Add Chinese language support #10

Open · wants to merge 1 commit into base: main
data_selection/hashed_ngram_dsir.py · 42 changes: 11 additions & 31 deletions
@@ -3,6 +3,7 @@
 from tqdm import tqdm
 from nltk.tokenize import WordPunctTokenizer
 from nltk.tokenize import word_tokenize
+import jieba  # import the jieba library
 from nltk import ngrams as get_ngrams
 import numpy as np

@@ -28,15 +29,6 @@ def get_ngram_counts(line: str,
                      num_buckets: int = 10000,
                      counts: Optional[np.ndarray] = None,
                      tokenizer: Callable = wpt.tokenize) -> np.ndarray:
-    '''Return ngram count features given a string.
-
-    Args:
-        line: string to get ngram counts from
-        n: n in ngrams
-        num_buckets: number of buckets to hash ngrams into
-        counts: pre-initialized counts array
-        tokenizer: tokenization function to use. Defaults to word_tokenize from nltk
-    '''
     words = tokenizer(line.lower())

     if counts is None:
@@ -52,8 +44,6 @@


 class HashedNgramDSIR(DSIR):
-    """DSIR with hashed n-gram features."""
-
     def __init__(self,
                  raw_datasets: List[str],
                  target_datasets: List[str],
@@ -69,25 +59,8 @@ def __init__(self,
                  min_example_length: int = 100,
                  target_laplace_smoothing: float = 0.0,
                  separate_targets: bool = False,
-                 target_proportions: Optional[List[float]] = None) -> None:
-        '''Initialize the HashedNgramDSIR object.
-
-        Args:
-            raw_datasets: List of data paths
-            target_datasets: List of data paths
-            cache_dir: place to store cached log_importance_weights
-            load_dataset_fn: Function to load a dataset from a path. Defaults to default_load_dataset_fn.
-            parse_example_fn: Function that takes in an example dict and returns a string.
-                Defaults to returning the 'text' field of the example.
-            num_proc: number of processes to use for parallelization. Defaults to number of cores.
-            ngrams: N in N-grams. 2 means both unigram and bigrams.
-            num_buckets: number of buckets to hash ngrams into.
-            tokenizer: word_tokenize or wordpunct
-            min_example_length: minimum number of tokens in an example to be considered.
-            target_laplace_smoothing: Smooth the target hashed ngram distribution. This parameter is a pseudo-count. This could be useful for small target datasets.
-            separate_targets: whether to select data separately for each target and then join them
-            target_proportions: weighting across multiple targets if separate_targets=True. Set to None to weight by the size of each target dataset
-        '''
+                 target_proportions: Optional[List[float]] = None,
+                 language: str = 'en') -> None:
         super().__init__(
             raw_datasets=raw_datasets,
             target_datasets=target_datasets,
@@ -99,12 +72,19 @@ def __init__(self,
             num_proc=num_proc,
             separate_targets=separate_targets,
             target_proportions=target_proportions)
-        if tokenizer == 'word_tokenize':
+        self.language = language
+
+        # add Chinese support (use jieba to segment words)
+        if self.language == 'zh':
+            self.tokenizer = lambda text: list(jieba.cut(text))
+
+        elif tokenizer == 'word_tokenize':
             self.tokenizer = word_tokenize
         elif tokenizer == 'wordpunct':
             self.tokenizer = wpt.tokenize
         else:
             raise ValueError('tokenizer not recognized')
+
         self.ngrams = ngrams
         self.num_buckets = num_buckets
         self.min_example_length = min_example_length
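
For reviewers, a minimal usage sketch of what this change enables. This is a sketch only: it assumes the module is importable as data_selection.hashed_ngram_dsir, and the dataset paths and sample sentence are placeholders, not part of the PR. Only HashedNgramDSIR, get_ngram_counts, and the new language argument come from the diff itself.

import jieba
from data_selection.hashed_ngram_dsir import HashedNgramDSIR, get_ngram_counts


def zh_tokenize(text):
    # Mirrors the added branch: jieba segments Chinese text into a word list.
    return list(jieba.cut(text))


# Hashed n-gram counts for a Chinese string, using the jieba-based tokenizer.
counts = get_ngram_counts("自然语言处理很有趣", tokenizer=zh_tokenize)

# Select data against a Chinese-language target (paths are hypothetical).
dsir = HashedNgramDSIR(
    raw_datasets=['raw/zh_web.jsonl'],
    target_datasets=['target/zh_wiki.jsonl'],
    cache_dir='cache/zh_dsir',
    language='zh')  # routes tokenization through jieba instead of NLTK

Note on the design as diffed: when language='zh', the jieba branch is checked first, so the tokenizer argument is effectively ignored for Chinese.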