Skip to content

Commit

Permalink
Resolved #315, #316, #317
Browse files Browse the repository at this point in the history
  • Loading branch information
hvgazula committed Apr 2, 2024
1 parent 690b68c commit b85e6a9
Showing 1 changed file with 31 additions and 9 deletions.
40 changes: 31 additions & 9 deletions nobrainer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def __init__(
self.volume_shape = volume_shape
self.n_classes = n_classes

if self.n_classes < 1:
raise ValueError("n_classes must be > 0.")

@classmethod
def from_tfrecords(
cls,
Expand All @@ -80,6 +83,7 @@ def from_tfrecords(
n_classes=1,
tf_dataset_options=None,
num_parallel_calls=1,
label_mapping=None,
):
"""Function to retrieve a saved tf record as a nobrainer Dataset
Expand Down Expand Up @@ -123,6 +127,7 @@ def from_tfrecords(

if not n_volumes:
n_volumes = block_length * len(files)
print(f"n_volumes: {n_volumes}")

dataset = dataset.interleave(
map_func=lambda x: tf.data.TFRecordDataset(
Expand All @@ -138,7 +143,9 @@ def from_tfrecords(
if block_shape:
ds_obj.block(block_shape)
if not scalar_labels:
ds_obj.map_labels()
ds_obj.map_labels(
label_mapping=label_mapping, num_parallel_calls=num_parallel_calls
)
# TODO automatically determine batch size
ds_obj.batch(1)

Expand All @@ -158,6 +165,7 @@ def from_files(
n_classes=1,
block_shape=None,
tf_dataset_options=None,
label_mapping=None,
):
"""Create Nobrainer datasets from data
filepaths: List(str), list of paths to individual input data files.
Expand Down Expand Up @@ -221,6 +229,7 @@ def from_files(
block_shape=block_shape,
tf_dataset_options=tf_dataset_options,
num_parallel_calls=num_parallel_calls,
label_mapping=label_mapping,
)
ds_eval = None
if n_eval > 0:
Expand All @@ -234,6 +243,7 @@ def from_files(
block_shape=block_shape,
tf_dataset_options=tf_dataset_options,
num_parallel_calls=num_parallel_calls,
label_mapping=label_mapping,
)
return ds_train, ds_eval

Expand Down Expand Up @@ -315,19 +325,31 @@ def _f(x, y):
self.dataset = self.dataset.unbatch()
return self

def map_labels(self, label_mapping=None, num_parallel_calls=1):
    """Transform label volumes in the dataset into training-ready targets.

    Optionally remaps raw label values, then encodes labels according to
    ``self.n_classes``: a single binarized channel for 1 class, or a
    one-hot encoding for 2 or more classes.

    Parameters
    ----------
    label_mapping : dict, optional
        Mapping of original label values to new values, forwarded to
        ``replace`` as its ``mapping`` argument. If None, no remapping
        is performed.
    num_parallel_calls : int, optional
        Parallelism forwarded to the encoding ``map`` calls.

    Returns
    -------
    The dataset object itself, to allow method chaining.
    """
    if label_mapping is not None:
        # Remap raw label values before any encoding step.
        self.map(lambda x, y: (x, replace(y, mapping=label_mapping)))

    if self.n_classes == 1:
        # Binary segmentation: binarize and keep a single trailing channel.
        self.map(
            lambda x, y: (x, tf.expand_dims(binarize(y), -1)),
            num_parallel_calls=num_parallel_calls,
        )
    elif self.n_classes == 2:
        # Two-class: binarize, then cast to int because tf.one_hot
        # requires integer indices.
        self.map(
            lambda x, y: (
                x,
                tf.one_hot(tf.cast(binarize(y), dtype=tf.int32), self.n_classes),
            ),
            num_parallel_calls=num_parallel_calls,
        )
    elif self.n_classes > 2:
        # Multi-class: one-hot encode the integer-cast labels directly.
        self.map(
            lambda x, y: (
                x,
                tf.one_hot(tf.cast(y, dtype=tf.int32), self.n_classes),
            ),
            num_parallel_calls=num_parallel_calls,
        )

    return self

Expand Down

0 comments on commit b85e6a9

Please sign in to comment.