From c395190388e9c014877723701c15d68d05b754af Mon Sep 17 00:00:00 2001 From: Igor Yusupov Date: Wed, 16 Oct 2024 11:46:31 +0500 Subject: [PATCH 1/3] Add int8 and uint8 support --- .DS_Store | Bin 0 -> 6148 bytes src/ops/convert.rs | 19 ++++++++++++++--- src/ops/gather.rs | 30 +++++++++++++++++++++++--- src/ops/layout.rs | 51 +++++++++++++++++++++++++++++++++++---------- src/ops/pad.rs | 9 +++++++- src/ops/slice.rs | 16 ++++++++++++-- 6 files changed, 105 insertions(+), 20 deletions(-) create mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..2c7dd38ba7e91a6cb6b7d24412bef3b610361f77 GIT binary patch literal 6148 zcmeHK!A=4(5S=2*5;WnU3CAX0F@hQ*@v>R*2VA2EHOLCljcZF_fde7gv;L4@;`cby z7NbEf9*iL~$@ER9Gu!la(`f-f1fy;ppbP*WDq*gO%@0E3qzjUB9ztQhu?P2{LBSY$ z3(@TOj||Y>ZNdOj+%W}T_YZy1WY~`q9bwE{)KB7JTC2T_+**Epqu>?1l6MPeV( z($O&J46dnnp>-0??M`$V_a^P~_Ni8BCsw_INr<~W47s_ARaZ{}JyPAF$qns*SM-YQ za%DQLH=0#>cziIc%4zeUR+UGGjoGZ|?drBpl7Z+u%q*e@g)ahz25y*vKV{$pf_6_C literal 0 HcmV?d00001 diff --git a/src/ops/convert.rs b/src/ops/convert.rs index fd30c48a..2aba72e8 100644 --- a/src/ops/convert.rs +++ b/src/ops/convert.rs @@ -8,14 +8,27 @@ fn cast(pool: &TensorPool, input: Input, dtype: DataType) -> Result match input { Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x).into()), Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()), + Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()), }, DataType::Float => match input { Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x).into()), Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()), + Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()), + }, + DataType::Int8 => match input { + Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()), + Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()), + Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x).into()), + Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()), + }, + DataType::UInt8 => match input { + Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()), + Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()), + Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()), + Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x).into()), }, - _ => Err(OpError::UnsupportedValue("Unsupported cast")), } } diff --git a/src/ops/gather.rs b/src/ops/gather.rs index 51d1e564..ad3ed612 100644 --- a/src/ops/gather.rs +++ b/src/ops/gather.rs @@ -90,7 +90,7 @@ impl Operator for Gather { Input::Int32Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(), Input::FloatTensor(input) => gather(pool, input, self.axis, indices).into_op_result(), Input::UInt8Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(), } } } @@ -238,7 +238,12 @@ impl Operator for GatherElements { Input::FloatTensor(input) => { gather_elements(pool, input, indices, self.axis).into_op_result() } - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => { + gather_elements(pool, input, indices, self.axis).into_op_result() + } + Input::UInt8Tensor(input) => { + gather_elements(pool, input, indices, self.axis).into_op_result() + } } } } @@ -336,7 +341,12 @@ impl Operator for GatherND { Input::FloatTensor(input) => { gather_nd(pool, input, indices, self.batch_dims).into_op_result() } - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => { + gather_nd(pool, input, indices, self.batch_dims).into_op_result() + } + Input::UInt8Tensor(input) => { + gather_nd(pool, input, indices, self.batch_dims).into_op_result() + } } } } @@ -451,6 +461,14 @@ impl Operator for ScatterElements { scatter_elements(pool, data, indices, updates, self.axis, self.reduction) .into_op_result() } + (Input::Int8Tensor(data), Input::Int8Tensor(updates)) => { + scatter_elements(pool, data, indices, updates, self.axis, self.reduction) + .into_op_result() + } + (Input::UInt8Tensor(data), Input::UInt8Tensor(updates)) => { + scatter_elements(pool, data, indices, updates, self.axis, self.reduction) + .into_op_result() + } _ => Err(OpError::UnsupportedType), } } @@ -547,6 +565,12 @@ impl Operator for ScatterND { (Input::FloatTensor(data), Input::FloatTensor(updates)) => { scatter_nd(pool, data, indices, updates, self.reduction).into_op_result() } + (Input::Int8Tensor(data), Input::Int8Tensor(updates)) => { + scatter_nd(pool, data, indices, updates, self.reduction).into_op_result() + } + (Input::UInt8Tensor(data), Input::UInt8Tensor(updates)) => { + scatter_nd(pool, data, indices, updates, self.reduction).into_op_result() + } _ => Err(OpError::UnsupportedType), } } diff --git a/src/ops/layout.rs b/src/ops/layout.rs index d9db9502..29303a73 100644 --- a/src/ops/layout.rs +++ b/src/ops/layout.rs @@ -95,7 +95,8 @@ impl Operator for Expand { match input { Input::FloatTensor(input) => expand(pool, input, &shape).into_op_result(), Input::Int32Tensor(input) => expand(pool, input, &shape).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::UInt8Tensor(input) => expand(pool, input, &shape).into_op_result(), + Input::Int8Tensor(input) => expand(pool, input, &shape).into_op_result(), } } @@ -122,7 +123,8 @@ impl Operator for Expand { let output: Output = match input { Output::FloatTensor(input) => expand_to(pool, input.view(), &out_shape).into(), Output::Int32Tensor(input) => expand_to(pool, input.view(), &out_shape).into(), - _ => return Err(OpError::UnsupportedType), + Output::Int8Tensor(input) => expand_to(pool, input.view(), &out_shape).into(), + Output::UInt8Tensor(input) => expand_to(pool, input.view(), &out_shape).into(), }; Ok(output) } @@ -172,7 +174,8 @@ impl Operator for Flatten { match input { Input::FloatTensor(input) => flatten(pool, input, self.axis).into_op_result(), Input::Int32Tensor(input) => flatten(pool, input, self.axis).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => flatten(pool, input, self.axis).into_op_result(), + Input::UInt8Tensor(input) => flatten(pool, input, self.axis).into_op_result(), } } @@ -195,7 +198,14 @@ impl Operator for Flatten { flatten_in_place(pool, &mut output, self.axis)?; Ok(output.into()) } - _ => Err(OpError::UnsupportedType), + Output::Int8Tensor(mut output) => { + flatten_in_place(pool, &mut output, self.axis)?; + Ok(output.into()) + } + Output::UInt8Tensor(mut output) => { + flatten_in_place(pool, &mut output, self.axis)?; + Ok(output.into()) + } } } } @@ -314,7 +324,8 @@ impl Operator for Reshape { match input { Input::Int32Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(), Input::FloatTensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(), + Input::UInt8Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(), } } @@ -340,7 +351,14 @@ impl Operator for Reshape { reshape_in_place(pool, &mut output, &shape, self.allow_zero)?; Ok(output.into()) } - _ => Err(OpError::UnsupportedType), + Output::Int8Tensor(mut output) => { + reshape_in_place(pool, &mut output, &shape, self.allow_zero)?; + Ok(output.into()) + } + Output::UInt8Tensor(mut output) => { + reshape_in_place(pool, &mut output, &shape, self.allow_zero)?; + Ok(output.into()) + } } } } @@ -449,7 +467,8 @@ impl Operator for Squeeze { match input { Input::FloatTensor(t) => squeeze(pool, t, axes).into_op_result(), Input::Int32Tensor(t) => squeeze(pool, t, axes).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(t) => squeeze(pool, t, axes).into_op_result(), + Input::UInt8Tensor(t) => squeeze(pool, t, axes).into_op_result(), } } @@ -475,7 +494,14 @@ impl Operator for Squeeze { squeeze_in_place(&mut t, axes)?; Ok(t.into()) } - _ => Err(OpError::UnsupportedType), + Output::UInt8Tensor(mut t) => { + squeeze_in_place(&mut t, axes)?; + Ok(t.into()) + } + Output::Int8Tensor(mut t) => { + squeeze_in_place(&mut t, axes)?; + Ok(t.into()) + } } } } @@ -519,7 +545,8 @@ impl Operator for Transpose { match input { Input::FloatTensor(input) => transpose(pool, input, perm_slice).into_op_result(), Input::Int32Tensor(input) => transpose(pool, input, perm_slice).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => transpose(pool, input, perm_slice).into_op_result(), + Input::UInt8Tensor(input) => transpose(pool, input, perm_slice).into_op_result(), } } } @@ -577,7 +604,8 @@ impl Operator for Unsqueeze { match input { Input::FloatTensor(input) => unsqueeze(pool, input, &axes).into_op_result(), Input::Int32Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(), + Input::UInt8Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(), } } @@ -597,7 +625,8 @@ impl Operator for Unsqueeze { match output { Output::FloatTensor(t) => unsqueeze_in_place(t, &axes).map(Output::FloatTensor), Output::Int32Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::Int32Tensor), - _ => Err(OpError::UnsupportedType), + Output::Int8Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::Int8Tensor), + Output::UInt8Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::UInt8Tensor), } } } diff --git a/src/ops/pad.rs b/src/ops/pad.rs index c4e79ab5..246e40ff 100644 --- a/src/ops/pad.rs +++ b/src/ops/pad.rs @@ -192,11 +192,18 @@ impl Operator for Pad { let const_val = inputs.get_as_scalar::(2)?.unwrap_or(0); pad(pool, t, &pads, self.mode, const_val).into_op_result() } + Input::Int8Tensor(t) => { + let const_val = inputs.get_as_scalar::(2)?.unwrap_or(0); + pad(pool, t, &pads, self.mode, const_val).into_op_result() + } + Input::UInt8Tensor(t) => { + let const_val = inputs.get_as_scalar::(2)?.unwrap_or(0); + pad(pool, t, &pads, self.mode, const_val).into_op_result() + } Input::FloatTensor(t) => { let const_val = inputs.get_as_scalar::(2)?.unwrap_or(0.); pad(pool, t, &pads, self.mode, const_val).into_op_result() } - _ => Err(OpError::UnsupportedType), } } } diff --git a/src/ops/slice.rs b/src/ops/slice.rs index 07944e07..f122cb5d 100644 --- a/src/ops/slice.rs +++ b/src/ops/slice.rs @@ -112,7 +112,12 @@ impl Operator for Slice { Input::Int32Tensor(input) => { slice(pool, input, &starts, &ends, axes.as_ref(), steps.as_ref()).map(|t| t.into()) } - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => { + slice(pool, input, &starts, &ends, axes.as_ref(), steps.as_ref()).map(|t| t.into()) + } + Input::UInt8Tensor(input) => { + slice(pool, input, &starts, &ends, axes.as_ref(), steps.as_ref()).map(|t| t.into()) + } }; result.into_op_result() } @@ -168,7 +173,14 @@ impl Operator for Slice { slice_in_place(&mut output, &starts, &ends, axes.as_ref())?; Ok(output.into()) } - _ => Err(OpError::UnsupportedType), + Output::Int8Tensor(mut output) => { + slice_in_place(&mut output, &starts, &ends, axes.as_ref())?; + Ok(output.into()) + } + Output::UInt8Tensor(mut output) => { + slice_in_place(&mut output, &starts, &ends, axes.as_ref())?; + Ok(output.into()) + } } } } From 91ed776627f4dd5e03df552ae875bb5f12ba4460 Mon Sep 17 00:00:00 2001 From: Igor Yusupov Date: Thu, 17 Oct 2024 14:51:49 +0500 Subject: [PATCH 2/3] Remove DS_Store and update gitignore --- .DS_Store | Bin 6148 -> 0 bytes .gitignore | 2 ++ 2 files changed, 2 insertions(+) delete mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 2c7dd38ba7e91a6cb6b7d24412bef3b610361f77..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK!A=4(5S=2*5;WnU3CAX0F@hQ*@v>R*2VA2EHOLCljcZF_fde7gv;L4@;`cby z7NbEf9*iL~$@ER9Gu!la(`f-f1fy;ppbP*WDq*gO%@0E3qzjUB9ztQhu?P2{LBSY$ z3(@TOj||Y>ZNdOj+%W}T_YZy1WY~`q9bwE{)KB7JTC2T_+**Epqu>?1l6MPeV( z($O&J46dnnp>-0??M`$V_a^P~_Ni8BCsw_INr<~W47s_ARaZ{}JyPAF$qns*SM-YQ za%DQLH=0#>cziIc%4zeUR+UGGjoGZ|?drBpl7Z+u%q*e@g)ahz25y*vKV{$pf_6_C diff --git a/.gitignore b/.gitignore index 85ab8bb9..c1051434 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ dist/ # Converted rten ML models *.rten models/ + +*.DS_Store From 34afdabc80fddb85edcecba22cdbb21f78e5b137 Mon Sep 17 00:00:00 2001 From: Igor Yusupov Date: Thu, 17 Oct 2024 17:00:01 +0500 Subject: [PATCH 3/3] Remove DS_Store from gitignore --- .gitignore | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitignore b/.gitignore index c1051434..85ab8bb9 100644 --- a/.gitignore +++ b/.gitignore @@ -12,5 +12,3 @@ dist/ # Converted rten ML models *.rten models/ - -*.DS_Store