Skip to content

Commit

Permalink
Added random config
Browse files Browse the repository at this point in the history
  • Loading branch information
edorado93 committed Jul 9, 2018
1 parent 0dead58 commit 9236f7f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
11 changes: 10 additions & 1 deletion Writing-editing network/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ class LargeConfig6(LargeDataset):
use_topics = True
experiment_name = "lg-with-topics-lr-0.0001-WE-300"

class RandomConfig(SmallDataset):
emsize = 512
context_dim = 128
lr = 0.0001
pretrained = None
use_topics = False
experiment_name = "random"

configuration = {
"st1": SmallTopicsConfig1(),
"st2": SmallTopicsConfig2(),
Expand All @@ -168,7 +176,8 @@ class LargeConfig6(LargeDataset):
"l3": LargeConfig3(),
"l4": LargeConfig4(),
"l5": LargeConfig5(),
"l6": LargeConfig6()}
"l6": LargeConfig6(),
"random": RandomConfig()}

def get_conf(name):
return configuration[name]
2 changes: 1 addition & 1 deletion Writing-editing network/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def train_generator(input_variable, input_lengths, target_variable, topics, mode
gen_log = torch.stack(probabilities[i])
discrim_input = torch.stack(drafts[i])
sequence_length = discrim_input.shape[1]
print("Log probabilities size is {}, discriminator's input size is {}, sequence length is {}, batch size is {}", gen_log.shape, discrim_input.shape, sequence_length, config.batch_size)
print("Log probabilities size is {}, discriminator's input size is {}, sequence length is {}, batch size is {}".format(gen_log.shape, discrim_input.shape, sequence_length, config.batch_size))
est_values = critic_model(discrim_input)
dis_out = discrim_model(discrim_input)
#gen_log is the log probabilities of generator output
Expand Down

0 comments on commit 9236f7f

Please sign in to comment.