From 09ecb7293e8b8831bedc819a0cc51995af410332 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Tue, 30 Jan 2024 07:24:03 +0000 Subject: [PATCH] Support `count_include_pad` attr for AveragePool operator PyTorch sets this attribute to true even if no padding is used. --- src/model.rs | 2 ++ src/model_builder.rs | 1 + src/ops/pooling.rs | 62 ++++++++++++++++++++++++--------------- src/schema.fbs | 2 ++ src/schema_generated.rs | 25 ++++++++++++++++ tools/convert-onnx.py | 12 +++++++- tools/schema_generated.py | 15 +++++++++- 7 files changed, 94 insertions(+), 25 deletions(-) diff --git a/src/model.rs b/src/model.rs index e8a459a2..1d866b3f 100644 --- a/src/model.rs +++ b/src/model.rs @@ -599,6 +599,7 @@ fn read_average_pool_op(node: &OperatorNode) -> ReadOpResult { Ok(Box::new(ops::AveragePool { kernel_size, padding, + count_include_pad: attrs.count_include_pad(), strides, })) } @@ -1335,6 +1336,7 @@ mod tests { kernel_size: [2, 2], strides: [2, 2], padding: [0, 0, 0, 0].into(), + count_include_pad: false, }); // Dummy value for BatchNormalization inputs which are vectors with diff --git a/src/model_builder.rs b/src/model_builder.rs index 254d0e8a..4ee7cbdc 100644 --- a/src/model_builder.rs +++ b/src/model_builder.rs @@ -343,6 +343,7 @@ impl<'a> ModelBuilder<'a> { pad_mode: pad_args.pad_mode, pads, strides, + count_include_pad: args.count_include_pad, } }), OpType::BatchNormalization(args) => op_with_attrs!( diff --git a/src/ops/pooling.rs b/src/ops/pooling.rs index 4612697c..c374be2b 100644 --- a/src/ops/pooling.rs +++ b/src/ops/pooling.rs @@ -91,6 +91,7 @@ pub fn average_pool( kernel_size: [usize; 2], strides: [usize; 2], padding: Padding, + count_include_pad: bool, ) -> Result { let [batch, in_c, in_h, in_w] = check_dims!(input, 4, "NCHW"); let (out_h, out_w, fixed_padding) = calc_output_size_and_padding( @@ -134,7 +135,13 @@ pub fn average_pool( } } - out_view[[out_y, out_x]] = accumulator / non_padding_elements; + let counted_elems = if count_include_pad { + (kernel_h * kernel_w) as f32 + } else { + non_padding_elements + }; + + out_view[[out_y, out_x]] = accumulator / counted_elems; } } } @@ -147,6 +154,7 @@ pub fn average_pool( pub struct AveragePool { pub kernel_size: [usize; 2], pub padding: Padding, + pub count_include_pad: bool, pub strides: [usize; 2], } @@ -162,6 +170,7 @@ impl Operator for AveragePool { self.kernel_size, self.strides, self.padding.clone(), + self.count_include_pad, ) .into_op_result() } @@ -368,19 +377,6 @@ mod tests { use crate::ops::tests::expect_eq_1e4; use crate::ops::{average_pool, global_average_pool, max_pool, OpError, Padding}; - fn from_2d_slice(data: &[&[T]]) -> Tensor { - let rows = data.len(); - let cols = data.get(0).map(|first_row| first_row.len()).unwrap_or(0); - - let mut result = Vec::new(); - for row in data { - assert!(cols == row.len(), "All row slices must have same length"); - result.extend_from_slice(row); - } - - Tensor::from_data(&[rows, cols], result) - } - #[test] fn test_average_pool() -> Result<(), Box> { let input = Tensor::from_data( @@ -452,6 +448,7 @@ mod tests { case.kernel_size, case.strides, [0, 0, 0, 0].into(), + false, /* count_include_pad */ ) .unwrap(); expect_equal(&result, &case.expected)?; @@ -462,21 +459,21 @@ mod tests { #[test] fn test_average_pool_padding() -> Result<(), Box> { - let mut input = from_2d_slice(&[ - &[0.0809, 0.5529, 0.1534, 0.7507], - &[0.4698, 0.7771, 0.9896, 0.4873], - &[0.9750, 0.5160, 0.6419, 0.3670], - &[0.4101, 0.3762, 0.9689, 0.4389], + let mut input = Tensor::from([ + [0.0809, 0.5529, 0.1534, 0.7507], + [0.4698, 0.7771, 0.9896, 0.4873], + [0.9750, 0.5160, 0.6419, 0.3670], + [0.4101, 0.3762, 0.9689, 0.4389], ]); let [rows, cols]: [usize; 2] = input.shape().try_into().unwrap(); input.reshape(&[1, 1, rows, cols]); // Computed with `torch.nn.functional.avg_pool2d` in PyTorch with // `padding=1` and `count_include_pad=False`. - let mut expected = from_2d_slice(&[ - &[0.0809, 0.3531, 0.7507], - &[0.7224, 0.7312, 0.4271], - &[0.4101, 0.6725, 0.4389], + let mut expected = Tensor::from([ + [0.0809, 0.3531, 0.7507], + [0.7224, 0.7312, 0.4271], + [0.4101, 0.6725, 0.4389], ]); let [rows, cols]: [usize; 2] = expected.shape().try_into().unwrap(); expected.reshape(&[1, 1, rows, cols]); @@ -486,10 +483,29 @@ mod tests { [2, 2], [2, 2], /* stride */ [1, 1, 1, 1].into(), + false, /* count_include_pad */ ) .unwrap(); expect_eq_1e4(&result, &expected)?; + // As above, but with `count_include_pad=True`. + let expected_include_pad = Tensor::from([ + [0.0202, 0.1766, 0.1877], + [0.3612, 0.7312, 0.2136], + [0.1025, 0.3363, 0.1097], + ]) + .into_shape([1, 1, 3, 3]) + .into_dyn(); + let result = average_pool( + input.view(), + [2, 2], + [2, 2], /* stride */ + [1, 1, 1, 1].into(), + true, /* count_include_pad */ + ) + .unwrap(); + expect_eq_1e4(&result, &expected_include_pad)?; + Ok(()) } diff --git a/src/schema.fbs b/src/schema.fbs index 9daf90fd..a6a63c95 100644 --- a/src/schema.fbs +++ b/src/schema.fbs @@ -192,6 +192,8 @@ table AveragePoolAttrs { pads:[uint]; strides:[uint]; + + count_include_pad:bool; } table BatchNormalizationAttrs { diff --git a/src/schema_generated.rs b/src/schema_generated.rs index 57571e1b..4d196921 100644 --- a/src/schema_generated.rs +++ b/src/schema_generated.rs @@ -1878,6 +1878,7 @@ impl<'a> AveragePoolAttrs<'a> { pub const VT_PAD_MODE: flatbuffers::VOffsetT = 6; pub const VT_PADS: flatbuffers::VOffsetT = 8; pub const VT_STRIDES: flatbuffers::VOffsetT = 10; + pub const VT_COUNT_INCLUDE_PAD: flatbuffers::VOffsetT = 12; #[inline] pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { @@ -1898,6 +1899,7 @@ impl<'a> AveragePoolAttrs<'a> { if let Some(x) = args.kernel_size { builder.add_kernel_size(x); } + builder.add_count_include_pad(args.count_include_pad); builder.add_pad_mode(args.pad_mode); builder.finish() } @@ -1953,6 +1955,17 @@ impl<'a> AveragePoolAttrs<'a> { ) } } + #[inline] + pub fn count_include_pad(&self) -> bool { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::(AveragePoolAttrs::VT_COUNT_INCLUDE_PAD, Some(false)) + .unwrap() + } + } } impl flatbuffers::Verifiable for AveragePoolAttrs<'_> { @@ -1979,6 +1992,7 @@ impl flatbuffers::Verifiable for AveragePoolAttrs<'_> { Self::VT_STRIDES, false, )? + .visit_field::("count_include_pad", Self::VT_COUNT_INCLUDE_PAD, false)? .finish(); Ok(()) } @@ -1988,6 +2002,7 @@ pub struct AveragePoolAttrsArgs<'a> { pub pad_mode: PadMode, pub pads: Option>>, pub strides: Option>>, + pub count_include_pad: bool, } impl<'a> Default for AveragePoolAttrsArgs<'a> { #[inline] @@ -1997,6 +2012,7 @@ impl<'a> Default for AveragePoolAttrsArgs<'a> { pad_mode: PadMode::Same, pads: None, strides: None, + count_include_pad: false, } } } @@ -2032,6 +2048,14 @@ impl<'a: 'b, 'b> AveragePoolAttrsBuilder<'a, 'b> { .push_slot_always::>(AveragePoolAttrs::VT_STRIDES, strides); } #[inline] + pub fn add_count_include_pad(&mut self, count_include_pad: bool) { + self.fbb_.push_slot::( + AveragePoolAttrs::VT_COUNT_INCLUDE_PAD, + count_include_pad, + false, + ); + } + #[inline] pub fn new( _fbb: &'b mut flatbuffers::FlatBufferBuilder<'a>, ) -> AveragePoolAttrsBuilder<'a, 'b> { @@ -2057,6 +2081,7 @@ impl core::fmt::Debug for AveragePoolAttrs<'_> { ds.field("pad_mode", &self.pad_mode()); ds.field("pads", &self.pads()); ds.field("strides", &self.strides()); + ds.field("count_include_pad", &self.count_include_pad()); ds.finish() } } diff --git a/tools/convert-onnx.py b/tools/convert-onnx.py index 729fc11b..48e53bcd 100755 --- a/tools/convert-onnx.py +++ b/tools/convert-onnx.py @@ -223,6 +223,16 @@ def get_attr(self, name: str, expected_type: str, default): return val return default + def get_bool_attr(self, name: str, default: bool) -> bool: + """ + Get the value of an optional boolean operator attribute. + + ONNX represents boolean attributes as "int" fields with values 0 or 1 + rather than a dedicated boolean type. This method converts these + attributes to Python booleans. + """ + return bool(self.get_attr(name, "int", int(default))) + def get_enum_attr(self, name: str, enum: Any, default: str, fallback: Any = None): """ Get an optional attribute whose value is an enum variant. @@ -564,7 +574,6 @@ def op_node_from_onnx_operator( check_ints_length("kernel_shape", kernel_shape, 2) pad_mode, pads = read_pads(op_reader) op_reader.check_attr("ceil_mode", "int", 0) - op_reader.check_attr("count_include_pad", "int", 0) attrs = sg.AveragePoolAttrsT() attrs.kernelSize = kernel_shape @@ -575,6 +584,7 @@ def op_node_from_onnx_operator( else: attrs.padMode = sg.PadMode.Fixed attrs.strides = read_strides(op_reader) + attrs.countIncludePad = op_reader.get_bool_attr("count_include_pad", False) case "BatchNormalization": attrs = sg.BatchNormalizationAttrsT() diff --git a/tools/schema_generated.py b/tools/schema_generated.py index 4171a622..6a7d975b 100644 --- a/tools/schema_generated.py +++ b/tools/schema_generated.py @@ -501,8 +501,15 @@ def StridesIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) return o == 0 + # AveragePoolAttrs + def CountIncludePad(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + def AveragePoolAttrsStart(builder): - builder.StartObject(4) + builder.StartObject(5) def AveragePoolAttrsAddKernelSize(builder, kernelSize): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(kernelSize), 0) @@ -525,6 +532,9 @@ def AveragePoolAttrsAddStrides(builder, strides): def AveragePoolAttrsStartStridesVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def AveragePoolAttrsAddCountIncludePad(builder, countIncludePad): + builder.PrependBoolSlot(4, countIncludePad, 0) + def AveragePoolAttrsEnd(builder): return builder.EndObject() @@ -542,6 +552,7 @@ def __init__(self): self.padMode = 0 # type: int self.pads = None # type: List[int] self.strides = None # type: List[int] + self.countIncludePad = False # type: bool @classmethod def InitFromBuf(cls, buf, pos): @@ -586,6 +597,7 @@ def _UnPack(self, averagePoolAttrs): self.strides.append(averagePoolAttrs.Strides(i)) else: self.strides = averagePoolAttrs.StridesAsNumpy() + self.countIncludePad = averagePoolAttrs.CountIncludePad() # AveragePoolAttrsT def Pack(self, builder): @@ -621,6 +633,7 @@ def Pack(self, builder): AveragePoolAttrsAddPads(builder, pads) if self.strides is not None: AveragePoolAttrsAddStrides(builder, strides) + AveragePoolAttrsAddCountIncludePad(builder, self.countIncludePad) averagePoolAttrs = AveragePoolAttrsEnd(builder) return averagePoolAttrs