Skip to content

Commit

Permalink
Merge pull request #160 from SludgePhD/fix-initializer-promotion
Browse files Browse the repository at this point in the history
Respect `raw_data` when promoting initializer
  • Loading branch information
pixelspark authored Apr 17, 2023
2 parents 868e882 + 381e148 commit bb5f57f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion wonnx/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ exclude = [

[dependencies]
wgpu = "0.14.0"
bytemuck = "1.9.1"
bytemuck = { version = "1.9.1", features = ["extern_crate_alloc"] }
protobuf = { version = "2.27.1", features = ["with-bytes"] }
log = "0.4.17"
tera = { version = "1.15.0", default-features = false }
Expand Down
2 changes: 1 addition & 1 deletion wonnx/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ pub enum CompileError {
#[error("the opset version {0} is not supported")]
UnsupportedOpsetVersion(i64),

#[error("the value '{attribute}' is invalid for attribute '{value}' (opset version {opset_version})")]
#[error("the value '{value}' is invalid for attribute '{attribute}' (opset version {opset_version})")]
InvalidAttributeValue {
attribute: String,
value: String,
Expand Down
20 changes: 17 additions & 3 deletions wonnx/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::{
GpuError,
};
use async_recursion::async_recursion;
use bytemuck::pod_collect_to_vec;
use protobuf::RepeatedField;
use std::{
borrow::Cow,
Expand Down Expand Up @@ -613,8 +614,14 @@ impl<'model> Optimizer<'model> {
| ("Resize", "scales")
| ("Clip", "min" | "max") => match data_type {
ScalarType::F32 => {
let value: Vec<f32> =
tensor_proto.get_float_data().to_vec();
let value: Vec<f32> = if tensor_proto
.get_float_data()
.is_empty()
{
pod_collect_to_vec(tensor_proto.get_raw_data())
} else {
tensor_proto.get_float_data().to_vec()
};
log::info!(
"transferring input {} for op {} to f32 attribute (initializer data type: {:?}): {:?}",
attr_name,
Expand All @@ -628,7 +635,14 @@ impl<'model> Optimizer<'model> {
));
}
ScalarType::I64 => {
let value = tensor_proto.get_int64_data().to_vec();
let value = if tensor_proto
.get_int64_data()
.is_empty()
{
pod_collect_to_vec(tensor_proto.get_raw_data())
} else {
tensor_proto.get_int64_data().to_vec()
};
log::info!(
"transferring input {} for op {} to i64 attribute (initializer data type: {:?}): {:?}",
attr_name,
Expand Down

0 comments on commit bb5f57f

Please sign in to comment.