diff --git a/src/ops/convert.rs b/src/ops/convert.rs index fd30c48a..2aba72e8 100644 --- a/src/ops/convert.rs +++ b/src/ops/convert.rs @@ -8,14 +8,27 @@ fn cast(pool: &TensorPool, input: Input, dtype: DataType) -> Result match input { Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x).into()), Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()), + Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()), }, DataType::Float => match input { Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x).into()), Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()), + Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()), + }, + DataType::Int8 => match input { + Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()), + Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()), + Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x).into()), + Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()), + }, + DataType::UInt8 => match input { + Input::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()), + Input::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()), + Input::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()), + Input::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x).into()), }, - _ => Err(OpError::UnsupportedValue("Unsupported cast")), } } diff --git a/src/ops/gather.rs b/src/ops/gather.rs index 51d1e564..ad3ed612 100644 --- a/src/ops/gather.rs +++ b/src/ops/gather.rs @@ -90,7 +90,7 @@ impl Operator for Gather { Input::Int32Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(), Input::FloatTensor(input) => gather(pool, input, self.axis, indices).into_op_result(), Input::UInt8Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(), } } } @@ -238,7 +238,12 @@ impl Operator for GatherElements { Input::FloatTensor(input) => { gather_elements(pool, input, indices, self.axis).into_op_result() } - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => { + gather_elements(pool, input, indices, self.axis).into_op_result() + } + Input::UInt8Tensor(input) => { + gather_elements(pool, input, indices, self.axis).into_op_result() + } } } } @@ -336,7 +341,12 @@ impl Operator for GatherND { Input::FloatTensor(input) => { gather_nd(pool, input, indices, self.batch_dims).into_op_result() } - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => { + gather_nd(pool, input, indices, self.batch_dims).into_op_result() + } + Input::UInt8Tensor(input) => { + gather_nd(pool, input, indices, self.batch_dims).into_op_result() + } } } } @@ -451,6 +461,14 @@ impl Operator for ScatterElements { scatter_elements(pool, data, indices, updates, self.axis, self.reduction) .into_op_result() } + (Input::Int8Tensor(data), Input::Int8Tensor(updates)) => { + scatter_elements(pool, data, indices, updates, self.axis, self.reduction) + .into_op_result() + } + (Input::UInt8Tensor(data), Input::UInt8Tensor(updates)) => { + scatter_elements(pool, data, indices, updates, self.axis, self.reduction) + .into_op_result() + } _ => Err(OpError::UnsupportedType), } } @@ -547,6 +565,12 @@ impl Operator for ScatterND { (Input::FloatTensor(data), Input::FloatTensor(updates)) => { scatter_nd(pool, data, indices, updates, self.reduction).into_op_result() } + (Input::Int8Tensor(data), Input::Int8Tensor(updates)) => { + scatter_nd(pool, data, indices, updates, self.reduction).into_op_result() + } + (Input::UInt8Tensor(data), Input::UInt8Tensor(updates)) => { + scatter_nd(pool, data, indices, updates, self.reduction).into_op_result() + } _ => Err(OpError::UnsupportedType), } } diff --git a/src/ops/layout.rs b/src/ops/layout.rs index d9db9502..29303a73 100644 --- a/src/ops/layout.rs +++ b/src/ops/layout.rs @@ -95,7 +95,8 @@ impl Operator for Expand { match input { Input::FloatTensor(input) => expand(pool, input, &shape).into_op_result(), Input::Int32Tensor(input) => expand(pool, input, &shape).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::UInt8Tensor(input) => expand(pool, input, &shape).into_op_result(), + Input::Int8Tensor(input) => expand(pool, input, &shape).into_op_result(), } } @@ -122,7 +123,8 @@ impl Operator for Expand { let output: Output = match input { Output::FloatTensor(input) => expand_to(pool, input.view(), &out_shape).into(), Output::Int32Tensor(input) => expand_to(pool, input.view(), &out_shape).into(), - _ => return Err(OpError::UnsupportedType), + Output::Int8Tensor(input) => expand_to(pool, input.view(), &out_shape).into(), + Output::UInt8Tensor(input) => expand_to(pool, input.view(), &out_shape).into(), }; Ok(output) } @@ -172,7 +174,8 @@ impl Operator for Flatten { match input { Input::FloatTensor(input) => flatten(pool, input, self.axis).into_op_result(), Input::Int32Tensor(input) => flatten(pool, input, self.axis).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => flatten(pool, input, self.axis).into_op_result(), + Input::UInt8Tensor(input) => flatten(pool, input, self.axis).into_op_result(), } } @@ -195,7 +198,14 @@ impl Operator for Flatten { flatten_in_place(pool, &mut output, self.axis)?; Ok(output.into()) } - _ => Err(OpError::UnsupportedType), + Output::Int8Tensor(mut output) => { + flatten_in_place(pool, &mut output, self.axis)?; + Ok(output.into()) + } + Output::UInt8Tensor(mut output) => { + flatten_in_place(pool, &mut output, self.axis)?; + Ok(output.into()) + } } } } @@ -314,7 +324,8 @@ impl Operator for Reshape { match input { Input::Int32Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(), Input::FloatTensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(), + Input::UInt8Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(), } } @@ -340,7 +351,14 @@ impl Operator for Reshape { reshape_in_place(pool, &mut output, &shape, self.allow_zero)?; Ok(output.into()) } - _ => Err(OpError::UnsupportedType), + Output::Int8Tensor(mut output) => { + reshape_in_place(pool, &mut output, &shape, self.allow_zero)?; + Ok(output.into()) + } + Output::UInt8Tensor(mut output) => { + reshape_in_place(pool, &mut output, &shape, self.allow_zero)?; + Ok(output.into()) + } } } } @@ -449,7 +467,8 @@ impl Operator for Squeeze { match input { Input::FloatTensor(t) => squeeze(pool, t, axes).into_op_result(), Input::Int32Tensor(t) => squeeze(pool, t, axes).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(t) => squeeze(pool, t, axes).into_op_result(), + Input::UInt8Tensor(t) => squeeze(pool, t, axes).into_op_result(), } } @@ -475,7 +494,14 @@ impl Operator for Squeeze { squeeze_in_place(&mut t, axes)?; Ok(t.into()) } - _ => Err(OpError::UnsupportedType), + Output::UInt8Tensor(mut t) => { + squeeze_in_place(&mut t, axes)?; + Ok(t.into()) + } + Output::Int8Tensor(mut t) => { + squeeze_in_place(&mut t, axes)?; + Ok(t.into()) + } } } } @@ -519,7 +545,8 @@ impl Operator for Transpose { match input { Input::FloatTensor(input) => transpose(pool, input, perm_slice).into_op_result(), Input::Int32Tensor(input) => transpose(pool, input, perm_slice).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => transpose(pool, input, perm_slice).into_op_result(), + Input::UInt8Tensor(input) => transpose(pool, input, perm_slice).into_op_result(), } } } @@ -577,7 +604,8 @@ impl Operator for Unsqueeze { match input { Input::FloatTensor(input) => unsqueeze(pool, input, &axes).into_op_result(), Input::Int32Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(), - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(), + Input::UInt8Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(), } } @@ -597,7 +625,8 @@ impl Operator for Unsqueeze { match output { Output::FloatTensor(t) => unsqueeze_in_place(t, &axes).map(Output::FloatTensor), Output::Int32Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::Int32Tensor), - _ => Err(OpError::UnsupportedType), + Output::Int8Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::Int8Tensor), + Output::UInt8Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::UInt8Tensor), } } } diff --git a/src/ops/pad.rs b/src/ops/pad.rs index c4e79ab5..246e40ff 100644 --- a/src/ops/pad.rs +++ b/src/ops/pad.rs @@ -192,11 +192,18 @@ impl Operator for Pad { let const_val = inputs.get_as_scalar::(2)?.unwrap_or(0); pad(pool, t, &pads, self.mode, const_val).into_op_result() } + Input::Int8Tensor(t) => { + let const_val = inputs.get_as_scalar::(2)?.unwrap_or(0); + pad(pool, t, &pads, self.mode, const_val).into_op_result() + } + Input::UInt8Tensor(t) => { + let const_val = inputs.get_as_scalar::(2)?.unwrap_or(0); + pad(pool, t, &pads, self.mode, const_val).into_op_result() + } Input::FloatTensor(t) => { let const_val = inputs.get_as_scalar::(2)?.unwrap_or(0.); pad(pool, t, &pads, self.mode, const_val).into_op_result() } - _ => Err(OpError::UnsupportedType), } } } diff --git a/src/ops/slice.rs b/src/ops/slice.rs index 07944e07..f122cb5d 100644 --- a/src/ops/slice.rs +++ b/src/ops/slice.rs @@ -112,7 +112,12 @@ impl Operator for Slice { Input::Int32Tensor(input) => { slice(pool, input, &starts, &ends, axes.as_ref(), steps.as_ref()).map(|t| t.into()) } - _ => Err(OpError::UnsupportedType), + Input::Int8Tensor(input) => { + slice(pool, input, &starts, &ends, axes.as_ref(), steps.as_ref()).map(|t| t.into()) + } + Input::UInt8Tensor(input) => { + slice(pool, input, &starts, &ends, axes.as_ref(), steps.as_ref()).map(|t| t.into()) + } }; result.into_op_result() } @@ -168,7 +173,14 @@ impl Operator for Slice { slice_in_place(&mut output, &starts, &ends, axes.as_ref())?; Ok(output.into()) } - _ => Err(OpError::UnsupportedType), + Output::Int8Tensor(mut output) => { + slice_in_place(&mut output, &starts, &ends, axes.as_ref())?; + Ok(output.into()) + } + Output::UInt8Tensor(mut output) => { + slice_in_place(&mut output, &starts, &ends, axes.as_ref())?; + Ok(output.into()) + } } } }