From 64b00d3f60b9f53f7ab989fdda1fc5913a2895e7 Mon Sep 17 00:00:00 2001
From: "dikw.nlp@gmail.com" <dikw.nlp@gmail.com>
Date: Sun, 7 Apr 2024 17:12:35 +0800
Subject: [PATCH] Add Chinese language support

---
 data_selection/hashed_ngram_dsir.py | 42 +++++++++++-------------------------------
 1 file changed, 11 insertions(+), 31 deletions(-)

diff --git a/data_selection/hashed_ngram_dsir.py b/data_selection/hashed_ngram_dsir.py
index 9a17e44..304c4cf 100644
--- a/data_selection/hashed_ngram_dsir.py
+++ b/data_selection/hashed_ngram_dsir.py
@@ -3,6 +3,7 @@
 from tqdm import tqdm
 from nltk.tokenize import WordPunctTokenizer
 from nltk.tokenize import word_tokenize
+import jieba  # import jieba for Chinese word segmentation
 from nltk import ngrams as get_ngrams
 import numpy as np
 
@@ -28,15 +29,6 @@ def get_ngram_counts(line: str,
                      num_buckets: int = 10000,
                      counts: Optional[np.ndarray] = None,
                      tokenizer: Callable = wpt.tokenize) -> np.ndarray:
-    '''Return ngram count features given a string.
-
-    Args:
-        line: string to get ngram counts from
-        n: n in ngrams
-        num_buckets: number of buckets to hash ngrams into
-        counts: pre-initialized counts array
-        tokenizer: tokenization function to use. Defaults to word_tokenize from nltk
-    '''
     words = tokenizer(line.lower())
 
     if counts is None:
@@ -52,8 +44,6 @@ def get_ngram_counts(line: str,
 
 
 class HashedNgramDSIR(DSIR):
-    """DSIR with hashed n-gram features."""
-
     def __init__(self,
                  raw_datasets: List[str],
                  target_datasets: List[str],
@@ -69,25 +59,8 @@ def __init__(self,
                  min_example_length: int = 100,
                  target_laplace_smoothing: float = 0.0,
                  separate_targets: bool = False,
-                 target_proportions: Optional[List[float]] = None) -> None:
-        '''Initialize the HashedNgramDSIR object.
-
-        Args:
-            raw_datasets: List of data paths
-            target_datasets: List of data paths
-            cache_dir: place to store cached log_importance_weights
-            load_dataset_fn: Function to load a dataset from a path. Defaults to default_load_dataset_fn.
-            parse_example_fn: Function that takes in an example dict and returns a string.
-                Defaults to returning the 'text' field of the example.
-            num_proc: number of processes to use for parallelization. Defaults to number of cores.
-            ngrams: N in N-grams. 2 means both unigram and bigrams.
-            num_buckets: number of buckets to hash ngrams into.
-            tokenizer: word_tokenize or wordpunct
-            min_example_length: minimum number of tokens in an example to be considered.
-            target_laplace_smoothing: Smooth the target hashed ngram distribution. This parameter is a pseudo-count. This could be useful for small target datasets.
-            separate_targets: whether to select data separately for each target and then join them
-            target_proportions: weighting across multiple targets if separate_targets=True. Set to None to weight by the size of each target dataset
-        '''
+                 target_proportions: Optional[List[float]] = None,
+                 language: str = 'en') -> None:
         super().__init__(
             raw_datasets=raw_datasets,
             target_datasets=target_datasets,
@@ -99,12 +72,19 @@ def __init__(self,
             num_proc=num_proc,
             separate_targets=separate_targets,
             target_proportions=target_proportions)
-        if tokenizer == 'word_tokenize':
+        self.language = language
+
+        # add Chinese support: use jieba to segment words
+        if self.language == 'zh':
+            self.tokenizer = lambda text: list(jieba.cut(text))
+
+        elif tokenizer == 'word_tokenize':
            self.tokenizer = word_tokenize
         elif tokenizer == 'wordpunct':
             self.tokenizer = wpt.tokenize
         else:
             raise ValueError('tokenizer not recognized')
+
         self.ngrams = ngrams
         self.num_buckets = num_buckets
         self.min_example_length = min_example_length
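
A minimal usage sketch for the new language option (not part of the diff itself): it assumes the constructor arguments shown in this patch plus the fit/compute/resample entry points documented in the upstream DSIR README; the dataset paths, cache directories, and sample count are placeholders, and jieba must be installed (pip install jieba) before passing language='zh'.

    from data_selection import HashedNgramDSIR

    # Placeholder inputs: jsonl files whose examples carry a 'text' field.
    raw_datasets = ['/path/to/raw_zh.jsonl']
    target_datasets = ['/path/to/target_zh.jsonl']

    # language='zh' routes tokenization through jieba.cut instead of
    # NLTK's word_tokenize / wordpunct tokenizers.
    dsir = HashedNgramDSIR(
        raw_datasets,
        target_datasets,
        cache_dir='/path/to/dsir_cache',
        language='zh')

    # Remaining steps follow the upstream DSIR API, assumed unchanged by this patch.
    dsir.fit_importance_estimator(num_tokens_to_fit='auto')
    dsir.compute_importance_weights()
    dsir.resample(out_dir='/path/to/resampled',
                  num_to_sample=100000,
                  cache_dir='/path/to/resample_cache')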