From 6827e15d93a3120a9e19f38d2477d8ff1eb584d9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Zientkiewicz?=
Date: Tue, 24 May 2022 18:48:44 +0200
Subject: [PATCH] Add tests for operator cast. Revert to plain batched cast
 kernel until the optimized one is fixed. (#3927)

* Add tests for operator cast. Revert to plain batched cast kernel (the
  BinSearch is broken).

Signed-off-by: Michal Zientkiewicz
---
 dali/operators/generic/cast.cu         |  19 ++++
 dali/test/python/test_operator_cast.py | 139 ++++++++++++++++++-------
 2 files changed, 118 insertions(+), 40 deletions(-)

diff --git a/dali/operators/generic/cast.cu b/dali/operators/generic/cast.cu
index 879b060951..1943f9d000 100644
--- a/dali/operators/generic/cast.cu
+++ b/dali/operators/generic/cast.cu
@@ -89,6 +89,24 @@ void CastGPU::RunImpl(DeviceWorkspace &ws) {
   }
 
   auto blocks = block_setup_.Blocks();
+
+  kernels::BlockDesc<1> *blocks_dev;
+  kernels::CastSampleDesc *samples_dev;
+  std::tie(blocks_dev, samples_dev) = scratchpad.ToContiguousGPU(ws.stream(),
+                                                                 blocks, samples_);
+
+  DALIDataType itype = input.type();
+  dim3 grid_dim = block_setup_.GridDim();
+  dim3 block_dim = block_setup_.BlockDim();
+  TYPE_SWITCH(output_type_, type2id, OType, CAST_ALLOWED_TYPES, (
+    TYPE_SWITCH(itype, type2id, IType, CAST_ALLOWED_TYPES, (
+      kernels::BatchedCastKernel<OType, IType>
+          <<<grid_dim, block_dim, 0, ws.stream()>>>(samples_dev, blocks_dev);
+    ), DALI_FAIL(make_string("Invalid input type: ", itype));); // NOLINT(whitespace/parens)
+  ), DALI_FAIL(make_string("Invalid output type: ", output_type_));); // NOLINT(whitespace/parens)
+
+  /*
+  TODO(michalz): Fix the kernel!
   // Calculate id of the earliest block that should process given sample
   for (int block_id = 0, sample_id = -1; block_id < blocks.size(); block_id++) {
     if (blocks[block_id].sample_idx != sample_id) {
@@ -112,6 +130,7 @@ void CastGPU::RunImpl(DeviceWorkspace &ws) {
       num_samples, block_volume_scale);
     ), DALI_FAIL(make_string("Invalid input type: ", itype));); // NOLINT(whitespace/parens)
   ), DALI_FAIL(make_string("Invalid output type: ", output_type_));); // NOLINT(whitespace/parens)
+  */
 }
 
 DALI_REGISTER_OPERATOR(Cast, CastGPU, GPU);
diff --git a/dali/test/python/test_operator_cast.py b/dali/test/python/test_operator_cast.py
index bdc28928bd..5a3f0a08fb 100644
--- a/dali/test/python/test_operator_cast.py
+++ b/dali/test/python/test_operator_cast.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,46 +12,105 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from nvidia.dali.pipeline import Pipeline -import nvidia.dali.ops as ops +import nose_utils +from nvidia.dali import pipeline_def +import nvidia.dali as dali +import nvidia.dali.fn as fn import nvidia.dali.types as types import numpy as np +from nose.tools import nottest -from test_utils import compare_pipelines -from test_utils import RandomlyShapedDataIterator - -class CastPipeline(Pipeline): - def __init__(self, device, batch_size, iterator, cast_dtypes, num_threads=1, device_id=0): - super(CastPipeline, self).__init__(batch_size, num_threads, device_id) - self.layout = "HWC" - self.device = device - self.iterator = iterator - self.inputs = ops.ExternalSource() - self.cast = [ops.Cast(device=device, dtype=dtype) for dtype in cast_dtypes] - - def define_graph(self): - self.data = self.inputs() - out = self.data.gpu() if self.device == 'gpu' else self.data - for k in range(len(self.cast)): - out = self.cast[k](out) - return out - - def iter_setup(self): - data = self.iterator.next() - self.feed_input(self.data, data, layout=self.layout) - -def check_cast_operator_float16(device, batch_size, in_type, out_type): - input_shape=(300, 400, 3) - eii1 = RandomlyShapedDataIterator(batch_size, max_shape=input_shape, dtype=in_type) - eii2 = RandomlyShapedDataIterator(batch_size, max_shape=input_shape, dtype=in_type) - compare_pipelines( - CastPipeline(device, batch_size, iter(eii1), [types.FLOAT16, out_type]), - CastPipeline(device, batch_size, iter(eii2), [out_type]), - batch_size=batch_size, N_iterations=5) - -def test_cast_operator_float16(): +from test_utils import check_batch, np_type_to_dali + + +def ref_cast(x, dtype): + if np.issubdtype(dtype, np.integer): + lo = np.iinfo(dtype).min + hi = np.iinfo(dtype).max + if np.issubdtype(x.dtype, np.floating): + x = np.round(x) + return x.clip(lo, hi).astype(dtype) + else: + return x.astype(dtype) + +def random_shape(rng, ndim: int, max_size: int): + if ndim == 0: + return [] + max_size = int(max_size ** (1/ndim)) + return list(rng.integers(0, max_size, [ndim])) + +def generate(rng, ndim: int, batch_size: int, in_dtype: np.dtype, out_dtype: np.dtype): + lo, hi = -1000, 1000 + if np.issubdtype(out_dtype, np.integer): + lo = np.iinfo(out_dtype).min + hi = np.iinfo(out_dtype).max + if hi < np.iinfo(np.int64).max: + r = hi - lo + hi += r // 2 + lo -= r // 2 + if np.issubdtype(in_dtype, np.integer): + lo = max(np.iinfo(in_dtype).min, lo) + hi = min(np.iinfo(in_dtype).max, hi) + else: + lo = max(-np.finfo(in_dtype).max, lo) + hi = min(np.finfo(in_dtype).max, hi) + + max_size = 100000 // batch_size + out = [rng.uniform(lo, hi, size=random_shape(rng, ndim, max_size)).astype(in_dtype) for _ in range(batch_size)] + if np.issubdtype(in_dtype, np.floating) and np.issubdtype(out_dtype, np.integer): + for x in out: + # avoid exactly halfway numbers - rounding is different for CPU and GPU + halfway = x[x - np.floor(x) == 0.5] + x[x - np.floor(x) == 0.5] = np.nextafter(halfway, np.Infinity) + return out + +rng = np.random.default_rng(1234) + +@nottest +def _test_operator_cast(ndim, batch_size, in_dtype, out_dtype, device): + src = lambda: generate(rng, ndim, batch_size, in_dtype, out_dtype) + @pipeline_def(batch_size=batch_size, num_threads=4, device_id=types.CPU_ONLY_DEVICE_ID if device == 'cpu' else 0) + def cast_pipe(): + inp = fn.external_source(src) + inp_dev = inp.gpu() if device == 'gpu' else inp + return inp, fn.cast(inp_dev, dtype=np_type_to_dali(out_dtype)) + + pipe = cast_pipe() + pipe.build() + for _ in range(10): + inp, out = pipe.run() + if 
device=='gpu': + out = out.as_cpu() + ref = [ref_cast(np.array(x), out_dtype) for x in inp] + + # work around a bug in numpy: when the argument is a scalar fp32 or fp16, nextafter + # promotes it to fp64, resulting in insufficient epsilon - we want an epsilon of the + # type specified in out_dtype + eps = 0 if np.issubdtype(out_dtype, np.integer) else (np.nextafter(out_dtype([1]), 2) - 1.0)[0] + + for i in range(batch_size): + if not np.allclose(out[i], ref[i], eps): + print("At sample", i) + I = np.array(inp[i]) + O = np.array(out[i]) + R = ref[i] + print(I) + print(R) + print(O) + mask = np.logical_not(np.isclose(O, R, eps)) + print("Differences at", mask) + print(I[mask]) + print(R[mask]) + print(O[mask]) + print(np.count_nonzero(mask), "wrong values out of", mask.size) + assert np.array_equal(out[i], ref[i]) + + +def test_operator_cast(): + types = [np.uint8, np.int8, np.uint16, np.int16, np.uint32, np.int32, np.uint64, np.int64, np.float16, np.float32] for device in ['cpu', 'gpu']: - for batch_size in [3]: - for in_type in [np.uint8, np.int64]: - for out_type in [types.FLOAT, types.INT8]: - yield check_cast_operator_float16, device, batch_size, in_type, out_type + for in_type in types: + for out_type in types: + ndim = rng.integers(0, 4) + batch_size = rng.integers(1, 11) + yield _test_operator_cast, ndim, batch_size, in_type, out_type, device
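
The new test's ref_cast helper encodes the conversion semantics the Cast operator is checked against: floating-point inputs are rounded and then saturated (clipped) to the target integer type's range. The sketch below illustrates that behavior outside the test harness; it is not part of the patch, assumes a working DALI installation, and the pipeline name and literal values are illustrative only.

import numpy as np
from nvidia.dali import pipeline_def, fn, types

@pipeline_def(batch_size=1, num_threads=1, device_id=types.CPU_ONLY_DEVICE_ID)
def cast_demo_pipe():
    # One float32 sample with values below, inside and above the uint8 range.
    data = fn.external_source(
        source=lambda: [np.array([-1.0, 0.4, 0.6, 300.0], dtype=np.float32)])
    return data, fn.cast(data, dtype=types.UINT8)

pipe = cast_demo_pipe()
pipe.build()
inp, out = pipe.run()

# Same reference as the test's ref_cast: round, then clip to the uint8 range,
# so the expected output here is [0, 0, 1, 255].
ref = np.clip(np.round(np.array(inp[0])), 0, 255).astype(np.uint8)
assert np.array_equal(np.array(out[0]), ref)

The halfway values (x.5) that the test explicitly nudges away with np.nextafter are avoided here as well, since CPU and GPU backends may round such ties differently.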