GAN integration #12

Open

wants to merge 23 commits into base: master

Changes from 14 commits (23 commits total)
269 changes: 269 additions & 0 deletions .idea/workspace.xml

Large diffs are not rendered by default.

11 changes: 10 additions & 1 deletion Writing-editing network/configurations.py
@@ -153,6 +153,14 @@ class LargeConfig6(LargeDataset):
use_topics = True
experiment_name = "lg-with-topics-lr-0.0001-WE-300"

class RandomConfig(SmallDataset):
emsize = 512
context_dim = 128
lr = 0.0001
pretrained = None
use_topics = False
experiment_name = "random"

configuration = {
"st1": SmallTopicsConfig1(),
"st2": SmallTopicsConfig2(),
@@ -168,7 +176,8 @@ class LargeConfig6(LargeDataset):
"l3": LargeConfig3(),
"l4": LargeConfig4(),
"l5": LargeConfig5(),
"l6": LargeConfig6()}
"l6": LargeConfig6(),
"random": RandomConfig()}

def get_conf(name):
return configuration[name]
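
A minimal usage sketch for the new entry (the "random" key, get_conf, and the field values come from this diff; the snippet itself is only an illustration, not a file in the repo):

import configurations

# Look up the new config the same way the existing ones are fetched.
config = configurations.get_conf("random")
print(config.experiment_name)              # "random"
print(config.emsize, config.context_dim)   # 512 128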
165 changes: 153 additions & 12 deletions Writing-editing network/main.py
@@ -1,4 +1,4 @@
import time, argparse, math, os, sys, pickle, copy
import time, argparse, math, os, sys, pickle, copy, random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
@@ -10,6 +10,7 @@
from seq2seq.EncoderRNN import EncoderRNN
from seq2seq.DecoderRNNFB import DecoderRNNFB
from seq2seq.ContextEncoder import ContextEncoder
from seq2seq.discriminator import reinforce, Encoder, DecoderRNN, Discriminator, Critic
from predictor import Predictor
from tensorboardX import SummaryWriter
import configurations
@@ -43,6 +44,7 @@

cwd = os.getcwd()
vectorizer = Vectorizer(min_frequency=config.min_freq)
train_shuffle_samples = []

validation_data_path = cwd + config.relative_dev_path
validation_abstracts = headline2abstractdataset(validation_data_path, vectorizer, args.cuda, max_len=1000)
@@ -73,6 +75,18 @@
n_layers=config.nlayers, rnn_cell=config.cell, bidirectional=config.bidirectional,
input_dropout_p=config.dropout, dropout_p=config.dropout)
model = FbSeq2seq(encoder_title, encoder, context_encoder, decoder)

""" Define the Discriminator model here """

discrim_encoder = Encoder(config.emsize, config.emsize, vocab_size, config.batch_size, use_cuda=args.cuda)
discrim_decoder = DecoderRNN(config.emsize, config.emsize, vocab_size, 1, config.batch_size)
discrim_model = Discriminator(discrim_encoder, discrim_decoder, use_cuda=args.cuda)
discrim_criterion = nn.BCELoss()
critic_model = Critic(config.emsize, config.emsize, vocab_size, config.batch_size, use_cuda=args.cuda)

""" Ends here """


total_params = sum(x.size()[0] * x.size()[1] if len(x.size()) > 1 else x.size()[0] for x in model.parameters())
print('Model total parameters:', total_params, flush=True)

@@ -83,8 +97,11 @@
criterion = nn.CrossEntropyLoss(ignore_index=0)
if args.cuda:
model = model.cuda()
discrim_model = discrim_model.cuda()
critic_model = critic_model.cuda()
discrim_criterion = discrim_criterion.cuda()
criterion = criterion.cuda()
optimizer = optim.Adam(model.parameters(), lr=config.lr)
optimizer = optim.Adam(list(model.parameters()) + list(discrim_model.parameters()), lr=config.lr)

# Mask variable
def _mask(prev_generated_seq):
@@ -104,30 +121,141 @@ def _mask(prev_generated_seq):
mask = mask.cuda()
return prev_generated_seq.data.masked_fill_(mask, 0)

def train_batch(input_variable, input_lengths, target_variable, topics, model,
teacher_forcing_ratio):
def load_training_samples_for_shuffling(dataset):
train_loader = DataLoader(dataset, config.batch_size)
for d in train_loader:
train_shuffle_samples.append(d)

def freeze_generator():
for param in model.parameters():
param.requires_grad = False

def unfreeze_generator():
for param in model.parameters():
param.requires_grad = True

def freeze_discriminator():
for param in discrim_model.parameters():
param.requires_grad = False

def unfreeze_discriminator():
for param in discrim_model.parameters():
param.requires_grad = True
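
Note that a single Adam optimizer is built over both the generator's and the discriminator's parameters (see the optimizer line above), so freezing one network with requires_grad=False only stops new gradients from being computed; any stale .grad left on the frozen parameters can still be consumed by optimizer.step(). A common alternative, sketched below with hypothetical names that are not part of this PR, is to keep one optimizer per network:

# Hypothetical alternative, not in this PR: separate optimizers so a step for one
# network never touches the other network's parameters.
gen_optimizer = optim.Adam(model.parameters(), lr=config.lr)
dis_optimizer = optim.Adam(discrim_model.parameters(), lr=config.lr)
# train_generator would then call gen_optimizer.step() and
# train_discriminator would call dis_optimizer.step().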


def train_discriminator(input_variable, target_variable, is_eval=False):
sequence_length = input_variable.shape[1]
'''add other return values'''
dis_out, dis_sig = discrim_model(input_variable, sequence_length, config.batch_size)
Review comment (Owner Author): Wrong dimensions are coming in here, kindly check. The dimensions of dis_out and target_variable should match; the target variable has a dimensionality of (20, 1).
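One hedged way to reconcile this, assuming dis_sig carries a trailing singleton dimension (or none) while the labels are built as a flat batch of floats, is to reshape both sides explicitly before the BCE loss. This is only an illustration of the shape fix being asked for, not code from the PR:

# Illustration of aligning shapes before nn.BCELoss; the exact shapes are assumptions.
dis_sig = dis_sig.view(config.batch_size, 1)
target_variable = target_variable.view(config.batch_size, 1).float()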

print("Discriminator's output dim is {}, target dim is {}".format(dis_sig.shape, target_variable.shape))
loss = discrim_criterion(dis_sig, target_variable)
""" Check if we need this if condition here, since we are freezing the weights anyhow """
if not is_eval:
discrim_model.zero_grad()
loss.backward(retain_graph=True)
optimizer.step()
'''Need to add code to train the critic when we train the discriminator'''
return loss
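
As for the TODO just above, one standard way to fit the critic is to regress its value estimates onto the discriminator's scores for the same batch. The snippet below is only a sketch of that idea; critic_optimizer and the shape handling are hypothetical and not part of this PR:

# Hypothetical sketch of a critic update, not code from this PR.
est_values = critic_model(input_variable)
critic_loss = nn.MSELoss()(est_values.view(-1), dis_sig.detach().view(-1))
critic_model.zero_grad()
critic_loss.backward()
critic_optimizer.step()  # assumes a separate optimizer for the critic's parameters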

def train_generator(input_variable, input_lengths, target_variable, topics, model,
teacher_forcing_ratio, is_eval=False):
loss_list = []
# Forward propagation
prev_generated_seq = None
target_variable_reshaped = target_variable[:, 1:].contiguous().view(-1)

sentences = []
drafts = [[] for _ in range(config.num_exams)]
probabilities = [[] for _ in range(config.num_exams)]
for i, t in zip(input_variable, target_variable):
sentences.append((" ".join([vectorizer.idx2word[tok.item()] for tok in i if tok.item() != 0 and tok.item() != 1 and tok.item() != 2]),
" ".join([vectorizer.idx2word[tok.item()] for tok in t if tok.item() != 0 and tok.item() != 1 and tok.item() != 2])))

for i in range(config.num_exams):
topics = topics if config.use_topics else None
decoder_outputs, _, other = \
model(input_variable, prev_generated_seq, input_lengths,
target_variable, teacher_forcing_ratio, topics)

decoder_outputs_reshaped = decoder_outputs.view(-1, vocab_size)
lossi = criterion(decoder_outputs_reshaped, target_variable_reshaped)
prev_generated_seq = torch.squeeze(torch.topk(decoder_outputs, 1, dim=2)[1]).view(-1, decoder_outputs.size(1))
prev_generated_seq = _mask(prev_generated_seq)

log_probabilities = torch.squeeze(torch.topk(decoder_outputs, 1, dim=2)[0]).view(-1, decoder_outputs.size(1))
for lp_tensor, p_tensor in zip(log_probabilities, prev_generated_seq):
drafts[i].append(p_tensor)
probabilities[i].append(lp_tensor)

# Only calculate the REINFORCE loss when the generator is being trained,
# i.e. this is not eval mode.
if not is_eval:
""" Call Discriminator, Critic and get the ReINFORCE Loss Term"""
# input is a (batch_size x sequence_length) tensor of word indices for the abstracts
gen_log = torch.stack(probabilities[i])
discrim_input = torch.stack(drafts[i])
sequence_length = discrim_input.shape[1]
print("Log probabilities size is {}, discriminator's input size is {}, sequence length is {}, batch size is {}".format(gen_log.shape, discrim_input.shape, sequence_length, config.batch_size))
est_values = critic_model(discrim_input)
dis_out, dis_sig = discrim_model(discrim_input, sequence_length, config.batch_size)
# gen_log holds the log probabilities of the generator's output
reinforce_loss, final_gen_obj = reinforce(gen_log, dis_out, est_values, sequence_length, config, args.cuda)
else:
reinforce_loss = 0

lossi = criterion(decoder_outputs_reshaped, target_variable_reshaped) + reinforce_loss
loss_list.append(lossi.item())
if model.training:
if not is_eval:
model.zero_grad()
lossi.backward(retain_graph=True)
torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
optimizer.step()
prev_generated_seq = torch.squeeze(torch.topk(decoder_outputs, 1, dim=2)[1]).view(-1, decoder_outputs.size(1))
prev_generated_seq = _mask(prev_generated_seq)
return loss_list

return loss_list, sentences, drafts
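
For context, the reinforce() term imported from seq2seq.discriminator is used above as a policy-gradient penalty. Its real definition lives in that module; the function below is only a sketch of the standard REINFORCE-with-baseline form it presumably follows, with the discriminator score taken as the reward and the critic's estimate as the baseline:

# Generic REINFORCE-with-baseline sketch; not the repo's actual reinforce() implementation.
def reinforce_sketch(gen_log, dis_out, est_values):
    # gen_log:    (batch, seq_len) log-probabilities of the generated tokens
    # dis_out:    (batch,) discriminator scores treated as rewards
    # est_values: (batch,) critic baseline estimates
    advantage = (dis_out - est_values).detach()            # baseline-subtracted reward
    loss = -(advantage.unsqueeze(1) * gen_log).mean()      # maximize expected reward
    return loss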

def train_batch(input_variables, input_lengths, target_variables, topics, teacher_forcing_ratio, is_generator):
if is_generator:
unfreeze_generator()
freeze_discriminator()
loss_list, sentences, drafts = train_generator(input_variables, input_lengths, target_variables, topics, model, teacher_forcing_ratio)
return loss_list
else:
# unfreeze_discriminator()
freeze_generator()
loss_list, sentences, drafts = train_generator(input_variables, input_lengths, target_variables,
topics, model, teacher_forcing_ratio, is_eval=True)
# Randomly select a batch of data from the actual training set.
_, true_target, _, _ = random.choice(train_shuffle_samples)
discriminator_input_variables = []
discriminator_target_variables = []

# Batch size is the same, so we can zip
for d, t in zip(drafts[1], true_target):

eos_tensor = torch.tensor(1).cuda() if input_variables.is_cuda else torch.tensor(1)
d = torch.cat((d.view(1, -1), eos_tensor.view(1, -1)), dim=1)
t = t.view(1, -1)

# Half of the samples should come from the actual dataset and the remaining half from generated samples.
if random.random() > 0.5:
discriminator_input_variables.append(t)
# CONFIRM ? The discriminator should output 1 if the abstract is from the true data
discriminator_target_variables.append(1.)
else:
discriminator_input_variables.append(d)
# CONFIRM ? The discriminator should output 0 if the abstract is from the generated data
discriminator_target_variables.append(0.)

discriminator_input_variables = torch.stack(discriminator_input_variables).squeeze()
discriminator_target_variables = torch.tensor(discriminator_target_variables).squeeze()
if input_variables.is_cuda:
discriminator_input_variables = discriminator_input_variables.cuda()
discriminator_target_variables = discriminator_target_variables.cuda()

""" Mix them with the true data and pass it to the discriminator """
output = train_discriminator(discriminator_input_variables, discriminator_target_variables)
return output


def evaluate(validation_dataset, model, teacher_forcing_ratio):
validation_loader = DataLoader(validation_dataset, config.batch_size)
@@ -137,8 +265,8 @@ def evaluate(validation_dataset, model, teacher_forcing_ratio):
input_variables = source
target_variables = target
# train model
loss_list = train_batch(input_variables, input_lengths,
target_variables, topics, model, teacher_forcing_ratio)
loss_list, _, _ = train_generator(input_variables, input_lengths, target_variables,
topics, model, teacher_forcing_ratio, is_eval=True)
num_examples = len(source)
for i in range(config.num_exams):
epoch_loss_list[i] += loss_list[i] * num_examples
@@ -151,6 +279,9 @@ def train_epoches(dataset, model, n_epochs, teacher_forcing_ratio):
prev_epoch_loss_list = [100] * config.num_exams
patience = 0
best_model = None
# Load the entire training set into memory so that we can fetch a random batch
# to feed to the discriminator while training.
load_training_samples_for_shuffling(dataset)
for epoch in range(1, n_epochs + 1):
model.train(True)
epoch_examples_total = 0
@@ -163,8 +294,18 @@ def train_epoches(dataset, model, n_epochs, teacher_forcing_ratio):
input_variables = source
target_variables = target
# train model

# Train the DISCRIMINATOR
train_batch(input_variables, input_lengths,
target_variables, topics, teacher_forcing_ratio, False)

print("Discriminator trained successfully")
exit(0)

# Train the GENERATOR
loss_list = train_batch(input_variables, input_lengths,
target_variables, topics, model, teacher_forcing_ratio)
target_variables, topics, teacher_forcing_ratio, True)

# Record average loss
num_examples = len(source)
epoch_examples_total += num_examples