Skip to content

Commit

Permalink
Redesign metal kernel api
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos committed Oct 16, 2024
1 parent 8f67034 commit 6090449
Show file tree
Hide file tree
Showing 35 changed files with 372 additions and 283 deletions.
7 changes: 5 additions & 2 deletions metal/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,12 @@ impl MetalContext {
self.arena.borrow_mut().take().ok_or_else(|| {
anyhow!("Unexpected None arena while executing inside a metal arena")
})?;
log::info!("MetalArena: {:.3} %", arena.used_capacity() as f32 / arena.capacity() as f32);
log::debug!("MetalArena: {:.3} %", arena.used_capacity() as f32 / arena.capacity() as f32);
arena.try_reset();
log::info!("MetalArena after reset: {:.3} %", arena.used_capacity() as f32 / arena.capacity() as f32);
log::debug!(
"MetalArena after reset: {:.3} %",
arena.used_capacity() as f32 / arena.capacity() as f32
);
Ok((arena, res))
}

Expand Down
48 changes: 35 additions & 13 deletions metal/src/fact.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,34 @@
use std::fmt;
use tract_core::internal::*;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MetalFact(pub TypedFact);
pub enum MetalFactKind {
Temporary,
Shared,
}

#[derive(Clone, PartialEq, Eq, Hash)]
pub struct MetalFact {
pub kind: MetalFactKind,
pub fact: TypedFact,
}

