From 8f02b531b021497c9dd4bda45c9670b7f238fb90 Mon Sep 17 00:00:00 2001
From: Martin Tutek
Date: Thu, 14 Jan 2021 18:27:53 +0100
Subject: [PATCH 1/7] Abstract tabular dataset loading

---
 podium/datasets/arrow_tabular_dataset.py |  82 +++++-------------
 podium/datasets/tabular_dataset.py       | 105 ++++++++++++++---------
 2 files changed, 86 insertions(+), 101 deletions(-)

diff --git a/podium/datasets/arrow_tabular_dataset.py b/podium/datasets/arrow_tabular_dataset.py
index c20248d3..1eb8fa6f 100644
--- a/podium/datasets/arrow_tabular_dataset.py
+++ b/podium/datasets/arrow_tabular_dataset.py
@@ -12,7 +12,7 @@
 from .dataset import Dataset, DatasetBase
 from .example_factory import Example, ExampleFactory
-
+from .tabular_dataset import load_tabular_file
 
 try:
     import pyarrow as pa
@@ -205,9 +205,9 @@ def from_tabular_file(
         cache_path: str = None,
         data_types: Dict[str, Tuple[pa.DataType, pa.DataType]] = None,
         chunk_size=10_000,
+        line2example=None,
         skip_header: bool = False,
-        delimiter=None,
-        csv_reader_params: Dict = None,
+        csv_reader_params: Dict = {},
     ) -> "ArrowDataset":
         """
         Loads a tabular file format (csv, tsv, json) as an ArrowDataset.
@@ -256,15 +256,18 @@
             Maximum number of examples to be loaded before dumping to the
             on-disk cache file. Use lower number if memory usage is an
             issue while loading.
+        line2example : callable
+            The function mapping from a file line to Fields.
+            In case your dataset is not in one of the standardized formats,
+            you can provide a function which performs a custom split for
+            each input line.
+
         skip_header : bool
             Whether to skip the first line of the input file.
             If format is CSV/TSV and 'fields' is a dict, then skip_header
             must be False and the data file must have a header.
             Default is False.
-        delimiter: str
-            Delimiter used to separate columns in a row.
-            If set to None, the default delimiter for the given format will
-            be used.
+
         csv_reader_params : Dict
             Parameters to pass to the csv reader. Only relevant when
             format is csv or tsv.
 
         Returns
         -------
         ArrowDataset
             ArrowDataset instance containing the examples from the tabular
             file.
         """
