Skip to content

Commit

Permalink
Merge pull request #319 from neuronets/fix/pgan
Browse files Browse the repository at this point in the history
Fix PGAN notebook
  • Loading branch information
satra authored Apr 4, 2024
2 parents ab03eda + 71c3f63 commit b616a06
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions nobrainer/processing/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .base import BaseEstimator
from .. import losses
from ..dataset import get_dataset
from ..dataset import Dataset


class ProgressiveGeneration(BaseEstimator):
Expand Down Expand Up @@ -147,15 +147,17 @@ def _compile():
if batch_size % self.strategy.num_replicas_in_sync:
raise ValueError("batch size must be a multiple of the number of GPUs")

dataset = get_dataset(
dataset = Dataset.from_tfrecords(
file_pattern=info.get("file_pattern"),
batch_size=batch_size,
num_parallel_calls=num_parallel_calls,
volume_shape=(resolution, resolution, resolution),
n_classes=1,
scalar_label=True,
normalizer=info.get("normalizer") or normalizer,
scalar_labels=True,
)
n_epochs = info.get("epochs") or epochs
dataset.batch(batch_size).normalize(
info.get("normalizer") or normalizer
).repeat(n_epochs)

with self.strategy.scope():
# grow the networks by one (2^x) resolution
Expand All @@ -164,9 +166,7 @@ def _compile():
self.model_.discriminator.add_resolution()
_compile()

steps_per_epoch = (info.get("epochs") or epochs) // info.get(
"batch_size"
)
steps_per_epoch = n_epochs // info.get("batch_size")

# save_best_only is set to False as it is an adversarial loss
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
Expand All @@ -182,7 +182,7 @@ def _compile():

print("Transition phase")
self.model_.fit(
dataset,
dataset.dataset,
phase="transition",
resolution=resolution,
steps_per_epoch=steps_per_epoch, # necessary for repeat dataset
Expand All @@ -191,7 +191,7 @@ def _compile():

print("Resolution phase")
self.model_.fit(
dataset,
dataset.dataset,
phase="resolution",
resolution=resolution,
steps_per_epoch=steps_per_epoch,
Expand Down

0 comments on commit b616a06

Please sign in to comment.