diff --git a/dominoes/networks/pointer_decoder.py b/dominoes/networks/pointer_decoder.py
index 59987d1..385799e 100644
--- a/dominoes/networks/pointer_decoder.py
+++ b/dominoes/networks/pointer_decoder.py
@@ -156,11 +156,11 @@ def _pointer_loop(
         - updates the mask if permutation is enabled
         """
         # update context representation
-        decoder_context = self.decode(processed_encoding, decoder_input, decoder_context, mask)
+        decoder_context = self.decode(processed_encoding.stored_encoding, decoder_input, decoder_context, mask)

         # use pointer attention to evaluate scores of each possible input given the context
         decoder_state = self.get_decoder_state(decoder_input, decoder_context)
-        log_score = self.pointer(processed_encoding, decoder_state, mask=mask, temperature=temperature)
+        log_score = self.pointer(processed_encoding.stored_encoding, decoder_state, mask=mask, temperature=temperature)

         # choose token for this sample
         if thompson:
@@ -172,7 +172,7 @@ def _pointer_loop(

         # next decoder_input is whatever token had the highest probability
         index_tensor = choice.unsqueeze(-1).expand(batch_size, 1, self.embedding_dim)
-        decoder_input = torch.gather(processed_encoding, dim=1, index=index_tensor).squeeze(1)
+        decoder_input = torch.gather(processed_encoding.stored_encoding, dim=1, index=index_tensor).squeeze(1)

         if self.permutation:
             # mask previously chosen tokens (don't include this in the computation graph)
diff --git a/dominoes/networks/pointer_layers.py b/dominoes/networks/pointer_layers.py
index 061a61a..095ff02 100644
--- a/dominoes/networks/pointer_layers.py
+++ b/dominoes/networks/pointer_layers.py
@@ -89,7 +89,7 @@ def process_encoded(self, encoded):

     def _get_logits(self, encoded, decoder_state, mask):
         transform_decoded = self.W2(decoder_state)
-        logits = self.vt(torch.tanh(encoded.stored_encoding + transform_decoded.unsqueeze(1))).squeeze(2)
+        logits = self.vt(torch.tanh(encoded + transform_decoded.unsqueeze(1))).squeeze(2)
         return logits


@@ -108,7 +108,7 @@ def process_encoded(self, encoded):

     def _get_logits(self, encoded, decoder_state, mask):
         transform_decoded = self.dln(self.W2(decoder_state))
-        logits = torch.bmm(encoded.stored_encoding, transform_decoded.unsqueeze(2)).squeeze(2)
+        logits = torch.bmm(encoded, transform_decoded.unsqueeze(2)).squeeze(2)
         return logits


@@ -129,7 +129,7 @@ def process_encoded(self, encoded):

     def _get_logits(self, encoded, decoder_state, mask):
         transform_decoded = self.dln(self.W2(decoder_state))
-        logits = torch.bmm(encoded.stored_encoding, transform_decoded.unsqueeze(2)).squeeze(2)
+        logits = torch.bmm(encoded, transform_decoded.unsqueeze(2)).squeeze(2)
         return logits


@@ -146,7 +146,7 @@ def process_encoded(self, encoded):

     def _get_logits(self, encoded, decoder_state, mask):
         transform_decoded = self.dln(decoder_state)
-        logits = torch.bmm(encoded.stored_encoding, transform_decoded.unsqueeze(2)).squeeze(2)
+        logits = torch.bmm(encoded, transform_decoded.unsqueeze(2)).squeeze(2)
         return logits


@@ -172,7 +172,7 @@ def process_encoded(self, encoded):

     def _get_logits(self, encoded, decoder_state, mask):
         # attention on encoded representations with decoder_state
-        attended = self.attention(encoded.stored_encoding, [decoder_state], mask=mask)
+        attended = self.attention(encoded, [decoder_state], mask=mask)
         logits = self.vt(torch.tanh(attended)).squeeze(2)
         return logits

@@ -198,7 +198,7 @@ def process_encoded(self, encoded):

     def _get_logits(self, encoded, decoder_state, mask):
         # transform encoded representations with decoder_state
-        transformed = self.transformer(encoded.stored_encoding, [decoder_state], mask=mask)
+        transformed = self.transformer(encoded, [decoder_state], mask=mask)
         logits = self.vt(torch.tanh(transformed)).squeeze(2)
         return logits
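Taken together, these hunks move the `.stored_encoding` unwrapping out of the individual pointer layers and into the decoder loop, so every `_get_logits` variant receives a plain tensor rather than the wrapper object. A minimal sketch of the resulting calling convention, assuming a hypothetical `ProcessedEncoding` wrapper and `dot_pointer_logits` helper (names invented for illustration; the real dominoes classes may differ):

```python
import torch


class ProcessedEncoding:
    """Hypothetical wrapper around the encoder output (not the actual dominoes class)."""

    def __init__(self, stored_encoding: torch.Tensor):
        # shape: (batch, num_tokens, embedding_dim)
        self.stored_encoding = stored_encoding


def dot_pointer_logits(encoded: torch.Tensor, transform_decoded: torch.Tensor) -> torch.Tensor:
    """Post-diff layer signature: `encoded` is a raw tensor, not a wrapper."""
    # (batch, tokens, dim) @ (batch, dim, 1) -> (batch, tokens, 1) -> (batch, tokens)
    return torch.bmm(encoded, transform_decoded.unsqueeze(2)).squeeze(2)


batch, tokens, dim = 2, 5, 8
processed = ProcessedEncoding(torch.randn(batch, tokens, dim))
decoder_state = torch.randn(batch, dim)

# The decoder unwraps once at the call site; the layer never sees the wrapper.
logits = dot_pointer_logits(processed.stored_encoding, decoder_state)
print(logits.shape)  # torch.Size([2, 5])
```

Unwrapping once at the call site keeps the six `_get_logits` implementations agnostic to how the encoding is stored, instead of repeating the attribute access in each layer.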