diff --git a/nobrainer/processing/segmentation.py b/nobrainer/processing/segmentation.py index c6805443..4b57e469 100644 --- a/nobrainer/processing/segmentation.py +++ b/nobrainer/processing/segmentation.py @@ -38,7 +38,6 @@ def __init__( self.volume_shape_ = None self.scalar_labels_ = None - def fit( self, dataset_train, @@ -58,7 +57,7 @@ def fit( batch_size = dataset_train.batch_size self.block_shape_ = dataset_train.block_shape self.volume_shape_ = dataset_train.volume_shape - self.scalar_labels_ = dataset_train.scalar_labels + # self.scalar_labels_ = dataset_train.scalar_labels n_classes = dataset_train.n_classes opt_args = opt_args or {} if optimizer is None: @@ -100,6 +99,9 @@ def _compile(): if callbacks is None: callbacks = [] + dataset_train.repeat(epochs) + dataset_validate.repeat(epochs) + if self.checkpoint_tracker: callbacks.append(self.checkpoint_tracker) self.model_.fit(