
Commit

Replace GPU dltensor per-sample copying kernel with a batched one (#5038)

* Replace GPU dltensor per-sample copying kernel with a batched one
* Adjust cpp tests
* Skip empty samples
* Add tests with torch unaligned sliced and permuted tensors
* Add negative strides/flipping test
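
For context on what the new tests exercise: a sliced, permuted, or flipped torch tensor arrives in DALI as a DLPack tensor whose strides no longer describe a dense row-major layout, and flipping makes some of them negative. The snippet below is a minimal, illustrative sketch (not taken from the added tests) of how such a view can be described at the DLPack level, assuming the standard dlpack.h header; the concrete 4x6 layout is made up for the example.

// Illustrative only: a transposed-and-flipped view of a dense 4x6 float buffer,
// described directly as a DLTensor. DLPack strides are counted in elements.
#include <dlpack/dlpack.h>  // include path may differ per project setup
#include <cstdint>
#include <vector>

int main() {
  std::vector<float> buffer(4 * 6);  // dense 4x6, row-major, element strides {6, 1}

  // View equivalent to torch's t.permute(1, 0).flip(1): shape (6, 4),
  // element strides (1, -6), origin moved to the last row of the buffer.
  int64_t shape[2]   = {6, 4};
  int64_t strides[2] = {1, -6};

  DLTensor view{};
  view.data = buffer.data();
  view.device = {kDLCPU, 0};
  view.ndim = 2;
  view.dtype = {kDLFloat, 32, 1};            // float32, single lane
  view.shape = shape;
  view.strides = strides;
  view.byte_offset = 3 * 6 * sizeof(float);  // element (0, 0) of the view
  return 0;
}

The copy routines touched by this commit consume exactly this representation: element strides are scaled to byte strides and the data is gathered into a dense, contiguous output sample.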

---------

Signed-off-by: Kamil Tokarski <[email protected]>
stiepan committed Sep 12, 2023
1 parent f1b68d4 commit 12a2d1d
Showing 8 changed files with 982 additions and 152 deletions.
dali/operators/python_function/dltensor_function.cc (13 changes: 6 additions & 7 deletions)
@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -143,23 +143,22 @@ TensorListShape<> GetDLTensorListShape(const std::vector<DLMTensorPtr>& dl_tenso
 
 template <>
 void CopyOutputData(TensorList<CPUBackend> &output, std::vector<DLMTensorPtr> &dl_tensors,
-                    int batch_size, Workspace &workspace) {
+                    Workspace &workspace) {
+  int batch_size = dl_tensors.size();
   auto &thread_pool = workspace.GetThreadPool();
   auto out_shape = output.shape();
   for (int i = 0; i < batch_size; ++i) {
     thread_pool.AddWork([&, i](int) {
-      CopyDlTensor<CPUBackend>(output.raw_mutable_tensor(i), dl_tensors[i]);
+      CopyDlTensorCpu(output.raw_mutable_tensor(i), dl_tensors[i]);
     }, out_shape.tensor_size(i));
   }
   thread_pool.RunAll();
 }
 
 template <>
 void CopyOutputData(TensorList<GPUBackend>& output, std::vector<DLMTensorPtr> &dl_tensors,
-                    int batch_size, Workspace &workspace) {
-  for (int i = 0; i < batch_size; ++i) {
-    CopyDlTensor<GPUBackend>(output.raw_mutable_tensor(i), dl_tensors[i], workspace.stream());
-  }
+                    Workspace &workspace) {
+  CopyDlTensorBatchGpu(output, dl_tensors, workspace.stream());
 }
 
 } // namespace detail
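
With the change above, the GPU branch no longer loops over samples issuing one strided copy (and one kernel launch) per sample; CopyDlTensorBatchGpu receives the whole output TensorList and the vector of DLPack tensors at once. As a rough mental model only, and not DALI's actual kernel (the CUDA side of copy_with_stride is among the changed files this capture does not show), a batched strided copy can be sketched as below; SampleDesc and BatchedStridedCopy are hypothetical names.

// A minimal sketch of a batched strided copy: per-sample descriptors are
// prepared on the host, and a single launch covers the whole batch, with
// blockIdx.y selecting the sample.
#include <cuda_runtime.h>
#include <cstdint>

constexpr int kMaxNDim = 8;

struct SampleDesc {
  char *out;                     // dense, contiguous destination
  const char *in;                // possibly strided source
  int64_t shape[kMaxNDim];       // sample extents
  int64_t in_strides[kMaxNDim];  // source strides in bytes (may be negative)
  int64_t num_elements;          // product of the extents
  int ndim;
  int item_size;                 // element size in bytes
};

__global__ void BatchedStridedCopy(const SampleDesc *samples) {
  const SampleDesc &s = samples[blockIdx.y];  // one grid row per sample
  int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < s.num_elements; idx += grid_stride) {
    // Unravel the flat output index into coordinates and accumulate the
    // corresponding byte offset in the strided input.
    int64_t rem = idx, in_offset = 0;
    for (int d = s.ndim - 1; d >= 0; --d) {
      int64_t coord = rem % s.shape[d];
      rem /= s.shape[d];
      in_offset += coord * s.in_strides[d];
    }
    for (int b = 0; b < s.item_size; ++b)
      s.out[idx * s.item_size + b] = s.in[in_offset + b];
  }
}

A launch would use a grid whose y dimension equals the batch size, so the per-sample launch overhead of the previous approach is paid once per batch; in this sketch an empty sample simply has num_elements == 0 and falls through the loop.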
dali/operators/python_function/dltensor_function.h (23 changes: 4 additions & 19 deletions)
@@ -17,9 +17,9 @@
 #include <dali/util/pybind.h>
 #include <pybind11/embed.h>
 #include <pybind11/stl.h>
-#include <vector>
-#include <utility>
 #include <string>
+#include <utility>
+#include <vector>
 #include "dali/pipeline/operator/operator.h"
 #include "dali/pipeline/util/copy_with_stride.h"

@@ -76,21 +76,6 @@ std::vector<DLMTensorPtr> CastToDLTensorList(py::list &list, Index exp_size, Ind

 TensorListShape<> GetDLTensorListShape(const std::vector<DLMTensorPtr> &dl_tensors);
 
-template <typename Backend>
-void CopyDlTensor(void *out_data, DLMTensorPtr &dlm_tensor_ptr, cudaStream_t stream = 0) {
-  auto &dl_tensor = dlm_tensor_ptr->dl_tensor;
-  auto item_size = dl_tensor.dtype.bits / 8;
-  if (dl_tensor.strides) {
-    std::vector<Index> strides(dl_tensor.ndim);
-    for (Index i = 0; i < dl_tensor.ndim; ++i) strides[i] = dl_tensor.strides[i] * item_size;
-    CopyWithStride<Backend>(out_data, dl_tensor.data, strides.data(),
-                            dl_tensor.shape, dl_tensor.ndim, item_size, stream);
-  } else {
-    CopyWithStride<Backend>(out_data, dl_tensor.data, nullptr,
-                            dl_tensor.shape, dl_tensor.ndim, item_size, stream);
-  }
-}
-
 template <typename Backend>
 py::list PrepareDLTensorInputs(Workspace &ws);

@@ -99,7 +84,7 @@ py::list PrepareDLTensorInputsPerSample(Workspace &ws);

 template <typename Workspace, typename Output>
 void CopyOutputData(Output& output, std::vector<DLMTensorPtr> &dl_tensors,
-                    int batch_size, Workspace &workspace);
+                    Workspace &workspace);
 
 template <typename Backend>
 void PrepareOutputs(Workspace &ws, const py::object &output_o, int batch_size) {
@@ -110,7 +95,7 @@ void PrepareOutputs(Workspace &ws, const py::object &output_o, int batch_size) {
     if (dl_tensors.empty()) continue;
     auto &tlist = ws.Output<Backend>(idx);
     tlist.Resize(GetDLTensorListShape(dl_tensors), DLToDALIType(dl_tensors[0]->dl_tensor.dtype));
-    CopyOutputData(tlist, dl_tensors, batch_size, ws);
+    CopyOutputData(tlist, dl_tensors, ws);
   }
 }

dali/pipeline/data/dltensor.cc (4 changes: 2 additions & 2 deletions)
@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -74,7 +74,7 @@ inline std::string to_string(const DLDataType &dl_type) {

 DALIDataType DLToDALIType(const DLDataType &dl_type) {
   DALI_ENFORCE(dl_type.lanes == 1,
-               "DALI Tensors do no not support types with the number of lanes other than 1");
+               "DALI Tensors do not support types with the number of lanes other than 1");
   switch (dl_type.code) {
     case kDLUInt: {
       switch (dl_type.bits) {
dali/pipeline/util/copy_with_stride.cc (26 changes: 18 additions & 8 deletions)
@@ -1,4 +1,4 @@
-// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -84,13 +84,8 @@ inline Index DeepestContiguous(const Index *in_strides,
   return 0;
 }
 
-template <>
-void CopyWithStride<CPUBackend>(void *output, const void *input,
-                                const Index *in_strides,
-                                const Index *shape,
-                                int ndim,
-                                size_t item_size,
-                                cudaStream_t) {
+void CopyWithStrideCpu(void *output, const void *input, const Index *in_strides, const Index *shape,
+                       int ndim, size_t item_size) {
   assert(ndim >= 0);
   if (!in_strides) {
     std::memcpy(output, input, item_size * volume(shape, shape + ndim));
@@ -106,4 +101,19 @@ void CopyWithStride<CPUBackend>(void *output, const void *input,
                      shape, ndim, 0, deepest_contiguous);
 }
 
+void CopyDlTensorCpu(void *out_data, DLMTensorPtr &dlm_tensor_ptr) {
+  auto &dl_tensor = dlm_tensor_ptr->dl_tensor;
+  auto item_size = dl_tensor.dtype.bits / 8;
+  if (dl_tensor.strides) {
+    std::vector<Index> strides(dl_tensor.ndim);
+    for (Index i = 0; i < dl_tensor.ndim; ++i)
+      strides[i] = dl_tensor.strides[i] * item_size;
+    CopyWithStrideCpu(out_data, dl_tensor.data, strides.data(), dl_tensor.shape, dl_tensor.ndim,
+                      item_size);
+  } else {
+    CopyWithStrideCpu(out_data, dl_tensor.data, nullptr, dl_tensor.shape, dl_tensor.ndim,
+                      item_size);
+  }
+}
+
 } // namespace dali
