Skip to content

Commit

Permalink
bind sigmoid and tanh
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed May 22, 2024
1 parent 6bf5c5e commit 211f46b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 2 additions & 0 deletions test-rt/test-tflite/suite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ fn ignore_onnx(t: &[String]) -> bool {
test_prelu
test_relu
test_selu
test_sigmoid
test_tanh
test_thresholdrelu
",
);
Expand Down
8 changes: 7 additions & 1 deletion tflite/src/ops/element_wise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use tract_core::internal::*;
use tract_core::ops::element_wise::ElementWiseOp;
use tract_core::ops::logic::{ Not, not };
use tract_core::ops::math::*;
use tract_core::ops::nn::{hard_swish, leaky_relu, HardSwish, LeakyRelu};
use tract_core::ops::nn::{hard_swish, leaky_relu, sigmoid, HardSwish, LeakyRelu, Sigmoid};

pub fn register_all(reg: &mut Registry) {
reg.reg_to_tflite(ser);
Expand All @@ -24,9 +24,11 @@ pub fn register_all(reg: &mut Registry) {
reg.reg_to_tract(BuiltinOperator::LOG, |op| deser(op, ln()));
reg.reg_to_tract(BuiltinOperator::LOGICAL_NOT, |op| deser(op, not()));
reg.reg_to_tract(BuiltinOperator::SIN, |op| deser(op, sin()));
reg.reg_to_tract(BuiltinOperator::LOGISTIC, |op| deser(op, sigmoid()));
reg.reg_to_tract(BuiltinOperator::SQRT, |op| deser(op, sqrt()));
reg.reg_to_tract(BuiltinOperator::SQUARE, |op| deser(op, square()));
reg.reg_to_tract(BuiltinOperator::RSQRT, |op| deser(op, rsqrt()));
reg.reg_to_tract(BuiltinOperator::TANH, |op| deser(op, tanh()));
}

fn deser(op: &mut DeserOp, ew: ElementWiseOp) -> TractResult<TVec<OutletId>> {
Expand Down Expand Up @@ -113,6 +115,10 @@ fn ser(
builder.write_op(&[input], &[output], 75, 1, BuiltinOperator::SQRT)
} else if (*op.0).is::<Rsqrt>() {
builder.write_op(&[input], &[output], 76, 1, BuiltinOperator::SQRT)
} else if (*op.0).is::<Sigmoid>() {
builder.write_op(&[input], &[output], 14, 1, BuiltinOperator::LOGISTIC)
} else if (*op.0).is::<Tanh>() {
builder.write_op(&[input], &[output], 28, 1, BuiltinOperator::TANH)
} else if (*op.0).is::<Ln>() {
builder.write_op(&[input], &[output], 73, 1, BuiltinOperator::LOG)
} else {
Expand Down

0 comments on commit 211f46b

Please sign in to comment.