Skip to content

Commit

Permalink
Refactor type of shift argument passed to tosa.rescale (#415)
Browse files Browse the repository at this point in the history
Closes #406.
  • Loading branch information
henri-gruender authored Apr 4, 2024
1 parent 51eb706 commit 922abce
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion reference-implementation/include/emitc/tosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ inline Src reciprocal(Src x) {
template <typename Dest, size_t Dim, typename Src>
inline Dest rescale(Src x, typename get_element_type<Src>::type in_zp,
typename get_element_type<Dest>::type out_zp,
Tensor1D<int32_t, Dim> mult, Tensor1D<int32_t, Dim> shift,
Tensor1D<int32_t, Dim> mult, Tensor1D<int8_t, Dim> shift,
bool scale32, bool double_round, bool per_channel) {
using ET_Dest = typename get_element_type<Dest>::type;
using Dest_I32 = typename replace_element_type<int32_t, Dest>::type;
Expand Down
8 changes: 4 additions & 4 deletions reference-implementation/unittests/tosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ TEST(tosa, rescale) {
int8_t in_zp = 10;
int16_t out_zp = 0;
Tensor1D<int32_t, 1> mult{10000};
Tensor1D<int32_t, 1> shift{5};
Tensor1D<int8_t, 1> shift{5};
bool scale32 = false;
bool double_round = false;
bool per_channel = false;
Expand All @@ -241,7 +241,7 @@ TEST(tosa, rescale) {
int32_t in_zp = 0;
int8_t out_zp = 0;
Tensor1D<int32_t, 3> mult{150, 100, 50};
Tensor1D<int32_t, 3> shift{13, 14, 15};
Tensor1D<int8_t, 3> shift{13, 14, 15};
bool scale32 = false;
bool double_round = false;
bool per_channel = true;
Expand All @@ -258,7 +258,7 @@ TEST(tosa, rescale) {
int64_t in_zp = 0;
uint8_t out_zp = 100;
Tensor1D<int32_t, 1> mult{100};
Tensor1D<int32_t, 1> shift{14};
Tensor1D<int8_t, 1> shift{14};
bool scale32 = false;
bool double_round = false;
bool per_channel = false;
Expand All @@ -278,7 +278,7 @@ TEST(tosa, rescale) {
int64_t in_zp = 0;
int32_t out_zp = 0;
Tensor1D<int32_t, 1> mult{2147483647};
Tensor1D<int32_t, 1> shift{32};
Tensor1D<int8_t, 1> shift{32};
bool scale32 = true;
bool double_round = true;
bool per_channel = false;
Expand Down

0 comments on commit 922abce

Please sign in to comment.