-        format = format.lower()
-        csv_reader_params = {} if csv_reader_params is None else csv_reader_params
-
-        with open(os.path.expanduser(path), encoding="utf8") as f:
-            if format in {"csv", "tsv"}:
-                delimiter = "," if format == "csv" else "\t"
-                reader = csv.reader(f, delimiter=delimiter, **csv_reader_params)
-            elif format == "json":
-                reader = f
-            else:
-                raise ValueError(f"Invalid format: {format}")
-
-            if skip_header:
-                if format == "json":
-                    raise ValueError(
-                        f"When using a {format} file, skip_header must be False."
-                    )
-                elif format in {"csv", "tsv"} and isinstance(fields, dict):
-                    raise ValueError(
-                        f"When using a dict to specify fields with a {format} "
-                        "file, skip_header must be False and the file must "
-                        "have a header."
-                    )
-                # skipping the header
-                next(reader)
-
-            # if format is CSV/TSV and fields is a dict, transform it to a list
-            if format in {"csv", "tsv"} and isinstance(fields, dict):
-                # we need a header to know the column names
-                header = next(reader)
-
-                # columns not present in the fields dict are ignored (None)
-                fields = [fields.get(column, None) for column in header]
-
-            # fields argument is the same for all examples
-            # fromlist is used for CSV/TSV because csv_reader yields data rows as
-            # lists, not strings
-            example_factory = ExampleFactory(fields)
-            make_example_function = {
-                "json": example_factory.from_json,
-                "csv": example_factory.from_list,
-                "tsv": example_factory.from_list,
-            }
-
-            make_example = make_example_function[format]
-
-            # map each line from the reader to an example
-            example_iterator = map(make_example, reader)
-            return ArrowDataset.from_examples(
-                fields,
-                example_iterator,
-                cache_path=cache_path,
-                data_types=data_types,
-                chunk_size=chunk_size,
-            )
+        example_generator = load_tabular_file(path, fields, format, line2example,
+                                              skip_header, csv_reader_params)
+
+        return ArrowDataset.from_examples(
+            fields,
+            example_generator,
+            cache_path=cache_path,
+            data_types=data_types,
+            chunk_size=chunk_size,
+        )
 
     @staticmethod
     def _schema_to_data_types(
diff --git a/podium/datasets/tabular_dataset.py b/podium/datasets/tabular_dataset.py
index fbd0768c..3696fbd5 100644
--- a/podium/datasets/tabular_dataset.py
+++ b/podium/datasets/tabular_dataset.py
@@ -83,47 +83,74 @@ def __init__(
             If format is "JSON" and skip_header is True.
         """
-        format = format.lower()
-
-        with open(os.path.expanduser(path), encoding="utf8") as f:
-
-            # Skip header prior to custom line2example in case
-            # the header is in a different format so we don't
-            # cause an error.
-            if skip_header:
-                if format == "json":
-                    raise ValueError(
-                        f"When using a {format} file, skip_header \
-                        must be False."
-                    )
-                elif format in {"csv", "tsv", "custom"} and isinstance(fields, dict):
-                    raise ValueError(
-                        f"When using a dict to specify fields with a {format} "
-                        "file, skip_header must be False and the file must "
-                        "have a header."
-                    )
-
-                # skip the header
-                next(f)
-
-            if line2example is not None:
-                reader = (line2example(line) for line in f)
-                format = "custom"
-            elif format in {"csv", "tsv"}:
-                delimiter = "," if format == "csv" else "\t"
-                reader = csv.reader(f, delimiter=delimiter, **csv_reader_params)
-            elif format == "json":
-                reader = f
-            else:
-                raise ValueError(f"Invalid format: {format}")
-
-            # create a list of examples
-            examples = create_examples(reader, format, fields)
-
-        # create a Dataset with lists of examples and fields
+        examples = load_tabular_file(path, fields, format, line2example,
+                                     skip_header, csv_reader_params)
+        examples = list(examples)
+
+        # Make the examples concrete here by casting to list
         super(TabularDataset, self).__init__(examples, fields, **kwargs)
         self.finalize_fields()
 
+def load_tabular_file(
+    path,
+    fields,
+    format,
+    line2example,
+    skip_header,
+    csv_reader_params
+    ):
+
+    with open(os.path.expanduser(path), encoding="utf8") as f:
+        # create a list of examples
+        reader = initialize_tabular_reader(f, format, fields, line2example,
+                                           skip_header, csv_reader_params)
+        examples = create_examples(reader, format, fields)
+        yield from examples
+
+def initialize_tabular_reader(
+    file,
+    format,
+    fields,
+    line2example,
+    skip_header,
+    csv_reader_params
+    ):
+
+    format = format.lower()
+
+
+    # Skip header prior to custom line2example in case
+    # the header is in a different format so we don't
+    # cause an error.
+    if skip_header:
+        if format == "json":
+            raise ValueError(
+                f"When using a {format} file, skip_header \
+                must be False."
+            )
+        elif format in {"csv", "tsv", "custom"} and isinstance(fields, dict):
+            raise ValueError(
+                f"When using a dict to specify fields with a {format} "
+                "file, skip_header must be False and the file must "
+                "have a header."
+            )
+
+        # skip the header
+        next(file)
+
+    if line2example is not None:
+        reader = (line2example(line) for line in file)
+        format = "custom"
+    elif format in {"csv", "tsv"}:
+        delimiter = "," if format == "csv" else "\t"
+        reader = csv.reader(file, delimiter=delimiter, **csv_reader_params)
+    elif format == "json":
+        reader = file
+    else:
+        raise ValueError(f"Invalid format: {format}")
+
+    return reader
+
 
 def create_examples(reader, format, fields):
     """
@@ -178,4 +205,4 @@ def create_examples(reader, format, fields):
     # map each line from the reader to an example
     examples = map(make_example, reader)
 
-    return list(examples)
+    return examples

From 36b3077c441eab01556cd472f91dea6c17bfc669 Mon Sep 17 00:00:00 2001
From: Martin Tutek
Date: Thu, 14 Jan 2021 18:29:03 +0100
Subject: [PATCH 2/7] style

---
 podium/datasets/arrow_tabular_dataset.py |  6 +++--
 podium/datasets/tabular_dataset.py       | 31 +++++++++---------------
 2 files changed, 15 insertions(+), 22 deletions(-)

diff --git a/podium/datasets/arrow_tabular_dataset.py b/podium/datasets/arrow_tabular_dataset.py
index 1eb8fa6f..af34ac0f 100644
--- a/podium/datasets/arrow_tabular_dataset.py
+++ b/podium/datasets/arrow_tabular_dataset.py
@@ -14,6 +14,7 @@
 from .example_factory import Example, ExampleFactory
 from .tabular_dataset import load_tabular_file
 
+
 try:
     import pyarrow as pa
 except ImportError:
@@ -280,8 +281,9 @@
         ArrowDataset
             ArrowDataset instance containing the examples from the tabular
             file.
""" - example_generator = load_tabular_file(path, fields, format, line2example, - skip_header, csv_reader_params) + example_generator = load_tabular_file( + path, fields, format, line2example, skip_header, csv_reader_params + ) return ArrowDataset.from_examples( fields, diff --git a/podium/datasets/tabular_dataset.py b/podium/datasets/tabular_dataset.py index 3696fbd5..68ffd6ba 100644 --- a/podium/datasets/tabular_dataset.py +++ b/podium/datasets/tabular_dataset.py @@ -83,42 +83,33 @@ def __init__( If format is "JSON" and skip_header is True. """ - examples = load_tabular_file(path, fields, format, line2example, - skip_header, csv_reader_params) + examples = load_tabular_file( + path, fields, format, line2example, skip_header, csv_reader_params + ) examples = list(examples) # Make the examples concrete here by casting to list super(TabularDataset, self).__init__(examples, fields, **kwargs) self.finalize_fields() -def load_tabular_file( - path, - fields, - format, - line2example, - skip_header, - csv_reader_params - ): + +def load_tabular_file(path, fields, format, line2example, skip_header, csv_reader_params): with open(os.path.expanduser(path), encoding="utf8") as f: # create a list of examples - reader = initialize_tabular_reader(f, format, fields, line2example, - skip_header, csv_reader_params) + reader = initialize_tabular_reader( + f, format, fields, line2example, skip_header, csv_reader_params + ) examples = create_examples(reader, format, fields) yield from examples + def initialize_tabular_reader( - file, - format, - fields, - line2example, - skip_header, - csv_reader_params - ): + file, format, fields, line2example, skip_header, csv_reader_params +): format = format.lower() - # Skip header prior to custom line2example in case # the header is in a different format so we don't # cause an error. 

From a0272a2b5e6ad48b6204f8b25e4e956d2cdef877 Mon Sep 17 00:00:00 2001
From: Martin Tutek
Date: Thu, 14 Jan 2021 18:31:35 +0100
Subject: [PATCH 3/7] style

---
 podium/datasets/arrow_tabular_dataset.py | 3 +--
 podium/datasets/tabular_dataset.py       | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/podium/datasets/arrow_tabular_dataset.py b/podium/datasets/arrow_tabular_dataset.py
index af34ac0f..0e84e667 100644
--- a/podium/datasets/arrow_tabular_dataset.py
+++ b/podium/datasets/arrow_tabular_dataset.py
@@ -1,4 +1,3 @@
-import csv
 import itertools
 import os
 import pickle
@@ -11,7 +10,7 @@
 from podium.field import Field, unpack_fields
 
 from .dataset import Dataset, DatasetBase
-from .example_factory import Example, ExampleFactory
+from .example_factory import Example
 from .tabular_dataset import load_tabular_file
 
diff --git a/podium/datasets/tabular_dataset.py b/podium/datasets/tabular_dataset.py
index 68ffd6ba..9e479f0d 100644
--- a/podium/datasets/tabular_dataset.py
+++ b/podium/datasets/tabular_dataset.py
@@ -86,9 +86,9 @@ def __init__(
         examples = load_tabular_file(
             path, fields, format, line2example, skip_header, csv_reader_params
         )
+        # Make the examples concrete here by casting to list
         examples = list(examples)
-        # Make the examples concrete here by casting to list
 
         super(TabularDataset, self).__init__(examples, fields, **kwargs)
         self.finalize_fields()

From 8300dc76df1c40bada2b13bf063f2dc00684562e Mon Sep 17 00:00:00 2001
From: Martin Tutek
Date: Sat, 16 Jan 2021 14:17:43 +0100
Subject: [PATCH 4/7] Minor comments

---
 podium/datasets/arrow_tabular_dataset.py | 4 ++--
 podium/datasets/tabular_dataset.py       | 8 +++++---
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/podium/datasets/arrow_tabular_dataset.py b/podium/datasets/arrow_tabular_dataset.py
index 0e84e667..2402a096 100644
--- a/podium/datasets/arrow_tabular_dataset.py
+++ b/podium/datasets/arrow_tabular_dataset.py
@@ -5,7 +5,7 @@
 import tempfile
 import warnings
 from collections import defaultdict
-from typing import Any, Dict, Iterable, Iterator, List, Tuple, Union
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
 
 from podium.field import Field, unpack_fields
 
 from .dataset import Dataset, DatasetBase
@@ -207,7 +207,7 @@
         chunk_size=10_000,
         line2example=None,
         skip_header: bool = False,
-        csv_reader_params: Dict = {},
+        csv_reader_params: Optional[Dict] = None,
     ) -> "ArrowDataset":
         """
         Loads a tabular file format (csv, tsv, json) as an ArrowDataset.
diff --git a/podium/datasets/tabular_dataset.py b/podium/datasets/tabular_dataset.py
index 9e479f0d..296f50c1 100644
--- a/podium/datasets/tabular_dataset.py
+++ b/podium/datasets/tabular_dataset.py
@@ -19,7 +19,7 @@
         format="csv",
         line2example=None,
         skip_header=False,
-        csv_reader_params={},
+        csv_reader_params=None,
         **kwargs,
     ):
         """
@@ -97,7 +97,7 @@ def load_tabular_file(path, fields, format, line2example, skip_header, csv_reade
 
     with open(os.path.expanduser(path), encoding="utf8") as f:
         # create a list of examples
-        reader = initialize_tabular_reader(
+        reader, format = initialize_tabular_reader(
             f, format, fields, line2example, skip_header, csv_reader_params
         )
         examples = create_examples(reader, format, fields)
@@ -134,13 +134,15 @@
         format = "custom"
     elif format in {"csv", "tsv"}:
         delimiter = "," if format == "csv" else "\t"
+        if csv_reader_params is None:
+            csv_reader_params = {}
         reader = csv.reader(file, delimiter=delimiter, **csv_reader_params)
     elif format == "json":
         reader = file
     else:
         raise ValueError(f"Invalid format: {format}")
 
-    return reader
+    return reader, format
 
 
 def create_examples(reader, format, fields):

From d0567f2e4265bdf93547d92c4ce81b92b5ffdee4 Mon Sep 17 00:00:00 2001
From: Martin Tutek
Date: Sat, 16 Jan 2021 14:37:05 +0100
Subject: [PATCH 5/7] Comments

---
 podium/datasets/arrow_tabular_dataset.py     | 14 +++---
 podium/datasets/tabular_dataset.py           | 46 +++++++++++++------
 .../datasets/test_pyarrow_tabular_dataset.py |  2 +-
 3 files changed, 40 insertions(+), 22 deletions(-)

diff --git a/podium/datasets/arrow_tabular_dataset.py b/podium/datasets/arrow_tabular_dataset.py
index 2402a096..49e2999d 100644
--- a/podium/datasets/arrow_tabular_dataset.py
+++ b/podium/datasets/arrow_tabular_dataset.py
@@ -5,13 +5,13 @@
 import tempfile
 import warnings
 from collections import defaultdict
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
 
 from podium.field import Field, unpack_fields
 
 from .dataset import Dataset, DatasetBase
 from .example_factory import Example
-from .tabular_dataset import load_tabular_file
+from .tabular_dataset import _load_tabular_file
 
 
 try:
@@ -200,12 +200,12 @@ def from_examples(
     @staticmethod
     def from_tabular_file(
         path: str,
-        format: str,
         fields: Union[Dict[str, Field], List[Field]],
-        cache_path: str = None,
+        format: str = "csv",
+        cache_path: Optional[str] = None,
         data_types: Dict[str, Tuple[pa.DataType, pa.DataType]] = None,
-        chunk_size=10_000,
-        line2example=None,
+        chunk_size: int = 1024,
+        line2example: Optional[Callable] = None,
         skip_header: bool = False,
         csv_reader_params: Optional[Dict] = None,
     ) -> "ArrowDataset":
         """
         Loads a tabular file format (csv, tsv, json) as an ArrowDataset.
""" - example_generator = load_tabular_file( + example_generator = _load_tabular_file( path, fields, format, line2example, skip_header, csv_reader_params ) diff --git a/podium/datasets/tabular_dataset.py b/podium/datasets/tabular_dataset.py index 296f50c1..1551c3b9 100644 --- a/podium/datasets/tabular_dataset.py +++ b/podium/datasets/tabular_dataset.py @@ -1,7 +1,9 @@ import csv import os +from typing import Callable, Dict, List, Optional, Union from podium.datasets.dataset import Dataset +from podium.field import Field from .example_factory import ExampleFactory @@ -14,12 +16,12 @@ class TabularDataset(Dataset): def __init__( self, - path, - fields, - format="csv", - line2example=None, - skip_header=False, - csv_reader_params=None, + path: str, + fields: Union[Dict[str, Field], List[Field]], + format: str = "csv", + line2example: Optional[Callable] = None, + skip_header: bool = False, + csv_reader_params: Optional[Dict] = None, **kwargs, ): """ @@ -83,7 +85,7 @@ def __init__( If format is "JSON" and skip_header is True. """ - examples = load_tabular_file( + examples = _load_tabular_file( path, fields, format, line2example, skip_header, csv_reader_params ) # Make the examples concrete here by casting to list @@ -93,20 +95,34 @@ def __init__( self.finalize_fields() -def load_tabular_file(path, fields, format, line2example, skip_header, csv_reader_params): +def _load_tabular_file( + path, fields, format, line2example, skip_header, csv_reader_params +): + """ + Loads examples as a generator from a dataset in tabular format. + + Abstracted from TabularDataset due to duplicate usage in ArrowDataset. + Parameters same as in TabularDataset constructor. + """ with open(os.path.expanduser(path), encoding="utf8") as f: # create a list of examples - reader, format = initialize_tabular_reader( + reader, format = _initialize_tabular_reader( f, format, fields, line2example, skip_header, csv_reader_params ) - examples = create_examples(reader, format, fields) + examples = _create_examples(reader, format, fields) yield from examples -def initialize_tabular_reader( +def _initialize_tabular_reader( file, format, fields, line2example, skip_header, csv_reader_params ): + """ + Initializes the input stream from a file. + + In case of using a custom format handled by line2example, lazily applies the + row transfromation on the data. + """ format = format.lower() @@ -145,7 +161,7 @@ def initialize_tabular_reader( return reader, format -def create_examples(reader, format, fields): +def _create_examples(reader, format, fields): """ Creates a list of examples from the given line reader and fields (see TabularDataset.__init__ docs for more info on the fields). @@ -153,8 +169,10 @@ def create_examples(reader, format, fields): Parameters ---------- reader - A reader object that reads one line at a time. Yields either strings - (when format is JSON) or lists of values (when format is CSV/TSV). + A reader object that reads one line at a time. Yields strings + (when format is JSON) or lists of values (when format is CSV/TSV) + or a sequence of pre-transformed lines if a custom format is used + via the `line2example` argument. format : str Format of the data file that is being read. Can be either CSV, TSV or JSON. 
diff --git a/tests/datasets/test_pyarrow_tabular_dataset.py b/tests/datasets/test_pyarrow_tabular_dataset.py
index 50f7efa6..fe615631 100644
--- a/tests/datasets/test_pyarrow_tabular_dataset.py
+++ b/tests/datasets/test_pyarrow_tabular_dataset.py
@@ -222,7 +222,7 @@ def test_from_tabular(data, fields, tmpdir):
         writer = csv.writer(f)
         writer.writerows(data)
 
-    csv_dataset = ArrowDataset.from_tabular_file(test_file, "csv", fields)
+    csv_dataset = ArrowDataset.from_tabular_file(test_file, fields, format="csv")
     for ex, d in zip(csv_dataset, data):
         assert int(ex.number[0]) == d[0]
         assert ex.tokens[0] == d[1]

From b2a2b5bdd381f148099b31f2a336198b636afb8f Mon Sep 17 00:00:00 2001
From: Martin Tutek
Date: Sat, 16 Jan 2021 15:43:05 +0100
Subject: [PATCH 6/7] Standardize file opening, expanduser

---
 podium/datasets/impl/catacx_comments_dataset.py | 2 +-
 podium/datasets/impl/snli_dataset.py            | 2 +-
 podium/datasets/tabular_dataset.py              | 2 +-
 podium/vectorizers/vectorizer.py                | 6 ++++--
 4 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/podium/datasets/impl/catacx_comments_dataset.py b/podium/datasets/impl/catacx_comments_dataset.py
index 0e43d6b0..1f1003e1 100644
--- a/podium/datasets/impl/catacx_comments_dataset.py
+++ b/podium/datasets/impl/catacx_comments_dataset.py
@@ -114,7 +114,7 @@
         A list of examples containing comments from the Catacx dataset.
     """
     example_factory = ExampleFactory(fields)
-    with open(dir_path, encoding="utf8") as f:
+    with open(os.path.expanduser(dir_path), encoding="utf8") as f:
         ds = json.load(f)
 
     examples = []
diff --git a/podium/datasets/impl/snli_dataset.py b/podium/datasets/impl/snli_dataset.py
index ac91807a..08b195d3 100644
--- a/podium/datasets/impl/snli_dataset.py
+++ b/podium/datasets/impl/snli_dataset.py
@@ -99,7 +99,7 @@
     example_factory = ExampleFactory(fields)
     examples = []
-    with open(file=file_path, encoding="utf8") as in_file:
+    with open(file=os.path.expanduser(file_path), encoding="utf8") as in_file:
         for line in in_file:
             examples.append(example_factory.from_json(line))
     return examples
diff --git a/podium/datasets/tabular_dataset.py b/podium/datasets/tabular_dataset.py
index 1551c3b9..1cad5dcc 100644
--- a/podium/datasets/tabular_dataset.py
+++ b/podium/datasets/tabular_dataset.py
@@ -105,7 +105,7 @@
     Abstracted from TabularDataset due to duplicate usage in ArrowDataset.
    Parameters same as in TabularDataset constructor.
""" - with open(os.path.expanduser(path), encoding="utf8") as f: + with open(os.path.expanduser(path), encoding="utf8", newline="") as f: # create a list of examples reader, format = _initialize_tabular_reader( f, format, fields, line2example, skip_header, csv_reader_params diff --git a/podium/vectorizers/vectorizer.py b/podium/vectorizers/vectorizer.py index 69100c27..83fafa90 100644 --- a/podium/vectorizers/vectorizer.py +++ b/podium/vectorizers/vectorizer.py @@ -91,9 +91,11 @@ def __init__( max_vectors : int, optional maximum number of vectors to load in memory """ - self._path = path + self._path = os.path.expanduser(path) if path is not None else path self._default_vector_function = default_vector_function - self._cache_path = cache_path + self._cache_path = ( + os.path.expanduser(cache_path) if cache_path is not None else cache_path + ) self._max_vectors = max_vectors @abstractmethod From 5ac458c65e3701b843d3cb8225c1dc9cf6756113 Mon Sep 17 00:00:00 2001 From: Martin Tutek Date: Mon, 18 Jan 2021 14:15:43 +0100 Subject: [PATCH 7/7] Flake --- podium/datasets/arrow_tabular_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/podium/datasets/arrow_tabular_dataset.py b/podium/datasets/arrow_tabular_dataset.py index 62845b6d..23a5298e 100644 --- a/podium/datasets/arrow_tabular_dataset.py +++ b/podium/datasets/arrow_tabular_dataset.py @@ -10,7 +10,7 @@ from podium.field import Field, unpack_fields from .dataset import DatasetBase -from .example_factory import Example, ExampleFactory +from .example_factory import Example from .tabular_dataset import _load_tabular_file