diff --git a/podium/datasets/arrow.py b/podium/datasets/arrow.py
index 5a33d3a7..caeea893 100644
--- a/podium/datasets/arrow.py
+++ b/podium/datasets/arrow.py
@@ -1,4 +1,3 @@
-import csv
 import itertools
 import os
 import pickle
@@ -6,12 +5,13 @@
 import tempfile
 import warnings
 from collections import defaultdict
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union

 from podium.field import Field, unpack_fields

 from .dataset import DatasetBase
-from .example_factory import Example, ExampleFactory
+from .example_factory import Example
+from .tabular_dataset import _load_tabular_file


 try:
@@ -218,14 +218,15 @@ def from_examples(
     @staticmethod
     def from_tabular_file(
         path: str,
-        format: str,
         fields: Union[Dict[str, Field], List[Field]],
+        format: str = "csv",
         cache_path: Optional[str] = None,
         data_types: Dict[str, Tuple[pa.DataType, pa.DataType]] = None,
-        chunk_size=10_000,
+        chunk_size: int = 1024,
+        line2example: Optional[Callable] = None,
         skip_header: bool = False,
         delimiter=None,
-        csv_reader_params: Dict = None,
+        csv_reader_params: Optional[Dict] = None,
     ) -> "DiskBackedDataset":
         """
         Loads a tabular file format (csv, tsv, json) as a DiskBackedDataset.
@@ -279,15 +280,18 @@ def from_tabular_file(
             Maximum number of examples to be loaded before dumping to the
             on-disk cache file. Use lower number if memory usage is an issue
             while loading.
+        line2example : callable
+            Callable mapping a raw file line to a list of column values.
+            If your dataset is not in one of the standardized formats,
+            you can provide a function which performs a custom split
+            for each input line.
+
         skip_header : bool
             Whether to skip the first line of the input file.
             If format is CSV/TSV and 'fields' is a dict, then skip_header
             must be False and the data file must have a header.
             Default is False.
-        delimiter: str
-            Delimiter used to separate columns in a row.
-            If set to None, the default delimiter for the given format will
-            be used.
+
         csv_reader_params : Dict
             Parameters to pass to the csv reader. Only relevant when
             format is csv or tsv.
@@ -299,62 +303,18 @@ def from_tabular_file(
         DiskBackedDataset
             DiskBackedDataset instance containing the examples from the tabular file.
         """
-        format = format.lower()
-        csv_reader_params = {} if csv_reader_params is None else csv_reader_params
-
-        with open(os.path.expanduser(path), encoding="utf8") as f:
-            if format in {"csv", "tsv"}:
-                delimiter = "," if format == "csv" else "\t"
-                reader = csv.reader(f, delimiter=delimiter, **csv_reader_params)
-            elif format == "json":
-                reader = f
-            else:
-                raise ValueError(f"Invalid format: {format}")
-
-            if skip_header:
-                if format == "json":
-                    raise ValueError(
-                        f"When using a {format} file, skip_header must be False."
-                    )
-                elif format in {"csv", "tsv"} and isinstance(fields, dict):
-                    raise ValueError(
-                        f"When using a dict to specify fields with a {format} "
-                        "file, skip_header must be False and the file must "
-                        "have a header."
-                    )
-                # skipping the header
-                next(reader)
-
-            # if format is CSV/TSV and fields is a dict, transform it to a list
-            if format in {"csv", "tsv"} and isinstance(fields, dict):
-                # we need a header to know the column names
-                header = next(reader)
-
-                # columns not present in the fields dict are ignored (None)
-                fields = [fields.get(column, None) for column in header]
-
-            # fields argument is the same for all examples
-            # fromlist is used for CSV/TSV because csv_reader yields data rows as
-            # lists, not strings
-            example_factory = ExampleFactory(fields)
-            make_example_function = {
-                "json": example_factory.from_json,
-                "csv": example_factory.from_list,
-                "tsv": example_factory.from_list,
-            }
-
-            make_example = make_example_function[format]
-
-            # map each line from the reader to an example
-            example_iterator = map(make_example, reader)
-            return DiskBackedDataset.from_examples(
-                fields,
-                example_iterator,
-                cache_path=cache_path,
-                data_types=data_types,
-                chunk_size=chunk_size,
-            )
+        example_generator = _load_tabular_file(
+            path, fields, format, line2example, skip_header, csv_reader_params
+        )
+
+        return DiskBackedDataset.from_examples(
+            fields,
+            example_generator,
+            cache_path=cache_path,
+            data_types=data_types,
+            chunk_size=chunk_size,
+        )

     @staticmethod
     def _schema_to_data_types(
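To make the new call shape concrete, here is a minimal usage sketch. The file names and field configuration are illustrative (borrowed from the test further below), and the `Field`/`Vocab` arguments are kept to a minimum:

```python
from podium import Field, Vocab
from podium.datasets.arrow import DiskBackedDataset

# Illustrative fields; any Field configuration works here.
fields = [Field("number"), Field("tokens", numericalizer=Vocab())]

# `fields` is now the second positional argument; `format` defaults to "csv".
dataset = DiskBackedDataset.from_tabular_file("data.csv", fields)

# For non-standard layouts, split each raw line yourself via `line2example`.
dataset = DiskBackedDataset.from_tabular_file(
    "data.txt",
    fields,
    line2example=lambda line: line.rstrip("\n").split("|"),
)
```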
diff --git a/podium/datasets/impl/catacx_comments_dataset.py b/podium/datasets/impl/catacx_comments_dataset.py
index 0e43d6b0..1f1003e1 100644
--- a/podium/datasets/impl/catacx_comments_dataset.py
+++ b/podium/datasets/impl/catacx_comments_dataset.py
@@ -114,7 +114,7 @@ def _create_examples(dir_path, fields):
         A list of examples containing comments from the Catacx dataset.
     """
     example_factory = ExampleFactory(fields)
-    with open(dir_path, encoding="utf8") as f:
+    with open(os.path.expanduser(dir_path), encoding="utf8") as f:
         ds = json.load(f)

     examples = []
diff --git a/podium/datasets/impl/snli_dataset.py b/podium/datasets/impl/snli_dataset.py
index ac91807a..08b195d3 100644
--- a/podium/datasets/impl/snli_dataset.py
+++ b/podium/datasets/impl/snli_dataset.py
@@ -99,7 +99,7 @@ def _create_examples(file_path, fields):
     example_factory = ExampleFactory(fields)
     examples = []

-    with open(file=file_path, encoding="utf8") as in_file:
+    with open(file=os.path.expanduser(file_path), encoding="utf8") as in_file:
         for line in in_file:
             examples.append(example_factory.from_json(line))
     return examples
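Both dataset fixes follow the same pattern: Python's built-in `open` does not expand a leading `~`, so user-relative dataset paths previously failed. A quick illustration with a hypothetical path:

```python
import os

# open("~/data/snli.jsonl") raises FileNotFoundError, because open()
# treats "~" as a literal character. expanduser resolves it to $HOME first.
path = os.path.expanduser("~/data/snli.jsonl")
with open(path, encoding="utf8") as in_file:
    for line in in_file:
        ...
```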
diff --git a/podium/datasets/tabular_dataset.py b/podium/datasets/tabular_dataset.py
index fbd0768c..1cad5dcc 100644
--- a/podium/datasets/tabular_dataset.py
+++ b/podium/datasets/tabular_dataset.py
@@ -1,7 +1,9 @@
 import csv
 import os
+from typing import Callable, Dict, List, Optional, Union

 from podium.datasets.dataset import Dataset
+from podium.field import Field

 from .example_factory import ExampleFactory

@@ -14,12 +16,12 @@ class TabularDataset(Dataset):

     def __init__(
         self,
-        path,
-        fields,
-        format="csv",
-        line2example=None,
-        skip_header=False,
-        csv_reader_params={},
+        path: str,
+        fields: Union[Dict[str, Field], List[Field]],
+        format: str = "csv",
+        line2example: Optional[Callable] = None,
+        skip_header: bool = False,
+        csv_reader_params: Optional[Dict] = None,
         **kwargs,
     ):
         """
@@ -83,49 +85,83 @@ def __init__(
             If format is "JSON" and skip_header is True.
         """

-        format = format.lower()
-
-        with open(os.path.expanduser(path), encoding="utf8") as f:
-
-            # Skip header prior to custom line2example in case
-            # the header is in a different format so we don't
-            # cause an error.
-            if skip_header:
-                if format == "json":
-                    raise ValueError(
-                        f"When using a {format} file, skip_header \
-                          must be False."
-                    )
-                elif format in {"csv", "tsv", "custom"} and isinstance(fields, dict):
-                    raise ValueError(
-                        f"When using a dict to specify fields with a {format} "
-                        "file, skip_header must be False and the file must "
-                        "have a header."
-                    )
-
-                # skip the header
-                next(f)
-
-            if line2example is not None:
-                reader = (line2example(line) for line in f)
-                format = "custom"
-            elif format in {"csv", "tsv"}:
-                delimiter = "," if format == "csv" else "\t"
-                reader = csv.reader(f, delimiter=delimiter, **csv_reader_params)
-            elif format == "json":
-                reader = f
-            else:
-                raise ValueError(f"Invalid format: {format}")
-
-            # create a list of examples
-            examples = create_examples(reader, format, fields)
-
-        # create a Dataset with lists of examples and fields
+        examples = _load_tabular_file(
+            path, fields, format, line2example, skip_header, csv_reader_params
+        )
+        # Make the examples concrete here by casting to list
+        examples = list(examples)
+
         super(TabularDataset, self).__init__(examples, fields, **kwargs)
         self.finalize_fields()


-def create_examples(reader, format, fields):
+def _load_tabular_file(
+    path, fields, format, line2example, skip_header, csv_reader_params
+):
+    """
+    Loads examples as a generator from a dataset in tabular format.
+
+    Abstracted from TabularDataset due to duplicate usage in ArrowDataset.
+    Parameters are the same as in the TabularDataset constructor.
+    """
+
+    with open(os.path.expanduser(path), encoding="utf8", newline="") as f:
+        # initialize the reader for the given format
+        reader, format = _initialize_tabular_reader(
+            f, format, fields, line2example, skip_header, csv_reader_params
+        )
+        examples = _create_examples(reader, format, fields)
+        yield from examples
+
+
+def _initialize_tabular_reader(
+    file, format, fields, line2example, skip_header, csv_reader_params
+):
+    """
+    Initializes the input stream from a file.
+
+    When a custom format is handled by line2example, the row transformation
+    is applied lazily to the data.
+    """
+
+    format = format.lower()
+
+    # Skip the header before applying a custom line2example,
+    # in case the header is in a different format and would
+    # cause an error.
+    if skip_header:
+        if format == "json":
+            raise ValueError(
+                f"When using a {format} file, skip_header \
+                  must be False."
+            )
+        elif format in {"csv", "tsv", "custom"} and isinstance(fields, dict):
+            raise ValueError(
+                f"When using a dict to specify fields with a {format} "
+                "file, skip_header must be False and the file must "
+                "have a header."
+            )
+
+        # skip the header
+        next(file)
+
+    if line2example is not None:
+        reader = (line2example(line) for line in file)
+        format = "custom"
+    elif format in {"csv", "tsv"}:
+        delimiter = "," if format == "csv" else "\t"
+        if csv_reader_params is None:
+            csv_reader_params = {}
+        reader = csv.reader(file, delimiter=delimiter, **csv_reader_params)
+    elif format == "json":
+        reader = file
+    else:
+        raise ValueError(f"Invalid format: {format}")
+
+    return reader, format
+
+
+def _create_examples(reader, format, fields):
     """
     Creates a list of examples from the given line reader and fields (see
     TabularDataset.__init__ docs for more info on the fields).
@@ -133,8 +169,10 @@ def create_examples(reader, format, fields):
     Parameters
     ----------
     reader
-        A reader object that reads one line at a time. Yields either strings
-        (when format is JSON) or lists of values (when format is CSV/TSV).
+        A reader object that reads one line at a time. Yields strings
+        (when format is JSON), lists of values (when format is CSV/TSV),
+        or pre-transformed lines when a custom format is used via the
+        `line2example` argument.
     format : str
         Format of the data file that is being read. Can be either CSV,
         TSV or JSON.
@@ -178,4 +216,4 @@ def create_examples(reader, format, fields):

     # map each line from the reader to an example
     examples = map(make_example, reader)
-    return list(examples)
+    return examples
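The refactor keeps `TabularDataset` eager (it materializes the generator with `list`), while `DiskBackedDataset.from_examples` can drain the same generator chunk by chunk without holding the whole dataset in memory. A sketch of the custom-format path through the public API, assuming a hypothetical tab-separated file where each line is `<text>\t<label>`:

```python
from podium import Field, LabelField, Vocab
from podium.datasets import TabularDataset

fields = [Field("text", numericalizer=Vocab()), LabelField("label")]

# line2example receives each raw line and returns the column values;
# the constructor calls finalize_fields() itself.
dataset = TabularDataset(
    "reviews.txt",
    fields,
    line2example=lambda line: line.rstrip("\n").split("\t"),
)
```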
diff --git a/podium/vectorizers/vectorizer.py b/podium/vectorizers/vectorizer.py
index 85a872cf..870008d6 100644
--- a/podium/vectorizers/vectorizer.py
+++ b/podium/vectorizers/vectorizer.py
@@ -91,9 +91,11 @@ def __init__(
         max_vectors : int, optional
             maximum number of vectors to load in memory
         """
-        self._path = path
+        self._path = os.path.expanduser(path) if path is not None else path
        self._default_vector_function = default_vector_function
-        self._cache_path = cache_path
+        self._cache_path = (
+            os.path.expanduser(cache_path) if cache_path is not None else cache_path
+        )
         self._max_vectors = max_vectors

     @abstractmethod
diff --git a/tests/datasets/test_arrow.py b/tests/datasets/test_arrow.py
index 505e5656..5ac93054 100644
--- a/tests/datasets/test_arrow.py
+++ b/tests/datasets/test_arrow.py
@@ -223,7 +223,7 @@ def test_from_tabular(data, fields, tmpdir):
         writer = csv.writer(f)
         writer.writerows(data)

-    csv_dataset = DiskBackedDataset.from_tabular_file(test_file, "csv", fields)
+    csv_dataset = DiskBackedDataset.from_tabular_file(test_file, fields, format="csv")
    for ex, d in zip(csv_dataset, data):
         assert int(ex.number[0]) == d[0]
         assert ex.tokens[0] == d[1]
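Worth flagging for reviewers: swapping the positions of `format` and `fields` is a breaking change for positional call sites, which is exactly what the test update above reflects. A sketch of the migration, with `test_file` and `fields` as defined in `test_from_tabular`:

```python
from podium.datasets.arrow import DiskBackedDataset

# Old order: format was the second positional parameter, so existing calls
# like the one below now mis-bind "csv" to `fields`:
# DiskBackedDataset.from_tabular_file(test_file, "csv", fields)

# New order: `fields` moves up and `format` defaults to "csv".
csv_dataset = DiskBackedDataset.from_tabular_file(test_file, fields)
tsv_dataset = DiskBackedDataset.from_tabular_file(test_file, fields, format="tsv")
```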