add opaque fact in packed tensors #1556

Merged
merged 1 commit into from
Oct 15, 2024
Merged
4 changes: 2 additions & 2 deletions cli/src/run.rs
@@ -8,7 +8,7 @@ use fs_err as fs;
use ndarray_npy::NpzWriter;
use nu_ansi_term::Color::*;
use tract_core::ops::cnn::conv::Im2Col;
-use tract_core::ops::matmul::pack::MatMatMulPack;
+use tract_core::ops::matmul::pack::OptMatMulPack;
use tract_core::tract_data::itertools::izip;
use tract_hir::internal::*;
use tract_libcli::tensor::RunParams;
@@ -213,7 +213,7 @@ fn run_regular(
}
if assert_sane_floats {
for (ix, o) in clarified_r.iter().enumerate() {
-if node.op_is::<Im2Col>() || node.op_is::<MatMatMulPack>() {
+if node.op_is::<Im2Col>() || node.op_is::<OptMatMulPack>() {
continue;
}
if let Ok(floats) = o.as_slice::<f32>() {
4 changes: 2 additions & 2 deletions core/src/ops/cnn/conv/conv.rs
@@ -17,7 +17,7 @@ use crate::ops::math::{add, div, mul, sub};
use crate::ops::math::{Add, Div, Mul, Sub};
use crate::ops::matmul::optimized::AddMatMulGeometry;
use crate::ops::matmul::optimized::MapOutputAxisToInput;
-use crate::ops::matmul::pack::MatMatMulPack;
+use crate::ops::matmul::pack::OptMatMulPack;
use crate::ops::matmul::quant::wire_ensure_q8_flavour;
use crate::ops::nn::Reduce;

@@ -79,7 +79,7 @@ impl Conv {
) -> TractResult<OutletId> {
Ok(model.wire_node(
format!("{name}.prep_kernel.pack"),
-MatMatMulPack { packers: vec![format], k_axis: 2, mn_axis: 1 },
+OptMatMulPack { packers: vec![format], k_axis: 2, mn_axis: 1 },
&[kernel],
)?[0])
}
12 changes: 6 additions & 6 deletions core/src/ops/einsum/kernel_selection.rs
@@ -6,7 +6,7 @@ use tract_linalg::mmm::{MMMInputValue, MatMatMul};

use crate::internal::*;
use crate::ops::matmul::de_block_quant::{BlockQuantFact, BlockQuantValue};
-use crate::ops::matmul::pack::MatMatMulPack;
+use crate::ops::matmul::pack::OptMatMulPack;

use super::optimize::EinSumAnnotatedAsMatMul;

@@ -47,13 +47,13 @@ pub fn wire_packing(
.with_context(|| format!("No packing for {mmm:?} with inputs {a_dt:?} and {b_dt:?}"))?;
let pa = patch.wire_node(
format!("{prefix}.pack_a"),
-MatMatMulPack { k_axis: op.a_k(), mn_axis: op.a_m(), packers: vec![pa.clone()] },
+OptMatMulPack { k_axis: op.a_k(), mn_axis: op.a_m(), packers: vec![pa.clone()] },
&[operands[0]],
)?[0];

let pb = patch.wire_node(
format!("{prefix}.pack_b"),
-MatMatMulPack { k_axis: op.b_k(), mn_axis: op.b_n(), packers: vec![pb.clone()] },
+OptMatMulPack { k_axis: op.b_k(), mn_axis: op.b_n(), packers: vec![pb.clone()] },
&[operands[1]],
)?[0];

@@ -99,7 +99,7 @@ fn with_block_quant(
.clone();
patch
.node_mut(pb.node)
-.op_as_mut::<MatMatMulPack>()
+.op_as_mut::<OptMatMulPack>()
.context("Expected MatMatMulPack on B")?
.packers
.push(alternative_b_packing);
@@ -149,7 +149,7 @@ fn with_block_quant_matmat(

let pb = patch.wire_node(
format!("{prefix}.pack_b"),
-MatMatMulPack { k_axis: op.b_k(), mn_axis: op.b_n(), packers: vec![pb.clone()] },
+OptMatMulPack { k_axis: op.b_k(), mn_axis: op.b_n(), packers: vec![pb.clone()] },
&[operands[1]],
)?[0];

@@ -198,7 +198,7 @@ fn with_block_quant_matvec(

let pb = patch.wire_node(
format!("{prefix}.pack_b"),
-MatMatMulPack { k_axis: op.b_k(), mn_axis: op.b_n(), packers: vec![pb.clone()] },
+OptMatMulPack { k_axis: op.b_k(), mn_axis: op.b_n(), packers: vec![pb.clone()] },
&[operands[1]],
)?[0];

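Note on the block-quant paths above: with_block_quant pushes a second candidate packing onto the pack op's packers list rather than replacing the first, and the op picks a single packer at evaluation time (see do_eval in pack.rs below). A standalone sketch of that "several candidate packers, pick one later" idea follows; it is not tract code, and the type, names, and selection rule are invented for illustration:

// Standalone sketch: an op carrying several candidate packers; with a single
// candidate it is used directly, otherwise the one matching the input is picked.
#[derive(Debug)]
struct CandidatePacker {
    name: &'static str,
    elem_size: usize,
}

fn choose_packer<'a>(
    packers: &'a [CandidatePacker],
    input_elem_size: usize,
) -> Option<&'a CandidatePacker> {
    if packers.len() == 1 {
        packers.first()
    } else {
        packers.iter().find(|p| p.elem_size == input_elem_size)
    }
}

fn main() {
    let packers = vec![
        CandidatePacker { name: "plain f32 panels", elem_size: 4 },
        CandidatePacker { name: "block-quant panels", elem_size: 1 },
    ];
    // A quantized input ends up on the alternative packing pushed by with_block_quant.
    println!("{:?}", choose_packer(&packers, 1).map(|p| p.name));
}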
32 changes: 25 additions & 7 deletions core/src/ops/matmul/pack.rs
@@ -5,15 +5,15 @@ use tract_data::TooEarly;
use tract_linalg::frame::PackedFormat;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-pub struct MatMatMulPack {
+pub struct OptMatMulPack {
pub(crate) packers: Vec<PackedFormat>,
pub(crate) k_axis: usize,
pub(crate) mn_axis: usize,
}

-impl Op for MatMatMulPack {
+impl Op for OptMatMulPack {
fn name(&self) -> Cow<str> {
"MatMatMulPack".into()
"OptMatMulPack".into()
}

fn info(&self) -> TractResult<Vec<String>> {
@@ -24,7 +24,7 @@ impl Op for MatMatMulPack {
impl_op_same_as!();
}

-impl EvalOp for MatMatMulPack {
+impl EvalOp for OptMatMulPack {
fn is_stateless(&self) -> bool {
true
}
@@ -38,9 +38,14 @@ impl EvalOp for MatMatMulPack {
}
}

-impl TypedOp for MatMatMulPack {
+impl TypedOp for OptMatMulPack {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
-Ok(tvec!(Opaque::datum_type().fact(self.output_shape(&inputs[0].shape))))
+let k = inputs[0].shape[self.k_axis].clone();
+let mn = inputs[0].shape[self.mn_axis].clone();
+let opaque_fact = PackedOpaqueFact { k, mn, packers: self.packers.clone() };
+Ok(tvec!(Opaque::datum_type()
+.fact(self.output_shape(&inputs[0].shape))
+.with_opaque_fact(opaque_fact)))
}

fn axes_mapping(
@@ -63,7 +68,7 @@ impl TypedOp for MatMatMulPack {
as_op!();
}

-impl MatMatMulPack {
+impl OptMatMulPack {
fn do_eval(&self, session: &SessionState, input: TValue) -> TractResult<TVec<TValue>> {
unsafe {
let packer = if self.packers.len() == 1 {
@@ -118,3 +123,16 @@ impl MatMatMulPack {
packed_shape
}
}

+#[derive(Hash, Clone, Debug, PartialEq, Eq)]
+pub struct PackedOpaqueFact {
+pub k: TDim,
+pub mn: TDim,
+pub packers: Vec<PackedFormat>,
+}
+
+impl OpaqueFact for PackedOpaqueFact {
+fn mem_size(&self) -> TDim {
+self.k.clone() * &self.mn * self.packers[0].dt.size_of()
+}
+}
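For a rough sense of what mem_size reports: it multiplies the k and mn dimensions by the element size of the first packer's data type. Here is a standalone sketch with concrete numbers; it is not tract code, and the helper name and values are made up for illustration:

// Standalone sketch mirroring PackedOpaqueFact::mem_size above:
// the packed operand's footprint is k * mn * size_of(element type).
fn packed_mem_size(k: usize, mn: usize, elem_size: usize) -> usize {
    k * mn * elem_size
}

fn main() {
    // e.g. a 64 x 128 f32 operand: 64 * 128 * 4 = 32768 bytes
    assert_eq!(packed_mem_size(64, 128, 4), 32_768);
    println!("{} bytes", packed_mem_size(64, 128, 4));
}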