Skip to content

Commit

Permalink
no more signflip arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed May 14, 2024
1 parent 8a21063 commit 3470cfe
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 177 deletions.
13 changes: 5 additions & 8 deletions dominoes/datasets/dominoe_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,25 +92,22 @@ def process_arguments(self, args):
highest_dominoe="highest_dominoe",
)
possible_kwargs = dict(
randomize_direction="randomize_direction",
train_fraction="train_fraction",
batch_size="batch_size",
return_target="return_target",
ignore_index="ignore_index",
threads="threads",
)
signflip_kwargs = dict(
no_randomize_direction="randomize_direction",
)
required_args, required_kwargs, possible_kwargs, signflip_kwargs = self.task_specific_arguments(
required_args, required_kwargs, possible_kwargs = self.task_specific_arguments(
required_args,
required_kwargs,
possible_kwargs,
signflip_kwargs,
)
init_prms = process_arguments(args, required_args, required_kwargs, possible_kwargs, signflip_kwargs, self.__class__.__name__)[1]
init_prms = process_arguments(args, required_args, required_kwargs, possible_kwargs, self.__class__.__name__)[1]
return init_prms

def task_specific_arguments(self, required_args, required_kwargs, possible_kwargs, signflip_kwargs):
def task_specific_arguments(self, required_args, required_kwargs, possible_kwargs):
"""add (or remove) parameters for each task, respectively"""
if self.task == "sequencer":
possible_kwargs["value_method"] = "value_method"
Expand All @@ -121,7 +118,7 @@ def task_specific_arguments(self, required_args, required_kwargs, possible_kwarg
required_args = ["highest_dominoe"]
else:
raise ValueError(f"task ({self.task}) not recognized!")
return required_args, required_kwargs, possible_kwargs, signflip_kwargs
return required_args, required_kwargs, possible_kwargs

@torch.no_grad()
def set_train_fraction(self, train_fraction):
Expand Down
3 changes: 1 addition & 2 deletions dominoes/datasets/tsp_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ def process_arguments(self, args):
ignore_index="ignore_index",
threads="threads",
)
signflip_kwargs = {}
init_prms = process_arguments(args, required_args, required_kwargs, possible_kwargs, signflip_kwargs, self.__class__.__name__)[1]
init_prms = process_arguments(args, required_args, required_kwargs, possible_kwargs, self.__class__.__name__)[1]
return init_prms

def get_input_dim(self, coord_dims=None):
Expand Down
167 changes: 33 additions & 134 deletions dominoes/experiments/arglib.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from ..utils import argbool


def add_standard_training_parameters(parser):
"""
arguments for defining the network type, dataset, optimizer, and other metaparameters
Expand All @@ -9,19 +12,9 @@ def add_standard_training_parameters(parser):
parser.add_argument("--train_epochs", type=int, default=2000, help="how many epochs to train the networks on")
parser.add_argument("--test_epochs", type=int, default=100, help="how many epochs to train the networks on")
parser.add_argument("--replicates", type=int, default=2, help="how many replicates of each network to train")
parser.add_argument("--silent", default=False, action="store_true", help="if used, won't print training progress")
parser.add_argument(
"--save_loss",
default=False,
action="store_true",
help="if used, will save loss during training (always True for learning_mode=supervised)",
)
parser.add_argument(
"--save_reward",
default=False,
action="store_true",
help="if used, will save reward during training (always True for learning_mode=reinforce)",
)
parser.add_argument("--silent", type=argbool, default=False, help="whether or not to print training progress (default=False)")
parser.add_argument("--save_loss", type=argbool, default=False, help="whether to save loss during training (default=False)")
parser.add_argument("--save_reward", type=argbool, default=False, help="whether to save reward during training (default=False)")
return parser


Expand All @@ -33,29 +26,20 @@ def add_network_training_metaparameters(parser):
parser.add_argument("--wd", type=float, default=0) # default weight decay
parser.add_argument("--gamma", type=float, default=1.0) # default gamma for reward processing
parser.add_argument("--train_temperature", type=float, default=5.0, help="temperature for training")
parser.add_argument(
"--no_thompson",
default=False,
action="store_true",
help="if used, do greedy instead of Thompson sampling during training (default=False)",
)
parser.add_argument(
"--no_baseline",
default=False,
action="store_true",
help="if used, will not use a baseline correction during training (default=False)",
)
parser.add_argument("--thompson", type=argbool, default=True, help="whether to use Thompson sampling during training (default=True)")
parser.add_argument("--baseline", type=argbool, default=True, help="whether to use a baseline correction during training (default=True)")
parser.add_argument("--bl_temperature", type=float, default=1.0, help="temperature for baseline networks during training")
parser.add_argument("--bl_thompson", default=False, action="store_true", help="if used, will use Thompson sampling for baseline networks")
parser.add_argument("--bl_significance", type=float, default=0.05, help="significance level for updating baseline networks")
parser.add_argument("--bl_batch_size", type=int, default=1024, help="batch size for baseline networks")
parser.add_argument("--bl_thompson", type=argbool, default=False, help="whether to use Thompson sampling for baseline networks (default=False)")
parser.add_argument("--bl_significance", type=float, default=0.05, help="significance level for updating baseline networks (default=0.05)")
parser.add_argument("--bl_batch_size", type=int, default=1024, help="batch size for baseline networks (default=1024)")
parser.add_argument("--bl_duty_cycle", type=int, default=10, help="how many epochs to wait before checking baseline improvement (default=10)")
return parser


def add_pointernet_parameters(parser):
"""arguments for the PointerNet"""
parser.add_argument("--embedding_dim", type=int, default=128, help="the dimensions of the embedding (default=128)")
parser.add_argument("--no_embedding_bias", default=False, action="store_true", help="whether to remove embedding_bias (default=False)")
parser.add_argument("--embedding_bias", type=argbool, default=True, help="whether to use embedding_bias (default=True)")
parser.add_argument("--num_encoding_layers", type=int, default=1, help="the number of encoding layers in the PointerNet (default=1)")
parser.add_argument("--encoder_method", type=str, default="transformer", help="PointerNet encoding layer method (default='transformer')")
parser.add_argument("--decoder_method", type=str, default="transformer", help="PointerNet decoding layer method (default='transformer')")
Expand All @@ -66,90 +50,35 @@ def add_pointernet_parameters(parser):
def add_pointernet_encoder_parameters(parser):
    """Add arguments for the encoder layers in a PointerNet.

    Args:
        parser: an argparse.ArgumentParser to extend in place.

    Returns:
        The same parser, with encoder-layer arguments registered.
    """
    # NOTE(review): the diff rendering interleaved the deleted pre-commit
    # "--encoder_no_*" store_true flags with their post-commit replacements;
    # only the post-commit argbool form is kept here.
    # argbool comes from ..utils (module-level import); presumably it parses a
    # true/false-style string into a bool -- confirm against its definition.
    parser.add_argument("--encoder_num_heads", type=int, default=1, help="the number of heads in ptrnet encoding layers (default=1)")
    parser.add_argument("--encoder_kqnorm", type=argbool, default=True, help="whether to use kqnorm in the encoder (default=True)")
    parser.add_argument("--encoder_expansion", type=int, default=4, help="the expansion of the FF layers in the encoder (default=4)")
    parser.add_argument("--encoder_kqv_bias", type=argbool, default=False, help="whether to use bias in the attention kqv layers (default=False)")
    parser.add_argument("--encoder_mlp_bias", type=argbool, default=True, help="use bias in the MLP part of transformer encoders (default=True)")
    parser.add_argument("--encoder_residual", type=argbool, default=True, help="use residual connections in the attentional encoders (default=True)")
    return parser


def add_pointernet_decoder_parameters(parser):
    """Add arguments for the decoder layers in a PointerNet.

    Args:
        parser: an argparse.ArgumentParser to extend in place.

    Returns:
        The same parser, with decoder-layer arguments registered.
    """
    # NOTE(review): the diff rendering interleaved the deleted pre-commit
    # "--decoder_no_*" store_true flags with their post-commit replacements;
    # only the post-commit argbool form is kept here.
    parser.add_argument("--decoder_num_heads", type=int, default=1, help="the number of heads in ptrnet decoding layers (default=1)")
    parser.add_argument("--decoder_kqnorm", type=argbool, default=True, help="whether to use kqnorm in the decoder (default=True)")
    parser.add_argument("--decoder_expansion", type=int, default=4, help="the expansion of the FF layers in the decoder (default=4)")
    parser.add_argument("--decoder_gru_bias", type=argbool, default=True, help="whether to use bias in the gru decoder method (default=True)")
    parser.add_argument("--decoder_kqv_bias", type=argbool, default=True, help="whether to use bias in the attention layer (default=True)")
    parser.add_argument("--decoder_mlp_bias", type=argbool, default=True, help="use bias in the MLP part of transformer decoders (default=True)")
    parser.add_argument("--decoder_residual", type=argbool, default=True, help="use residual connections in the attentional decoders (default=True)")
    return parser


def add_pointernet_pointer_parameters(parser):
    """Add arguments for the pointer layer in a PointerNet.

    Args:
        parser: an argparse.ArgumentParser to extend in place.

    Returns:
        The same parser, with pointer-layer arguments registered.
    """
    # NOTE(review): the diff rendering interleaved the deleted pre-commit
    # "--pointer_no_*" / store_true "--pointer_bias" flags with their
    # post-commit replacements; only the post-commit argbool form is kept.
    parser.add_argument("--pointer_num_heads", type=int, default=1, help="the number of heads in ptrnet decoding layers (default=1)")
    parser.add_argument("--pointer_kqnorm", type=argbool, default=True, help="whether to use kqnorm in the decoder (default=True)")
    parser.add_argument("--pointer_expansion", type=int, default=4, help="the expansion of the FF layers in the decoder (default=4)")
    parser.add_argument("--pointer_bias", type=argbool, default=False, help="whether to use bias in pointer projection layers (default=False)")
    parser.add_argument("--pointer_kqv_bias", type=argbool, default=True, help="use bias in the attention layer of pointers (default=True)")
    parser.add_argument("--pointer_mlp_bias", type=argbool, default=True, help="use bias in the MLP part of transformer pointers (default=True)")
    parser.add_argument("--pointer_residual", type=argbool, default=True, help="use residual connections in the attentional pointer (default=True)")
    return parser


Expand All @@ -160,30 +89,10 @@ def add_checkpointing(parser):
TODO: probably add some arguments for controlling the details of the checkpointing
: e.g. how often to checkpoint, etc.
"""
parser.add_argument(
"--use_prev",
default=False,
action="store_true",
help="if used, will pick up training off previous checkpoint",
)
parser.add_argument(
"--save_ckpts",
default=False,
action="store_true",
help="if used, will save checkpoints of models",
)
parser.add_argument(
"--ckpt_frequency",
default=1,
type=int,
help="frequency (by epoch) to save checkpoints of models",
)
parser.add_argument(
"--use_wandb",
default=False,
action="store_true",
help="if used, will log experiment to WandB",
)
parser.add_argument("--use_prev", default=False, action="store_true", help="if used, will pick up training off previous checkpoint")
parser.add_argument("--save_ckpts", default=False, action="store_true", help="if used, will save checkpoints of models")
parser.add_argument("--ckpt_frequency", default=1, type=int, help="frequency (by epoch) to save checkpoints of models")
parser.add_argument("--use_wandb", default=False, action="store_true", help="if used, will log experiment to WandB")

return parser

Expand Down Expand Up @@ -211,12 +120,7 @@ def add_dominoe_parameters(parser):
parser.add_argument("--highest_dominoe", type=int, default=9, help="the highest dominoe value (default=9)")
parser.add_argument("--train_fraction", type=float, default=0.8, help="the fraction of dominoes to train with (default=0.8)")
parser.add_argument("--hand_size", type=int, default=8, help="the number of dominoes in the hand (default=8)")
parser.add_argument(
"--no_randomize_direction",
default=False,
action="store_true",
help="if used, won't randomize the direction of the dominoes (default=False)",
)
parser.add_argument("--randomize_direction", type=argbool, default=True, help="randomize the direction of the dominoes (default=True)")
return parser


Expand All @@ -229,10 +133,5 @@ def add_dominoe_sequencer_parameters(parser):

def add_dominoe_sorting_parameters(parser):
    """Add arguments for the dominoe sorting task.

    Args:
        parser: an argparse.ArgumentParser to extend in place.

    Returns:
        The same parser, with the sorting-task argument registered.
    """
    # NOTE(review): the diff rendering showed both the deleted store_true form
    # and the post-commit argbool form; only the post-commit line is kept.
    parser.add_argument("--allow_mistakes", type=argbool, default=False, help="allow mistakes in the sorting task (default=False)")
    return parser
3 changes: 1 addition & 2 deletions dominoes/experiments/ptr_arch_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from .. import datasets
from .. import train
from ..networks import get_pointer_network, get_pointer_methods, get_pointer_arguments
from ..utils import loadSavedExperiment, compute_stats_by_type
from .. import utils
from ..utils import compute_stats_by_type

from .experiment import Experiment
from . import arglib
Expand Down
Loading

0 comments on commit 3470cfe

Please sign in to comment.