-
Notifications
You must be signed in to change notification settings - Fork 18
/
NCE.py
25 lines (22 loc) · 1.21 KB
/
NCE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numpy as np
def NCE(source_label: np.ndarray, target_label: np.ndarray) -> float:
    """
    Compute the Negative Conditional Entropy, i.e. ``-H(Y | Z)``, of the target
    labels Y given the source labels Z, using the empirical joint distribution
    of the paired labels. The natural logarithm is used, so the value is in nats
    and is always <= 0 (up to floating point noise).

    :param source_label: shape [N], elements in [0, C_s), often got from taken argmax from pre-trained predictions
    :param target_label: shape [N], elements in [0, C_t)
    :return: the negated empirical conditional entropy ``-H(Y | Z)`` as a scalar
    :raises ValueError: if the inputs are empty or do not contain the same number of labels
    """
    # Flatten and cast once up front; the cast truncates like the per-element int() it replaces.
    source = np.asarray(source_label).ravel().astype(np.int64)
    target = np.asarray(target_label).ravel().astype(np.int64)

    if source.size != target.size:
        # Pairing labels of different lengths would silently drop samples and
        # yield a "joint distribution" that does not sum to one.
        raise ValueError(
            "source_label and target_label must have the same number of elements, "
            "got {} and {}".format(source.size, target.size)
        )
    if source.size == 0:
        raise ValueError("source_label and target_label must not be empty")

    num_samples = source.size
    num_target = int(target.max()) + 1  # the number of target classes, C_t
    num_source = int(source.max()) + 1  # the number of source classes, C_s

    # Joint histogram of (target, source) pairs, shape [C_t, C_s]. np.add.at
    # accumulates repeated index pairs correctly (plain fancy-index += would not).
    counts = np.zeros((num_target, num_source), dtype=np.int64)
    np.add.at(counts, (target, source), 1)

    joint = counts / num_samples  # empirical P(y, z), shape [C_t, C_s]
    p_z = joint.sum(axis=0)  # marginal P(z), shape [C_s]

    # Drop source classes that never occur *before* dividing, so no 0/0 is
    # ever evaluated (avoids NaNs and RuntimeWarnings).
    valid = counts.sum(axis=0) > 0  # shape [C_s]
    p_z_valid = p_z[valid]

    # P(y | z) for the observed z only, shape [C_s_valid, C_t];
    # add 1e-20 to avoid log(0) for (y, z) pairs that never co-occur.
    p_target_given_source = (joint[:, valid] / p_z_valid).T + 1e-20

    # H(Y | Z = z) for every observed z, shape [C_s_valid]
    entropy_y_given_z = -np.sum(
        p_target_given_source * np.log(p_target_given_source), axis=1
    )

    # H(Y | Z) = sum_z P(z) * H(Y | Z = z), scalar
    conditional_entropy = np.sum(entropy_y_given_z * p_z_valid)
    return -conditional_entropy