diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py index 30c5be53..3b7b010d 100644 --- a/nobrainer/dataset.py +++ b/nobrainer/dataset.py @@ -69,6 +69,9 @@ def __init__( self.volume_shape = volume_shape self.n_classes = n_classes + if self.n_classes < 1: + raise ValueError("n_classes must be > 0.") + @classmethod def from_tfrecords( cls, @@ -80,6 +83,7 @@ def from_tfrecords( n_classes=1, tf_dataset_options=None, num_parallel_calls=1, + label_mapping=None, ): """Function to retrieve a saved tf record as a nobrainer Dataset @@ -123,6 +127,7 @@ def from_tfrecords( if not n_volumes: n_volumes = block_length * len(files) + print(f"n_volumes: {n_volumes}") dataset = dataset.interleave( map_func=lambda x: tf.data.TFRecordDataset( @@ -138,7 +143,9 @@ def from_tfrecords( if block_shape: ds_obj.block(block_shape) if not scalar_labels: - ds_obj.map_labels() + ds_obj.map_labels( + label_mapping=label_mapping, num_parallel_calls=num_parallel_calls + ) # TODO automatically determine batch size ds_obj.batch(1) @@ -158,6 +165,7 @@ def from_files( n_classes=1, block_shape=None, tf_dataset_options=None, + label_mapping=None, ): """Create Nobrainer datasets from data filepaths: List(str), list of paths to individual input data files. @@ -221,6 +229,7 @@ def from_files( block_shape=block_shape, tf_dataset_options=tf_dataset_options, num_parallel_calls=num_parallel_calls, + label_mapping=label_mapping, ) ds_eval = None if n_eval > 0: @@ -234,6 +243,7 @@ def from_files( block_shape=block_shape, tf_dataset_options=tf_dataset_options, num_parallel_calls=num_parallel_calls, + label_mapping=label_mapping, ) return ds_train, ds_eval @@ -315,19 +325,31 @@ def _f(x, y): self.dataset = self.dataset.unbatch() return self - def map_labels(self, label_mapping=None): - if self.n_classes < 1: - raise ValueError("n_classes must be > 0.") - + def map_labels(self, label_mapping=None, num_parallel_calls=1): if label_mapping is not None: - self.map(lambda x, y: (x, replace(y, label_mapping=label_mapping))) + self.map(lambda x, y: (x, replace(y, mapping=label_mapping))) if self.n_classes == 1: - self.map(lambda x, y: (x, tf.expand_dims(binarize(y), -1))) + self.map( + lambda x, y: (x, tf.expand_dims(binarize(y), -1)), + num_parallel_calls=num_parallel_calls, + ) elif self.n_classes == 2: - self.map(lambda x, y: (x, tf.one_hot(binarize(y), self.n_classes))) + self.map( + lambda x, y: ( + x, + tf.one_hot(tf.cast(binarize(y), dtype=tf.int32), self.n_classes), + ), + num_parallel_calls=num_parallel_calls, + ) elif self.n_classes > 2: - self.map(lambda x, y: (x, tf.one_hot(y, self.n_classes))) + self.map( + lambda x, y: ( + x, + tf.one_hot(tf.cast(y, dtype=tf.int32), self.n_classes), + ), + num_parallel_calls=num_parallel_calls, + ) return self