net_clip.py
#!/usr/bin/env python3
# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL)
#
# SPDX-License-Identifier: MIT
#
import warnings
import torch
import torch.nn as nn
warnings.filterwarnings("ignore")
import pdb
import pickle
import clip
import numpy as np
import torch.nn.functional as F
from PIL import Image
import globvars as gv
# Vision and Language pretrained models.
class SMART_VL_CLIP_Net(nn.Module):
    def __init__(self, args, VL_backbone):
        super(SMART_VL_CLIP_Net, self).__init__()
        vocab_path = args.vocab_path
        with open(vocab_path, "rb") as f:
            self.vocab = pickle.load(f)

        self.num_opts = 5
        self.out_dim = args.feat_size
        self.h_sz = 256
        self.feat_size = 512  # dimensionality of the CLIP image/text embeddings.
        self.dummy_question = None
        self.model_name = args.model_name
        self.use_clip_text = args.use_clip_text
        self.loss_type = args.loss_type
        self.monolithic = args.monolithic
        self.use_single_image_head = args.use_single_image_head
        self.train_backbone = args.train_backbone
        self.sorted_puzzle_ids = np.sort(np.array([int(ii) for ii in args.puzzle_ids]))

        if args.loss_type == "classifier" or args.loss_type == "puzzle_tails":
            self.max_val = gv.MAX_VAL + 1
        elif args.loss_type == "regression":
            self.max_val = 1

        self.preprocess = args.preprocess
        self.VL_backbone = VL_backbone
        self.create_puzzle_head(args)

        # Project the CLIP text embedding into the shared feature space.
        self.q_MLP = nn.Sequential(
            nn.Linear(self.feat_size, self.h_sz),
            nn.ReLU(),
            nn.Linear(self.h_sz, self.out_dim),
            nn.ReLU(),
        )
        # Fuse the image and question features.
        self.qv_fusion = nn.Sequential(
            nn.Linear(self.out_dim * 2, self.out_dim),
            nn.ReLU(),
            nn.Linear(self.out_dim, self.out_dim),
            nn.ReLU(),
        )
        if self.monolithic:
            # A single shared answer classifier over all puzzles.
            self.qvo_fusion = nn.Sequential(nn.Linear(self.out_dim, self.max_val))
        else:
            # One answer decoder per puzzle.
            self.create_puzzle_tail(args)
    def create_puzzle_head(self, args):
        if args.use_single_image_head:
            self.im_encoder = nn.Sequential(
                nn.Linear(self.feat_size, self.out_dim), nn.ReLU(), nn.Linear(self.out_dim, self.out_dim)
            )
        else:
            self.puzzle_ids = args.puzzle_ids
            # Index 0 holds a dummy head; puzzle ids are 1-indexed.
            im_encoder = [nn.Sequential(nn.Linear(self.out_dim, 1))]
            for i in range(1, gv.num_puzzles + 1):
                im_encoder.append(
                    nn.Sequential(
                        nn.Linear(self.feat_size, self.out_dim), nn.ReLU(), nn.Linear(self.out_dim, self.out_dim)
                    )
                )
            self.im_encoder = nn.ModuleList(im_encoder)
    def create_puzzle_tail(self, args):
        self.puzzle_ids = args.puzzle_ids
        ans_decoder = [
            nn.Sequential(nn.Linear(self.out_dim, 1))
        ]  # start with a dummy as we are 1-indexed wrt puzzle ids.
        if args.puzzles == "all":
            puzzles = range(1, gv.num_puzzles + 1)
        else:
            puzzles = self.puzzle_ids
        for pid in puzzles:
            num_classes = gv.NUM_CLASSES_PER_PUZZLE[str(pid)] if args.loss_type == "classifier" else 1
            if int(pid) not in gv.SEQ_PUZZLES:
                ans_decoder.append(
                    nn.Sequential(
                        nn.Linear(self.out_dim, self.out_dim),
                        nn.ReLU(),
                        nn.Linear(self.out_dim, self.out_dim),
                        nn.ReLU(),
                        nn.Linear(self.out_dim, num_classes),
                    )
                )
            else:
                # Sequence-answer puzzles are decoded one step at a time with an LSTM.
                ans_decoder.append(nn.LSTM(self.out_dim, num_classes, num_layers=1, batch_first=True))
        self.ans_decoder = nn.ModuleList(ans_decoder)
    def process(self, im, q_text):
        # Decode the padded token ids back to words and re-tokenize for CLIP.
        q_text = self.decode_text(q_text)
        text = clip.tokenize(q_text, truncate=True).to("cuda")
        return im, text

    def encode_image(self, im_feat, pids=None):
        if self.use_single_image_head:
            y = self.im_encoder(im_feat)
        else:
            # Route each sample through the image head of its puzzle id.
            y = torch.zeros(len(im_feat), self.out_dim).cuda()
            for t in range(len(self.puzzle_ids)):
                idx = pids == int(self.puzzle_ids[t])
                idx = idx.cuda()
                if idx.sum() > 0:
                    y[idx] = F.relu(self.im_encoder[int(self.puzzle_ids[t])](im_feat[idx]))
        return y
    def encode_text(self, q_feat):
        x = F.relu(self.q_MLP(q_feat))
        return x

    def decode_image(self, im_list):
        """Convert torch tensor images back to PIL Images, because the VL backbone works with images."""
        im_list = (im_list.permute(0, 2, 3, 1) * 255).cpu().numpy().astype("uint8")
        im_list = [Image.fromarray(im_list[ii]) for ii in range(len(im_list))]
        return im_list
    def decode_text(self, text):
        # Take the tokens between the leading <start> token and the last
        # nonzero (end) token; for long questions keep only a trailing window,
        # presumably so the re-tokenized text fits the tokenizer's context.
        get_range = lambda x: range(1, x) if x < 70 else range(x - 70 + 4, x)
        tt = text.cpu()
        text = [
            " ".join([self.vocab.idx2word[int(j)] for j in tt[i][get_range(torch.nonzero(tt[i])[-1])]])
            for i in range(len(tt))
        ]
        return text
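
    # Worked example (hypothetical vocabulary): for a padded row
    # [1, 23, 7, 2, 0, 0] with idx2word = {23: "count", 7: "cells"}, the last
    # nonzero index is 3, so get_range(3) = range(1, 3) keeps indices 1 and 2,
    # dropping <start> and the end token and yielding "count cells".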
    def seq_decoder(self, decoder, feat):
        """Run the LSTM decoder sequentially for k steps."""
        out = [None] * gv.MAX_DECODE_STEPS
        hx = None
        for k in range(gv.MAX_DECODE_STEPS):
            try:
                out[k], hx = decoder(feat, hx)
            except:
                pdb.set_trace()  # drop into the debugger if a decode step fails.
        return out
    def decode_individual_puzzles(self, feat, pids):
        upids = torch.unique(pids)
        out_feats = {}
        for t in range(len(upids)):
            idx = pids == upids[t]
            key = str(upids[t].item())
            key_idx = np.where(int(key) == np.array(self.sorted_puzzle_ids))[0][0] + 1  # +1 because we use 1-indexed.
            if upids[t] not in gv.SEQ_PUZZLES:
                out_feats[int(key)] = self.ans_decoder[key_idx](feat[idx])
            else:
                out_feats[int(key)] = self.seq_decoder(self.ans_decoder[key_idx], feat[idx])
        return out_feats
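
    # Example of the key_idx mapping above (hypothetical puzzle ids): with
    # sorted_puzzle_ids = [16, 39, 58], samples with pid 39 give
    # key_idx = 1 + 1 = 2, selecting self.ans_decoder[2], since index 0 holds
    # the dummy decoder.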
    def forward(self, im, q=None, puzzle_ids=None):
        im, text = self.process(im, q)
        if self.train_backbone:
            im_feat = self.VL_backbone.encode_image(im)
            q_feat = self.VL_backbone.encode_text(text)
        else:
            # Keep the CLIP backbone frozen.
            with torch.no_grad():
                im_feat = self.VL_backbone.encode_image(im)
                q_feat = self.VL_backbone.encode_text(text)

        im_feat = self.encode_image(im_feat.float(), puzzle_ids)
        q_feat = self.encode_text(q_feat.float())
        qv_feat = self.qv_fusion(torch.cat([im_feat, q_feat], dim=1))

        if self.monolithic:
            qv_feat = qv_feat.unsqueeze(1)
            qvo_feat = self.qvo_fusion(qv_feat).squeeze()
        else:
            qvo_feat = self.decode_individual_puzzles(qv_feat, puzzle_ids)
        return qvo_feat
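

# A minimal usage sketch, assuming a CUDA device, the OpenAI `clip` package,
# and an `args` namespace carrying the fields read in __init__; the vocab path
# and puzzle ids below are hypothetical placeholders, not values from the repo.
if __name__ == "__main__":
    from types import SimpleNamespace

    VL_backbone, preprocess = clip.load("ViT-B/32", device="cuda")
    args = SimpleNamespace(
        vocab_path="vocab.pkl",  # hypothetical pickle exposing idx2word.
        feat_size=128,
        model_name="clip",
        use_clip_text=True,
        loss_type="classifier",
        monolithic=True,
        use_single_image_head=True,
        puzzle_ids=["1", "2"],  # hypothetical subset of puzzle ids.
        puzzles="all",
        train_backbone=False,
        preprocess=preprocess,
    )
    net = SMART_VL_CLIP_Net(args, VL_backbone).cuda()

    im = torch.rand(2, 3, 224, 224).cuda()  # a batch of preprocessed images.
    q = torch.randint(1, 100, (2, 30)).cuda()  # padded question ids; assumes they exist in the vocab.
    pids = torch.tensor([1, 2])
    out = net(im, q, puzzle_ids=pids)  # answer logits of shape (2, gv.MAX_VAL + 1).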