This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Provide vocab as param to constraints (#5321)
* Provide vocab as param to constraints

* Update changelog
JohnGiorgi authored Jul 19, 2021
1 parent 56e1f49 commit f8fad9f
Showing 2 changed files with 21 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `TransformerModule._post_load_pretrained_state_dict_hook()` method. Can be used to modify `missing_keys` and `unexpected_keys` after
   loading a pretrained state dictionary. This is useful when tying weights, for example.
 - Added an end-to-end test for the Transformer Toolkit.
+- Added `vocab` argument to `BeamSearch`, which is passed to each constraint in `constraints` (if provided).

 ### Fixed

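A short usage sketch of the change described in the changelog entry above (not part of the diff itself). Constraints are now wrapped in `Lazy`, and `BeamSearch` calls `.construct(vocab=vocab)` on each one. The `end_index`, `beam_size`, and `ngram_size` values below are illustrative.

```python
from allennlp.common import Lazy
from allennlp.data import Vocabulary
from allennlp.nn.beam_search import BeamSearch, RepeatedNGramBlockingConstraint

vocab = Vocabulary()  # in practice, built from your dataset instances

beam_search = BeamSearch(
    end_index=2,  # illustrative id of the end-of-sequence token
    beam_size=5,
    # Each Lazy constraint receives the vocab when BeamSearch constructs it.
    constraints=[Lazy(lambda vocab: RepeatedNGramBlockingConstraint(ngram_size=3, vocab=vocab))],
    vocab=vocab,
)
```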
25 changes: 20 additions & 5 deletions allennlp/nn/beam_search.py
@@ -6,8 +6,9 @@
 from overrides import overrides
 import torch

-from allennlp.common import Registrable
+from allennlp.common import Lazy, Registrable
 from allennlp.common.checks import ConfigurationError
+from allennlp.data import Vocabulary
 from allennlp.nn.util import min_value_of_dtype


@@ -568,6 +569,9 @@ class Constraint(Registrable):
     """

+    def __init__(self, vocab: Optional[Vocabulary] = None) -> None:
+        self.vocab = vocab
+
     def init_state(
         self,
         batch_size: int,
@@ -625,8 +629,8 @@ def _update_state(

 @Constraint.register("repeated-ngram-blocking")
 class RepeatedNGramBlockingConstraint(Constraint):
-    def __init__(self, ngram_size: int) -> None:
-        super().__init__()
+    def __init__(self, ngram_size: int, **kwargs) -> None:
+        super().__init__(**kwargs)
         self.ngram_size = ngram_size

     @overrides
@@ -729,6 +733,15 @@ class BeamSearch(Registrable):
     constraints: `List[Constraint]`, optional (default = `None`)
         An optional list of `Constraint`s which should be applied during beam search. If not
         provided, no constraints will be enforced.
+    vocab: `Vocabulary`, optional (default = `None`)
+        If `constraints` is not `None`, the `Vocabulary` will be passed to each constraint
+        during its initialization. Having access to the vocabulary may be useful for certain
+        constraints, e.g., to mask out invalid predictions during structured prediction.
+        In a typical AllenNLP configuration file, this parameter does not get an entry under
+        the "model"; it is specified as a top-level parameter and then passed to the model
+        separately.
     """

     default_implementation = "beam_search"
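A sketch of the configuration-driven path described in the docstring above, expressed here with `Params` rather than a jsonnet file; the parameter values are illustrative. The constraint entry carries only its own params, while the vocabulary is supplied as an extra at construction time.

```python
from allennlp.common import Params
from allennlp.data import Vocabulary
from allennlp.nn.beam_search import BeamSearch

vocab = Vocabulary()  # in practice, the vocabulary built for the model

beam_search = BeamSearch.from_params(
    Params({
        "end_index": 2,  # illustrative end-of-sequence token id
        "constraints": [{"type": "repeated-ngram-blocking", "ngram_size": 3}],
    }),
    vocab=vocab,  # passed as an extra, not as part of the config entry
)
```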
@@ -742,7 +755,8 @@ def __init__(
         sampler: Sampler = None,
         min_steps: Optional[int] = None,
         final_sequence_scorer: FinalSequenceScorer = None,
-        constraints: Optional[List[Constraint]] = None,
+        constraints: Optional[List[Lazy[Constraint]]] = None,
+        vocab: Optional[Vocabulary] = None,
     ) -> None:
         if not max_steps > 0:
             raise ValueError("max_steps must be positive")
@@ -763,7 +777,8 @@ def __init__(
         self.sampler = sampler or DeterministicSampler()
         self.min_steps = min_steps or 0
         self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
-        self.constraints = constraints or []
+        # Lazily build the constraints with the vocab (if provided).
+        self.constraints = [constraint.construct(vocab=vocab) for constraint in constraints or []]

     @staticmethod
     def _reconstruct_sequences(predictions, backpointers):
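To make the purpose of the new `vocab` plumbing concrete, here is a hypothetical constraint, not part of this commit or of AllenNLP, that uses `self.vocab` at construction time to look up the OOV token id and then masks it out during search. The registered name `no-oov` and the class are invented for illustration; the `init_state`/`apply`/`_update_state` signatures follow the `Constraint` base class shown in the diff.

```python
import torch

from allennlp.data import Vocabulary
from allennlp.nn.beam_search import Constraint
from allennlp.nn.util import min_value_of_dtype


@Constraint.register("no-oov")  # hypothetical name, for illustration only
class NoOovConstraint(Constraint):
    """Forbids the beam from ever predicting the OOV token."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)  # the base class stores `vocab` on `self.vocab`
        assert self.vocab is not None, "this constraint needs the vocabulary"
        # Resolve the OOV token id once, at construction time.
        self.oov_index = self.vocab.get_token_index(self.vocab._oov_token)

    def init_state(self, batch_size: int):
        # Stateless constraint: one empty dict per (batch, beam) entry.
        return [[{}] for _ in range(batch_size)]

    def apply(self, state, class_log_probabilities: torch.Tensor) -> torch.Tensor:
        # Shape: (batch_size, beam_size, num_classes). Mask out the OOV id.
        class_log_probabilities[:, :, self.oov_index] = min_value_of_dtype(
            class_log_probabilities.dtype
        )
        return class_log_probabilities

    def _update_state(self, state, last_prediction: torch.Tensor):
        # Nothing to track between steps.
        return state
```

Because `BeamSearch` now constructs constraints lazily, a config entry like `{"type": "no-oov"}` would suffice for such a constraint; the vocabulary is injected automatically.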
