Abstract tabular dataset loading #274

Open · wants to merge 10 commits into base: master
90 changes: 25 additions & 65 deletions podium/datasets/arrow.py
@@ -1,17 +1,17 @@
-import csv
 import itertools
 import os
 import pickle
 import shutil
 import tempfile
 import warnings
 from collections import defaultdict
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union

 from podium.field import Field, unpack_fields

 from .dataset import DatasetBase
-from .example_factory import Example, ExampleFactory
+from .example_factory import Example
+from .tabular_dataset import _load_tabular_file


 try:
@@ -218,14 +218,15 @@ def from_examples(
     @staticmethod
     def from_tabular_file(
         path: str,
-        format: str,
         fields: Union[Dict[str, Field], List[Field]],
+        format: str = "csv",
         cache_path: Optional[str] = None,
         data_types: Dict[str, Tuple[pa.DataType, pa.DataType]] = None,
-        chunk_size=10_000,
+        chunk_size: int = 1024,
+        line2example: Optional[Callable] = None,
         skip_header: bool = False,
-        delimiter=None,
-        csv_reader_params: Dict = None,
+        csv_reader_params: Optional[Dict] = None,
     ) -> "DiskBackedDataset":
         """
         Loads a tabular file format (csv, tsv, json) as a DiskBackedDataset.
@@ -279,15 +280,18 @@ def from_tabular_file(
             Maximum number of examples to be loaded before dumping to the on-disk cache
             file. Use lower number if memory usage is an issue while loading.
+
+        line2example : callable
+            The function mapping from a file line to Fields.
+            In case your dataset is not in one of the standardized formats,
+            you can provide a function which performs a custom split for
+            each input line.
+
         skip_header : bool
             Whether to skip the first line of the input file.
             If format is CSV/TSV and 'fields' is a dict, then skip_header
             must be False and the data file must have a header.
             Default is False.
-        delimiter: str
-            Delimiter used to separate columns in a row.
-            If set to None, the default delimiter for the given format will
-            be used.

         csv_reader_params : Dict
             Parameters to pass to the csv reader. Only relevant when
             format is csv or tsv.
@@ -299,62 +303,18 @@
         DiskBackedDataset
             DiskBackedDataset instance containing the examples from the tabular file.
         """
-        format = format.lower()
-        csv_reader_params = {} if csv_reader_params is None else csv_reader_params
-
-        with open(os.path.expanduser(path), encoding="utf8") as f:
-            if format in {"csv", "tsv"}:
-                delimiter = "," if format == "csv" else "\t"
-                reader = csv.reader(f, delimiter=delimiter, **csv_reader_params)
-            elif format == "json":
-                reader = f
-            else:
-                raise ValueError(f"Invalid format: {format}")
-
-            if skip_header:
-                if format == "json":
-                    raise ValueError(
-                        f"When using a {format} file, skip_header must be False."
-                    )
-                elif format in {"csv", "tsv"} and isinstance(fields, dict):
-                    raise ValueError(
-                        f"When using a dict to specify fields with a {format} "
-                        "file, skip_header must be False and the file must "
-                        "have a header."
-                    )
-
-                # skipping the header
-                next(reader)
-
-            # if format is CSV/TSV and fields is a dict, transform it to a list
-            if format in {"csv", "tsv"} and isinstance(fields, dict):
-                # we need a header to know the column names
-                header = next(reader)
-
-                # columns not present in the fields dict are ignored (None)
-                fields = [fields.get(column, None) for column in header]
-
-            # fields argument is the same for all examples
-            # fromlist is used for CSV/TSV because csv_reader yields data rows as
-            # lists, not strings
-            example_factory = ExampleFactory(fields)
-            make_example_function = {
-                "json": example_factory.from_json,
-                "csv": example_factory.from_list,
-                "tsv": example_factory.from_list,
-            }
-
-            make_example = make_example_function[format]
-
-            # map each line from the reader to an example
-            example_iterator = map(make_example, reader)
-            return DiskBackedDataset.from_examples(
-                fields,
-                example_iterator,
-                cache_path=cache_path,
-                data_types=data_types,
-                chunk_size=chunk_size,
-            )
+        example_generator = _load_tabular_file(
+            path, fields, format, line2example, skip_header, csv_reader_params
+        )
+
+        return DiskBackedDataset.from_examples(
+            fields,
+            example_generator,
+            cache_path=cache_path,
+            data_types=data_types,
+            chunk_size=chunk_size,
+        )

     @staticmethod
     def _schema_to_data_types(
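For reviewers, a minimal sketch of how the reworked from_tabular_file signature is meant to be called. The field names and the sample.csv path are invented for illustration and are not part of this diff:

# Hypothetical usage: fields is now the second positional argument
# and format defaults to "csv".
from podium import Field, Vocab
from podium.datasets.arrow import DiskBackedDataset

text = Field("text", numericalizer=Vocab())
label = Field("label", numericalizer=Vocab())

dataset = DiskBackedDataset.from_tabular_file(
    "sample.csv",
    [text, label],
    skip_header=True,
    chunk_size=1024,  # examples buffered in memory before each on-disk dump
)

The same argument reordering shows up in the updated test at the bottom of this diff.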
2 changes: 1 addition & 1 deletion podium/datasets/impl/catacx_comments_dataset.py
@@ -114,7 +114,7 @@ def _create_examples(dir_path, fields):
         A list of examples containing comments from the Catacx dataset.
         """
         example_factory = ExampleFactory(fields)
-        with open(dir_path, encoding="utf8") as f:
+        with open(os.path.expanduser(dir_path), encoding="utf8") as f:
             ds = json.load(f)

         examples = []
2 changes: 1 addition & 1 deletion podium/datasets/impl/snli_dataset.py
@@ -99,7 +99,7 @@ def _create_examples(file_path, fields):
         example_factory = ExampleFactory(fields)
         examples = []

-        with open(file=file_path, encoding="utf8") as in_file:
+        with open(file=os.path.expanduser(file_path), encoding="utf8") as in_file:
             for line in in_file:
                 examples.append(example_factory.from_json(line))
         return examples
134 changes: 86 additions & 48 deletions podium/datasets/tabular_dataset.py
@@ -1,7 +1,9 @@
 import csv
 import os
+from typing import Callable, Dict, List, Optional, Union

 from podium.datasets.dataset import Dataset
+from podium.field import Field

 from .example_factory import ExampleFactory

@@ -14,12 +16,12 @@ class TabularDataset(Dataset):

     def __init__(
         self,
-        path,
-        fields,
-        format="csv",
-        line2example=None,
-        skip_header=False,
-        csv_reader_params={},
+        path: str,
+        fields: Union[Dict[str, Field], List[Field]],
+        format: str = "csv",
+        line2example: Optional[Callable] = None,
+        skip_header: bool = False,
+        csv_reader_params: Optional[Dict] = None,
         **kwargs,
     ):
         """
@@ -83,58 +85,94 @@ def __init__(
            If format is "JSON" and skip_header is True.
        """

-        format = format.lower()
-
-        with open(os.path.expanduser(path), encoding="utf8") as f:
-
-            # Skip header prior to custom line2example in case
-            # the header is in a different format so we don't
-            # cause an error.
-            if skip_header:
-                if format == "json":
-                    raise ValueError(
-                        f"When using a {format} file, skip_header \
-                        must be False."
-                    )
-                elif format in {"csv", "tsv", "custom"} and isinstance(fields, dict):
-                    raise ValueError(
-                        f"When using a dict to specify fields with a {format} "
-                        "file, skip_header must be False and the file must "
-                        "have a header."
-                    )
-
-                # skip the header
-                next(f)
-
-            if line2example is not None:
-                reader = (line2example(line) for line in f)
-                format = "custom"
-            elif format in {"csv", "tsv"}:
-                delimiter = "," if format == "csv" else "\t"
-                reader = csv.reader(f, delimiter=delimiter, **csv_reader_params)
-            elif format == "json":
-                reader = f
-            else:
-                raise ValueError(f"Invalid format: {format}")
-
-            # create a list of examples
-            examples = create_examples(reader, format, fields)
-
-        # create a Dataset with lists of examples and fields
+        examples = _load_tabular_file(
+            path, fields, format, line2example, skip_header, csv_reader_params
+        )
+        # Make the examples concrete here by casting to list
+        examples = list(examples)

         super(TabularDataset, self).__init__(examples, fields, **kwargs)
         self.finalize_fields()


-def create_examples(reader, format, fields):
+def _load_tabular_file(
+    path, fields, format, line2example, skip_header, csv_reader_params
+):
+    """
+    Loads examples as a generator from a dataset in tabular format.
+
+    Abstracted from TabularDataset due to duplicate usage in ArrowDataset.
+    Parameters same as in TabularDataset constructor.
+    """
+
+    with open(os.path.expanduser(path), encoding="utf8", newline="") as f:
+        # lazily create examples from the initialized reader
+        reader, format = _initialize_tabular_reader(
+            f, format, fields, line2example, skip_header, csv_reader_params
+        )
+        examples = _create_examples(reader, format, fields)
+        yield from examples
+
+
+def _initialize_tabular_reader(
+    file, format, fields, line2example, skip_header, csv_reader_params
+):
+    """
+    Initializes the input stream from a file.
+
+    In case of using a custom format handled by line2example, lazily applies the
+    row transformation on the data.
+    """
+
+    format = format.lower()
+
+    # Skip header prior to custom line2example in case
+    # the header is in a different format so we don't
+    # cause an error.
+    if skip_header:
+        if format == "json":
+            raise ValueError(
+                f"When using a {format} file, skip_header \
+                must be False."
+            )
+        elif format in {"csv", "tsv", "custom"} and isinstance(fields, dict):
+            raise ValueError(
+                f"When using a dict to specify fields with a {format} "
+                "file, skip_header must be False and the file must "
+                "have a header."
+            )

+        # skip the header
+        next(file)
+
+    if line2example is not None:
+        reader = (line2example(line) for line in file)
+        format = "custom"
+    elif format in {"csv", "tsv"}:
+        delimiter = "," if format == "csv" else "\t"
+        if csv_reader_params is None:
+            csv_reader_params = {}
+        reader = csv.reader(file, delimiter=delimiter, **csv_reader_params)
+    elif format == "json":
+        reader = file
+    else:
+        raise ValueError(f"Invalid format: {format}")
+
+    return reader, format
+
+
+def _create_examples(reader, format, fields):
     """
     Creates a list of examples from the given line reader and fields (see
     TabularDataset.__init__ docs for more info on the fields).

     Parameters
     ----------
     reader
-        A reader object that reads one line at a time. Yields either strings
-        (when format is JSON) or lists of values (when format is CSV/TSV).
+        A reader object that reads one line at a time. Yields strings
+        (when format is JSON) or lists of values (when format is CSV/TSV)
+        or a sequence of pre-transformed lines if a custom format is used
+        via the `line2example` argument.
     format : str
         Format of the data file that is being read. Can be either CSV,
         TSV or JSON.
@@ -178,4 +216,4 @@ def create_examples(reader, format, fields):
     # map each line from the reader to an example
     examples = map(make_example, reader)

-    return list(examples)
+    return examples
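As a sketch of the new line2example path (the pipe-delimited file and field names below are invented for illustration, assuming podium's top-level Field and Vocab exports):

# Hypothetical custom-format loading: each raw line such as
# "some text|positive\n" is split into a list of column values,
# which is then consumed like a csv.reader row.
from podium import Field, Vocab
from podium.datasets import TabularDataset

text = Field("text", numericalizer=Vocab())
label = Field("label", numericalizer=Vocab())

def pipe_split(line):
    return line.rstrip("\n").split("|")

dataset = TabularDataset("data.txt", fields=[text, label], line2example=pipe_split)

Since _load_tabular_file is now a generator, TabularDataset materializes it eagerly with list(), while DiskBackedDataset.from_examples can consume it lazily chunk by chunk.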
6 changes: 4 additions & 2 deletions podium/vectorizers/vectorizer.py
@@ -91,9 +91,11 @@ def __init__(
         max_vectors : int, optional
             maximum number of vectors to load in memory
         """
-        self._path = path
+        self._path = os.path.expanduser(path) if path is not None else path
         self._default_vector_function = default_vector_function
-        self._cache_path = cache_path
+        self._cache_path = (
+            os.path.expanduser(cache_path) if cache_path is not None else cache_path
+        )
         self._max_vectors = max_vectors

     @abstractmethod
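A brief aside on why expanduser matters here, using an invented path; a leading "~" is a shell convention that plain open() does not resolve:

import os

# Resolves "~" to the current user's home directory before the path
# is handed to open(); a no-op for paths without "~".
print(os.path.expanduser("~/vectors/glove.txt"))
# e.g. /home/alice/vectors/glove.txt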
2 changes: 1 addition & 1 deletion tests/datasets/test_arrow.py
@@ -223,7 +223,7 @@ def test_from_tabular(data, fields, tmpdir):
         writer = csv.writer(f)
         writer.writerows(data)

-    csv_dataset = DiskBackedDataset.from_tabular_file(test_file, "csv", fields)
+    csv_dataset = DiskBackedDataset.from_tabular_file(test_file, fields, format="csv")
     for ex, d in zip(csv_dataset, data):
        assert int(ex.number[0]) == d[0]
        assert ex.tokens[0] == d[1]