Roll backbone (#5229)
Adding support for inputs to the backbone with more than 3 dimensions
jacob-morrison authored May 28, 2021
1 parent babc450 commit 3d5799d
Showing 2 changed files with 51 additions and 10 deletions.
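The change applies a common roll/unroll idiom: flatten any extra leading dimensions into the batch dimension, run the batch-first modules unchanged, then restore the leading dimensions on the outputs. Below is a minimal, self-contained sketch of that idiom with made-up tensor names and sizes; it illustrates the pattern, it is not code from this commit.

```python
import torch

# Hypothetical input with two extra leading dimensions:
# (batch_size=2, num_images=3, num_boxes=36, feature_size=1024).
box_features = torch.randn(2, 3, 36, 1024)

# "Roll": collapse everything except the last two dimensions into one batch axis.
rolled_dimensions = box_features.shape[:-2]  # torch.Size([2, 3])
rolled_dimensions_product = 1
for dim in rolled_dimensions:  # same accumulation the commit uses
    rolled_dimensions_product *= dim  # 2 * 3 = 6
flat = box_features.view(rolled_dimensions_product, 36, 1024)

# ... any module that expects 3-D input runs here; mean() is a stand-in pooler ...
pooled = flat.mean(dim=1)  # shape: (6, 1024)

# "Unroll": restore the original leading dimensions on the output.
unrolled = pooled.view(rolled_dimensions + (pooled.shape[-1],))
print(unrolled.shape)  # torch.Size([2, 3, 1024])
```

The manual product loop mirrors the diff below; `math.prod(rolled_dimensions)` or `tensor.flatten(0, -3)` would be equivalent ways to roll.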
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   an actual `torch.nn.Module`. Other parameters to this method have changed as well.
 - Print the first batch to the console by default.
 - Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0).
+- VilBERT backbone now rolls and unrolls extra dimensions to handle input with > 3 dimensions.
 
 ### Added
60 changes: 50 additions & 10 deletions allennlp/modules/backbones/vilbert_backbone.py
@@ -111,19 +111,50 @@ def forward(
         box_mask: torch.Tensor,
         text: TextFieldTensors,
     ) -> Dict[str, torch.Tensor]:
-        batch_size, _, feature_size = box_features.size()
-
         if "token_ids" in text["tokens"]:
             token_ids = text["tokens"]["token_ids"]
         else:
             token_ids = text["tokens"]["tokens"]
 
+        if token_ids.shape[:-1] != box_features.shape[:-2]:
+            raise ValueError(
+                "Tokens and boxes must have the same batch size and extra "
+                "dimensions (if applicable). Token size {0} did not match "
+                "box feature size {1}.".format(token_ids.shape[:-1], box_features.shape[:-2])
+            )
+
         # Shape: (batch_size, num_tokens)
         token_type_ids = text["tokens"].get("type_ids")
         # Shape: (batch_size, num_tokens)
         attention_mask = text["tokens"].get("mask")
 
-        # Shape: (batch_size, num_tokens, embedding_dim)
+        box_feature_dimensions = box_features.shape
+        feature_size = box_feature_dimensions[-1]
+        rolled_dimensions = box_feature_dimensions[:-2]
+        rolled_dimensions_product = 1
+        for dim in rolled_dimensions:
+            rolled_dimensions_product *= dim
+
+        token_ids = token_ids.view(rolled_dimensions_product, token_ids.shape[-1])
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(
+                rolled_dimensions_product, token_type_ids.shape[-1]
+            )
+        if attention_mask is not None:
+            attention_mask = attention_mask.view(
+                rolled_dimensions_product, attention_mask.shape[-1]
+            )
+        box_features = box_features.view(
+            rolled_dimensions_product, box_feature_dimensions[-2], feature_size
+        )
+        box_coordinates = box_coordinates.view(
+            rolled_dimensions_product,
+            box_coordinates.shape[-2],
+            box_coordinates.shape[-1],
+        )
+        box_mask = box_mask.view(rolled_dimensions_product, box_mask.shape[-1])
+
+        # Shape: (rolled_dimensions_product, num_tokens, embedding_dim)
         embedding_output = self.text_embeddings(token_ids, token_type_ids)
         num_tokens = embedding_output.size(1)

@@ -137,16 +168,16 @@ def forward(
 
         extended_image_attention_mask = box_mask
 
-        # Shape: (batch_size, feature_size, num_tokens)
+        # Shape: (rolled_dimensions_product, feature_size, num_tokens)
         # TODO (epwalsh): Why all zeros?? This doesn't seem right.
         extended_co_attention_mask = torch.zeros(
-            batch_size,
+            extended_image_attention_mask.shape[0],
             feature_size,
             num_tokens,
             dtype=extended_image_attention_mask.dtype,
         )
 
-        # Shape: (batch_size, num_boxes, image_embedding_dim)
+        # Shape: (rolled_dimensions_product, num_boxes, image_embedding_dim)
         v_embedding_output = self.image_embeddings(box_features, box_coordinates)
 
         encoded_layers_t, encoded_layers_v = self.encoder(
@@ … @@ def forward(
             extended_attention_mask,
             extended_image_attention_mask,
             extended_co_attention_mask,
         )
 
-        # Shape: (batch_size, num_tokens, embedding_dim)
+        # Shape: (rolled_dimensions_product, num_tokens, embedding_dim)
         sequence_output_t = encoded_layers_t[:, :, :, -1]
-        # Shape: (batch_size, num_boxes, image_embedding_dim)
+        # Shape: (rolled_dimensions_product, num_boxes, image_embedding_dim)
         sequence_output_v = encoded_layers_v[:, :, :, -1]
 
-        # Shape: (batch_size, pooled_output_dim)
+        # Shape: (rolled_dimensions_product, pooled_output_dim)
         pooled_output_t = self.t_pooler(sequence_output_t)
-        # Shape: (batch_size, pooled_output_dim)
+        # Shape: (rolled_dimensions_product, pooled_output_dim)
         pooled_output_v = self.v_pooler(sequence_output_v)
 
+        sequence_output_t = sequence_output_t.view(
+            rolled_dimensions + (sequence_output_t.shape[-2], sequence_output_t.shape[-1])
+        )
+        sequence_output_v = sequence_output_v.view(
+            rolled_dimensions + (sequence_output_v.shape[-2], sequence_output_v.shape[-1])
+        )
+        pooled_output_t = pooled_output_t.view(rolled_dimensions + (pooled_output_t.shape[-1],))
+        pooled_output_v = pooled_output_v.view(rolled_dimensions + (pooled_output_v.shape[-1],))
+
         if self.fusion_method == "sum":
             pooled_output = self.dropout(pooled_output_t + pooled_output_v)
         elif self.fusion_method == "mul":
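As a quick trace of the new behavior, here is a hedged, self-contained walkthrough under assumed sizes (two dialogs of three images each, 36 boxes, 1024-dim features, 12 tokens; every name and number here is illustrative, not taken from the commit):

```python
import torch

token_ids = torch.randint(0, 100, (2, 3, 12))
box_features = torch.randn(2, 3, 36, 1024)

# The new compatibility check: everything before the last text dimension must
# match everything before the last two box dimensions, else a ValueError is raised.
assert token_ids.shape[:-1] == box_features.shape[:-2]  # both torch.Size([2, 3])

# Rolled, the encoder sees an ordinary 3-D problem with batch size 6 ...
flat_tokens = token_ids.view(-1, token_ids.shape[-1])  # (6, 12)
flat_boxes = box_features.view(-1, 36, 1024)           # (6, 36, 1024)

# ... and the pooled outputs, stand-ins of width 512 here, unroll back to (2, 3, 512).
pooled_output = torch.randn(6, 512)
unrolled = pooled_output.view(token_ids.shape[:-1] + (pooled_output.shape[-1],))
print(unrolled.shape)  # torch.Size([2, 3, 512])
```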
