Skip to content

Commit

Permalink
making a smart argument handler
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed May 9, 2024
1 parent 2531718 commit cadf03c
Showing 1 changed file with 65 additions and 30 deletions.
95 changes: 65 additions & 30 deletions dominoes/networks/pointer_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,38 +84,73 @@ def get_pointer_network(
"""


def get_pointer_kwargs(args):
"""
method for getting the pointer method and its kwargs from a dictionary of arguments
"""
# get embedding dim (this is always required)
embedding_dim = args["embedding_dim"]

# these are the possible kwargs that can be passed to the pointer network
# key is the argument name in the ArgParser, value is the pointer network keyword
possible_kwargs = dict(
embedding_bias="embedding_bias",
num_encoding_layers="num_encoding_layers",
encoder_method="encoder_method",
decoder_method="decoder_method",
pointer_method="pointer_method",
train_temperature="temperature",
)
# these use the default=False, store as True and require a sign flip
signflip_kwargs = dict(
no_thompson="thompson",
)
class PointerArguments:
def __init__(self, args):
self.args = args
self._get_pointernet_kwargs()
self._get_encoder_kwargs()
self._get_decoder_kwargs()
self._get_pointer_kwargs()

def __call__(self):
"""return stored arguments"""
return self.embedding_dim, self.pointernet_kwargs, self.encoder_kwargs, self.decoder_kwargs, self.pointer_kwargs

def _get_kwargs(self, required_args, required_kwargs, possible_kwargs, signflip_kwargs, name="pointer network"):
"""method for getting the required and optional kwargs from stored argument dictionary"""
# if any required args are missing, raise an error
_check_kwargs(name, self.args, required_args)

# get required args (in order of list!)
args = [self.args[arg] for arg in required_args]

# get kwargs
kwargs = {}

# if any required kwargs are missing, raise an error
_check_kwargs(name, self.args, required_kwargs)
for key, value in required_kwargs.items():
kwargs[value] = self.args[key]

# if any kwargs are included in args, add them to the dictionary
for key, value in possible_kwargs.items():
if key in self.args:
kwargs[value] = self.args[key]

# these use the default=False, store as True and require a sign flip
for key, value in signflip_kwargs.items():
if key in self.args:
kwargs[value] = not self.args[key]

return args, kwargs

def _get_pointernet_kwargs(self):
"""method for getting the pointer network kwargs from a dictionary of arguments"""
required_args = ["embedding_dim"]

# these are the possible kwargs that can be passed to the pointer network
# key is the argument name in the ArgParser, value is the pointer network keyword
required_kwargs = dict(
embedding_bias="embedding_bias",
num_encoding_layers="num_encoding_layers",
encoder_method="encoder_method",
decoder_method="decoder_method",
pointer_method="pointer_method",
)
possible_kwargs = dict(
train_temperature="temperature",
)
# these use the default=False, store as True and require a sign flip
signflip_kwargs = dict(
no_thompson="thompson",
)

# if any kwargs are included in args, add them to the pointer_kwargs dictionary
pointer_kwargs = {}
for key, value in possible_kwargs.items():
if key in args:
pointer_kwargs[value] = args[key]
for key, value in signflip_kwargs.items():
if key in args:
pointer_kwargs[value] = not args[key]
# get arguments for pointer network
(embedding_dim,), pointernet_kwargs = self._get_kwargs(required_args, required_kwargs, possible_kwargs, signflip_kwargs)

return embedding_dim, pointer_kwargs
# store arguments in self
self.embedding_dim = embedding_dim
self.pointernet_kwargs = pointernet_kwargs


class PointerNetworkBaseClass(nn.Module, ABC):
Expand Down

0 comments on commit cadf03c

Please sign in to comment.