Skip to content

Commit

Permalink
Support count_include_pad attr for AveragePool operator
Browse files Browse the repository at this point in the history
PyTorch sets this attribute to true even if no padding is used.
  • Loading branch information
robertknight committed Jan 30, 2024
1 parent f9827c7 commit 09ecb72
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 25 deletions.
2 changes: 2 additions & 0 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,7 @@ fn read_average_pool_op(node: &OperatorNode) -> ReadOpResult {
Ok(Box::new(ops::AveragePool {
kernel_size,
padding,
count_include_pad: attrs.count_include_pad(),
strides,
}))
}
Expand Down Expand Up @@ -1335,6 +1336,7 @@ mod tests {
kernel_size: [2, 2],
strides: [2, 2],
padding: [0, 0, 0, 0].into(),
count_include_pad: false,
});

// Dummy value for BatchNormalization inputs which are vectors with
Expand Down
1 change: 1 addition & 0 deletions src/model_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ impl<'a> ModelBuilder<'a> {
pad_mode: pad_args.pad_mode,
pads,
strides,
count_include_pad: args.count_include_pad,
}
}),
OpType::BatchNormalization(args) => op_with_attrs!(
Expand Down
62 changes: 39 additions & 23 deletions src/ops/pooling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ pub fn average_pool(
kernel_size: [usize; 2],
strides: [usize; 2],
padding: Padding,
count_include_pad: bool,
) -> Result<Tensor, OpError> {
let [batch, in_c, in_h, in_w] = check_dims!(input, 4, "NCHW");
let (out_h, out_w, fixed_padding) = calc_output_size_and_padding(
Expand Down Expand Up @@ -134,7 +135,13 @@ pub fn average_pool(
}
}

out_view[[out_y, out_x]] = accumulator / non_padding_elements;
let counted_elems = if count_include_pad {
(kernel_h * kernel_w) as f32
} else {
non_padding_elements
};

out_view[[out_y, out_x]] = accumulator / counted_elems;
}
}
}
Expand All @@ -147,6 +154,7 @@ pub fn average_pool(
/// 2D average pooling operator for NCHW tensors.
pub struct AveragePool {
    /// Height and width of the pooling window.
    pub kernel_size: [usize; 2],
    /// Padding applied to the input's spatial dimensions before pooling.
    pub padding: Padding,
    /// If true, padding elements are counted in the divisor when computing
    /// each window's average (ONNX `count_include_pad` semantics); if false,
    /// only non-padding elements are counted. PyTorch exporters may set this
    /// to true even when no padding is used.
    pub count_include_pad: bool,
    /// Vertical and horizontal strides of the pooling window.
    pub strides: [usize; 2],
}

Expand All @@ -162,6 +170,7 @@ impl Operator for AveragePool {
self.kernel_size,
self.strides,
self.padding.clone(),
self.count_include_pad,
)
.into_op_result()
}
Expand Down Expand Up @@ -368,19 +377,6 @@ mod tests {
use crate::ops::tests::expect_eq_1e4;
use crate::ops::{average_pool, global_average_pool, max_pool, OpError, Padding};

/// Build a rank-2 tensor from a slice of row slices.
///
/// Panics if the rows do not all have the same length.
fn from_2d_slice<T: Clone>(data: &[&[T]]) -> Tensor<T> {
    let n_rows = data.len();
    let n_cols = data.first().map_or(0, |row| row.len());

    let mut elements = Vec::with_capacity(n_rows * n_cols);
    for row in data {
        assert!(n_cols == row.len(), "All row slices must have same length");
        elements.extend_from_slice(row);
    }

    Tensor::from_data(&[n_rows, n_cols], elements)
}

#[test]
fn test_average_pool() -> Result<(), Box<dyn Error>> {
let input = Tensor::from_data(
Expand Down Expand Up @@ -452,6 +448,7 @@ mod tests {
case.kernel_size,
case.strides,
[0, 0, 0, 0].into(),
false, /* count_include_pad */
)
.unwrap();
expect_equal(&result, &case.expected)?;
Expand All @@ -462,21 +459,21 @@ mod tests {

#[test]
fn test_average_pool_padding() -> Result<(), Box<dyn Error>> {
let mut input = from_2d_slice(&[
&[0.0809, 0.5529, 0.1534, 0.7507],
&[0.4698, 0.7771, 0.9896, 0.4873],
&[0.9750, 0.5160, 0.6419, 0.3670],
&[0.4101, 0.3762, 0.9689, 0.4389],
let mut input = Tensor::from([
[0.0809, 0.5529, 0.1534, 0.7507],
[0.4698, 0.7771, 0.9896, 0.4873],
[0.9750, 0.5160, 0.6419, 0.3670],
[0.4101, 0.3762, 0.9689, 0.4389],
]);
let [rows, cols]: [usize; 2] = input.shape().try_into().unwrap();
input.reshape(&[1, 1, rows, cols]);

// Computed with `torch.nn.functional.avg_pool2d` in PyTorch with
// `padding=1` and `count_include_pad=False`.
let mut expected = from_2d_slice(&[
&[0.0809, 0.3531, 0.7507],
&[0.7224, 0.7312, 0.4271],
&[0.4101, 0.6725, 0.4389],
let mut expected = Tensor::from([
[0.0809, 0.3531, 0.7507],
[0.7224, 0.7312, 0.4271],
[0.4101, 0.6725, 0.4389],
]);
let [rows, cols]: [usize; 2] = expected.shape().try_into().unwrap();
expected.reshape(&[1, 1, rows, cols]);
Expand All @@ -486,10 +483,29 @@ mod tests {
[2, 2],
[2, 2], /* stride */
[1, 1, 1, 1].into(),
false, /* count_include_pad */
)
.unwrap();
expect_eq_1e4(&result, &expected)?;

// As above, but with `count_include_pad=True`.
let expected_include_pad = Tensor::from([
[0.0202, 0.1766, 0.1877],
[0.3612, 0.7312, 0.2136],
[0.1025, 0.3363, 0.1097],
])
.into_shape([1, 1, 3, 3])
.into_dyn();
let result = average_pool(
input.view(),
[2, 2],
[2, 2], /* stride */
[1, 1, 1, 1].into(),
true, /* count_include_pad */
)
.unwrap();
expect_eq_1e4(&result, &expected_include_pad)?;

Ok(())
}

Expand Down
2 changes: 2 additions & 0 deletions src/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ table AveragePoolAttrs {
pads:[uint];

strides:[uint];

count_include_pad:bool;
}

table BatchNormalizationAttrs {
Expand Down
25 changes: 25 additions & 0 deletions src/schema_generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1878,6 +1878,7 @@ impl<'a> AveragePoolAttrs<'a> {
pub const VT_PAD_MODE: flatbuffers::VOffsetT = 6;
pub const VT_PADS: flatbuffers::VOffsetT = 8;
pub const VT_STRIDES: flatbuffers::VOffsetT = 10;
pub const VT_COUNT_INCLUDE_PAD: flatbuffers::VOffsetT = 12;

#[inline]
pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self {
Expand All @@ -1898,6 +1899,7 @@ impl<'a> AveragePoolAttrs<'a> {
if let Some(x) = args.kernel_size {
builder.add_kernel_size(x);
}
builder.add_count_include_pad(args.count_include_pad);
builder.add_pad_mode(args.pad_mode);
builder.finish()
}
Expand Down Expand Up @@ -1953,6 +1955,17 @@ impl<'a> AveragePoolAttrs<'a> {
)
}
}
/// Returns the `count_include_pad` attribute, or `false` (the schema
/// default) if the field is absent from the serialized table — e.g. in
/// models produced before this field was added.
#[inline]
pub fn count_include_pad(&self) -> bool {
    // Safety:
    // Created from valid Table for this object
    // which contains a valid value in this slot
    unsafe {
        self._tab
            .get::<bool>(AveragePoolAttrs::VT_COUNT_INCLUDE_PAD, Some(false))
            .unwrap()
    }
}
}

impl flatbuffers::Verifiable for AveragePoolAttrs<'_> {
Expand All @@ -1979,6 +1992,7 @@ impl flatbuffers::Verifiable for AveragePoolAttrs<'_> {
Self::VT_STRIDES,
false,
)?
.visit_field::<bool>("count_include_pad", Self::VT_COUNT_INCLUDE_PAD, false)?
.finish();
Ok(())
}
Expand All @@ -1988,6 +2002,7 @@ pub struct AveragePoolAttrsArgs<'a> {
pub pad_mode: PadMode,
pub pads: Option<flatbuffers::WIPOffset<flatbuffers::Vector<'a, u32>>>,
pub strides: Option<flatbuffers::WIPOffset<flatbuffers::Vector<'a, u32>>>,
pub count_include_pad: bool,
}
impl<'a> Default for AveragePoolAttrsArgs<'a> {
#[inline]
Expand All @@ -1997,6 +2012,7 @@ impl<'a> Default for AveragePoolAttrsArgs<'a> {
pad_mode: PadMode::Same,
pads: None,
strides: None,
count_include_pad: false,
}
}
}
Expand Down Expand Up @@ -2032,6 +2048,14 @@ impl<'a: 'b, 'b> AveragePoolAttrsBuilder<'a, 'b> {
.push_slot_always::<flatbuffers::WIPOffset<_>>(AveragePoolAttrs::VT_STRIDES, strides);
}
/// Write the `count_include_pad` field into the table being built.
/// `push_slot` omits the field from the buffer entirely when the value
/// equals the default (`false`), keeping the serialized table compact.
#[inline]
pub fn add_count_include_pad(&mut self, count_include_pad: bool) {
    self.fbb_.push_slot::<bool>(
        AveragePoolAttrs::VT_COUNT_INCLUDE_PAD,
        count_include_pad,
        false,
    );
}
#[inline]
pub fn new(
_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a>,
) -> AveragePoolAttrsBuilder<'a, 'b> {
Expand All @@ -2057,6 +2081,7 @@ impl core::fmt::Debug for AveragePoolAttrs<'_> {
ds.field("pad_mode", &self.pad_mode());
ds.field("pads", &self.pads());
ds.field("strides", &self.strides());
ds.field("count_include_pad", &self.count_include_pad());
ds.finish()
}
}
Expand Down
12 changes: 11 additions & 1 deletion tools/convert-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,16 @@ def get_attr(self, name: str, expected_type: str, default):
return val
return default

def get_bool_attr(self, name: str, default: bool) -> bool:
    """
    Get the value of an optional boolean operator attribute.

    ONNX has no dedicated boolean attribute type; boolean attributes are
    stored as "int" fields with values 0 or 1. This reads the attribute
    as an int and converts the result to a Python bool.
    """
    int_value = self.get_attr(name, "int", int(default))
    return bool(int_value)

def get_enum_attr(self, name: str, enum: Any, default: str, fallback: Any = None):
"""
Get an optional attribute whose value is an enum variant.
Expand Down Expand Up @@ -564,7 +574,6 @@ def op_node_from_onnx_operator(
check_ints_length("kernel_shape", kernel_shape, 2)
pad_mode, pads = read_pads(op_reader)
op_reader.check_attr("ceil_mode", "int", 0)
op_reader.check_attr("count_include_pad", "int", 0)

attrs = sg.AveragePoolAttrsT()
attrs.kernelSize = kernel_shape
Expand All @@ -575,6 +584,7 @@ def op_node_from_onnx_operator(
else:
attrs.padMode = sg.PadMode.Fixed
attrs.strides = read_strides(op_reader)
attrs.countIncludePad = op_reader.get_bool_attr("count_include_pad", False)

case "BatchNormalization":
attrs = sg.BatchNormalizationAttrsT()
Expand Down
15 changes: 14 additions & 1 deletion tools/schema_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,15 @@ def StridesIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
return o == 0

# AveragePoolAttrs
def CountIncludePad(self):
    # Field lives in vtable slot 12; an offset of 0 means the field was
    # omitted from the buffer, so fall back to the schema default (False).
    o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
    if o == 0:
        return False
    return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))

def AveragePoolAttrsStart(builder):
    # AveragePoolAttrs now has 5 fields (kernel_size, pad_mode, pads,
    # strides, count_include_pad), so reserve 5 vtable slots.
    # NOTE: the scraped diff retained both the pre-change StartObject(4)
    # and post-change StartObject(5) lines; only the post-change call is
    # correct — starting two objects would corrupt the builder state.
    builder.StartObject(5)

def AveragePoolAttrsAddKernelSize(builder, kernelSize):
    # kernel_size vector offset goes in vtable slot 0; the trailing 0 is
    # the default offset (field omitted when unset).
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(kernelSize), 0)
Expand All @@ -525,6 +532,9 @@ def AveragePoolAttrsAddStrides(builder, strides):
def AveragePoolAttrsStartStridesVector(builder, numElems):
    # Each strides element is a 4-byte uint with 4-byte alignment.
    return builder.StartVector(4, numElems, 4)

def AveragePoolAttrsAddCountIncludePad(builder, countIncludePad):
    # Bool field in vtable slot 4; omitted from the buffer when equal to
    # the default (0 / False).
    builder.PrependBoolSlot(4, countIncludePad, 0)

def AveragePoolAttrsEnd(builder):
    # Finish the AveragePoolAttrs table and return its offset in the buffer.
    return builder.EndObject()

Expand All @@ -542,6 +552,7 @@ def __init__(self):
self.padMode = 0 # type: int
self.pads = None # type: List[int]
self.strides = None # type: List[int]
self.countIncludePad = False # type: bool

@classmethod
def InitFromBuf(cls, buf, pos):
Expand Down Expand Up @@ -586,6 +597,7 @@ def _UnPack(self, averagePoolAttrs):
self.strides.append(averagePoolAttrs.Strides(i))
else:
self.strides = averagePoolAttrs.StridesAsNumpy()
self.countIncludePad = averagePoolAttrs.CountIncludePad()

# AveragePoolAttrsT
def Pack(self, builder):
Expand Down Expand Up @@ -621,6 +633,7 @@ def Pack(self, builder):
AveragePoolAttrsAddPads(builder, pads)
if self.strides is not None:
AveragePoolAttrsAddStrides(builder, strides)
AveragePoolAttrsAddCountIncludePad(builder, self.countIncludePad)
averagePoolAttrs = AveragePoolAttrsEnd(builder)
return averagePoolAttrs

Expand Down

0 comments on commit 09ecb72

Please sign in to comment.