
Commit

consolidate input processing methods
landoskape committed May 9, 2024
1 parent 681bfde commit 0d826cb
Showing 3 changed files with 79 additions and 88 deletions.
48 changes: 48 additions & 0 deletions dominoes/networks/__init__.py
@@ -1,3 +1,5 @@
import torch
from ..utils import named_transpose
from .pointer_networks import get_pointer_network
from .pointer_layers import get_pointer_methods

@@ -7,3 +9,49 @@ def _check_kwargs(method_name, kwargs, required_kwargs):
    for key in required_kwargs:
        if key not in kwargs:
            raise ValueError(f"required kwarg {key} not found in kwargs ({method_name} requires {required_kwargs})")


def _process_input(input, mask, expected_dim, name="input"):
    """check sizes and create mask if not provided"""
    assert input.ndim == 3, f"{name} should have size: (batch_size, num_tokens, input_dimensionality)"
    batch_size, num_tokens, input_dim = input.size()
    assert input_dim == expected_dim, f"dimensionality of {name} ({input_dim}) doesn't match network ({expected_dim})"

    if mask is not None:
        assert mask.ndim == 2, f"{name} mask must have shape (batch_size, num_tokens)"
        assert mask.size(0) == batch_size and mask.size(1) == num_tokens, f"{name} mask must have same batch size and max tokens as x"
        assert not torch.any(torch.all(mask == 0, dim=1)), f"{name} mask includes rows where all elements are masked, this is not permitted"
    else:
        mask = torch.ones((batch_size, num_tokens), dtype=input.dtype).to(input.device)

    return batch_size, mask


def _process_multimodal_input(multimode, mm_mask, num_multimodal, mm_dim):
    """check sizes and create mask for all multimodal inputs if not provided"""
    # first check if the multimodal input is a sequence (tuple or list)
    assert type(multimode) == tuple or type(multimode) == list, "multimode should be a tuple or a list"
    if len(multimode) != num_multimodal:
        raise ValueError(f"this network requires {num_multimodal} multimodal tensors but {len(multimode)} were provided")

    # handle mm_mask
    if mm_mask is None:
        # make a None list for the mask if not provided
        mm_mask = [None for _ in range(num_multimodal)]
    else:
        assert len(mm_mask) == num_multimodal, f"if mm_mask provided, must have {num_multimodal} elements"

    # handle mm_dim
    if type(mm_dim) == int:
        mm_dim = [mm_dim] * num_multimodal
    assert len(mm_dim) == num_multimodal, f"mm_dim must be an integer or a list of integers of length {num_multimodal}"

    # get the batch size and mask for each multimodal input
    mm_batch_size, mm_mask = named_transpose(
        [_process_input(mmc, mmm, mmd, name=f"multimodal input #{imm}") for imm, (mmc, mmm, mmd) in enumerate(zip(multimode, mm_mask, mm_dim))]
    )

    # make sure batch_size is consistent
    assert all([mmb == mm_batch_size[0] for mmb in mm_batch_size]), "batch size of each multimodal input should be the same"

    return mm_batch_size[0], mm_mask
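
Because the consolidated helpers are plain module-level functions, they can be exercised on their own. A minimal sketch of how they might be called, assuming the package is installed and these private helpers are imported directly for illustration; the tensor shapes and dimensionalities are illustrative only:

import torch

from dominoes.networks import _process_input, _process_multimodal_input

# main input: batch of 2 sequences, 5 tokens each, 8-dimensional features
x = torch.randn(2, 5, 8)
batch_size, mask = _process_input(x, None, expected_dim=8)  # no mask given, so an all-ones mask is created
print(batch_size, mask.shape)  # 2, torch.Size([2, 5])

# two multimodal inputs with different token counts and dimensionalities
mm = [torch.randn(2, 3, 4), torch.randn(2, 7, 6)]
mm_batch_size, mm_mask = _process_multimodal_input(mm, None, num_multimodal=2, mm_dim=[4, 6])
print(mm_batch_size, [m.shape for m in mm_mask])  # 2, [torch.Size([2, 3]), torch.Size([2, 7])]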
55 changes: 10 additions & 45 deletions dominoes/networks/attention_modules.py
@@ -2,8 +2,9 @@
import torch
from torch import nn

from ..utils import masked_softmax, named_transpose
from ..utils import masked_softmax

from . import _process_input, _process_multimodal_input

"""
Almost everything I've learned about machine learning and pytorch has been due
@@ -140,42 +141,6 @@ def _build_mixing_matrix(self):
"""method for building mixing matrix for unifying num_heads"""
self.unifynum_heads = nn.Linear(self.embedding_dim, self.embedding_dim)

def _process_input(self, x, mask=None):
"""check sizes and create mask if not provided"""
# attention layer forward pass
assert x.ndim == 3, "x should have size: (batch_size, num_tokens, embedding_dimensionality)"
batch_size, tokens, embedding_dim = x.size() # get size of input

mask = mask if mask is not None else torch.ones((batch_size, tokens), dtype=x.dtype).to(x.device)
assert x.size(0) == mask.size(0) and x.size(1) == mask.size(1), "mask must have same batch_size and num_tokens as x"

# this is the only requirement on the input (other than the number of dimensions)
msg = f"Input embedding dim ({embedding_dim}) should match layer embedding dim ({self.embedding_dim})"
assert embedding_dim == self.embedding_dim, msg

return batch_size, mask

def _process_multimodal_input(self, multimode, mm_mask=None):
"""check sizes and create mask for all multimodal inputs if not provided"""
# first check if multimodal context is a sequence (tuple or list)
assert type(multimode) == tuple or type(multimode) == list, "context should be a tuple or a list"
if len(multimode) != 0:
raise ValueError(f"this network requires {self.num_multimodal} context tensors but {len(multimode)} were provided")

if mm_mask is None:
# make a None list for the mask if not provided
mm_mask = [None for _ in range(self.num_multimodal)]
else:
assert len(mm_mask) == self.num_multimodal, f"if mm_mask provided, must have {self.num_multimodal} elements"

# get the batch and mask for each multimode input
mm_batch, mm_mask = named_transpose([self._process_input(mmc, mmm) for mmc, mmm in zip(multimode, mm_mask)])

# make sure batch_size is consistent
assert all([mmb == mm_batch[0] for mmb in mm_batch]), "batch size of each mm context tensor should be the same"

return mm_batch[0], mm_mask

def _send_to_kqv(self, x, context=None, multimode=None):
"""
centralized method for sending input to queries, keys, and values
@@ -337,7 +302,7 @@ def __init__(self, embedding_dim, num_heads=8, kqnorm=True, bias=False, residual
def forward(self, x, mask=None):
"""core forward method with residual connection for attention mechanism"""
# create mask if not provided, check input sizes
batch_size, mask = self._process_input(x, mask)
batch_size, mask = _process_input(x, mask, self.embedding_dim)

# convert input tokens to their keys, queries, and values
keys, queries, values = self._send_to_kqv(x)
@@ -365,8 +330,8 @@ def __init__(self, embedding_dim, num_heads=8, kqnorm=True, bias=False, residual
def forward(self, x, context, mask=None, context_mask=None):
"""core forward method with residual connection for attention mechanism"""
# create mask if not provided, check input sizes
batch_size, mask = self._process_input(x, mask)
context_batch_size, context_mask = self._process_input(context, context_mask)
batch_size, mask = _process_input(x, mask, self.embedding_dim)
context_batch_size, context_mask = _process_input(context, context_mask, self.embedding_dim, name="context")
assert batch_size == context_batch_size, "batch size of x and context should match"

# convert input tokens to their keys, queries, and values
@@ -393,8 +358,8 @@ class MultimodalAttention(AttentionBaseClass):
def forward(self, x, multimode, mask=None, mm_mask=None):
"""core forward method with residual connection for attention mechanism"""
# create mask if not provided, check input sizes
batch_size, mask = self._process_input(x, mask)
mm_batch_size, mm_mask = self._process_multimodal_input(multimode, mm_mask)
batch_size, mask = _process_input(x, mask, self.embedding_dim)
mm_batch_size, mm_mask = _process_multimodal_input(multimode, mm_mask, self.num_multimodal, self.embedding_dim)

assert batch_size == mm_batch_size, "batch size of x and multimode inputs should match"

@@ -423,9 +388,9 @@ class MultimodalContextualAttention(AttentionBaseClass):
def forward(self, x, context, multimode, mask=None, context_mask=None, mm_mask=None):
"""core forward method with residual connection for attention mechanism"""
# create mask if not provided, check input sizes
batch_size, mask = self._process_input(x, mask)
context_batch_size, context_mask = self._process_input(context, context_mask)
mm_batch_size, mm_mask = self._process_multimodal_input(multimode, mm_mask)
batch_size, mask = _process_input(x, mask, self.embedding_dim)
context_batch_size, context_mask = _process_input(context, context_mask, self.embedding_dim, name="context")
mm_batch_size, mm_mask = _process_multimodal_input(multimode, mm_mask, self.num_multimodal, self.embedding_dim)

assert batch_size == context_batch_size, "batch size of x and context should match"
assert batch_size == mm_batch_size, "batch size of x and multimode inputs should match"
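
With the per-class `_process_input` and `_process_multimodal_input` removed, every attention forward pass now routes through the shared helpers, which also enforce that no row of a mask is entirely zero. A small standalone sketch of that check, assuming the `_process_input` function defined in dominoes/networks/__init__.py above (tensor values are illustrative):

import torch

from dominoes.networks import _process_input

x = torch.randn(2, 4, 16)

# a valid mask: at least one unmasked token per row
good_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]], dtype=torch.float)
batch_size, mask = _process_input(x, good_mask, 16)  # passes, returns (2, good_mask)

# an invalid mask: the first row masks out every token
bad_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 1, 1]], dtype=torch.float)
try:
    _process_input(x, bad_mask, 16)
except AssertionError as err:
    print(err)  # "input mask includes rows where all elements are masked, this is not permitted"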
64 changes: 21 additions & 43 deletions dominoes/networks/pointer_networks.py
@@ -6,7 +6,7 @@
from .attention_modules import get_attention_layer, _attention_type
from .transformer_modules import get_transformer_layer
from .pointer_decoder import PointerDecoder
from . import _check_kwargs
from . import _check_kwargs, _process_input, _process_multimodal_input


def _get_pointernet_constructor(contextual, multimodal):
@@ -151,7 +151,7 @@ def __init__(
self.contextual = contextual
self.multimodal = multimodal
self.num_multimodal = num_multimodal * multimodal
self.mm_input_dim = self._set_multimodal_input_dim(mm_input_dim, input_dim, self.num_multimodal)
self.mm_input_dim = self._check_multimodal_input_dim(mm_input_dim, self.num_multimodal) if multimodal else None
self.num_encoding_layers = num_encoding_layers
self.require_init = require_init
self.encoder_method = encoder_method
@@ -168,9 +168,10 @@ def __init__(
self.embedding = nn.Linear(in_features=input_dim, out_features=self.embedding_dim, bias=self.embedding_bias)

# create embedding for multimodal inputs to embedding dimension
self.mm_embedding = nn.ModuleList(
[nn.Linear(in_features=mid, out_features=self.embedding_dim, bias=self.embedding_bias) for mid in self.mm_input_dim]
)
if self.multimodal:
self.mm_embedding = nn.ModuleList(
[nn.Linear(in_features=mid, out_features=self.embedding_dim, bias=self.embedding_bias) for mid in self.mm_input_dim]
)

# build encoder layers
self._build_encoder(num_encoding_layers, encoder_method, encoder_kwargs)
@@ -197,14 +198,12 @@ def set_thompson(self, thompson):
"""method for setting the thompson sampling flag of the pointer network"""
self.thompson = thompson

def _set_multimodal_input_dim(self, mm_input_dim, input_dim, num_multimodal):
def _check_multimodal_input_dim(self, mm_input_dim, num_multimodal):
"""helper for setting input dim of multimodal inputs"""
if mm_input_dim is not None:
assert type(mm_input_dim) == tuple or type(mm_input_dim) == list, "mm_input_dim must be a tuple or list"
assert all([type(mid) == int for mid in mm_input_dim]), "all elements of mm_input_dim must be integers"
return mm_input_dim
else:
return [input_dim] * num_multimodal
assert type(mm_input_dim) == tuple or type(mm_input_dim) == list, "mm_input_dim must be a tuple or list"
assert len(mm_input_dim) == num_multimodal, f"mm_input_dim must have {num_multimodal} elements"
assert all([type(mid) == int for mid in mm_input_dim]), "all elements of mm_input_dim must be integers"
return mm_input_dim

def _build_encoder(self, num_encoding_layers, encoder_method, encoder_kwargs):
"""flexible method for creating encoding layers for pointer network"""
@@ -268,27 +267,6 @@ def _get_decoder_state(self, temperature, thompson):
thompson = thompson or self.thompson
return temperature, thompson

def _process_input(self, input, mask, name="input"):
"""method for processing inputs (for main input or context inputs, and used by multimode)"""
batch_size, num_tokens, inp_dim = input.size()
assert inp_dim == self.input_dim, f"dimensionality of {name} doesn't match network"

if mask is not None:
assert mask.ndim == 2, f"{name} mask must have shape (batch, num_tokens)"
assert mask.size(0) == batch_size and mask.size(1) == num_tokens, f"{name} mask must have same batch size and max tokens as x"
assert not torch.any(torch.all(mask == 0, dim=1)), f"{name} mask includes rows where all elements are masked, this is not permitted"
else:
mask = torch.ones((batch_size, num_tokens), dtype=input.dtype).to(input.device)

return mask, batch_size

def _process_multimodal_inputs(self, multimode, mm_mask):
"""method for processing all multimode inputs"""
msg = f"multimode inputs must be a tuple or list of tensors with length {self.num_multimodal}"
assert (type(multimode) == tuple or type(multimode) == list) and len(multimode) == self.num_multimodal, msg
mm_mask, mm_batch_size = named_transpose([self._process_input(mmx, mmm, name="multimode") for mmx, mmm in zip(multimode, mm_mask)])
return mm_mask, mm_batch_size

def _get_max_output(self, x, init, max_output=None):
"""method for getting the maximum number of outputs for the pointer network"""
if self.require_init and init is None:
@@ -404,7 +382,7 @@ def forward(self, x, mask=None, init=None, temperature=None, thompson=None, max_
temperature, thompson = self._get_decoder_state(temperature, thompson)

# process main input
mask, batch_size = self._process_input(x, mask)
batch_size, mask = _process_input(x, mask, self.input_dim)

# get max output
max_output = self._get_max_output(x, init, max_output=max_output)
@@ -458,8 +436,8 @@ def forward(self, x, context, mask=None, context_mask=None, init=None, temperatu
temperature, thompson = self._get_decoder_state(temperature, thompson)

# process main input
mask, batch_size = self._process_input(x, mask)
context_mask, context_batch_size = self._process_input(context, context_mask, name="context")
batch_size, mask = _process_input(x, mask, self.input_dim)
context_batch_size, context_mask = _process_input(context, context_mask, self.input_dim, name="context")

# check for consistency in batch sizes
assert batch_size == context_batch_size, "batch sizes of x and context must match"
@@ -516,11 +494,11 @@ def forward(self, x, multimode, mask=None, mm_mask=None, init=None, temperature=
temperature, thompson = self._get_decoder_state(temperature, thompson)

# process main input
mask, batch_size = self._process_input(x, mask)
mm_mask, mm_batch_size = self._process_multimodal_inputs(multimode, mm_mask)
batch_size, mask = _process_input(x, mask, self.input_dim)
mm_batch_size, mm_mask = _process_multimodal_input(multimode, mm_mask, self.num_multimodal, self.mm_input_dim)

# check for consistency in batch sizes
assert all([batch_size == mmbs for mmbs in mm_batch_size]), "batch sizes of x and each multimodal input must match"
assert batch_size == mm_batch_size, "batch sizes of x and multimodal inputs must match"

# get max output
max_output = self._get_max_output(x, init, max_output=max_output)
@@ -581,13 +559,13 @@ def forward(self, x, context, multimode, mask=None, context_mask=None, mm_mask=N
temperature, thompson = self._get_decoder_state(temperature, thompson)

# process main input
mask, batch_size = self._process_input(x, mask)
context_mask, context_batch_size = self._process_input(context, context_mask, name="context")
mm_mask, mm_batch_size = self._process_multimodal_inputs(multimode, mm_mask)
batch_size, mask = _process_input(x, mask, self.input_dim)
context_batch_size, context_mask = _process_input(context, context_mask, self.input_dim, name="context")
mm_batch_size, mm_mask = _process_multimodal_input(multimode, mm_mask, self.num_multimodal, self.mm_input_dim)

# check for consistency in batch sizes
assert batch_size == context_batch_size, "batch sizes of x and context must match"
assert all([batch_size == mmbs for mmbs in mm_batch_size]), "batch sizes of x and each multimodal input must match"
assert batch_size == mm_batch_size, "batch sizes of x and multimodal inputs must match"

# get max output
max_output = self._get_max_output(x, init, max_output=max_output)
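
One behavioral consequence of replacing `_set_multimodal_input_dim` with `_check_multimodal_input_dim` is that a multimodal pointer network no longer falls back to `input_dim` for its multimodal inputs; explicit per-modality dimensions are now required. A standalone, isinstance-based restatement of that validation (the helper name and values below are illustrative, not part of the commit):

def check_multimodal_input_dim(mm_input_dim, num_multimodal):
    """restatement of _check_multimodal_input_dim for illustration"""
    assert isinstance(mm_input_dim, (tuple, list)), "mm_input_dim must be a tuple or list"
    assert len(mm_input_dim) == num_multimodal, f"mm_input_dim must have {num_multimodal} elements"
    assert all(isinstance(mid, int) for mid in mm_input_dim), "all elements of mm_input_dim must be integers"
    return mm_input_dim

print(check_multimodal_input_dim([4, 6], num_multimodal=2))  # [4, 6]
# check_multimodal_input_dim(None, num_multimodal=2) would raise an AssertionError,
# whereas the old _set_multimodal_input_dim silently defaulted to [input_dim] * num_multimodal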
