
Commit 4790dc9
handle storedEncoding in pointer_decoder only
landoskape committed May 10, 2024
1 parent 0f6d2da commit 4790dc9
Showing 2 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions dominoes/networks/pointer_decoder.py
@@ -156,11 +156,11 @@ def _pointer_loop(
- updates the mask if permutation is enabled
"""
# update context representation
- decoder_context = self.decode(processed_encoding, decoder_input, decoder_context, mask)
+ decoder_context = self.decode(processed_encoding.stored_encoding, decoder_input, decoder_context, mask)

# use pointer attention to evaluate scores of each possible input given the context
decoder_state = self.get_decoder_state(decoder_input, decoder_context)
- log_score = self.pointer(processed_encoding, decoder_state, mask=mask, temperature=temperature)
+ log_score = self.pointer(processed_encoding.stored_encoding, decoder_state, mask=mask, temperature=temperature)

# choose token for this sample
if thompson:
@@ -172,7 +172,7 @@ def _pointer_loop(

# next decoder_input is whatever token had the highest probability
index_tensor = choice.unsqueeze(-1).expand(batch_size, 1, self.embedding_dim)
- decoder_input = torch.gather(processed_encoding, dim=1, index=index_tensor).squeeze(1)
+ decoder_input = torch.gather(processed_encoding.stored_encoding, dim=1, index=index_tensor).squeeze(1)

if self.permutation:
# mask previously chosen tokens (don't include this in the computation graph)
12 changes: 6 additions & 6 deletions dominoes/networks/pointer_layers.py
@@ -89,7 +89,7 @@ def process_encoded(self, encoded):

def _get_logits(self, encoded, decoder_state, mask):
transform_decoded = self.W2(decoder_state)
- logits = self.vt(torch.tanh(encoded.stored_encoding + transform_decoded.unsqueeze(1))).squeeze(2)
+ logits = self.vt(torch.tanh(encoded + transform_decoded.unsqueeze(1))).squeeze(2)
return logits


@@ -108,7 +108,7 @@ def process_encoded(self, encoded):

def _get_logits(self, encoded, decoder_state, mask):
transform_decoded = self.dln(self.W2(decoder_state))
- logits = torch.bmm(encoded.stored_encoding, transform_decoded.unsqueeze(2)).squeeze(2)
+ logits = torch.bmm(encoded, transform_decoded.unsqueeze(2)).squeeze(2)
return logits


@@ -129,7 +129,7 @@ def process_encoded(self, encoded):

def _get_logits(self, encoded, decoder_state, mask):
transform_decoded = self.dln(self.W2(decoder_state))
- logits = torch.bmm(encoded.stored_encoding, transform_decoded.unsqueeze(2)).squeeze(2)
+ logits = torch.bmm(encoded, transform_decoded.unsqueeze(2)).squeeze(2)
return logits


@@ -146,7 +146,7 @@ def process_encoded(self, encoded):

def _get_logits(self, encoded, decoder_state, mask):
transform_decoded = self.dln(decoder_state)
- logits = torch.bmm(encoded.stored_encoding, transform_decoded.unsqueeze(2)).squeeze(2)
+ logits = torch.bmm(encoded, transform_decoded.unsqueeze(2)).squeeze(2)
return logits


@@ -172,7 +172,7 @@ def process_encoded(self, encoded):

def _get_logits(self, encoded, decoder_state, mask):
# attention on encoded representations with decoder_state
- attended = self.attention(encoded.stored_encoding, [decoder_state], mask=mask)
+ attended = self.attention(encoded, [decoder_state], mask=mask)
logits = self.vt(torch.tanh(attended)).squeeze(2)
return logits

@@ -198,7 +198,7 @@ def process_encoded(self, encoded):

def _get_logits(self, encoded, decoder_state, mask):
# transform encoded representations with decoder_state
- transformed = self.transformer(encoded.stored_encoding, [decoder_state], mask=mask)
+ transformed = self.transformer(encoded, [decoder_state], mask=mask)
logits = self.vt(torch.tanh(transformed)).squeeze(2)
return logits

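Taken together, the two diffs move the `stored_encoding` lookup out of the pointer layers and into the decoder: `_pointer_loop` unwraps `processed_encoding.stored_encoding` before calling `decode`, the pointer head, and `torch.gather`, and every `_get_logits` implementation now receives a plain tensor. Below is a minimal sketch of that pattern, assuming a hypothetical `ProcessedEncoding` wrapper and a toy dot-product pointer head; the names are illustrative, not the repository's actual classes.

```python
from dataclasses import dataclass

import torch


@dataclass
class ProcessedEncoding:
    """Hypothetical wrapper around an encoded batch of shape (batch, tokens, embedding_dim)."""

    stored_encoding: torch.Tensor


class DotPointerSketch(torch.nn.Module):
    """Toy pointer head: after this commit, _get_logits works on a plain tensor."""

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.W2 = torch.nn.Linear(embedding_dim, embedding_dim)

    def _get_logits(self, encoded: torch.Tensor, decoder_state: torch.Tensor) -> torch.Tensor:
        # encoded is already a tensor; the decoder did the .stored_encoding unwrapping
        transform_decoded = self.W2(decoder_state)
        return torch.bmm(encoded, transform_decoded.unsqueeze(2)).squeeze(2)


# Decoder-side usage: unwrap stored_encoding once, then pass tensors downstream.
batch, tokens, dim = 2, 5, 8
processed_encoding = ProcessedEncoding(stored_encoding=torch.randn(batch, tokens, dim))
pointer = DotPointerSketch(dim)
decoder_state = torch.randn(batch, dim)
log_score = pointer._get_logits(processed_encoding.stored_encoding, decoder_state)
print(log_score.shape)  # torch.Size([2, 5])
```

Centralizing the attribute access in the decoder means each pointer variant only has to agree on tensors, not on the wrapper's interface.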
