[Python] #15 Sentence generation tutorial development in progress
ratsgo committed Feb 1, 2021
1 parent 76857c4 commit beb5f27
Showing 5 changed files with 46 additions and 51 deletions.
2 changes: 1 addition & 1 deletion ratsnlp/nlpbook/generation/__init__.py

@@ -1,4 +1,4 @@
-from .corpus import KoreanChatCorpus
+from .corpus import GenerationDataset, NsmcCorpus
 from .task import GenerationTask
 from .tokenizer import KoGPT2Tokenizer
 from .arguments import GenerationTrainArguments, GenerationDeployArguments
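The rename swaps the chatbot-specific entry point for the NSMC-based one. As a rough usage sketch (assumed, not part of this commit), the package's generation API would now be imported like this:

    # Usage sketch (assumed): the public generation API after the rename.
    from ratsnlp.nlpbook.generation import (
        GenerationTrainArguments,  # training hyperparameters (defaults changed below)
        NsmcCorpus,                # NSMC movie-review corpus reader
        GenerationDataset,         # wraps corpus + tokenizer into a torch Dataset
        GenerationTask,            # LightningModule that fine-tunes the LM
    )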
8 changes: 4 additions & 4 deletions ratsnlp/nlpbook/generation/arguments.py

@@ -26,7 +26,7 @@ class GenerationTrainArguments:
         metadata={"help": "The output model dir."}
     )
     max_seq_length: int = field(
-        default=128,
+        default=32,
         metadata={
             "help": "The maximum total input sequence length after tokenization. Sequences longer "
                     "than this will be truncated, sequences shorter will be padded."
@@ -61,7 +61,7 @@ class GenerationTrainArguments:
         metadata={"help": "Test Mode enables `fast_dev_run`"}
     )
     learning_rate: float = field(
-        default=5e-6,
+        default=5e-5,
         metadata={"help": "learning rate"}
     )
     optimizer: str = field(
@@ -77,7 +77,7 @@ class GenerationTrainArguments:
         metadata={"help": "max epochs"}
     )
     batch_size: int = field(
-        default=0,
+        default=96,
         metadata={"help": "batch size. if 0, let PyTorch Lightning find the best batch size"}
     )
     cpu_workers: int = field(
@@ -126,7 +126,7 @@ class GenerationDeployArguments:
         metadata={"help": "The output model checkpoint path."}
     )
     max_seq_length: int = field(
-        default=128,
+        default=64,
         metadata={
             "help": "The maximum total input sequence length after tokenization. Sequences longer "
                     "than this will be truncated, sequences shorter will be padded."
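The four new defaults retune the recipe for short NSMC reviews: training sequences drop from 128 to 32 tokens, the learning rate rises tenfold, and a fixed batch size of 96 replaces Lightning's automatic batch-size search. A minimal sketch of the resulting configuration (assuming the remaining fields all carry defaults):

    # Sketch (assumed): training defaults after this commit.
    from ratsnlp.nlpbook.generation import GenerationTrainArguments

    args = GenerationTrainArguments()  # assumes no required, default-less fields
    assert args.max_seq_length == 32   # was 128; NSMC reviews are short
    assert args.learning_rate == 5e-5  # was 5e-6
    assert args.batch_size == 96       # was 0 (auto batch-size search)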
82 changes: 38 additions & 44 deletions ratsnlp/nlpbook/generation/corpus.py

@@ -8,87 +8,81 @@
 from typing import List, Optional
 from torch.utils.data.dataset import Dataset
 from transformers import PreTrainedTokenizerFast
-from ratsnlp.nlpbook.generation.tokenizer import MASK_TOKEN
 from ratsnlp.nlpbook.generation.arguments import GenerationTrainArguments


 logger = logging.getLogger(__name__)


 @dataclass
-class KoreanChatExample:
-    question: str
-    answer: str
+class GenerationExample:
+    text: str


 @dataclass
 class GenerationFeatures:
     input_ids: List[int]
     attention_mask: Optional[List[int]] = None
     token_type_ids: Optional[List[int]] = None
-    label: Optional[List[int]] = None
+    labels: Optional[List[int]] = None


-class KoreanChatCorpus:
+class NsmcCorpus:

-    @classmethod
-    def _read_corpus(cls, input_file):
+    def __init__(self):
+        pass
+
+    def _read_corpus(cls, input_file, quotechar='"'):
         with open(input_file, "r", encoding="utf-8") as f:
-            return list(csv.reader(f, delimiter=","))[1:]
+            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

     def _create_examples(self, lines):
         examples = []
-        for line in lines:
-            question, answer, _ = line
-            examples.append(KoreanChatExample(question=question, answer=answer))
+        for (i, line) in enumerate(lines):
+            if i == 0:
+                continue
+            _, review_sentence, sentiment = line
+            sentiment = "긍정" if sentiment == "1" else "부정"
+            text = sentiment + " " + review_sentence
+            examples.append(GenerationExample(text=text))
         return examples

     def get_examples(self, data_root_path, mode):
         data_fpath = os.path.join(data_root_path, f"ratings_{mode}.txt")
         logger.info(f"loading {mode} data... LOOKING AT {data_fpath}")
         return self._create_examples(self._read_corpus(data_fpath))


-def _convert_chatbot_examples_to_generation_features(
-        examples: List[KoreanChatExample],
+def _convert_examples_to_generation_features(
+        examples: List[GenerationExample],
         tokenizer: PreTrainedTokenizerFast,
         args: GenerationTrainArguments,
 ):
-
-    mask_token_id = tokenizer.convert_tokens_to_ids(MASK_TOKEN)
     logger.info(
         "tokenize sentences, it could take a lot of time..."
     )
     start = time.time()
-    features = []
-    for example in examples:
-        question_token_ids = [tokenizer.bos_token_id] + tokenizer.encode(example.question) + [tokenizer.eos_token_id]
-        answer_token_ids = tokenizer.encode(example.answer) + [tokenizer.eos_token_id]
-        answer_length = args.max_seq_length - len(question_token_ids)
-        if answer_length > 0:
-            if len(question_token_ids) + len(answer_token_ids) > args.max_seq_length:
-                answer_token_ids = answer_token_ids[:answer_length]
-            token_type_ids = [0] * len(question_token_ids) + [1] * len(answer_token_ids)
-            attention_mask = [1] * (len(question_token_ids) + len(answer_token_ids))
-            label = [mask_token_id] * len(question_token_ids) + answer_token_ids
-            if len(question_token_ids) + len(answer_token_ids) < args.max_seq_length:
-                padding_length = args.max_seq_length - len(question_token_ids) - len(answer_token_ids)
-                answer_token_ids += [tokenizer.pad_token_id] * padding_length
-                token_type_ids += [0] * padding_length
-                attention_mask += [0] * padding_length
-                label += [tokenizer.pad_token_id] * padding_length
-            feature = GenerationFeatures(
-                input_ids=question_token_ids + answer_token_ids,
-                token_type_ids=token_type_ids,
-                attention_mask=attention_mask,
-                label=label,
-            )
-            features.append(feature)
+    batch_encoding = tokenizer(
+        [example.text for example in examples],
+        max_length=args.max_seq_length,
+        padding="max_length",
+        truncation=True,
+    )
     logger.info(
         "tokenize sentences [took %.3f s]", time.time() - start
     )
+
+    features = []
+    for i in range(len(examples)):
+        inputs = {k: batch_encoding[k][i] for k in batch_encoding}
+        feature = GenerationFeatures(**inputs, labels=batch_encoding["input_ids"][i])
+        features.append(feature)

     for i, example in enumerate(examples[:5]):
         logger.info("*** Example ***")
-        logger.info("question: %s" % (example.question))
-        logger.info("answer: %s" % (example.answer))
-        logger.info("tokens: %s" % (" ".join(tokenizer.decode(features[i].input_ids))))
-        logger.info("label: %s" % (" ".join(tokenizer.decode(features[i].label))))
+        logger.info("sentence: %s" % (example.text))
+        logger.info("tokens: %s" % (" ".join(tokenizer.convert_ids_to_tokens(features[i].input_ids))))
+        logger.info("features: %s" % features[i])

     return features
@@ -102,7 +96,7 @@ def __init__(
             tokenizer: PreTrainedTokenizerFast,
             corpus,
             mode: Optional[str] = "train",
-            convert_examples_to_features_fn=_convert_chatbot_examples_to_generation_features,
+            convert_examples_to_features_fn=_convert_examples_to_generation_features,
     ):
         if corpus is not None:
             self.corpus = corpus
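The rewrite retargets the corpus from chatbot question-answer pairs to NSMC movie reviews: each row of the tab-separated ratings file becomes the single string "긍정 <review>" or "부정 <review>" ("positive"/"negative" plus the review), and that string serves as both model input and label, as is usual for causal-LM fine-tuning. A condensed, self-contained sketch of the new pipeline (NSMC's "id<TAB>document<TAB>label" layout and the skt/kogpt2-base-v2 tokenizer are assumptions, not from the diff):

    # Condensed sketch (assumptions noted above) of the new NSMC preprocessing.
    import csv
    from transformers import PreTrainedTokenizerFast

    tokenizer = PreTrainedTokenizerFast.from_pretrained("skt/kogpt2-base-v2")
    tokenizer.pad_token = "<pad>"  # assumed pad token; required for padding="max_length"

    with open("ratings_train.txt", encoding="utf-8") as f:
        rows = list(csv.reader(f, delimiter="\t", quotechar='"'))[1:]  # drop header row

    # "1" -> 긍정 (positive), otherwise 부정 (negative); prepend the label word.
    texts = [("긍정" if label == "1" else "부정") + " " + review
             for _, review, label in rows]

    enc = tokenizer(texts, max_length=32, padding="max_length", truncation=True)
    features = [
        {"input_ids": enc["input_ids"][i],
         "attention_mask": enc["attention_mask"][i],
         "labels": enc["input_ids"][i]}  # labels mirror input_ids for causal LM
        for i in range(len(texts))
    ]

Note that because labels simply mirror the padded input_ids, the loss is also computed over padding positions; masking those label positions with -100 would be the usual refinement.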
2 changes: 1 addition & 1 deletion ratsnlp/nlpbook/generation/task.py

@@ -38,7 +38,7 @@ def forward(self, **kwargs):
         return self.model(**kwargs)

     def step(self, inputs, mode="train"):
-        loss, logits = self.model(**inputs)
+        loss, logits, _ = self.model(**inputs)
         preds = logits.argmax(dim=-1)
         labels = inputs["labels"]
         acc = accuracy(preds, labels)
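The extra underscore reflects the model's tuple output: when labels are passed and return_dict=False, Hugging Face's GPT2LMHeadModel returns (loss, lm_logits, past_key_values), so the cache must be discarded before computing accuracy. A minimal sketch of that call shape (the KoGPT2 checkpoint name is an assumption):

    # Minimal sketch: why step() now unpacks three values.
    from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast

    tokenizer = PreTrainedTokenizerFast.from_pretrained("skt/kogpt2-base-v2")  # assumed
    model = GPT2LMHeadModel.from_pretrained("skt/kogpt2-base-v2")              # assumed

    enc = tokenizer("긍정 재미있는 영화였다", return_tensors="pt")
    outputs = model(input_ids=enc["input_ids"], labels=enc["input_ids"], return_dict=False)
    loss, logits, _ = outputs  # third element is the past_key_values cache
    preds = logits.argmax(dim=-1)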
3 changes: 2 additions & 1 deletion setup.py

@@ -2,7 +2,7 @@

 setuptools.setup(
     name="ratsnlp",
-    version="0.0.955",
+    version="0.0.957",
     license='MIT',
     author="ratsgo",
     author_email="[email protected]",
@@ -16,6 +16,7 @@
         'ratsnlp.nlpbook.ner': ['*.html'],
         'ratsnlp.nlpbook.qa': ['*.html'],
         'ratsnlp.nlpbook.paircls': ['*.html'],
+        'ratsnlp.nlpbook.generation': ['*.html'],
     },
     install_requires=[
         "torch>=1.4.0",
