Fix bugs in iterator checkpointing, enable other last batch policies (#5298)

Fixed: iterator checkpointing tests, external source checkpointing, and restoring the iterator's internal state.
Added tests for different last batch policies. Enabled last_batch_policy=DROP.

---------

Signed-off-by: Szymon Karpiński <[email protected]>
szkarpinski authored Feb 7, 2024
1 parent a0a3b7f commit a48c723
Showing 3 changed files with 244 additions and 17 deletions.
7 changes: 3 additions & 4 deletions dali/python/nvidia/dali/pipeline.py
@@ -1061,8 +1061,6 @@ def outputs(self):
with self._check_api_type_scope(types.PipelineAPIType.SCHEDULED):
self._consumer_iter += 1
if self._batches_to_consume == 0:
self._consumer_iter = 0
self._consumer_epoch_idx += 1
raise StopIteration
self._batches_to_consume -= 1
return self._outputs()
@@ -1110,8 +1108,6 @@ def share_outputs(self):
with self._check_api_type_scope(types.PipelineAPIType.SCHEDULED):
self._consumer_iter += 1
if self._batches_to_consume == 0:
self._consumer_iter = 0
self._consumer_epoch_idx += 1
raise StopIteration
self._batches_to_consume -= 1
return self._pipe.ShareOutputs()
@@ -1326,6 +1322,9 @@ def reset(self):
self._first_iter = True
self._last_iter = False
self._epoch_idx += 1
if self._consumer_iter > 0:
self._consumer_epoch_idx += 1
self._consumer_iter = 0
if self._input_callbacks:
for group in self._input_callbacks:
group.reset_indices()
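
The pipeline.py change above moves the consumer-epoch bookkeeping out of `outputs()`/`share_outputs()` and into `reset()`, guarded by `_consumer_iter > 0`. The guard ensures that a reset with no intervening consumption (for example, right after restoring from a checkpoint) does not advance the epoch counter. A minimal sketch of the resulting behavior, using a simplified stand-in for the real pipeline class:

```python
# Minimal sketch (not DALI's actual class) of the bookkeeping change above:
# the consumer epoch counter is advanced in reset(), and only if the
# consumer actually drew batches since the last reset.

class ConsumerState:
    def __init__(self):
        self._consumer_iter = 0        # batches drawn in the current epoch
        self._consumer_epoch_idx = 0   # completed consumer epochs

    def on_batch(self):
        self._consumer_iter += 1

    def reset(self):
        # Advance the epoch counter only if this epoch was actually consumed;
        # a reset with no consumption must not inflate the epoch index.
        if self._consumer_iter > 0:
            self._consumer_epoch_idx += 1
            self._consumer_iter = 0

state = ConsumerState()
for _ in range(3):
    state.on_batch()
state.reset()
state.reset()  # no-op: nothing consumed since the previous reset
assert state._consumer_epoch_idx == 1
```
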
53 changes: 47 additions & 6 deletions dali/python/nvidia/dali/plugin/base_iterator.py
@@ -226,19 +226,60 @@ def __init__(
)

if self._enable_checkpointing:
# Note: currently, checkpointing is not supported with last_batch_padded=False.
# It is verified in FileReader, where this assumption is needed.
# Adding this assertion here led to problems when the reader was not used in the pipeline.

if self._last_batch_policy == LastBatchPolicy.DROP:
if self._last_batch_policy == LastBatchPolicy.FILL and self._last_batch_padded is False:
raise NotImplementedError(
"Currently, checkpointing is not supported with last_batch_policy=DROP"
"Currently, checkpointing is not supported for iterators with "
+ "last_batch_policy=FILL and last_batch_padded=False"
)

# Precompute the initial checkpoints, to prevent any problems
# related to the `prepare_first_batch` flag.
self._initial_checkpoints = [p.checkpoint() for p in self._pipes]

if any(p.is_restored_from_checkpoint for p in self._pipes):
iters = [p._consumer_iter for p in self._pipes]
if not all(p.is_restored_from_checkpoint for p in self._pipes):
logging.warning(
"Some, but not all of the pipelines used were restored from checkpoint. "
+ "This iterator might produce unexpected results."
)
elif not all(i == iters[0] for i in iters):
logging.warning(
"The provided pipelines have different number of completed iterations. "
+ "This iterator might produce unexpected results."
)

self._restore_state(min(iters))

def _restore_state(self, pipeline_iterations):
"""
Restores the iterator to the state it would have after `pipeline_iterations` iterations
of the pipeline.
"""

if self._last_batch_policy == LastBatchPolicy.FILL and self._last_batch_padded is False:
raise NotImplementedError(
"Currently, checkpointing is not supported for iterators with "
+ "last_batch_policy=FILL and last_batch_padded=False"
)

# In modes other than FILL + last_batch_padded=False, each epoch starts with the first
# sample of a shard and the number of pipeline iterations per epoch is constant.

size = self._shard_sizes_per_gpu.min() if self._reader_name else self._size
iters_per_epoch = (size + self.batch_size - 1) // self.batch_size
complete_epochs = (max(0, pipeline_iterations - 1)) // iters_per_epoch

self._counter = self.batch_size * (pipeline_iterations - complete_epochs * iters_per_epoch)

if not self._reader_name:
# If not in reader_name mode, the counter keeps the total count for all pipelines
self._counter *= self._num_gpus

if self._reader_name and not self._is_stick_to_shard:
self._shard_sizes_per_gpu = np.roll(self._shard_sizes_per_gpu, complete_epochs)
self._shards_id = (self._shards_id + complete_epochs) % self._num_gpus

def _calculate_shard_sizes(self, shard_nums):
shards_beg = np.floor(shard_nums * self._size_no_pad / self._shards_num)
shards_end = np.floor((shard_nums + 1) * self._size_no_pad / self._shards_num)
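
The new `_restore_state` above re-derives the iterator's progress from a single number: how many iterations the restored pipelines have completed (when the pipelines disagree, the minimum is used and a warning is emitted). Because every supported mode starts each epoch at the first sample of a shard, the split into complete epochs plus a remainder is exact. A standalone sketch of that arithmetic, using the 11 + 11 + 12 / batch_size=4 configuration from the tests in the next file; the names mirror the diff, but this is an illustration, not DALI code:

```python
# Sketch of the counter arithmetic in _restore_state, for shard sizes
# [11, 11, 12], batch_size=4, after 7 completed pipeline iterations.
import numpy as np

shard_sizes_per_gpu = np.array([11, 11, 12])
batch_size = 4
pipeline_iterations = 7  # completed iterations reported by the restored pipelines

size = shard_sizes_per_gpu.min()
iters_per_epoch = (size + batch_size - 1) // batch_size              # ceil(11 / 4) = 3
complete_epochs = max(0, pipeline_iterations - 1) // iters_per_epoch  # 6 // 3 = 2

# Position within the current epoch, in samples:
counter = batch_size * (pipeline_iterations - complete_epochs * iters_per_epoch)
assert counter == 4  # one iteration into the third epoch

# Without stick_to_shard, each GPU advances one shard per epoch, so the shard
# assignment is rotated by the number of complete epochs. DALI tracks one
# shard id per pipeline; here we show all of them at once for illustration.
shard_sizes_per_gpu = np.roll(shard_sizes_per_gpu, complete_epochs)
shards_id = (np.arange(3) + complete_epochs) % 3
print(shard_sizes_per_gpu, shards_id)  # [11 12 11] [2 0 1]
```
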
201 changes: 194 additions & 7 deletions dali/test/python/checkpointing/test_dali_checkpointing_fw_iterators.py
@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
import os
import tempfile

from checkpointing.test_dali_checkpointing import (
warmup_epochs,
pipeline_args,
@@ -23,6 +27,8 @@
import nvidia.dali.fn as fn
from nvidia.dali.pipeline import pipeline_def
from nose2.tools import params, cartesian_params
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
from nose import SkipTest


class FwTestBase:
@@ -40,22 +46,39 @@ def compare_outs(self, out1, out2):
assert self.equal(d1[key], d2[key])

def compare_iters(self, iter, iter2):
for out1, out2 in zip(iter, iter2):
outs1 = list(x for x in iter)
outs2 = list(x for x in iter2)
assert len(outs1) == len(outs2)
for out1, out2 in zip(outs1, outs2):
self.compare_outs(out1, out2)

def check_pipeline_checkpointing(self, pipeline_factory, reader_name=None, size=-1):
pipe = pipeline_factory(**pipeline_args)
pipe.build()

iter = self.FwIterator(pipe, ["data"], auto_reset=True, reader_name=reader_name, size=size)
iter = self.FwIterator(
pipe,
["data"],
auto_reset=True,
reader_name=reader_name,
size=size,
last_batch_policy=LastBatchPolicy.FILL,
last_batch_padded=True,
)
for _ in range(warmup_epochs):
for _ in iter:
pass

restored = pipeline_factory(**pipeline_args, checkpoint=iter.checkpoints()[0])
restored.build()
iter2 = self.FwIterator(
restored, ["data"], auto_reset=True, reader_name=reader_name, size=size
restored,
["data"],
auto_reset=True,
reader_name=reader_name,
size=size,
last_batch_policy=LastBatchPolicy.FILL,
last_batch_padded=True,
)

self.compare_iters(iter, iter2)
@@ -129,6 +152,127 @@ def pipeline():

self.compare_iters(iter, iter2)

@dataclass
class DatasetConfig:
dataset_size: int
batch_size: int
num_shards: int

@cartesian_params(
(
DatasetConfig(dataset_size=11 + 11 + 12, batch_size=4, num_shards=3),
DatasetConfig(dataset_size=4 + 5, batch_size=3, num_shards=2),
),
(2, 3, 7),
(
# (last_batch_policy, pad_last_batch)
(LastBatchPolicy.FILL, True),
(LastBatchPolicy.DROP, True),
(LastBatchPolicy.DROP, False),
(LastBatchPolicy.PARTIAL, True),
(LastBatchPolicy.PARTIAL, False),
),
(True, False), # stick_to_shard
)
def test_last_batch_policy(
self, dataset_config: DatasetConfig, iterations, last_batch_config, stick_to_shard
):
policy, pad_last_batch = last_batch_config
if last_batch_config not in self.supported_last_batch_policies():
raise SkipTest(
f"Policy {policy} with last_batch_padded={pad_last_batch} "
+ f"is not supported by {self.FwIterator}"
)
with tempfile.TemporaryDirectory() as data_dir:
os.mkdir(os.path.join(data_dir, "0"))
for i in range(dataset_config.dataset_size):
with open(os.path.join(data_dir, f"0/{i:02}.jpg"), "wb") as f:
f.write(bytes([i]))

def make_pipeline(shard_id, checkpoint=None):
@pipeline_def(
batch_size=dataset_config.batch_size,
enable_checkpointing=True,
num_threads=4,
device_id=0,
)
def pipeline():
data, _ = fn.readers.file(
file_root=data_dir,
name="Reader",
pad_last_batch=pad_last_batch,
num_shards=dataset_config.num_shards,
shard_id=shard_id,
stick_to_shard=stick_to_shard,
)
return data

p = pipeline(checkpoint=checkpoint)
p.build()
return p

def make_pipelines(checkpoints=None):
if not checkpoints:
return [
make_pipeline(shard_id) for shard_id in range(dataset_config.num_shards)
]
else:
assert len(checkpoints) == dataset_config.num_shards
return [
make_pipeline(shard_id, checkpoint=cpt)
for (shard_id, cpt) in zip(range(dataset_config.num_shards), checkpoints)
]

def make_iterator(pipes):
return self.FwIterator(
pipes,
output_map=["data"],
auto_reset=True,
last_batch_policy=policy,
prepare_first_batch=False,
reader_name="Reader",
)

pipes = make_pipelines()
it = make_iterator(pipes)

completed_iterations = 0
while completed_iterations < iterations:
try:
next(it)
completed_iterations += 1
except StopIteration:
pass

def observe(it, steps):
"""
Returns a list with the data returned on each step, or None where an epoch ended.
This allows comparing the behavior of two iterators step by step.
"""
results = []
for _ in range(steps):
try:
results.append(next(it))
except StopIteration:
results.append(None)
return results

pipes_restored = make_pipelines(it.checkpoints())
it_restored = make_iterator(pipes_restored)

steps = dataset_config.dataset_size * 2 // dataset_config.batch_size

a = observe(it, steps)
b = observe(it_restored, steps)

assert len(a) == len(b)

for x, y in zip(a, b):
if x is None or y is None:
assert x is None and y is None
else:
self.compare_outs(x, y)
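
The shard layout these configurations rely on follows from DALI's floor-based shard boundaries (see `_calculate_shard_sizes` in the previous file). A quick check of the first `DatasetConfig`, plus the per-shard batch counts the policies produce, ignoring the pad_last_batch interaction that the test matrix above covers:

```python
# Back-of-the-envelope check for DatasetConfig(dataset_size=34, batch_size=4,
# num_shards=3). Shard boundaries are floor(i * size / num_shards), as in
# _calculate_shard_sizes; the rest is ceiling/floor division per policy.
import math

dataset_size, batch_size, num_shards = 34, 4, 3
bounds = [dataset_size * i // num_shards for i in range(num_shards + 1)]
shard_sizes = [hi - lo for lo, hi in zip(bounds, bounds[1:])]
assert shard_sizes == [11, 11, 12]  # the "11 + 11 + 12" in the test config

for n in shard_sizes:
    print(
        n,
        math.ceil(n / batch_size),  # FILL / PARTIAL: last batch padded or partial
        n // batch_size,            # DROP (unpadded): incomplete last batch dropped
    )
# -> 11 3 2 / 11 3 2 / 12 3 3
```
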

# Random operators section

@cartesian_params(("cpu", "gpu"), (None, (1,), (10,)))
@@ -165,14 +309,27 @@ def run(iterator, iterations):
pipeline = pipeline_factory()
pipeline.build()

iter = self.FwIterator(pipeline, ["data"], auto_reset=True, size=size)
iter = self.FwIterator(
pipeline,
["data"],
auto_reset=True,
size=size,
last_batch_policy=LastBatchPolicy.FILL,
last_batch_padded=True,
)

run(iter, iterations)

restored = pipeline_factory(checkpoint=iter.checkpoints()[0])
restored.build()
iter2 = self.FwIterator(restored, ["data"], auto_reset=True, size=size)

iter2 = self.FwIterator(
restored,
["data"],
auto_reset=True,
size=size,
last_batch_policy=LastBatchPolicy.FILL,
last_batch_padded=True,
)
self.compare_iters(iter, iter2)

@cartesian_params(
Expand All @@ -185,7 +342,9 @@ def test_external_source_checkpointing(self, dataset_info, iterations, mode, par
epoch_size, batch_size = dataset_info
source = make_dummy_source(epoch_size, batch_size, mode)
pf = make_external_source_test_pipeline_factory(source, mode, batch_size, parallel)
self.check_external_source_pipeline_checkpointing(pf, iterations)
self.check_external_source_pipeline_checkpointing(
pf, iterations, size=epoch_size * batch_size
)


# Framework tests
@@ -201,6 +360,16 @@ def __init__(self):
def equal(self, a, b):
return (a == b).all()

def supported_last_batch_policies(self):
return (
# (last_batch_policy, pad_last_batch)
(LastBatchPolicy.DROP, True),
(LastBatchPolicy.DROP, False),
(LastBatchPolicy.FILL, True),
(LastBatchPolicy.PARTIAL, False),
(LastBatchPolicy.PARTIAL, True),
)


class TestPytorchRagged(FwTestBase):
def __init__(self):
@@ -212,6 +381,16 @@ def __init__(self):
def equal(self, a, b):
return (a == b).all()

def supported_last_batch_policies(self):
return (
# (last_batch_policy, pad_last_batch)
(LastBatchPolicy.DROP, True),
(LastBatchPolicy.DROP, False),
(LastBatchPolicy.FILL, True),
(LastBatchPolicy.PARTIAL, False),
(LastBatchPolicy.PARTIAL, True),
)


class TestJax(FwTestBase):
def __init__(self):
@@ -223,3 +402,11 @@ def __init__(self):
def compare_outs(self, out1, out2):
for key in out1.keys():
assert (out1[key] == out2[key]).all()

def supported_last_batch_policies(self):
return (
# (last_batch_policy, pad_last_batch)
(LastBatchPolicy.DROP, True),
(LastBatchPolicy.DROP, False),
(LastBatchPolicy.FILL, True),
)
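
Putting it together, the flow the new tests exercise looks roughly like this from user code: run an iterator for a while (possibly stopping mid-epoch), collect `checkpoints()`, and rebuild the pipelines with the `checkpoint=` argument. A sketch for a single-GPU PyTorch setup; the file_root path and batch size are illustrative:

```python
# Sketch of the user-facing save/restore flow: consume a few batches
# (possibly stopping mid-epoch), collect the checkpoint, then rebuild
# and continue from the same point.
import nvidia.dali.fn as fn
from nvidia.dali.pipeline import pipeline_def
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
from nvidia.dali.plugin.pytorch import DALIGenericIterator

@pipeline_def(batch_size=4, num_threads=4, device_id=0, enable_checkpointing=True)
def make_pipe():
    data, _ = fn.readers.file(file_root="/data/images", name="Reader",
                              pad_last_batch=True)
    return data

def make_iter(pipe):
    pipe.build()
    return DALIGenericIterator(
        [pipe], output_map=["data"], auto_reset=True, reader_name="Reader",
        last_batch_policy=LastBatchPolicy.DROP,  # now supported with checkpointing
        prepare_first_batch=False,
    )

it = make_iter(make_pipe())
for _ in range(5):  # advance the iterator, tolerating epoch boundaries
    try:
        next(it)
    except StopIteration:
        pass

checkpoint = it.checkpoints()[0]  # one serialized state per pipeline

# Later: restore and continue exactly where the original iterator left off.
it_restored = make_iter(make_pipe(checkpoint=checkpoint))
```
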
