diff --git a/dominoes/datasets/dominoe_dataset.py b/dominoes/datasets/dominoe_dataset.py
index 8d89262..f46e1ad 100644
--- a/dominoes/datasets/dominoe_dataset.py
+++ b/dominoes/datasets/dominoe_dataset.py
@@ -92,25 +92,22 @@ def process_arguments(self, args):
             highest_dominoe="highest_dominoe",
         )
         possible_kwargs = dict(
+            randomize_direction="randomize_direction",
             train_fraction="train_fraction",
             batch_size="batch_size",
             return_target="return_target",
             ignore_index="ignore_index",
             threads="threads",
         )
-        signflip_kwargs = dict(
-            no_randomize_direction="randomize_direction",
-        )
-        required_args, required_kwargs, possible_kwargs, signflip_kwargs = self.task_specific_arguments(
+        required_args, required_kwargs, possible_kwargs = self.task_specific_arguments(
             required_args,
             required_kwargs,
             possible_kwargs,
-            signflip_kwargs,
         )
-        init_prms = process_arguments(args, required_args, required_kwargs, possible_kwargs, signflip_kwargs, self.__class__.__name__)[1]
+        init_prms = process_arguments(args, required_args, required_kwargs, possible_kwargs, self.__class__.__name__)[1]
         return init_prms
 
-    def task_specific_arguments(self, required_args, required_kwargs, possible_kwargs, signflip_kwargs):
+    def task_specific_arguments(self, required_args, required_kwargs, possible_kwargs):
         """add (or remove) parameters for each task, respectively"""
         if self.task == "sequencer":
             possible_kwargs["value_method"] = "value_method"
@@ -121,7 +118,7 @@ def task_specific_arguments(self, required_args, required_kwargs, possible_kwarg
             required_args = ["highest_dominoe"]
         else:
             raise ValueError(f"task ({self.task}) not recognized!")
-        return required_args, required_kwargs, possible_kwargs, signflip_kwargs
+        return required_args, required_kwargs, possible_kwargs
 
     @torch.no_grad()
     def set_train_fraction(self, train_fraction):
diff --git a/dominoes/datasets/tsp_dataset.py b/dominoes/datasets/tsp_dataset.py
index 56f34b4..faf87ff 100644
--- a/dominoes/datasets/tsp_dataset.py
+++ b/dominoes/datasets/tsp_dataset.py
@@ -61,8 +61,7 @@ def process_arguments(self, args):
             ignore_index="ignore_index",
             threads="threads",
         )
-        signflip_kwargs = {}
-        init_prms = process_arguments(args, required_args, required_kwargs, possible_kwargs, signflip_kwargs, self.__class__.__name__)[1]
+        init_prms = process_arguments(args, required_args, required_kwargs, possible_kwargs, self.__class__.__name__)[1]
         return init_prms
 
     def get_input_dim(self, coord_dims=None):
diff --git a/dominoes/experiments/arglib.py b/dominoes/experiments/arglib.py
index 71ea882..57f7a0e 100644
--- a/dominoes/experiments/arglib.py
+++ b/dominoes/experiments/arglib.py
@@ -1,3 +1,6 @@
+from ..utils import argbool
+
+
 def add_standard_training_parameters(parser):
     """
     arguments for defining the network type, dataset, optimizer, and other metaparameters
@@ -9,19 +12,9 @@ def add_standard_training_parameters(parser):
     parser.add_argument("--train_epochs", type=int, default=2000, help="how many epochs to train the networks on")
     parser.add_argument("--test_epochs", type=int, default=100, help="how many epochs to train the networks on")
     parser.add_argument("--replicates", type=int, default=2, help="how many replicates of each network to train")
-    parser.add_argument("--silent", default=False, action="store_true", help="if used, won't print training progress")
-    parser.add_argument(
-        "--save_loss",
-        default=False,
-        action="store_true",
-        help="if used, will save loss during training (always True for learning_mode=supervised)",
-    )
-    parser.add_argument(
-        "--save_reward",
-        default=False,
-        action="store_true",
-        help="if used, will save reward during training (always True for learning_mode=reinforce)",
-    )
+    parser.add_argument("--silent", type=argbool, default=False, help="whether or not to print training progress (default=False)")
+    parser.add_argument("--save_loss", type=argbool, default=False, help="whether to save loss during training (default=False)")
+    parser.add_argument("--save_reward", type=argbool, default=False, help="whether to save reward during training (default=False)")
     return parser
 
 
@@ -33,29 +26,20 @@ def add_network_training_metaparameters(parser):
     parser.add_argument("--wd", type=float, default=0)  # default weight decay
     parser.add_argument("--gamma", type=float, default=1.0)  # default gamma for reward processing
     parser.add_argument("--train_temperature", type=float, default=5.0, help="temperature for training")
-    parser.add_argument(
-        "--no_thompson",
-        default=False,
-        action="store_true",
-        help="if used, do greedy instead of Thompson sampling during training (default=False)",
-    )
-    parser.add_argument(
-        "--no_baseline",
-        default=False,
-        action="store_true",
-        help="if used, will not use a baseline correction during training (default=False)",
-    )
+    parser.add_argument("--thompson", type=argbool, default=True, help="whether to use Thompson sampling during training (default=True)")
+    parser.add_argument("--baseline", type=argbool, default=True, help="whether to use a baseline correction during training (default=True)")
     parser.add_argument("--bl_temperature", type=float, default=1.0, help="temperature for baseline networks during training")
-    parser.add_argument("--bl_thompson", default=False, action="store_true", help="if used, will use Thompson sampling for baseline networks")
-    parser.add_argument("--bl_significance", type=float, default=0.05, help="significance level for updating baseline networks")
-    parser.add_argument("--bl_batch_size", type=int, default=1024, help="batch size for baseline networks")
+    parser.add_argument("--bl_thompson", type=argbool, default=False, help="whether to use Thompson sampling for baseline networks (default=False)")
+    parser.add_argument("--bl_significance", type=float, default=0.05, help="significance level for updating baseline networks (default=0.05)")
+    parser.add_argument("--bl_batch_size", type=int, default=1024, help="batch size for baseline networks (default=1024)")
+    parser.add_argument("--bl_duty_cycle", type=int, default=10, help="how many epochs to wait before checking baseline improvement (default=10)")
     return parser
 
 
 def add_pointernet_parameters(parser):
     """arguments for the PointerNet"""
     parser.add_argument("--embedding_dim", type=int, default=128, help="the dimensions of the embedding (default=128)")
-    parser.add_argument("--no_embedding_bias", default=False, action="store_true", help="whether to remove embedding_bias (default=False)")
+    parser.add_argument("--embedding_bias", type=argbool, default=True, help="whether to use embedding_bias (default=True)")
     parser.add_argument("--num_encoding_layers", type=int, default=1, help="the number of encoding layers in the PointerNet (default=1)")
     parser.add_argument("--encoder_method", type=str, default="transformer", help="PointerNet encoding layer method (default='transformer')")
     parser.add_argument("--decoder_method", type=str, default="transformer", help="PointerNet decoding layer method (default='transformer')")
@@ -66,90 +50,35 @@ def add_pointernet_parameters(parser):
 def add_pointernet_encoder_parameters(parser):
     """arguments for the encoder layers in a PointerNet"""
     parser.add_argument("--encoder_num_heads", type=int, default=1, help="the number of heads in ptrnet encoding layers (default=1)")
-    parser.add_argument("--encoder_no_kqnorm", default=False, action="store_true", help="if used, won't use kqnorm in the encoder (default=False)")
+    parser.add_argument("--encoder_kqnorm", type=argbool, default=True, help="whether to use kqnorm in the encoder (default=True)")
     parser.add_argument("--encoder_expansion", type=int, default=4, help="the expansion of the FF layers in the encoder (default=4)")
-    parser.add_argument(
-        "--encoder_no_kqv_bias",
-        default=False,
-        action="store_true",
-        help="if used, won't use bias in the attention layers (default=False)",
-    )
-    parser.add_argument(
-        "--encoder_no_mlp_bias",
-        default=False,
-        action="store_true",
-        help="if used, won't use bias in the MLP part of transformer encoders (default=False)",
-    )
-    parser.add_argument(
-        "--encoder_no_residual",
-        default=False,
-        action="store_true",
-        help="if used, won't use residual connections in the encoder (default=False)",
-    )
+    parser.add_argument("--encoder_kqv_bias", type=argbool, default=False, help="whether to use bias in the attention kqv layers (default=False)")
+    parser.add_argument("--encoder_mlp_bias", type=argbool, default=True, help="use bias in the MLP part of transformer encoders (default=True)")
+    parser.add_argument("--encoder_residual", type=argbool, default=True, help="use residual connections in the attentional encoders (default=True)")
     return parser
 
 
 def add_pointernet_decoder_parameters(parser):
     """arguments for the decoder layers in a PointerNet"""
     parser.add_argument("--decoder_num_heads", type=int, default=1, help="the number of heads in ptrnet decoding layers (default=1)")
-    parser.add_argument("--decoder_no_kqnorm", default=False, action="store_true", help="if used, won't use kqnorm in the decoder (default=False)")
+    parser.add_argument("--decoder_kqnorm", type=argbool, default=True, help="whether to use kqnorm in the decoder (default=True)")
     parser.add_argument("--decoder_expansion", type=int, default=4, help="the expansion of the FF layers in the decoder (default=4)")
-    parser.add_argument(
-        "--decoder_no_gru_bias",
-        default=False,
-        action="store_true",
-        help="if used, won't use bias in the gru decoder method (default=False)",
-    )
-    parser.add_argument(
-        "--decoder_no_kqv_bias",
-        default=False,
-        action="store_true",
-        help="if used, won't use bias in the attention layer (default=False)",
-    )
-    parser.add_argument(
-        "--decoder_no_mlp_bias",
-        default=False,
-        action="store_true",
-        help="if used, won't use bias in the MLP part of transformer decoders (default=False)",
-    )
-    parser.add_argument(
-        "--decoder_no_residual",
-        default=False,
-        action="store_true",
-        help="if used, won't use residual connections in the decoder (default=False)",
-    )
+    parser.add_argument("--decoder_gru_bias", type=argbool, default=True, help="whether to use bias in the gru decoder method (default=True)")
+    parser.add_argument("--decoder_kqv_bias", type=argbool, default=True, help="whether to use bias in the attention layer (default=True)")
+    parser.add_argument("--decoder_mlp_bias", type=argbool, default=True, help="use bias in the MLP part of transformer decoders (default=True)")
+    parser.add_argument("--decoder_residual", type=argbool, default=True, help="use residual connections in the attentional decoders (default=True)")
     return parser
 
 
 def add_pointernet_pointer_parameters(parser):
     """arguments for the pointer layer in a PointerNet"""
     parser.add_argument("--pointer_num_heads", type=int, default=1, help="the number of heads in ptrnet decoding layers (default=1)")
-    parser.add_argument("--pointer_no_kqnorm", default=False, action="store_true", help="if used, won't use kqnorm in the decoder (default=False)")
+    parser.add_argument("--pointer_kqnorm", type=argbool, default=True, help="whether to use kqnorm in the decoder (default=True)")
     parser.add_argument("--pointer_expansion", type=int, default=4, help="the expansion of the FF layers in the decoder (default=4)")
-    parser.add_argument(
-        "--pointer_bias",
-        default=False,
-        action="store_true",
-        help="if used, will use bias in pointer projection layers (default=False)",
-    )
-    parser.add_argument(
-        "--pointer_no_kqv_bias",
-        default=False,
-        action="store_true",
-        help="if used, won't use bias in the attention layer of pointers (default=False)",
-    )
-    parser.add_argument(
-        "--pointer_no_mlp_bias",
-        default=False,
-        action="store_true",
-        help="if used, won't use bias in the MLP part of transformer pointers (default=False)",
-    )
-    parser.add_argument(
-        "--pointer_no_residual",
-        default=False,
-        action="store_true",
-        help="if used, won't use residual connections in the attentional pointer (default=False)",
-    )
+    parser.add_argument("--pointer_bias", type=argbool, default=False, help="whether to use bias in pointer projection layers (default=False)")
+    parser.add_argument("--pointer_kqv_bias", type=argbool, default=True, help="use bias in the attention layer of pointers (default=True)")
+    parser.add_argument("--pointer_mlp_bias", type=argbool, default=True, help="use bias in the MLP part of transformer pointers (default=True)")
+    parser.add_argument("--pointer_residual", type=argbool, default=True, help="use residual connections in the attentional pointer (default=True)")
     return parser
 
 
@@ -160,30 +89,10 @@ def add_checkpointing(parser):
     TODO: probably add some arguments for controlling the details of the checkpointing
     : e.g. how often to checkpoint, etc.
     """
-    parser.add_argument(
-        "--use_prev",
-        default=False,
-        action="store_true",
-        help="if used, will pick up training off previous checkpoint",
-    )
-    parser.add_argument(
-        "--save_ckpts",
-        default=False,
-        action="store_true",
-        help="if used, will save checkpoints of models",
-    )
-    parser.add_argument(
-        "--ckpt_frequency",
-        default=1,
-        type=int,
-        help="frequency (by epoch) to save checkpoints of models",
-    )
-    parser.add_argument(
-        "--use_wandb",
-        default=False,
-        action="store_true",
-        help="if used, will log experiment to WandB",
-    )
+    parser.add_argument("--use_prev", default=False, action="store_true", help="if used, will pick up training off previous checkpoint")
+    parser.add_argument("--save_ckpts", default=False, action="store_true", help="if used, will save checkpoints of models")
+    parser.add_argument("--ckpt_frequency", default=1, type=int, help="frequency (by epoch) to save checkpoints of models")
+    parser.add_argument("--use_wandb", default=False, action="store_true", help="if used, will log experiment to WandB")
     return parser
 
 
@@ -211,12 +120,7 @@ def add_dominoe_parameters(parser):
     parser.add_argument("--highest_dominoe", type=int, default=9, help="the highest dominoe value (default=9)")
     parser.add_argument("--train_fraction", type=float, default=0.8, help="the fraction of dominoes to train with (default=0.8)")
     parser.add_argument("--hand_size", type=int, default=8, help="the number of dominoes in the hand (default=8)")
-    parser.add_argument(
-        "--no_randomize_direction",
-        default=False,
-        action="store_true",
-        help="if used, won't randomize the direction of the dominoes (default=False)",
-    )
+    parser.add_argument("--randomize_direction", type=argbool, default=True, help="randomize the direction of the dominoes (default=True)")
     return parser
 
 
@@ -229,10 +133,5 @@ def add_dominoe_sequencer_parameters(parser):
 
 def add_dominoe_sorting_parameters(parser):
     """arguments for the dominoe sorting task"""
-    parser.add_argument(
-        "--allow_mistakes",
-        default=False,
-        action="store_true",
-        help="if used, will allow mistakes in the sorting task (default=False)",
-    )
+    parser.add_argument("--allow_mistakes", type=argbool, default=False, help="allow mistakes in the sorting task (default=False)")
     return parser
diff --git a/dominoes/experiments/ptr_arch_comp.py b/dominoes/experiments/ptr_arch_comp.py
index 572ac15..9e91a11 100644
--- a/dominoes/experiments/ptr_arch_comp.py
+++ b/dominoes/experiments/ptr_arch_comp.py
@@ -13,8 +13,7 @@
 from .. import datasets
 from .. import train
 from ..networks import get_pointer_network, get_pointer_methods, get_pointer_arguments
-from ..utils import loadSavedExperiment, compute_stats_by_type
-from .. import utils
+from ..utils import compute_stats_by_type
 from .experiment import Experiment
 from . import arglib
 
diff --git a/dominoes/networks/pointer_networks.py b/dominoes/networks/pointer_networks.py
index c56165e..091dfba 100644
--- a/dominoes/networks/pointer_networks.py
+++ b/dominoes/networks/pointer_networks.py
@@ -106,6 +106,7 @@ def _get_pointernet_kwargs(self):
         # these are the possible kwargs that can be passed to the pointer network
         # key is the argument name in the ArgParser, value is the pointer network keyword
         required_kwargs = dict(
+            embedding_bias="embedding_bias",
             num_encoding_layers="num_encoding_layers",
             encoder_method="encoder_method",
             decoder_method="decoder_method",
@@ -113,20 +114,14 @@ def _get_pointernet_kwargs(self):
         )
         possible_kwargs = dict(
             train_temperature="temperature",
+            thompson="thompson",
         )
-        # these use the default=False, store as True and require a sign flip
-        signflip_kwargs = dict(
-            no_embedding_bias="embedding_bias",
-            no_thompson="thompson",
-        )
-
         # get arguments for pointer network
         (embedding_dim,), pointernet_kwargs = process_arguments(
             self.args,
             required_args,
             required_kwargs,
             possible_kwargs,
-            signflip_kwargs,
             name="PointerNetwork-MainNetwork",
         )
 
@@ -145,21 +140,17 @@ def _get_encoder_kwargs(self):
         )
         possible_kwargs = dict(
             encoder_expansion="expansion",
+            encoder_kqnorm="kqnorm",
+            encoder_kqv_bias="kqv_bias",
+            encoder_mlp_bias="mlp_bias",
+            encoder_residual="residual",
         )
-        signflip_kwargs = dict(
-            encoder_no_kqnorm="kqnorm",
-            encoder_no_kqv_bias="kqv_bias",
-            encoder_no_mlp_bias="mlp_bias",
-            encoder_no_residual="residual",
-        )
-
         # get arguments for encoder
         _, encoder_kwargs = process_arguments(
             self.args,
             required_args,
             required_kwargs,
             possible_kwargs,
-            signflip_kwargs,
             name="PointerNetwork-Encoder",
         )
 
@@ -176,22 +167,18 @@ def _get_decoder_kwargs(self):
         possible_kwargs = dict(
             decoder_num_heads="num_heads",
             decoder_expansion="expansion",
+            decoder_gru_bias="gru_bias",
+            decoder_kqnorm="kqnorm",
+            decoder_kqv_bias="kqv_bias",
+            decoder_mlp_bias="mlp_bias",
+            decoder_residual="residual",
         )
-        signflip_kwargs = dict(
-            decoder_no_gru_bias="gru_bias",
-            decoder_no_kqnorm="kqnorm",
-            decoder_no_kqv_bias="kqv_bias",
-            decoder_no_mlp_bias="mlp_bias",
-            decoder_no_residual="residual",
-        )
-
         # get arguments for decoder
         _, decoder_kwargs = process_arguments(
             self.args,
             required_args,
             required_kwargs,
             possible_kwargs,
-            signflip_kwargs,
             name="PointerNetwork-Decoder",
         )
 
@@ -209,12 +196,10 @@ def _get_pointer_kwargs(self):
             pointer_bias="bias",
             pointer_num_heads="num_heads",
             pointer_expansion="expansion",
-        )
-        signflip_kwargs = dict(
-            decoder_no_kqnorm="kqnorm",
-            decoder_no_kqv_bias="kqv_bias",
-            decoder_no_mlp_bias="mlp_bias",
-            decoder_no_residual="residual",
+            pointer_kqnorm="kqnorm",
+            pointer_kqv_bias="kqv_bias",
+            pointer_mlp_bias="mlp_bias",
+            pointer_residual="residual",
         )
 
         # get arguments for pointer
@@ -223,7 +208,6 @@ def _get_pointer_kwargs(self):
             required_args,
             required_kwargs,
             possible_kwargs,
-            signflip_kwargs,
             name="PointerNetwork-Pointer",
         )
 
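Note: the diff imports argbool from ..utils but does not show its implementation. A minimal sketch of what such a string-to-bool converter for argparse typically looks like follows; the function body and the accepted spellings are assumptions for illustration, not the repository's actual code.

import argparse

# Hypothetical sketch of an argbool helper (not the repository's implementation).
# It converts a command-line string into a bool so it can be passed as type= to
# parser.add_argument, letting flags be set explicitly to true or false.
def argbool(value):
    """Convert a command-line string to a bool for use as an argparse type."""
    if isinstance(value, bool):
        return value
    if value.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if value.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

With a converter like this, options that previously required a sign flip (e.g. --no_thompson stored as True and inverted later) can instead be passed directly as --thompson true or --thompson false, so the parsed value maps straight onto the keyword expected by process_arguments without a signflip_kwargs step.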