From beb5f27b543947f8657c2c84e4b638443e267729 Mon Sep 17 00:00:00 2001 From: ratsgo Date: Mon, 1 Feb 2021 19:41:58 +0900 Subject: [PATCH] =?UTF-8?q?[Python]=20#15=20=EB=AC=B8=EC=9E=A5=20=EC=83=9D?= =?UTF-8?q?=EC=84=B1=20=ED=8A=9C=ED=86=A0=EB=A6=AC=EC=96=BC=20=EA=B0=9C?= =?UTF-8?q?=EB=B0=9C=20=EC=A7=84=ED=96=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ratsnlp/nlpbook/generation/__init__.py | 2 +- ratsnlp/nlpbook/generation/arguments.py | 8 +-- ratsnlp/nlpbook/generation/corpus.py | 82 ++++++++++++------------- ratsnlp/nlpbook/generation/task.py | 2 +- setup.py | 3 +- 5 files changed, 46 insertions(+), 51 deletions(-) diff --git a/ratsnlp/nlpbook/generation/__init__.py b/ratsnlp/nlpbook/generation/__init__.py index a3b27e3..aa33868 100644 --- a/ratsnlp/nlpbook/generation/__init__.py +++ b/ratsnlp/nlpbook/generation/__init__.py @@ -1,4 +1,4 @@ -from .corpus import KoreanChatCorpus +from .corpus import GenerationDataset, NsmcCorpus from .task import GenerationTask from .tokenizer import KoGPT2Tokenizer from .arguments import GenerationTrainArguments, GenerationDeployArguments diff --git a/ratsnlp/nlpbook/generation/arguments.py b/ratsnlp/nlpbook/generation/arguments.py index 1fcb03f..73d34dd 100644 --- a/ratsnlp/nlpbook/generation/arguments.py +++ b/ratsnlp/nlpbook/generation/arguments.py @@ -26,7 +26,7 @@ class GenerationTrainArguments: metadata={"help": "The output model dir."} ) max_seq_length: int = field( - default=128, + default=32, metadata={ "help": "The maximum total input sequence length after tokenization. Sequences longer " "than this will be truncated, sequences shorter will be padded." @@ -61,7 +61,7 @@ class GenerationTrainArguments: metadata={"help": "Test Mode enables `fast_dev_run`"} ) learning_rate: float = field( - default=5e-6, + default=5e-5, metadata={"help": "learning rate"} ) optimizer: str = field( @@ -77,7 +77,7 @@ class GenerationTrainArguments: metadata={"help": "max epochs"} ) batch_size: int = field( - default=0, + default=96, metadata={"help": "batch size. if 0, Let PyTorch Lightening find the best batch size"} ) cpu_workers: int = field( @@ -126,7 +126,7 @@ class GenerationDeployArguments: metadata={"help": "The output model checkpoint path."} ) max_seq_length: int = field( - default=128, + default=64, metadata={ "help": "The maximum total input sequence length after tokenization. Sequences longer " "than this will be truncated, sequences shorter will be padded." diff --git a/ratsnlp/nlpbook/generation/corpus.py b/ratsnlp/nlpbook/generation/corpus.py index 0f27977..45dd5f1 100644 --- a/ratsnlp/nlpbook/generation/corpus.py +++ b/ratsnlp/nlpbook/generation/corpus.py @@ -8,7 +8,6 @@ from typing import List, Optional from torch.utils.data.dataset import Dataset from transformers import PreTrainedTokenizerFast -from ratsnlp.nlpbook.generation.tokenizer import MASK_TOKEN from ratsnlp.nlpbook.generation.arguments import GenerationTrainArguments @@ -16,9 +15,8 @@ @dataclass -class KoreanChatExample: - question: str - answer: str +class GenerationExample: + text: str @dataclass @@ -26,69 +24,65 @@ class GenerationFeatures: input_ids: List[int] attention_mask: Optional[List[int]] = None token_type_ids: Optional[List[int]] = None - label: Optional[List[int]] = None + labels: Optional[List[int]] = None -class KoreanChatCorpus: +class NsmcCorpus: - @classmethod - def _read_corpus(cls, input_file): + def __init__(self): + pass + + def _read_corpus(cls, input_file, quotechar='"'): with open(input_file, "r", encoding="utf-8") as f: - return list(csv.reader(f, delimiter=","))[1:] + return list(csv.reader(f, delimiter="\t", quotechar=quotechar)) def _create_examples(self, lines): examples = [] - for line in lines: - question, answer, _ = line - examples.append(KoreanChatExample(question=question, answer=answer)) + for (i, line) in enumerate(lines): + if i == 0: + continue + _, review_sentence, sentiment = line + sentiment = "긍정" if sentiment == "1" else "부정" + text = sentiment + " " + review_sentence + examples.append(GenerationExample(text=text)) return examples + def get_examples(self, data_root_path, mode): + data_fpath = os.path.join(data_root_path, f"ratings_{mode}.txt") + logger.info(f"loading {mode} data... LOOKING AT {data_fpath}") + return self._create_examples(self._read_corpus(data_fpath)) + -def _convert_chatbot_examples_to_generation_features( - examples: List[KoreanChatExample], +def _convert_examples_to_generation_features( + examples: List[GenerationExample], tokenizer: PreTrainedTokenizerFast, args: GenerationTrainArguments, ): - mask_token_id = tokenizer.convert_tokens_to_ids(MASK_TOKEN) logger.info( "tokenize sentences, it could take a lot of time..." ) start = time.time() - features = [] - for example in examples: - question_token_ids = [tokenizer.bos_token_id] + tokenizer.encode(example.question) + [tokenizer.eos_token_id] - answer_token_ids = tokenizer.encode(example.answer) + [tokenizer.eos_token_id] - answer_length = args.max_seq_length - len(question_token_ids) - if answer_length > 0: - if len(question_token_ids) + len(answer_token_ids) > args.max_seq_length: - answer_token_ids = answer_token_ids[:answer_length] - token_type_ids = [0] * len(question_token_ids) + [1] * len(answer_token_ids) - attention_mask = [1] * (len(question_token_ids + len(answer_token_ids))) - label = [mask_token_id] * len(question_token_ids) + answer_token_ids - if len(question_token_ids) + len(answer_token_ids) < args.max_seq_length: - padding_length = args.max_seq_length - len(question_token_ids) - len(answer_token_ids) - answer_token_ids += [tokenizer.pad_token_id] * padding_length - token_type_ids += [0] * padding_length - attention_mask += [0] * padding_length - label += [tokenizer.pad_token_id] * padding_length - feature = GenerationFeatures( - input_ids=question_token_ids + answer_token_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - label=label, - ) - features.append(feature) + batch_encoding = tokenizer( + [example.text for example in examples], + max_length=args.max_seq_length, + padding="max_length", + truncation=True, + ) logger.info( "tokenize sentences [took %.3f s]", time.time() - start ) + features = [] + for i in range(len(examples)): + inputs = {k: batch_encoding[k][i] for k in batch_encoding} + feature = GenerationFeatures(**inputs, labels=batch_encoding["input_ids"][i]) + features.append(feature) + for i, example in enumerate(examples[:5]): logger.info("*** Example ***") - logger.info("question: %s" % (example.question)) - logger.info("answer: %s" % (example.answer)) - logger.info("tokens: %s" % (" ".join(tokenizer.decode(features[i].input_ids)))) - logger.info("label: %s" % (" ".join(tokenizer.decode(features[i].label)))) + logger.info("sentence: %s" % (example.text)) + logger.info("tokens: %s" % (" ".join(tokenizer.convert_ids_to_tokens(features[i].input_ids)))) logger.info("features: %s" % features[i]) return features @@ -102,7 +96,7 @@ def __init__( tokenizer: PreTrainedTokenizerFast, corpus, mode: Optional[str] = "train", - convert_examples_to_features_fn=_convert_chatbot_examples_to_generation_features, + convert_examples_to_features_fn=_convert_examples_to_generation_features, ): if corpus is not None: self.corpus = corpus diff --git a/ratsnlp/nlpbook/generation/task.py b/ratsnlp/nlpbook/generation/task.py index 3305940..fa5a1dc 100644 --- a/ratsnlp/nlpbook/generation/task.py +++ b/ratsnlp/nlpbook/generation/task.py @@ -38,7 +38,7 @@ def forward(self, **kwargs): return self.model(**kwargs) def step(self, inputs, mode="train"): - loss, logits = self.model(**inputs) + loss, logits, _ = self.model(**inputs) preds = logits.argmax(dim=-1) labels = inputs["labels"] acc = accuracy(preds, labels) diff --git a/setup.py b/setup.py index c7ea210..146e749 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setuptools.setup( name="ratsnlp", - version="0.0.955", + version="0.0.957", license='MIT', author="ratsgo", author_email="ratsgo@naver.com", @@ -16,6 +16,7 @@ 'ratsnlp.nlpbook.ner': ['*.html'], 'ratsnlp.nlpbook.qa': ['*.html'], 'ratsnlp.nlpbook.paircls': ['*.html'], + 'ratsnlp.nlpbook.generation': ['*.html'], }, install_requires=[ "torch>=1.4.0",