diff --git a/dali/operators/python_function/dltensor_function.cc b/dali/operators/python_function/dltensor_function.cc index f042a2cd0c..1167f1fb05 100644 --- a/dali/operators/python_function/dltensor_function.cc +++ b/dali/operators/python_function/dltensor_function.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -143,12 +143,13 @@ TensorListShape<> GetDLTensorListShape(const std::vector& dl_tenso template <> void CopyOutputData(TensorList &output, std::vector &dl_tensors, - int batch_size, Workspace &workspace) { + Workspace &workspace) { + int batch_size = dl_tensors.size(); auto &thread_pool = workspace.GetThreadPool(); auto out_shape = output.shape(); for (int i = 0; i < batch_size; ++i) { thread_pool.AddWork([&, i](int) { - CopyDlTensor(output.raw_mutable_tensor(i), dl_tensors[i]); + CopyDlTensorCpu(output.raw_mutable_tensor(i), dl_tensors[i]); }, out_shape.tensor_size(i)); } thread_pool.RunAll(); @@ -156,10 +157,8 @@ void CopyOutputData(TensorList &output, std::vector &d template <> void CopyOutputData(TensorList& output, std::vector &dl_tensors, - int batch_size, Workspace &workspace) { - for (int i = 0; i < batch_size; ++i) { - CopyDlTensor(output.raw_mutable_tensor(i), dl_tensors[i], workspace.stream()); - } + Workspace &workspace) { + CopyDlTensorBatchGpu(output, dl_tensors, workspace.stream()); } } // namespace detail diff --git a/dali/operators/python_function/dltensor_function.h b/dali/operators/python_function/dltensor_function.h index 64246f1306..1ebbd71792 100644 --- a/dali/operators/python_function/dltensor_function.h +++ b/dali/operators/python_function/dltensor_function.h @@ -17,9 +17,9 @@ #include #include #include -#include -#include #include +#include +#include #include "dali/pipeline/operator/operator.h" #include "dali/pipeline/util/copy_with_stride.h" @@ -76,21 +76,6 @@ std::vector CastToDLTensorList(py::list &list, Index exp_size, Ind TensorListShape<> GetDLTensorListShape(const std::vector &dl_tensors); -template -void CopyDlTensor(void *out_data, DLMTensorPtr &dlm_tensor_ptr, cudaStream_t stream = 0) { - auto &dl_tensor = dlm_tensor_ptr->dl_tensor; - auto item_size = dl_tensor.dtype.bits / 8; - if (dl_tensor.strides) { - std::vector strides(dl_tensor.ndim); - for (Index i = 0; i < dl_tensor.ndim; ++i) strides[i] = dl_tensor.strides[i] * item_size; - CopyWithStride(out_data, dl_tensor.data, strides.data(), - dl_tensor.shape, dl_tensor.ndim, item_size, stream); - } else { - CopyWithStride(out_data, dl_tensor.data, nullptr, - dl_tensor.shape, dl_tensor.ndim, item_size, stream); - } -} - template py::list PrepareDLTensorInputs(Workspace &ws); @@ -99,7 +84,7 @@ py::list PrepareDLTensorInputsPerSample(Workspace &ws); template void CopyOutputData(Output& output, std::vector &dl_tensors, - int batch_size, Workspace &workspace); + Workspace &workspace); template void PrepareOutputs(Workspace &ws, const py::object &output_o, int batch_size) { @@ -110,7 +95,7 @@ void PrepareOutputs(Workspace &ws, const py::object &output_o, int batch_size) { if (dl_tensors.empty()) continue; auto &tlist = ws.Output(idx); tlist.Resize(GetDLTensorListShape(dl_tensors), DLToDALIType(dl_tensors[0]->dl_tensor.dtype)); - CopyOutputData(tlist, dl_tensors, batch_size, ws); + CopyOutputData(tlist, dl_tensors, ws); } } diff 
--git a/dali/pipeline/data/dltensor.cc b/dali/pipeline/data/dltensor.cc index 46b56dbc9a..67179fee5f 100644 --- a/dali/pipeline/data/dltensor.cc +++ b/dali/pipeline/data/dltensor.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -74,7 +74,7 @@ inline std::string to_string(const DLDataType &dl_type) { DALIDataType DLToDALIType(const DLDataType &dl_type) { DALI_ENFORCE(dl_type.lanes == 1, - "DALI Tensors do no not support types with the number of lanes other than 1"); + "DALI Tensors do not support types with the number of lanes other than 1"); switch (dl_type.code) { case kDLUInt: { switch (dl_type.bits) { diff --git a/dali/pipeline/util/copy_with_stride.cc b/dali/pipeline/util/copy_with_stride.cc index f708db3183..cc5afb01b9 100644 --- a/dali/pipeline/util/copy_with_stride.cc +++ b/dali/pipeline/util/copy_with_stride.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -84,13 +84,8 @@ inline Index DeepestContiguous(const Index *in_strides, return 0; } -template <> -void CopyWithStride(void *output, const void *input, - const Index *in_strides, - const Index *shape, - int ndim, - size_t item_size, - cudaStream_t) { +void CopyWithStrideCpu(void *output, const void *input, const Index *in_strides, const Index *shape, + int ndim, size_t item_size) { assert(ndim >= 0); if (!in_strides) { std::memcpy(output, input, item_size * volume(shape, shape + ndim)); @@ -106,4 +101,19 @@ void CopyWithStride(void *output, const void *input, shape, ndim, 0, deepest_contiguous); } +void CopyDlTensorCpu(void *out_data, DLMTensorPtr &dlm_tensor_ptr) { + auto &dl_tensor = dlm_tensor_ptr->dl_tensor; + auto item_size = dl_tensor.dtype.bits / 8; + if (dl_tensor.strides) { + std::vector strides(dl_tensor.ndim); + for (Index i = 0; i < dl_tensor.ndim; ++i) + strides[i] = dl_tensor.strides[i] * item_size; + CopyWithStrideCpu(out_data, dl_tensor.data, strides.data(), dl_tensor.shape, dl_tensor.ndim, + item_size); + } else { + CopyWithStrideCpu(out_data, dl_tensor.data, nullptr, dl_tensor.shape, dl_tensor.ndim, + item_size); + } +} + } // namespace dali diff --git a/dali/pipeline/util/copy_with_stride.cu b/dali/pipeline/util/copy_with_stride.cu index 17bd9479ca..c047170be7 100644 --- a/dali/pipeline/util/copy_with_stride.cu +++ b/dali/pipeline/util/copy_with_stride.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,58 +12,404 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "dali/pipeline/util/copy_with_stride.h" #include -#include #include +#include +#include +#include +#include +#include "dali/pipeline/util/copy_with_stride.h" namespace dali { +namespace strided_copy { + constexpr int MAX_DIMS = 15; -__global__ void CopyWithStrideKernel(uint8_t *output, const uint8_t *input, Index size, - DeviceArray out_strides, - DeviceArray in_strides, - int ndim) { - auto out_idx = blockIdx.x * blockDim.x + threadIdx.x; - if (out_idx >= size) - return; - Index in_idx = 0; - Index elem_offset = out_idx; - for (int dim = 0; dim < ndim; ++dim) { +struct StridedCopyDesc { + void *output; + const void *input; + // input and output strides, both kept in the same order + // (the out_strides are decreasing) + int64_t in_strides[MAX_DIMS]; + int64_t out_strides[MAX_DIMS]; + // the size of the tensor in number of elements (not bytes) + int64_t size; + + // filled separately by FillSampleAlignmentInfo + // based on the above inforation and data type info + struct { + // the number of aligned elements, i.e. + // the total number of elements minus the skip_left and skip_right + int64_t size; + // the offset (as number of elements, not bytes) that + // need to be skipped from the start of the output + // tensor for the vectorized writes to be aligned + int skip_left; + // the number of elements that need to be skipped from the end + // of the sample for the vectorized write to fit in the sample + int skip_right; + } aligned; +}; + +template +struct alignas(Alignment) Vectorized { + T payload[NumElements]; // NOLINT(runtime/arrays) + + DALI_DEVICE DALI_FORCEINLINE T &operator[](int idx) { + return payload[idx]; + } +}; + +template +struct ElementTypeOfSize {}; + +template <> +struct ElementTypeOfSize<1> { + using type = uint8_t; +}; + +template <> +struct ElementTypeOfSize<2> { + using type = uint16_t; +}; + +template <> +struct ElementTypeOfSize<4> { + using type = uint32_t; +}; + +template <> +struct ElementTypeOfSize<8> { + using type = uint64_t; +}; + +template +struct ElementType { + static constexpr int vec_len = VecLen; + using type = typename ElementTypeOfSize::type; + using vec_type = Vectorized; +}; + + +/** + * @brief A helper wrapper to abstract away if the number of strides + * the copy kernel needs to use for mapping the output index into + * input index is known as a compile time constant or a runtime input. + * + * @tparam MaxNDim_ -1 for runtime ndim or the non-negative number of extents. + */ +template +struct MismatchedNdim { + DALI_DEVICE DALI_FORCEINLINE constexpr int operator()() { + return MaxNDim_; + } +}; + +template <> +struct MismatchedNdim<-1> { + MismatchedNdim(int max_ndim) : max_ndim_{max_ndim} {} // NOLINT(runtime/explicit) + DALI_DEVICE DALI_FORCEINLINE int operator()() { + return max_ndim_; + } + + private: + int max_ndim_; +}; + +/** + * @brief Takes a flat output index, recomputes the output coordiantes based on the out_strides + * and returns flat input index based on input strides. 
+ *
+ */
+template
+DALI_DEVICE DALI_FORCEINLINE int64_t GetInputIdx(MismatchedNdimT mismatched_ndim, int64_t out_idx,
+                                                 const int64_t *in_strides,
+                                                 const int64_t *out_strides) {
+  int64_t in_idx = 0;
+  int64_t elem_offset = out_idx;
+#pragma unroll
+  for (int dim = 0; dim < mismatched_ndim(); ++dim) {
     auto n = elem_offset / out_strides[dim];
     in_idx += n * in_strides[dim];
     elem_offset -= n * out_strides[dim];
   }
-  output[out_idx] = input[in_idx + elem_offset];
+  return in_idx + elem_offset;
 }

-template <>
-void CopyWithStride(void *output, const void *input,
-                    const Index *in_strides,
-                    const Index *shape,
-                    int ndim,
-                    size_t item_size,
-                    cudaStream_t stream) {
-  if (!in_strides) {
-    CUDA_CALL(
-      cudaMemcpyAsync(output, input, volume(shape, shape + ndim) * item_size,
-                      cudaMemcpyDeviceToDevice, stream));
+/**
+ * @brief Copies the unaligned ends of the sample element by element
+ * (in contrast to the vectorized `AlignedCopy`).
+ *
+ * Assumes that there are very few padded and cropped elements
+ * (less than the ElementTypeDesc's vector type elements, typically 4).
+ * In particular, only a single block will be active.
+ */
+template
+DALI_DEVICE DALI_FORCEINLINE void UnalignedCopy(const StridedCopyDesc &sample,
+                                                MismatchedNdimT mismatched_ndim) {
+  using T = typename ElementTypeDesc::type;
+  constexpr int vec_len = ElementTypeDesc::vec_len;
+  int skip_left = sample.aligned.skip_left;
+  int skip_right = sample.aligned.skip_right;
+  assert(2 * vec_len <= blockDim.x);
+  assert(skip_left < vec_len && skip_right < vec_len);
+  if (blockIdx.x == 0) {
+    const T *__restrict__ input = static_cast(sample.input);
+    T *__restrict__ output = static_cast(sample.output);
+    int padded_idx = threadIdx.x;
+    int cropped_idx = threadIdx.x - vec_len;
+    if (padded_idx < skip_left) {
+      auto in_idx = GetInputIdx(mismatched_ndim, padded_idx, sample.in_strides, sample.out_strides);
+      output[threadIdx.x] = input[in_idx];
+    } else if (0 <= cropped_idx && cropped_idx < skip_right) {
+      int64_t idx = sample.size - skip_right + cropped_idx;
+      auto in_idx = GetInputIdx(mismatched_ndim, idx, sample.in_strides, sample.out_strides);
+      output[idx] = input[in_idx];
+    }
+  }
+}
+
+/**
+ * @brief Performs an output-aligned copy.
+ *
+ * The input is read element-by-element but the output is stored in a vectorized type and
+ * written into global memory with wider, vectorized writes.
+ */
+template
+DALI_DEVICE DALI_FORCEINLINE void AlignedCopy(const StridedCopyDesc &sample,
+                                              MismatchedNdimT mismatched_ndim) {
+  using T = typename ElementTypeDesc::type;
+  using VecT = typename ElementTypeDesc::vec_type;
+  constexpr int vec_len = ElementTypeDesc::vec_len;
+  const T *__restrict__ input = static_cast(sample.input);
+  VecT *__restrict__ output =
+      reinterpret_cast(static_cast(sample.output) + sample.aligned.skip_left);
+  for (int64_t flat_idx = vec_len * (blockIdx.x * blockDim.x + threadIdx.x);
+       flat_idx < sample.aligned.size; flat_idx += vec_len * blockDim.x * gridDim.x) {
+    VecT out_vec;
+#pragma unroll
+    for (int i = 0; i < vec_len; i++) {
+      auto in_idx = GetInputIdx(mismatched_ndim, flat_idx + sample.aligned.skip_left + i,
+                                sample.in_strides, sample.out_strides);
+      out_vec[i] = input[in_idx];
+    }
+    output[flat_idx / vec_len] = out_vec;
+  }
+}
+
+/**
+ * @brief Performs a partially vectorized copy.
+ *
+ * The input is read element-by-element but the output is stored in a vectorized type and
+ * written into global memory with wider, vectorized writes.
+ * This benefits performance by 1. vectorized writes and 2. utilization of
+ * cache in the reads if the input happens to be mostly compact
+ * (for example the strides are due to a row-major image with padded rows).
+ *
+ * If the output base address or size is not aligned with the vectorized type, the
+ * beginning and end of the sample are handled separately.
+ */
+template
+__global__ void BatchedCopy(const StridedCopyDesc *sample_descs, MismatchedNdimT mismatched_ndim) {
+  using T = typename ElementTypeDesc::type;
+  using VecT = typename ElementTypeDesc::vec_type;
+  static_assert(sizeof(VecT) == sizeof(T) * ElementTypeDesc::vec_len);
+  const auto sample = sample_descs[blockIdx.y];
+  if constexpr (!IsOutputAligned) {
+    UnalignedCopy(sample, mismatched_ndim);
+  }
+  AlignedCopy(sample, mismatched_ndim);
+}
+
+template
+void FillSampleAlignmentInfo(StridedCopyDesc &sample) {
+  using T = typename ElementTypeDesc::type;
+  using VecT = typename ElementTypeDesc::vec_type;
+  constexpr int vec_len = ElementTypeDesc::vec_len;
+  static_assert(vec_len >= alignof(VecT) / sizeof(T));
+  auto output_base_addr = reinterpret_cast(sample.output);
+  auto aligned_output_addr = align_up(output_base_addr, sizeof(T) * vec_len);
+  sample.aligned.skip_left = (aligned_output_addr - output_base_addr) / sizeof(T);
+  assert(0 <= sample.aligned.skip_left && sample.aligned.skip_left < vec_len);
+  sample.aligned.skip_left = std::min(sample.size, sample.aligned.skip_left);
+  int64_t remaining_size = sample.size - sample.aligned.skip_left;
+  assert(0 <= remaining_size && remaining_size <= sample.size);
+  sample.aligned.size = align_down(remaining_size, vec_len);
+  sample.aligned.skip_right = remaining_size - sample.aligned.size;
+  assert(0 <= sample.aligned.skip_right && sample.aligned.skip_right < vec_len);
+  assert(sample.aligned.skip_left + sample.aligned.skip_right + sample.aligned.size == sample.size);
+}
+
+bool IsAligned(const StridedCopyDesc &sample) {
+  return sample.aligned.skip_left == 0 && sample.aligned.skip_right == 0;
+}
+
+template
+void CopyBatchTyped(span sample_descs, MismatchedNdimT mismatched_ndim,
+                    cudaStream_t stream) {
+  kernels::DynamicScratchpad scratchpad({}, stream);
+  using T = ElementType;
+  constexpr unsigned int kMaxBlockSize = 1024u;
+  static constexpr int kBlockSize = 128;
+  int64_t max_vol = 0;
+  bool has_aligned_output = true;
+  for (auto &sample_desc : sample_descs) {
+    FillSampleAlignmentInfo(sample_desc);
+    has_aligned_output &= IsAligned(sample_desc);
+    // if needed, the first block for a given sample handles unaligned writes on top
+    // of the "aligned work". If the sample is small enough that there is nothing
+    // left after the alignment is considered, still make sure to launch a single
+    // block for the unaligned elements
+    int64_t tensor_vol = sample_desc.aligned.size;
+    if (!tensor_vol) {
+      tensor_vol = sample_desc.aligned.skip_left + sample_desc.aligned.skip_right;
+    }
+    max_vol = std::max(max_vol, tensor_vol);
+  }
+  const StridedCopyDesc *sample_descs_gpu;
+  std::tie(sample_descs_gpu) = scratchpad.ToContiguousGPU(stream, sample_descs);
+  unsigned int blocks_num = div_ceil(max_vol, T::vec_len * kBlockSize);
+  blocks_num = std::min(blocks_num, kMaxBlockSize);
+  unsigned int num_samples = sample_descs.size();
+  dim3 grid = {blocks_num, num_samples, 1};
+  BOOL_SWITCH(has_aligned_output, HasAlignedOutput,
+              (BatchedCopy
+               <<>>(sample_descs_gpu, mismatched_ndim);));
+}
+
+
+void CopyBatch(span sample_descs, int max_mismatched_ndim, int element_size,
+               cudaStream_t stream) {
+  VALUE_SWITCH(element_size, ElementSize, (1, 2, 4, 8), (
+    VALUE_SWITCH(max_mismatched_ndim, NDim, (0, 1, 2, 3, 4, 5), (
+      CopyBatchTyped(sample_descs, MismatchedNdim{}, stream);
+    ), (  // NOLINT
+      CopyBatchTyped(sample_descs, MismatchedNdim<-1>(max_mismatched_ndim), stream);
+    ));  // NOLINT
+  ), DALI_FAIL(make_string("Unsupported element size: ", element_size)););  // NOLINT
+}
+
+}  // namespace strided_copy
+
+void ValidateBatch(int &element_size, int &ndim, std::vector &dl_tensors,
+                   int batch_size) {
+  int num_bits;
+  for (int i = 0; i < batch_size; i++) {
+    auto &dlm_tensor_ptr = dl_tensors[i];
+    auto &dl_tensor = dlm_tensor_ptr->dl_tensor;
+    int lanes = dl_tensor.dtype.lanes;
+    DALI_ENFORCE(lanes == 1, make_string("DALI Tensors do not support types with the number of "
+                                         "lanes other than 1, got tensor with `",
+                                         lanes, "` lanes."));
+    if (i == 0) {
+      num_bits = dl_tensor.dtype.bits;
+      ndim = dl_tensor.ndim;
+    } else {
+      DALI_ENFORCE(num_bits == dl_tensor.dtype.bits,
+                   "All tensors in the DALI batch must have the same type.");
+      int cur_ndim = dl_tensor.ndim;
+      DALI_ENFORCE(
+          ndim == cur_ndim,
+          make_string("All tensors in the DALI batch must have the same number of extents. Got "
+                      "tensors with `",
+                      ndim, "` and `", cur_ndim, "` dims."));
+    }
+  }
+  // Limitation based on DLToDALIType
+  DALI_ENFORCE(num_bits == 8 || num_bits == 16 || num_bits == 32 || num_bits == 64,
+               "Unsupported data type width. Currently DALI tensors support only types of 8, 16, "
+               "32 or 64 bits");
+  DALI_ENFORCE(0 <= ndim && ndim <= strided_copy::MAX_DIMS,
+               make_string("DALI tensor must have between 0 and ", strided_copy::MAX_DIMS,
+                           " dimensions, got tensor with `", ndim, "` dimensions."));
+  element_size = num_bits / 8;
+}
+
+
+/**
+ * @brief Copies a batch of DLTensors (which may be strided) into a batch of DALI tensors (which
+ * are dense/compact).
+ *
+ * For input DLTensors that are not strided, we simply run cudaMemcpyAsync. Otherwise, a
+ * copy kernel is used. The copy kernel will go over the output DALI tensors linearly (the tensor
+ * is compact/dense) and translate the flat output indices into input indices.
+ *
+ * The input batch is validated against some of the DALI batch requirements, such as uniform data
+ * type and dimensionality.
+ *
+ * @param output
+ * @param dl_tensors
+ * @param stream
+ */
+void CopyDlTensorBatchGpu(TensorList &output, std::vector &dl_tensors,
+                          cudaStream_t stream) {
+  int batch_size = dl_tensors.size();
+  if (batch_size <= 0) {
     return;
   }
-  DeviceArray out_strides{};
-  out_strides[ndim - 1] = item_size;
-  for (int i = ndim - 2; i >= 0; --i) {
-    out_strides[i] = out_strides[i + 1] * shape[i + 1];
+  int element_size, ndim;
+  ValidateBatch(element_size, ndim, dl_tensors, batch_size);
+  SmallVector sample_descs;
+  const auto cuda_mem_copy = [&output, element_size, stream](int sample_idx,
+                                                             const auto &dl_tensor) {
+    void *out_data = output.raw_mutable_tensor(sample_idx);
+    auto size = volume(dl_tensor.shape, dl_tensor.shape + dl_tensor.ndim) * element_size;
+    CUDA_CALL(cudaMemcpyAsync(out_data, dl_tensor.data, size, cudaMemcpyDeviceToDevice, stream));
+  };
+  // If some innermost (the smallest in the DALI tensor) strides match the strides of the incoming
+  // DLPack tensor, we can stop the translation from output index to input index early. For that,
+  // we need to keep track of how many outermost dimensions actually mismatch.
+  int max_mismatched_ndim = 0;
+  for (int sample_idx = 0; sample_idx < batch_size; sample_idx++) {
+    auto &dlm_tensor_ptr = dl_tensors[sample_idx];
+    auto &dl_tensor = dlm_tensor_ptr->dl_tensor;
+    if (!dl_tensor.strides) {
+      cuda_mem_copy(sample_idx, dl_tensor);
+      continue;
+    }
+    strided_copy::StridedCopyDesc sample_desc;
+    sample_desc.output = output.raw_mutable_tensor(sample_idx);
+    sample_desc.input = dl_tensor.data;
+    sample_desc.size = volume(dl_tensor.shape, dl_tensor.shape + ndim);
+    // nothing to do for an empty tensor
+    if (!sample_desc.size) {
+      continue;
+    }
+    // compute the input and the compact output strides first, to check if they match
+    for (int d = 0; d < dl_tensor.ndim; d++) {
+      sample_desc.in_strides[d] = dl_tensor.strides[d];
+    }
+    if (ndim > 0) {
+      sample_desc.out_strides[ndim - 1] = 1;
+      for (int d = ndim - 2; d >= 0; --d) {
+        sample_desc.out_strides[d] = sample_desc.out_strides[d + 1] * dl_tensor.shape[d + 1];
+      }
+    }
+    // if the strides match (and given that the out strides are compact/dense),
+    // we can just go with cudaMemcpy
+    {
+      int mismatched_ndim = ndim;
+      for (int d = ndim - 1; d >= 0; d--) {
+        if (sample_desc.in_strides[d] != sample_desc.out_strides[d]) {
+          break;
+        }
+        mismatched_ndim--;
+      }
+      if (mismatched_ndim == 0) {
+        cuda_mem_copy(sample_idx, dl_tensor);
+        continue;
+      }
+      max_mismatched_ndim = std::max(max_mismatched_ndim, mismatched_ndim);
+    }
+    // otherwise, when the strides do not match, add it
+    // to the vector with samples for the kernel
+    sample_descs.push_back(sample_desc);
+  }
+  if (sample_descs.size() > 0) {
+    strided_copy::CopyBatch(make_span(sample_descs), max_mismatched_ndim, element_size, stream);
   }
-  DeviceArray in_strides_arr{};
-  std::copy(in_strides, in_strides + ndim, in_strides_arr.data());
-  Index size = volume(shape, shape + ndim) * item_size;
-  auto blocks_num = (size + 1023) / 1024;
-  auto block_size = (size < 1024) ? size : 1024;
-  CopyWithStrideKernel<<>>
-    (static_cast(output), static_cast(input),
-     size, out_strides, in_strides_arr, ndim);
 }

 }  // namespace dali
diff --git a/dali/pipeline/util/copy_with_stride.h b/dali/pipeline/util/copy_with_stride.h
index 8dd6d7a00b..2ceb05b5b2 100644
--- a/dali/pipeline/util/copy_with_stride.h
+++ b/dali/pipeline/util/copy_with_stride.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,17 +17,19 @@ #include #include "dali/core/common.h" +#include "dali/core/span.h" +#include "dali/core/util.h" +#include "dali/kernels/dynamic_scratchpad.h" #include "dali/pipeline/data/backend.h" +#include "dali/pipeline/data/dltensor.h" namespace dali { -template -DLL_PUBLIC void CopyWithStride(void *output, const void *input, - const Index *in_strides, - const Index *shape, - int ndim, - size_t item_size, - cudaStream_t stream = 0); +DLL_PUBLIC void CopyDlTensorCpu(void *out_data, DLMTensorPtr &dlm_tensor_ptr); + +DLL_PUBLIC void CopyDlTensorBatchGpu(TensorList &output, + std::vector &dl_tensors, + cudaStream_t stream); } // namespace dali diff --git a/dali/pipeline/util/copy_with_stride_test.cc b/dali/pipeline/util/copy_with_stride_test.cc index 75ecdcbc66..ca8282898e 100644 --- a/dali/pipeline/util/copy_with_stride_test.cc +++ b/dali/pipeline/util/copy_with_stride_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,101 +12,242 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "dali/pipeline/util/copy_with_stride.h" #include #include -#include "dali/pipeline/util/copy_with_stride.h" +#include #include "dali/core/dev_buffer.h" +#include "dali/pipeline/data/dltensor.h" namespace dali { TEST(CopyWithStrideTest, OneDim) { - float data[] = {1, 2, 3, 4, 5, 6}; - std::array out; - Index stride = 2 * sizeof(float); - Index shape = 3; - CopyWithStride(out.data(), data, &stride, &shape, 1, sizeof(float)); - ASSERT_TRUE((out == std::array{1, 3, 5})); + const auto dtype = DALI_FLOAT; + using T = float; + T data[] = {1, 2, 3, 4, 5, 6}; + TensorShape<1> stride{2}; + TensorShape<1> shape{3}; + constexpr int vol = 3; + ASSERT_EQ(vol, volume(shape)); + std::array out; + DLTensorResource resource(shape); + resource.strides = stride; + auto dl_tensor = + MakeDLTensor(data, dtype, false, -1, std::make_unique(resource)); + CopyDlTensorCpu(out.data(), dl_tensor); + ASSERT_TRUE((out == std::array{1, 3, 5})); } TEST(CopyWithStrideTest, TwoDims) { - size_t data[] = {11, 12, 13, 14, - 21, 22, 23, 24, - 31, 32, 33, 34, - 41, 42, 43, 44}; - std::array out; - Index stride[] = {8 * sizeof(size_t), sizeof(size_t)}; - Index shape[] = {2, 4}; - CopyWithStride(out.data(), data, stride, shape, 2, sizeof(size_t)); - ASSERT_TRUE((out == std::array{11, 12, 13, 14, - 31, 32, 33, 34})); + const auto dtype = DALI_INT64; + using T = int64_t; + T data[] = {11, 12, 13, 14, + 21, 22, 23, 24, + 31, 32, 33, 34, + 41, 42, 43, 44}; + TensorShape<2> stride{8, 1}; + TensorShape<2> shape{2, 4}; + constexpr int vol = 8; + ASSERT_EQ(vol, volume(shape)); + std::array out; + DLTensorResource resource(shape); + resource.strides = stride; + auto dl_tensor = + MakeDLTensor(data, dtype, false, -1, std::make_unique(resource)); + CopyDlTensorCpu(out.data(), dl_tensor); + ASSERT_TRUE((out == std::array{11, 12, 13, 14, + 31, 32, 33, 34})); } TEST(CopyWithStrideTest, SimpleCopy) { - uint8 data[] = {1, 2, - 3, 4, - - 5, 6, - 7, 8}; - std::array out; - Index stride[] = {4, 2, 1}; - 
Index shape[] = {2, 2, 2}; - CopyWithStride(out.data(), data, stride, shape, 3, 1); - ASSERT_TRUE((out == std::array{1, 2, - 3, 4, - - 5, 6, - 7, 8})); + const auto dtype = DALI_UINT8; + using T = uint8_t; + T data[] = {1, 2, + 3, 4, + + 5, 6, + 7, 8}; + TensorShape<3> stride{4, 2, 1}; + TensorShape<3> shape{2, 2, 2}; + constexpr int vol = 8; + ASSERT_EQ(vol, volume(shape)); + std::array out; + DLTensorResource resource(shape); + resource.strides = stride; + auto dl_tensor = + MakeDLTensor(data, dtype, false, -1, std::make_unique(resource)); + CopyDlTensorCpu(out.data(), dl_tensor); + ASSERT_TRUE((out == std::array{1, 2, + 3, 4, + + 5, 6, + 7, 8})); +} + +DLMTensorPtr AsDlTensor(void* data, DALIDataType dtype, TensorShape<> shape, TensorShape<> stride) { + DLTensorResource resource(shape); + resource.strides = stride; + return MakeDLTensor(data, dtype, true, 0, std::make_unique(resource)); +} + +std::vector DlTensorSingletonBatch(DLMTensorPtr dl_tensor) { + std::vector dl_tensors; + dl_tensors.push_back(std::move(dl_tensor)); + return dl_tensors; +} + +TensorList SingletonTL(TensorShape<> shape, DALIDataType dtype) { + TensorList output; + TensorListShape tls(1, shape.sample_dim()); + tls.set_tensor_shape(0, shape); + output.Resize(tls, dtype); + return output; } TEST(CopyWithStrideTest, OneDimGPU) { - float h_data[] = {1, 2, 3, 4, 5, 6}; - Index stride = 2 * sizeof(float); - Index shape = 3; - DeviceBuffer data, out; + const auto dtype = DALI_FLOAT; + using T = float; + T h_data[] = {1, 2, 3, 4, 5, 6}; + DeviceBuffer data; data.from_host(h_data); - out.resize(data.size()); - CopyWithStride(out, data, &stride, &shape, 1, sizeof(float)); - std::array h_out; - CUDA_CALL(cudaMemcpy(h_out.data(), out, 3 * sizeof(float), cudaMemcpyDeviceToHost)); - ASSERT_TRUE((h_out == std::array{1, 3, 5})); + TensorShape<1> stride{2}; + TensorShape<1> shape{3}; + constexpr int vol = 3; + ASSERT_EQ(vol, volume(shape)); + auto dl_tensors = DlTensorSingletonBatch(AsDlTensor(data, dtype, shape, stride)); + auto output_tl = SingletonTL(shape, dtype); + CopyDlTensorBatchGpu(output_tl, dl_tensors, 0); + std::array h_out; + CUDA_CALL(cudaMemcpy(h_out.data(), output_tl.raw_mutable_tensor(0), vol * sizeof(T), + cudaMemcpyDeviceToHost)); + ASSERT_TRUE((h_out == std::array{1, 3, 5})); } TEST(CopyWithStrideTest, TwoDimsGPU) { - size_t h_data[] = {11, 12, 13, 14, - 21, 22, 23, 24, - 31, 32, 33, 34, - 41, 42, 43, 44}; - Index stride[] = {8 * sizeof(size_t), sizeof(size_t)}; - Index shape[] = {2, 4}; - DeviceBuffer data, out; + const auto dtype = DALI_INT64; + using T = int64_t; + T h_data[] = {11, 12, 13, 14, + 21, 22, 23, 24, + 31, 32, 33, 34, + 41, 42, 43, 44}; + TensorShape<2> stride{8, 1}; + TensorShape<2> shape{2, 4}; + constexpr int vol = 8; + ASSERT_EQ(vol, volume(shape)); + DeviceBuffer data; data.from_host(h_data); - out.resize(data.size()); - CopyWithStride(out, data, stride, shape, 2, sizeof(size_t)); - std::array h_out; - CUDA_CALL(cudaMemcpy(h_out.data(), out, 8 * sizeof(size_t), cudaMemcpyDeviceToHost)); - ASSERT_TRUE((h_out == std::array{11, 12, 13, 14, - 31, 32, 33, 34})); + auto dl_tensors = DlTensorSingletonBatch(AsDlTensor(data, dtype, shape, stride)); + auto output_tl = SingletonTL(shape, dtype); + CopyDlTensorBatchGpu(output_tl, dl_tensors, 0); + std::array h_out; + CUDA_CALL(cudaMemcpy(h_out.data(), output_tl.raw_mutable_tensor(0), vol * sizeof(T), + cudaMemcpyDeviceToHost)); + ASSERT_TRUE((h_out == std::array{11, 12, 13, 14, 31, 32, 33, 34})); +} + +TEST(CopyWithStrideTest, TwoDimsGPUOdd) { + const 
auto dtype = DALI_UINT8; + using T = uint8_t; + T h_data[] = {1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30}; + TensorShape<2> stride{15, 1}; + TensorShape<2> shape{2, 4}; + constexpr int vol = 8; + ASSERT_EQ(vol, volume(shape)); + DeviceBuffer data; + data.from_host(h_data); + auto dl_tensors = DlTensorSingletonBatch(AsDlTensor(data, dtype, shape, stride)); + auto output_tl = SingletonTL(shape, dtype); + CopyDlTensorBatchGpu(output_tl, dl_tensors, 0); + std::array h_out; + CUDA_CALL(cudaMemcpy(h_out.data(), output_tl.raw_mutable_tensor(0), vol * sizeof(T), + cudaMemcpyDeviceToHost)); + ASSERT_TRUE((h_out == std::array{1, 2, 3, 4, 16, 17, 18, 19})); +} + +TEST(CopyWithStrideTest, TwoDimsInnerStride) { + const auto dtype = DALI_UINT8; + using T = uint8_t; + T h_data[] = {1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30}; + TensorShape<2> stride{15, 5}; + TensorShape<2> shape{2, 3}; + constexpr int vol = 6; + ASSERT_EQ(vol, volume(shape)); + DeviceBuffer data; + data.from_host(h_data); + auto dl_tensors = DlTensorSingletonBatch(AsDlTensor(data, dtype, shape, stride)); + auto output_tl = SingletonTL(shape, dtype); + CopyDlTensorBatchGpu(output_tl, dl_tensors, 0); + std::array h_out; + CUDA_CALL(cudaMemcpy(h_out.data(), output_tl.raw_mutable_tensor(0), vol * sizeof(T), + cudaMemcpyDeviceToHost)); + ASSERT_TRUE((h_out == std::array{1, 6, 11, 16, 21, 26})); +} + +TEST(CopyWithStrideTest, TwoDimsTransposed) { + const auto dtype = DALI_UINT16; + using T = uint16_t; + T h_data[] = {1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30}; + TensorShape<2> stride{1, 5}; + TensorShape<2> shape{5, 6}; + constexpr int vol = 30; + ASSERT_EQ(vol, volume(shape)); + DeviceBuffer data; + data.from_host(h_data); + auto dl_tensors = DlTensorSingletonBatch(AsDlTensor(data, dtype, shape, stride)); + auto output_tl = SingletonTL(shape, dtype); + CopyDlTensorBatchGpu(output_tl, dl_tensors, 0); + std::array h_out; + CUDA_CALL(cudaMemcpy(h_out.data(), output_tl.raw_mutable_tensor(0), vol * sizeof(T), + cudaMemcpyDeviceToHost)); + std::array ref = { + 1, 6, 11, 16, 21, 26, + 2, 7, 12, 17, 22, 27, + 3, 8, 13, 18, 23, 28, + 4, 9, 14, 19, 24, 29, + 5, 10, 15, 20, 25, 30}; + ASSERT_TRUE(h_out == ref); } TEST(CopyWithStrideTest, SimpleCopyGPU) { - uint8 h_data[] = {1, 2, - 3, 4, - - 5, 6, - 7, 8}; - Index stride[] = {4, 2, 1}; - Index shape[] = {2, 2, 2}; - DeviceBuffer data, out; + const auto dtype = DALI_FLOAT; + using T = float; + T h_data[] = {1, 2, 3, + 4, 5, 6, + + 7, 8, 9, + 10, 11, 12}; + TensorShape<3> stride{6, 3, 1}; + TensorShape<3> shape{2, 2, 3}; + constexpr int vol = 12; + ASSERT_EQ(vol, volume(shape)); + DeviceBuffer data; data.from_host(h_data); - out.resize(data.size()); - CopyWithStride(out, data, stride, shape, 3, sizeof(uint8)); - std::array h_out; - CUDA_CALL(cudaMemcpy(h_out.data(), out, 8 * sizeof(uint8), cudaMemcpyDeviceToHost)); - ASSERT_TRUE((h_out == std::array{1, 2, - 3, 4, - - 5, 6, - 7, 8})); + auto dl_tensors = DlTensorSingletonBatch(AsDlTensor(data, dtype, shape, stride)); + auto output_tl = SingletonTL(shape, dtype); + CopyDlTensorBatchGpu(output_tl, dl_tensors, 0); + std::array h_out; + CUDA_CALL(cudaMemcpy(h_out.data(), output_tl.raw_mutable_tensor(0), vol * sizeof(T), + cudaMemcpyDeviceToHost)); + ASSERT_TRUE((h_out == std::array{1, 2, 3, + 4, 5, 6, + + 7, 8, 9, + 10, 11, 12})); 
} } // namespace dali diff --git a/dali/test/python/test_dltensor_operator.py b/dali/test/python/test_dltensor_operator.py index 35155a2e0d..94321a15e3 100644 --- a/dali/test/python/test_dltensor_operator.py +++ b/dali/test/python/test_dltensor_operator.py @@ -19,6 +19,8 @@ import random from functools import partial from nvidia.dali.pipeline import Pipeline +from nvidia.dali import fn, pipeline_def +from nvidia.dali.python_function_plugin import current_dali_stream test_data_root = os.environ['DALI_EXTRA_PATH'] images_dir = os.path.join(test_data_root, 'db', 'single', 'jpeg') @@ -188,6 +190,9 @@ def test_pytorch(): for device in ['cpu', 'gpu']: yield pytorch_case, testcase, device + yield from _gpu_sliced_torch_suite() + yield from _gpu_permuted_extents_torch_suite() + def mxnet_adapter(fun, in1, in2): tin1 = [mxnd.from_dlpack(dltensor) for dltensor in in1] @@ -326,8 +331,350 @@ def test_cupy(): print(cupy) for testcase in [cupy_simple, cupy_kernel_square_diff, cupy_kernel_mix_channels]: yield cupy_case, testcase + yield from _cupy_flip_with_negative_strides_suite() def test_cupy_kernel_gray_scale(): setup_cupy() cupy_case(cupy_kernel_gray_scale, synchronize=False) + + +# ---------------- test strided copy kernel with strided tensors ----------------- + + +def get_random_torch_batch(g, shapes, dtype): + is_fp = torch.is_floating_point(torch.tensor([], dtype=dtype)) + if is_fp: + return [torch.rand((shape), generator=g, dtype=dtype) for shape in shapes] + else: + iinfo = torch.iinfo(dtype) + dtype_min, dtype_max = iinfo.min, iinfo.max + return [ + torch.randint(dtype_min, dtype_max, shape, generator=g, dtype=dtype) for shape in shapes + ] + + +def get_sliced_torch_case(case_name): + # [(extents of the original shape), (slice of the corresponding extent)] + # the original extents and slice shapes are purposely all prime numbers + # to test handling of unaligned tensors + prime_images = [ + ((107, 181, 3), (slice(1, 102), slice(179), slice(None))), + ((1097, 227, 5), (slice(None), slice(None), slice(1, 4))), + ((107, 167, 1), (slice(1, 14), slice(None), slice(None))), + ((107, 23, 3), (slice(103), slice(None), slice(None))), + ((173, 23, 5), (slice(None), slice(None), slice(1, 1))), + ((401, 167, 5), (slice(4, 167), slice(None), slice(0, 3))), + ((181, 401, 5), (slice(2, None), slice(397), slice(None))), + ((181, 107, 1), (slice(179), slice(103), slice(1))), + ((373, 181, 5), (slice(None), slice(None), slice(None, None, 2))), + ((199, 401, 3), (slice(None), slice(None), slice(None))), + ((167, 1097, 1), (slice(8, None, 7), slice(24, None, 23), slice(None))), + ((181, 61, 1), (slice(179), slice(58, None), slice(None))), + ((401, 61, 1), (slice(397), slice(None), slice(None))), + ((373, 173, 1), (slice(None), slice(167), slice(None))), + ((173, 199, 3), (slice(None), slice(None), slice(2, 3))), + ((181, 1097, 1), (slice(2, None, None), slice(1093), slice(None))), + ] + + prime_grey_images = [((199, 23), (slice(None, 173, None), slice(None, 19, None))), + ((373, 373), (slice(None, 331, None), slice(42, None, None))), + ((1097, 181), (slice(114, None, None), slice(None, 157, None))), + ((61, 227), (slice(None, 53, None), slice(28, None, None))), + ((1097, 61), (slice(114, None, None), slice(None, 53, None))), + ((181, 199), (slice(None, 157, None), slice(None, 173, None))), + ((1097, 1097), (slice(114, None, None), slice(None, 983, None))), + ((373, 227), (slice(42, None, None), slice(None, 199, None))), + ((227, 173), (slice(None, 199, None), slice(None, 151, None))), + ((227, 173), 
(slice(None, 199, None), slice(22, None, None))), + ((401, 173), (slice(42, None, None), slice(None, 151, None))), + ((107, 23), (slice(18, None, None), slice(None, 19, None))), + ((23, 199), (slice(4, None, None), slice(26, None, None))), + ((199, 23), (slice(26, None, None), slice(4, None, None))), + ((227, 23), (slice(None, 199, None), slice(None, 19, None))), + ((23, 23), (slice(4, None, None), slice(4, None, None))), + ((167, 181), (slice(18, None, None), slice(24, None, None))), + ((167, 181), (slice(18, None, None), slice(24, None, None))), + ((181, 227), (slice(None, 157, None), slice(None, 199, None))), + ((401, 199), (slice(None, 359, None), slice(None, 173, None))), + ((107, 181), (slice(None, 89, None), slice(None, 157, None))), + ((173, 61), (slice(None, 151, None), slice(8, None, None))), + ((227, 167), (slice(None, 199, None), slice(18, None, None))), + ((173, 401), (slice(22, None, None), slice(None, 359, None))), + ((23, 227), (slice(4, None, None), slice(28, None, None))), + ((227, 23), (slice(28, None, None), slice(4, None, None))), + ((373, 373), (slice(42, None, None), slice(None, 331, None))), + ((61, 107), (slice(None, 53, None), slice(18, None, None))), + ((181, 61), (slice(24, None, None), slice(None, 53, None))), + ((107, 181), (slice(None, 89, None), slice(24, None, None))), + ((401, 23), (slice(42, None, None), slice(4, None, None))), + ((373, 401), (slice(None, 331, None), slice(42, None, None)))] + + vid = [((17, ) + shape, (slice(None), ) + sl) for shape, sl in prime_images] + + ndim_11 = [(tuple(3 if i == j else 1 for j in range(11)) + shape, ((slice(None), ) * 11) + sl) + for i, (shape, sl) in enumerate(prime_images)] + + cases = { + "slice_images": prime_images, + "slice_grey_images": prime_grey_images, + "slice_vid": vid, + "slice_ndim_11": ndim_11 + } + shape_slices = cases[case_name] + shapes, slices = tuple(zip(*shape_slices)) + assert len(shapes) == len(slices) == len(shape_slices) + return shapes, slices + + +def _gpu_sliced_torch_case(case_name, dtype, g): + + shapes, slices = get_sliced_torch_case(case_name) + input_batch = get_random_torch_batch(g, shapes, dtype) + assert len(input_batch) == len(shapes) + + # returns sliced view of the input tensors + def sliced_tensor(batch): + stream = current_dali_stream() + torch_stream = torch.cuda.ExternalStream(stream) + with torch.cuda.stream(torch_stream): + tensors = [torch_dlpack.from_dlpack(t) for t in batch] + assert len(tensors) == len(slices) + tensor_views = [t[sl] for t, sl in zip(tensors, slices)] + out = [torch_dlpack.to_dlpack(t) for t in tensor_views] + return out + + @pipeline_def(batch_size=len(input_batch), num_threads=4, device_id=0) + def pipeline(): + data = fn.external_source(lambda: input_batch) + data = fn.dl_tensor_python_function(data.gpu(), batch_processing=True, + function=sliced_tensor, synchronize_stream=False) + return data + + p = pipeline() + p.build() + out, = p.run() + + out = [numpy.array(sample) for sample in out.as_cpu()] + ref = [numpy.array(sample)[sl] for sample, sl in zip(input_batch, slices)] + + numpy.testing.assert_equal(out, ref) + + +def _gpu_sliced_torch_suite(): + + g = torch.Generator() + g.manual_seed(42) + + for case_name in ("slice_images", "slice_grey_images", "slice_vid", "slice_ndim_11"): + for dtype in (torch.uint8, torch.int16, torch.float32, torch.float64): + yield _gpu_sliced_torch_case, case_name, dtype, g + + +def get_permute_extents_case(case_name): + rng = random.Random(44) + + def permuted(it): + copy = list(it) + rng.shuffle(copy) + return 
tuple(copy) + + def permuted_extents(ndim): + extents = list(range(ndim)) + rng.shuffle(extents) + return tuple(extents) + + # the original extents are purposely all prime numbers + # to test handling of unaligned tensors + prime_images = [ + (199, 181, 3), + (1097, 61, 5), + (373, 373, 1), + (107, 23, 3), + (173, 23, 5), + (401, 167, 5), + (181, 401, 5), + (181, 107, 1), + (373, 181, 5), + (199, 401, 3), + (1097, 1097, 1), + (181, 61, 1), + (401, 61, 1), + (373, 173, 1), + (227, 199, 3), + (181, 1097, 1), + ] + + if case_name == "transpose_channels_image": + prime_images_transposed_channel = list(zip(prime_images, [(2, 0, 1)] * len(prime_images))) + assert len(prime_images_transposed_channel) == len(prime_images) + return prime_images_transposed_channel + + if case_name == "transpose_hw_image": + prime_images_transposed_hw = list(zip(prime_images, [(1, 0, 2)] * len(prime_images))) + assert len(prime_images_transposed_hw) == len(prime_images) + return prime_images_transposed_hw + + if case_name == "image_random_permutation": + prime_images_rnd_permuted = list( + zip(prime_images, [permuted_extents(3) for _ in range(len(prime_images))])) + assert len(prime_images_rnd_permuted) == len(prime_images) + return prime_images_rnd_permuted + + if case_name == "transpose_channels_video": + prime_vid_like = [ + (13, 199, 181, 3), + (3, 1097, 61, 5), + (17, 373, 373, 1), + (5, 107, 23, 3), + (11, 173, 23, 5), + (11, 401, 167, 5), + (7, 181, 401, 5), + (5, 181, 107, 1), + (3, 373, 181, 5), + (23, 199, 401, 3), + (3, 1097, 1097, 1), + (31, 181, 61, 1), + (17, 401, 61, 1), + (5, 373, 173, 1), + (3, 227, 199, 3), + (7, 181, 1097, 1), + ] + + prime_vid_like_transposed_channel = list( + zip(prime_vid_like, [(3, 0, 1, 2)] * len(prime_vid_like))) + assert len(prime_vid_like_transposed_channel) == len(prime_vid_like) + return prime_vid_like_transposed_channel + + if case_name == "ndim_6_permute_outermost_3": + # optimization to early stop translation of flat output index to flat input index + # should kick in, test if that's fine + ndim_6_transpose_outermost = [(permuted([3, 5, 7, 11, 13, 17]), permuted_extents(3) + + (3, 4, 5)) for _ in range(5)] + assert len(ndim_6_transpose_outermost) == 5 + return ndim_6_transpose_outermost + + if case_name == "ndim_6_permute_all": + ndim_6_rnd_permuted = [(permuted([3, 5, 7, 11, 13, 17]), permuted_extents(6)) + for _ in range(32)] + assert len(ndim_6_rnd_permuted) == 32 + return ndim_6_rnd_permuted + + if case_name == "ndim_15_permute_all": + # max ndim supported + ndim_15_rnd_permuted = [(permuted([3, 5, 7, 11, 13, 17, 1, 1, 1, 1, 1, 1, 1, 1, + 1]), permuted_extents(15)) for _ in range(32)] + assert len(ndim_15_rnd_permuted) == 32 + return ndim_15_rnd_permuted + + +def _gpu_permuted_extents_torch_case(case_name, dtype, g): + + shapes_perms = get_permute_extents_case(case_name) + shapes, perms = tuple(zip(*shapes_perms)) + assert len(shapes) == len(perms) == len(shapes_perms) + input_batch = get_random_torch_batch(g, shapes, dtype) + assert len(input_batch) == len(shapes) + + # returns permuted view of the input tensors + def permuted_tensors(batch): + stream = current_dali_stream() + torch_stream = torch.cuda.ExternalStream(stream) + with torch.cuda.stream(torch_stream): + tensors = [torch_dlpack.from_dlpack(t) for t in batch] + assert len(tensors) == len(perms) + tensor_views = [t.permute(perm) for t, perm in zip(tensors, perms)] + out = [torch_dlpack.to_dlpack(t) for t in tensor_views] + return out + + @pipeline_def(batch_size=len(input_batch), num_threads=4, 
device_id=0) + def pipeline(): + data = fn.external_source(lambda: input_batch) + data = fn.dl_tensor_python_function(data.gpu(), batch_processing=True, + function=permuted_tensors, synchronize_stream=False) + return data + + p = pipeline() + p.build() + out, = p.run() + + out = [numpy.array(sample) for sample in out.as_cpu()] + ref = [numpy.array(sample).transpose(perm) for sample, perm in zip(input_batch, perms)] + + numpy.testing.assert_equal(out, ref) + + +def _gpu_permuted_extents_torch_suite(): + + g = torch.Generator() + g.manual_seed(44) + + for case_name in ( + "transpose_channels_image", + "transpose_hw_image", + "image_random_permutation", + "transpose_channels_video", + "ndim_6_permute_outermost_3", + "ndim_6_permute_all", + "ndim_15_permute_all", + ): + for dtype in (torch.uint8, torch.int16, torch.int32, torch.float64): + yield _gpu_permuted_extents_torch_case, case_name, dtype, g + + +def _cupy_negative_strides_case(dtype, batch_size, steps): + + @pipeline_def(batch_size=batch_size, num_threads=4, device_id=0, seed=42) + def baseline_pipeline(): + img, _ = fn.readers.file(name="Reader", file_root=images_dir, random_shuffle=True, seed=42) + img = fn.decoders.image(img, device="mixed") + img = fn.cast(img, dtype=dtype) + img = img[tuple(slice(None, None, step) for step in steps)] + return img + + def flip_cupy(dlps): + stream = current_dali_stream() + cp_stream = cupy.cuda.ExternalStream(stream, device_id=0) + with cp_stream: + imgs = [cupy.from_dlpack(dlp) for dlp in dlps] + imgs = [img[tuple(slice(None, None, step) for step in steps)] for img in imgs] + imgs = [img.toDlpack() for img in imgs] + return imgs + + @pipeline_def(batch_size=batch_size, num_threads=4, device_id=0, seed=42) + def pipeline(): + img, _ = fn.readers.file(name="Reader", file_root=images_dir, random_shuffle=True, seed=42) + img = fn.decoders.image(img, device="mixed") + img = fn.cast(img, dtype=dtype) + img = fn.dl_tensor_python_function(img, batch_processing=True, function=flip_cupy, + synchronize_stream=False) + return img + + p = pipeline() + p.build() + baseline = baseline_pipeline() + baseline.build() + + for _ in range(5): + batch, = p.run() + baseline_batch, = baseline.run() + batch = [numpy.array(sample) for sample in batch.as_cpu()] + baseline_batch = [numpy.array(sample) for sample in baseline_batch.as_cpu()] + assert len(batch) == len(baseline_batch) == batch_size + for sample, baseline_sample in zip(batch, baseline_batch): + numpy.testing.assert_equal(sample, baseline_sample) + + +def _cupy_flip_with_negative_strides_suite(): + for dtype, batch_size, steps in [ + (types.DALIDataType.UINT8, 4, (-1, -1, None)), + (types.DALIDataType.UINT8, 16, (-1, None, None)), + (types.DALIDataType.UINT8, 2, (None, None, -1)), + (types.DALIDataType.UINT8, 5, (-1, -1, -1)), + (types.DALIDataType.UINT8, 16, (-2, -2, None)), + (types.DALIDataType.UINT16, 11, (None, -1, None)), + (types.DALIDataType.FLOAT, 16, (2, -2, None)), + (types.DALIDataType.INT32, 12, (-2, None, None)), + (types.DALIDataType.FLOAT64, 11, (-2, 4, -1)), + ]: + yield _cupy_negative_strides_case, dtype, batch_size, steps
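
For reference, below is a minimal, self-contained host-side C++ sketch of the two ideas the new GPU path above is built on: translating a flat index in the dense output into an offset in the strided input (what GetInputIdx does per element), and splitting each output tensor into an unaligned head (skip_left), a vectorizable middle, and an unaligned tail (skip_right) (what FillSampleAlignmentInfo computes). It is an illustration only, not code from the patch; the helper name FlatOutputToInputIdx and the example shape, strides, and misalignment are made up for the demonstration.

// Simplified stand-in for the device-side index translation: map a flat index in the
// dense output to an offset in the (possibly padded/permuted) strided input.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

int64_t FlatOutputToInputIdx(int64_t out_idx, const std::vector<int64_t> &in_strides,
                             const std::vector<int64_t> &out_strides) {
  int64_t in_idx = 0;
  int64_t elem_offset = out_idx;
  for (size_t dim = 0; dim < out_strides.size(); ++dim) {
    int64_t n = elem_offset / out_strides[dim];  // coordinate along `dim`
    in_idx += n * in_strides[dim];
    elem_offset -= n * out_strides[dim];
  }
  return in_idx + elem_offset;
}

int main() {
  // Example: a 2x4 view sliced out of a 2x8 row-major buffer -> input strides {8, 1}.
  std::vector<int64_t> shape = {2, 4};
  std::vector<int64_t> in_strides = {8, 1};

  // Dense output strides: innermost stride is 1, outer strides are products of inner extents.
  std::vector<int64_t> out_strides(shape.size());
  out_strides.back() = 1;
  for (int d = static_cast<int>(shape.size()) - 2; d >= 0; --d)
    out_strides[d] = out_strides[d + 1] * shape[d + 1];

  std::vector<int> input(16);
  for (int i = 0; i < 16; ++i)
    input[i] = i;

  // Element-by-element gather: the same lookup each kernel thread performs for its output index.
  int64_t size = shape[0] * shape[1];
  for (int64_t out_idx = 0; out_idx < size; ++out_idx)
    std::printf("%d ", input[FlatOutputToInputIdx(out_idx, in_strides, out_strides)]);
  std::printf("\n");

  // Alignment bookkeeping in the spirit of FillSampleAlignmentInfo, for 4-element vectorized
  // writes and an output pointer that is misaligned by one element.
  int64_t vec_len = 4;
  int64_t misalignment = 1;  // elements between the output pointer and the next aligned address
  int64_t skip_left = std::min(size, (vec_len - misalignment) % vec_len);
  int64_t remaining = size - skip_left;
  int64_t aligned_size = remaining / vec_len * vec_len;
  int64_t skip_right = remaining - aligned_size;
  assert(skip_left + aligned_size + skip_right == size);
  std::printf("skip_left=%lld aligned=%lld skip_right=%lld\n",
              static_cast<long long>(skip_left), static_cast<long long>(aligned_size),
              static_cast<long long>(skip_right));
  return 0;
}

With input strides {8, 1} and shape {2, 4} the gather prints 0 1 2 3 8 9 10 11, i.e. the first four elements of each padded row, and the alignment split reports skip_left=3, aligned=4, skip_right=1: the unaligned head and tail are what UnalignedCopy handles element by element, while the middle is written with vectorized stores by AlignedCopy.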