diff --git a/dominoes/networks/pointer_networks.py b/dominoes/networks/pointer_networks.py index cbd086c..b6e9660 100644 --- a/dominoes/networks/pointer_networks.py +++ b/dominoes/networks/pointer_networks.py @@ -84,38 +84,73 @@ def get_pointer_network( """ -def get_pointer_kwargs(args): - """ - method for getting the pointer method and its kwargs from a dictionary of arguments - """ - # get embedding dim (this is always required) - embedding_dim = args["embedding_dim"] - - # these are the possible kwargs that can be passed to the pointer network - # key is the argument name in the ArgParser, value is the pointer network keyword - possible_kwargs = dict( - embedding_bias="embedding_bias", - num_encoding_layers="num_encoding_layers", - encoder_method="encoder_method", - decoder_method="decoder_method", - pointer_method="pointer_method", - train_temperature="temperature", - ) - # these use the default=False, store as True and require a sign flip - signflip_kwargs = dict( - no_thompson="thompson", - ) +class PointerArguments: + def __init__(self, args): + self.args = args + self._get_pointernet_kwargs() + self._get_encoder_kwargs() + self._get_decoder_kwargs() + self._get_pointer_kwargs() + + def __call__(self): + """return stored arguments""" + return self.embedding_dim, self.pointernet_kwargs, self.encoder_kwargs, self.decoder_kwargs, self.pointer_kwargs + + def _get_kwargs(self, required_args, required_kwargs, possible_kwargs, signflip_kwargs, name="pointer network"): + """method for getting the required and optional kwargs from stored argument dictionary""" + # if any required args are missing, raise an error + _check_kwargs(name, self.args, required_args) + + # get required args (in order of list!) + args = [self.args[arg] for arg in required_args] + + # get kwargs + kwargs = {} + + # if any required kwargs are missing, raise an error + _check_kwargs(name, self.args, required_kwargs) + for key, value in required_kwargs.items(): + kwargs[value] = self.args[key] + + # if any kwargs are included in args, add them to the dictionary + for key, value in possible_kwargs.items(): + if key in self.args: + kwargs[value] = self.args[key] + + # these use the default=False, store as True and require a sign flip + for key, value in signflip_kwargs.items(): + if key in self.args: + kwargs[value] = not self.args[key] + + return args, kwargs + + def _get_pointernet_kwargs(self): + """method for getting the pointer network kwargs from a dictionary of arguments""" + required_args = ["embedding_dim"] + + # these are the possible kwargs that can be passed to the pointer network + # key is the argument name in the ArgParser, value is the pointer network keyword + required_kwargs = dict( + embedding_bias="embedding_bias", + num_encoding_layers="num_encoding_layers", + encoder_method="encoder_method", + decoder_method="decoder_method", + pointer_method="pointer_method", + ) + possible_kwargs = dict( + train_temperature="temperature", + ) + # these use the default=False, store as True and require a sign flip + signflip_kwargs = dict( + no_thompson="thompson", + ) - # if any kwargs are included in args, add them to the pointer_kwargs dictionary - pointer_kwargs = {} - for key, value in possible_kwargs.items(): - if key in args: - pointer_kwargs[value] = args[key] - for key, value in signflip_kwargs.items(): - if key in args: - pointer_kwargs[value] = not args[key] + # get arguments for pointer network + (embedding_dim,), pointernet_kwargs = self._get_kwargs(required_args, required_kwargs, possible_kwargs, signflip_kwargs) - return embedding_dim, pointer_kwargs + # store arguments in self + self.embedding_dim = embedding_dim + self.pointernet_kwargs = pointernet_kwargs class PointerNetworkBaseClass(nn.Module, ABC):