impl MetalFact {
pub fn new(fact: TypedFact) -> TractResult<Self> {
pub fn new(kind: MetalFactKind, fact: TypedFact) -> TractResult<Self> {
ensure!(fact.as_metal_fact().is_none());
Ok(Self(fact))
Ok(Self { kind, fact })
}

pub fn shared(fact: TypedFact) -> TractResult<Self> {
Self::new(MetalFactKind::Shared, fact)
}

pub fn marked_as_temporary(self) -> Self {
Self { kind: MetalFactKind::Temporary, ..self }
}

pub fn into_typed_fact(self) -> TypedFact {
self.0
self.fact
}

pub fn into_opaque_fact(self) -> TypedFact {
Expand All @@ -20,25 +38,29 @@ impl MetalFact {

impl OpaqueFact for MetalFact {
fn clarify_dt_shape(&self) -> Option<(DatumType, &[usize])> {
self.0.shape.as_concrete().map(|s| (self.0.datum_type, s))
self.fact.shape.as_concrete().map(|s| (self.fact.datum_type, s))
}

fn mem_size(&self) -> TDim {
self.0.mem_size()
self.fact.mem_size()
}
}

impl fmt::Debug for MetalFact {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match self.kind {
MetalFactKind::Shared => write!(fmt, "Metal,Shared({:?})", self.fact),
MetalFactKind::Temporary => write!(fmt, "Metal,Tmp({:?})", self.fact),
}
}
}

pub trait MetalTypedFactExt {
fn into_opaque_metal_fact(self) -> TractResult<TypedFact>;
fn to_metal_fact(&self) -> TractResult<&MetalFact>;
fn as_metal_fact(&self) -> Option<&MetalFact>;
}

impl MetalTypedFactExt for TypedFact {
fn into_opaque_metal_fact(self) -> TractResult<TypedFact> {
Ok(MetalFact::new(self)?.into_opaque_fact())
}

fn to_metal_fact(&self) -> TractResult<&MetalFact> {
ensure!(
self.datum_type == DatumType::Opaque,
Expand All @@ -57,12 +79,12 @@ impl MetalTypedFactExt for TypedFact {
impl std::ops::Deref for MetalFact {
type Target = TypedFact;
fn deref(&self) -> &Self::Target {
&self.0
&self.fact
}
}

impl std::convert::AsRef<TypedFact> for MetalFact {
fn as_ref(&self) -> &TypedFact {
&self.0
&self.fact
}
}
19 changes: 9 additions & 10 deletions metal/src/kernels/array/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ impl MultiBroadcast {
input_offset: usize,
output_shape: &[usize],
) -> Result<MetalTensor> {
let output = self.dispatch_eval(context, input, input_offset, output_shape)?;
let output = unsafe { MetalTensor::uninitialized_dt(input.datum_type(), &output_shape)? };
self.dispatch_eval(context, input, input_offset, &output)?;
context.wait_until_completed()?;
Ok(output)
}
Expand All @@ -56,15 +57,13 @@ impl MultiBroadcast {
context: &MetalContext,
input: &MetalTensor,
input_offset: usize,
output_shape: &[usize],
) -> Result<MetalTensor> {
output: &MetalTensor,
) -> Result<()> {
input.retain_until_completion();
ensure!(input_offset % input.datum_type().size_of() == 0);

let output = unsafe { MetalTensor::uninitialized_dt(input.datum_type(), output_shape)? };
output.retain_until_completion();

ensure!(input.rank() <= output.rank(), "Input must have a rank lowe than output");
ensure!(input_offset % input.datum_type().size_of() == 0);
ensure!(input.rank() <= output.rank(), "Input must have a rank lower or equal to output");

let mut input_shape = vec![1; output.rank() - input.rank()];
input_shape.extend(input.shape());
Expand All @@ -74,7 +73,7 @@ impl MultiBroadcast {
anyhow!(
"Unsupported broadcast for broadcast op: (in: {:?}, out: {:?})",
input.shape(),
output_shape
output.shape(),
)
})?;

Expand All @@ -98,14 +97,14 @@ impl MultiBroadcast {
);
encoder.set_slice(1, &input_broadcast_strides);
encoder.set_metal_tensor(2, &output, metal::MTLResourceUsage::Write);
encoder.set_slice(3, output_shape);
encoder.set_slice(3, output.shape());
encoder.set_slice(4, output.strides());

let grid_size = utils::build_metal_size_for_shape(output.shape());
let group_size = utils::build_metal_size_with_ones();

encoder.dispatch_thread_groups(grid_size, group_size);
encoder.end_encoding();
Ok(output)
Ok(())
}
}
20 changes: 13 additions & 7 deletions metal/src/kernels/array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,28 @@ impl Cast {
input: &MetalTensor,
to_dt: DatumType,
) -> Result<MetalTensor> {
let o = self.dispatch_eval(context, input, to_dt)?;
let output = unsafe { MetalTensor::uninitialized_dt(to_dt, input.shape())? };
self.dispatch_eval(context, input, &output)?;
context.wait_until_completed()?;
Ok(o)
Ok(output)
}

pub fn dispatch_eval(
&self,
context: &MetalContext,
input: &MetalTensor,
to_dt: DatumType,
) -> Result<MetalTensor> {
let output = unsafe { MetalTensor::uninitialized_dt(to_dt, input.shape())? };
output: &MetalTensor,
) -> Result<()> {
input.retain_until_completion();
output.retain_until_completion();
ensure!(
input.shape() == output.shape(),
"Cast I/O don't have the same shape in: {:?}, out: {:?}",
input.shape(),
output.shape()
);

let kernel_name = self.kernel_name(input.datum_type(), to_dt)?;
let kernel_name = self.kernel_name(input.datum_type(), output.datum_type())?;

let pipeline =
context.shared_context().load_pipeline(LibraryName::ArrayOps, &kernel_name)?;
Expand All @@ -80,6 +86,6 @@ impl Cast {
let group_size = MTLSize { width: 1, height: 1, depth: 1 };
encoder.dispatch_thread_groups(grid_size, group_size);
encoder.end_encoding();
Ok(output)
Ok(())
}
}
49 changes: 25 additions & 24 deletions metal/src/kernels/array/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ use std::fmt;
use tract_core::internal::*;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Concat;
pub struct Concat {
pub axis: usize,
}

impl fmt::Display for Concat {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Expand Down Expand Up @@ -39,13 +41,14 @@ impl Concat {
Ok(format!("array_ops::copy_{broadcast_name}_{tname}"))
}

pub fn eval(
&self,
context: &MetalContext,
inputs: &[&MetalTensor],
axis: usize,
) -> Result<MetalTensor> {
let output = self.dispatch_eval(context, inputs, axis)?;
pub fn eval(&self, context: &MetalContext, inputs: &[&MetalTensor]) -> Result<MetalTensor> {
ensure!(!inputs.is_empty());
let mut output_shape = inputs[0].shape().to_vec();
output_shape[self.axis] = inputs.iter().map(|it| it.shape()[self.axis]).sum();
let output =
unsafe { MetalTensor::uninitialized_dt(inputs[0].datum_type(), &output_shape)? };

self.dispatch_eval(context, inputs, &output)?;
context.wait_until_completed()?;
Ok(output)
}
Expand All @@ -54,29 +57,27 @@ impl Concat {
&self,
context: &MetalContext,
inputs: &[&MetalTensor],
axis: usize,
) -> Result<MetalTensor> {
output: &MetalTensor,
) -> Result<()> {
ensure!(!inputs.is_empty());

let output_dt = inputs[0].datum_type();
let mut output_shape = inputs[0].shape().to_vec();
output_shape[axis] = inputs.iter().map(|it| it.shape()[axis]).sum();
let output_strides = Tensor::natural_strides(&output_shape);
output.retain_until_completion();

let output_shape = output.shape();
let output_strides = output.strides();

let mut offsets = tvec![0; inputs.len()];
let mut cursor = 0;

for (i_idx, input) in inputs.iter().enumerate() {
let i_shape = input.shape();
ensure!(i_shape[..axis] == output_shape[..axis]);
ensure!(i_shape[axis + 1..] == output_shape[axis + 1..]);
offsets[i_idx] = cursor * (output_strides[axis] as usize) * output_dt.size_of();
cursor += i_shape[axis];
ensure!(i_shape[..self.axis] == output_shape[..self.axis]);
ensure!(i_shape[self.axis + 1..] == output_shape[self.axis + 1..]);
offsets[i_idx] =
cursor * (output_strides[self.axis] as usize) * output.datum_type().size_of();
cursor += i_shape[self.axis];
}

let output = unsafe { MetalTensor::uninitialized_dt(output_dt, &output_shape)? };
output.retain_until_completion();

let broadcast_kind = BroadcastKind::from_rank(output.rank()).with_context(|| {
anyhow!(
"Unsupported broadcast for broadcast op: (in: {:?}, out: {:?})",
Expand All @@ -85,7 +86,7 @@ impl Concat {
)
})?;

let kernel_name = self.kernel_name(output_dt, broadcast_kind)?;
let kernel_name = self.kernel_name(output.datum_type(), broadcast_kind)?;
let pipeline =
context.shared_context().load_pipeline(LibraryName::ArrayOps, &kernel_name)?;
let command_buffer = context.command_buffer();
Expand Down Expand Up @@ -117,7 +118,7 @@ impl Concat {
encoder.end_encoding();
}

Ok(output)
Ok(())
}
}

Expand All @@ -141,7 +142,7 @@ mod tests {
inputs.push(Tensor::from_shape(shape, &data)?.into_metal()?);
}

let output = Concat.eval(context, &inputs.iter().collect_vec(), axis)?;
let output = Concat { axis }.eval(context, &inputs.iter().collect_vec())?;
let ref_output = Tensor::stack_tensors(
axis,
&inputs.iter().map(|it| it.to_cpu()).collect::<Result<Vec<_>>>()?,
Expand Down
12 changes: 7 additions & 5 deletions metal/src/kernels/array/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ impl Memcpy {
context: &MetalContext,
input: &MetalTensor,
input_offset: usize,
) -> Result<MetalTensor> {
output: &MetalTensor,
) -> Result<()> {
ensure!(input_offset % input.datum_type().size_of() == 0);
ensure!(output.len() <= input.len() - input_offset);

input.retain_until_completion();

let output = unsafe { MetalTensor::uninitialized_dt(input.datum_type(), input.shape())? };
output.retain_until_completion();

let kernel_name = self.kernel_name(input.datum_type())?;
Expand All @@ -72,16 +72,18 @@ impl Memcpy {
let group_size = MTLSize { width: 1, height: 1, depth: 1 };
encoder.dispatch_thread_groups(grid_size, group_size);
encoder.end_encoding();
Ok(output)
Ok(())
}

pub fn eval(
&self,
context: &MetalContext,
input: &MetalTensor,
input_offset: usize,
output_shape: &[usize],
) -> Result<MetalTensor> {
let output = self.dispatch_eval(context, input, input_offset)?;
let output = unsafe { MetalTensor::uninitialized_dt(input.datum_type(), output_shape)? };
self.dispatch_eval(context, input, input_offset, &output)?;
context.wait_until_completed()?;
Ok(output)
}
Expand Down
Loading

0 comments on commit 6090449

Please sign in to comment.