Skip to content

Commit

Permalink
[TensorV2] Average Tensor element by axes
Browse files Browse the repository at this point in the history
This pull request adds new functions to the TensorV2 that allow users to average tensor elements along specified axes.
The functions take in an axis or list of axes as input and return a new tensor with elements replaced by their corresponding means. If no axis is provided, it returns a tensor value by averaging the elements by all axes.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
  • Loading branch information
djeong20 authored and jijoongmoon committed Mar 7, 2024
1 parent 459810e commit 35d4b36
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
60 changes: 60 additions & 0 deletions nntrainer/tensor/tensor_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,66 @@ TensorV2 &TensorV2::sum(const std::vector<unsigned int> &axes, TensorV2 &output,
return output;
}

TensorV2 TensorV2::average(unsigned int axis) const {
TensorV2 output("", this->getFormat(), this->getDataType());
return average(axis, output);
}

TensorV2 &TensorV2::average(unsigned int axis, TensorV2 &output) const {
if (axis >= TensorDim::MAXDIM)
throw std::out_of_range(
"negative axis or axis more then MAXDIM is invalid");

unsigned int axis_size = getDim()[axis];
if (axis_size == 1)
output.copy(*this);
else
this->sum(axis, output, 1.0 / ((float)axis_size));

return output;
}

TensorV2 TensorV2::average(const std::vector<unsigned int> &axes) const {
TensorV2 output("", this->getFormat(), this->getDataType());
return average(axes, output);
}

TensorV2 &TensorV2::average(const std::vector<unsigned int> &axes,
TensorV2 &output) const {
if (axes.empty())
return this->average(output);

TensorDim ret_shape(getTensorType());

for (const auto &idx : axes) {
if (idx >= TensorDim::MAXDIM) {
throw std::out_of_range("axis more then MAXDIM is invalid");
}
ret_shape.setTensorDim(idx, getDim().getTensorDim(idx));
}

return this->sum(axes, output, 1.0 / (float)ret_shape.getDataLen());
}

TensorV2 TensorV2::average() const {
TensorV2 output = *this;
unsigned int axis = 0;
if (this->getFormat() == Tformat::NHWC) {
output.reshape({1, getDim().getDataLen(), 1, 1, this->getTensorType()});
axis = 1;
} else {
output.reshape({1, 1, 1, getDim().getDataLen(), this->getTensorType()});
axis = 3;
}
return output.average(axis);
}

TensorV2 &TensorV2::average(TensorV2 &output) const {
TensorV2 result = *this;
result.reshape({1, 1, 1, getDim().getDataLen()});
return result.average(3, output);
}

int TensorV2::pow_i(float exponent) {
pow(exponent, *this);
return ML_ERROR_NONE;
Expand Down
44 changes: 44 additions & 0 deletions nntrainer/tensor/tensor_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,50 @@ class TensorV2 {
TensorV2 &sum(const std::vector<unsigned int> &axes, TensorV2 &output,
float alpha = 1.0) const;

/**
* @brief Averaging the Tensor elements according to the axis
* 0 : batch direction
* 1 : channel direction
* 2 : height direction
* 3 : width direction
* @retval Calculated Tensor
*/
TensorV2 average(unsigned int axis) const;

/**
* @brief Averaging the Tensor elements according to the axis
* @retval Calculated Tensor
*/
TensorV2 &average(unsigned int axis, TensorV2 &output) const;

/**
* @brief Average all the Tensor by multiple axes
* @param[in] axes axes to sum along
* @retval Calculated Tensor
*/
TensorV2 average(const std::vector<unsigned int> &axes) const;

/**
* @brief Average all the Tensor by multiple axes
* @param[in] axes axes to sum along
* @param[out] output output tensor
* @retval Calculated Tensor
*/
TensorV2 &average(const std::vector<unsigned int> &axes,
TensorV2 &output) const;

/**
* @brief Average the Tensor elements by all axis
* @retval Calculated Tensor
*/
TensorV2 average() const;

/**
* @brief Averaging the Tensor elements by all axis
* @retval Calculated Tensor
*/
TensorV2 &average(TensorV2 &output) const;

/**
* @brief Tensor power element without mem copy
* @param[in] exponent exponent
Expand Down

0 comments on commit 35d4b36

Please sign in to comment.