diff --git a/src/core/graph/autodiff.rs b/src/core/graph/autodiff.rs
index c664d10..b54faee 100644
--- a/src/core/graph/autodiff.rs
+++ b/src/core/graph/autodiff.rs
@@ -64,7 +64,9 @@ impl Context {
             //Again again, clone() here is not wonderful, there's gotta be a better way to
             //store the i64 vec for Transpose
             match self.nodes[dependent_node].operation.clone() {
-                Operation::Constant(_) => panic!("Constant found as dependent node!"),
+                Operation::Constant(_)
+                | Operation::RngUniform(_, _, _)
+                | Operation::RngNormal(_, _, _) => panic!("Constant or RNG node found as dependent node!"),
                 Operation::Parameter(_) => panic!("Parameter found as dependent node!"),
                 Operation::StopGradient(_) => continue,
diff --git a/src/core/graph/compile.rs b/src/core/graph/compile.rs
index 926114f..c748507 100644
--- a/src/core/graph/compile.rs
+++ b/src/core/graph/compile.rs
@@ -1,6 +1,7 @@
 use super::*;
 use slotmap::SlotMap;
 use smallvec::SmallVec;
+use xla::{XlaOp, ArrayShape};
 use std::collections::{HashMap, HashSet, VecDeque};
 
 #[derive(thiserror::Error, Debug)]
@@ -168,6 +169,36 @@ impl Context {
                        }
                    }
 
+                    Operation::RngNormal(mu, sigma, shape) => {
+                        if unda_xla_map.contains_key(&mu)
+                            && unda_xla_map.contains_key(&sigma)
+                            && xla_op_slotmap.contains_key(unda_xla_map[&mu])
+                            && xla_op_slotmap.contains_key(unda_xla_map[&sigma])
+                        {
+                            let dtype = self.nodes[mu].dtype;
+                            let xla_op = XlaOp::rng_normal(&xla_op_slotmap[unda_xla_map[&mu]],
+                                &xla_op_slotmap[unda_xla_map[&sigma]], &shape.to_array_shape(dtype))?;
+                            let xla_id = xla_op_slotmap.insert(xla_op);
+                            unda_xla_map.insert(*dependent_op, xla_id);
+                            unda_op_queue.push_back(*dependent_op);
+                            covered_ops.insert(*dependent_op);
+                        }
+                    }
+                    Operation::RngUniform(min, max, shape) => {
+                        if unda_xla_map.contains_key(&min)
+                            && unda_xla_map.contains_key(&max)
+                            && xla_op_slotmap.contains_key(unda_xla_map[&min])
+                            && xla_op_slotmap.contains_key(unda_xla_map[&max])
+                        {
+                            let dtype = self.nodes[min].dtype;
+                            let xla_op = XlaOp::rng_uniform(&xla_op_slotmap[unda_xla_map[&min]],
+                                &xla_op_slotmap[unda_xla_map[&max]], &shape.to_array_shape(dtype))?;
+                            let xla_id = xla_op_slotmap.insert(xla_op);
+                            unda_xla_map.insert(*dependent_op, xla_id);
+                            unda_op_queue.push_back(*dependent_op);
+                            covered_ops.insert(*dependent_op);
+                        }
+                    }
                    Operation::Pow(a, b) => {
                        if unda_xla_map.contains_key(&a)
                            && unda_xla_map.contains_key(&b)
diff --git a/src/core/graph/consteval.rs b/src/core/graph/consteval.rs
index eaa81d6..578e27f 100644
--- a/src/core/graph/consteval.rs
+++ b/src/core/graph/consteval.rs
@@ -35,6 +35,24 @@ impl Context {
                        changed = true;
                    }
                }
+                Operation::RngUniform(a, b, shape) => {
+                    if a == to_remove {
+                        self.nodes[dep_node].operation = Operation::RngUniform(rep_with, b, shape);
+                        changed = true;
+                    } else if b == to_remove {
+                        self.nodes[dep_node].operation = Operation::RngUniform(a, rep_with, shape);
+                        changed = true;
+                    }
+                }
+                Operation::RngNormal(a, b, shape) => {
+                    if a == to_remove {
+                        self.nodes[dep_node].operation = Operation::RngNormal(rep_with, b, shape);
+                        changed = true;
+                    } else if b == to_remove {
+                        self.nodes[dep_node].operation = Operation::RngNormal(a, rep_with, shape);
+                        changed = true;
+                    }
+                }
                Operation::Pow(a, b) => {
                    if a == to_remove && a == b {
                        self.nodes[dep_node].operation = Operation::Pow(rep_with, rep_with);
@@ -478,6 +496,8 @@ impl Context {
                | Operation::NotEqual(a, b)
                | Operation::Div(a, b)
                | Operation::Pow(a, b)
+                | Operation::RngUniform(a, b, _)
+                | Operation::RngNormal(a, b, _)
                | Operation::MatMul(a, b) => {
                    if self.nodes[a].is_const().is_none() {
                        to_visit.push(a);
diff --git a/src/core/graph/context.rs b/src/core/graph/context.rs
index 8f2e9cb..665c66d 100644
--- a/src/core/graph/context.rs
+++ b/src/core/graph/context.rs
@@ -104,6 +104,8 @@ impl Context {
            Operation::Exp(a) => format!("Exp ({})", self.to_string(a)),
            Operation::Log(a) => format!("Log ({})", self.to_string(a)),
            Operation::Transpose(a, b) => format!("Transpose: ({}) ({:?})", self.to_string(a), b),
+            Operation::RngUniform(a, b, shape) => format!("RngUniform: ({}) ({}) ({})", self.to_string(a), self.to_string(b), shape),
+            Operation::RngNormal(a, b, shape) => format!("RngNormal: ({}) ({}) ({})", self.to_string(a), self.to_string(b), shape),
            Operation::Equal(a, b) => {
                format!("LessThan ({}) ({})", self.to_string(a), self.to_string(b))
            }
diff --git a/src/core/graph/math.rs b/src/core/graph/math.rs
index 49e4e6c..9fccb5d 100644
--- a/src/core/graph/math.rs
+++ b/src/core/graph/math.rs
@@ -102,6 +102,52 @@ impl Context {
        Ok(node_id)
    }
 
+    pub fn rng_uniform(&mut self, min: NodeIdentifier, max: NodeIdentifier, shape: &[u32]) -> Result<NodeIdentifier> {
+        if self.nodes[min].dtype != self.nodes[max].dtype {
+            Err(ContextError::IncompatibleOperandTypes(
+                self.nodes[min].dtype,
+                self.nodes[max].dtype,
+                callsite!(1),
+            ))
+        } else {
+            let shape_node = Shape::from(shape);
+            let node = Node {
+                callsite: callsite!(1),
+                shape: shape_node.clone(),
+                operation: Operation::RngUniform(min, max, shape_node),
+                dtype: self.nodes[min].dtype,
+            };
+            let node_id = self.nodes.insert(node);
+            self.dependent_nodes.entry(min).or_default().push(node_id);
+            self.dependent_nodes.entry(max).or_default().push(node_id);
+
+            Ok(node_id)
+        }
+    }
+
+    pub fn rng_normal(&mut self, mu: NodeIdentifier, sigma: NodeIdentifier, shape: &[u32]) -> Result<NodeIdentifier> {
+        if self.nodes[mu].dtype != self.nodes[sigma].dtype {
+            Err(ContextError::IncompatibleOperandTypes(
+                self.nodes[mu].dtype,
+                self.nodes[sigma].dtype,
+                callsite!(1),
+            ))
+        } else {
+            let shape_node = Shape::from(shape);
+            let node = Node {
+                callsite: callsite!(1),
+                shape: shape_node.clone(),
+                operation: Operation::RngNormal(mu, sigma, shape_node),
+                dtype: self.nodes[mu].dtype,
+            };
+            let node_id = self.nodes.insert(node);
+            self.dependent_nodes.entry(mu).or_default().push(node_id);
+            self.dependent_nodes.entry(sigma).or_default().push(node_id);
+
+            Ok(node_id)
+        }
+    }
+
    pub fn exp(&mut self, a: NodeIdentifier) -> Result<NodeIdentifier> {
        let node = Node {
            callsite: callsite!(1),
diff --git a/src/core/graph/operation.rs b/src/core/graph/operation.rs
index 14c9f8d..2bd15a2 100644
--- a/src/core/graph/operation.rs
+++ b/src/core/graph/operation.rs
@@ -70,6 +70,8 @@ pub enum Operation {
    },
 
    OneHot(NodeIdentifier),
+    RngUniform(NodeIdentifier, NodeIdentifier, Shape),
+    RngNormal(NodeIdentifier, NodeIdentifier, Shape),
 }
 
 impl Hash for Operation {
@@ -147,6 +149,12 @@ impl Hash for Operation {
                n_tiles.hash(state);
                dim.hash(state);
            }
+            Self::RngUniform(a, b, dim)
+            | Self::RngNormal(a, b, dim) => {
+                a.hash(state);
+                b.hash(state);
+                dim.hash(state);
+            }
        }
    }
 }
@@ -190,6 +198,8 @@ impl PartialEq for Operation {
                },
            ) => pred == pred2 && on_true == on_true2 && on_false == on_false2,
            (&Self::TypeCast(a, ty), &Self::TypeCast(b, ty2)) => a == b && ty == ty2,
+            (&Self::RngUniform(a, b, ref shape), &Self::RngUniform(a2, b2, ref shape2)) => a == a2 && b == b2 && shape == shape2,
+            (&Self::RngNormal(a, b, ref shape), &Self::RngNormal(a2, b2, ref shape2)) => a == a2 && b == b2 && shape == shape2,
            (&Self::Transpose(a, dim), &Self::Transpose(b, dim2)) => a == b && dim == dim2,
            (
                &Self::SliceInDim {
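Side note on the merged Hash arm in operation.rs above: hashing RngUniform and RngNormal identically is sound, because Hash only requires that equal values produce equal hashes, and the separate PartialEq arms still distinguish the two variants; at worst they land in the same bucket. A minimal standalone sketch of the same pattern (the Op enum and fingerprint helper below are illustrative stand-ins, not part of this patch):

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Illustrative stand-in: two RNG-like variants share one Hash arm,
// mirroring the RngUniform/RngNormal arm in operation.rs.
#[derive(Debug, PartialEq)]
enum Op {
    Uniform(u64, u64),
    Normal(u64, u64),
}

impl Hash for Op {
    fn hash<H: Hasher>(&self, state: &mut H) {
        match self {
            // Both variants hash only their operands, so Uniform(a, b)
            // and Normal(a, b) collide. That is allowed: Hash demands
            // that equal values hash equally, never the converse.
            Op::Uniform(a, b) | Op::Normal(a, b) => {
                a.hash(state);
                b.hash(state);
            }
        }
    }
}

fn fingerprint(op: &Op) -> u64 {
    let mut hasher = DefaultHasher::new();
    op.hash(&mut hasher);
    hasher.finish()
}

fn main() {
    let u = Op::Uniform(0, 1);
    let n = Op::Normal(0, 1);
    assert_eq!(fingerprint(&u), fingerprint(&n)); // same hash bucket...
    assert_ne!(u, n); // ...but PartialEq still tells the variants apart
}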
diff --git a/src/core/graph/shape.rs b/src/core/graph/shape.rs
index 7bd7086..f2ea895 100644
--- a/src/core/graph/shape.rs
+++ b/src/core/graph/shape.rs
@@ -1,4 +1,5 @@
 use smallvec::SmallVec;
+use xla::ArrayShape;
 
 use super::callsite::Callsite;
 
@@ -64,6 +65,10 @@ impl Shape {
        }
    }
 
+    pub fn to_array_shape(&self, dtype: xla::ElementType) -> ArrayShape {
+        ArrayShape::new(self.sizes.iter().map(|d| *d as i64).collect(), dtype)
+    }
+
    pub fn matmul_shape(&self, dim2: &[u32]) -> Option<SmallVec<[u32; 4]>> {
        let dim1 = &self.sizes;
        if dim1.last()? == dim2.get(dim2.len().saturating_sub(2))? {
diff --git a/src/core/graph/subterm.rs b/src/core/graph/subterm.rs
index 0bf0a3d..3edd299 100644
--- a/src/core/graph/subterm.rs
+++ b/src/core/graph/subterm.rs
@@ -47,6 +47,8 @@ impl Context {
                | Operation::GreaterThanEq(a, b)
                | Operation::LessThanEq(a, b)
                | Operation::MatMul(a, b)
+                | Operation::RngNormal(a, b, _)
+                | Operation::RngUniform(a, b, _)
                | Operation::Pow(a, b) => {
                    to_visit.push(a);
                    to_visit.push(b);
diff --git a/src/core/graph/tests_cpu.rs b/src/core/graph/tests_cpu.rs
index 149525e..b0d5cea 100644
--- a/src/core/graph/tests_cpu.rs
+++ b/src/core/graph/tests_cpu.rs
@@ -68,9 +68,68 @@ mod tests {
    create_test!(test_add_1_2, add, F32, 1f32, 2f32, 3f32);
    create_test!(test_sub_1_2, sub, F32, 1f32, 2f32, -1f32);
 
+    #[test]
+    fn test_normal_dist() {
+        let mut ctx = Context::new();
+        let mu = ctx.scalar(0, xla::ElementType::F32).expect("mu = 0");
+        let sigma = ctx.scalar(1, xla::ElementType::F32).expect("sigma = 1");
+        let mat = ctx.rng_normal(mu, sigma, &[2, 3]).expect("sample the normal distribution");
+
+        let client = xla::PjRtClient::cpu().expect("client");
+        let name = "test";
+        let executable = ctx.compile(&name, [mat], &client).expect("executable");
+
+        let device_result = executable.execute::<xla::Literal>(&[]).expect("execute");
+        let host_result = device_result[0][0]
+            .to_literal_sync()
+            .expect("to_literal_sync");
+        let untupled_result = host_result.to_tuple1().expect("untuple");
+        let rust_result = untupled_result.to_vec::<f32>().expect("to_vec");
+        println!("{:?}", rust_result);
+
+        match untupled_result.shape().unwrap() {
+            xla::Shape::Array(shape) => {
+                assert_eq!(shape.dims(), &[2, 3]);
+            },
+            _ => {
+                panic!("Shape is not correct");
+            }
+        }
+    }
+
+    #[test]
+    fn test_uniform_dist() {
+        let mut ctx = Context::new();
+        let min = ctx.scalar(0, xla::ElementType::F32).expect("min = 0");
+        let max = ctx.scalar(1, xla::ElementType::F32).expect("max = 1");
+        let mat = ctx.rng_uniform(min, max, &[10, 1]).expect("sample the uniform distribution");
+
+        let client = xla::PjRtClient::cpu().expect("client");
+        let name = "test";
+        let executable = ctx.compile(&name, [mat], &client).expect("executable");
+
+        let device_result = executable.execute::<xla::Literal>(&[]).expect("execute");
+        let host_result = device_result[0][0]
+            .to_literal_sync()
+            .expect("to_literal_sync");
+        let untupled_result = host_result.to_tuple1().expect("untuple");
+        let rust_result = untupled_result.to_vec::<f32>().expect("to_vec");
+        println!("{:?}", rust_result);
+
+        match untupled_result.shape().unwrap() {
+            xla::Shape::Array(shape) => {
+                assert_eq!(shape.dims(), &[10, 1]);
+            },
+            _ => {
+                panic!("Shape is not correct");
+            }
+        }
+    }
+
    #[test]
    fn test_large_cte() {
-        let mut ctx = Context::new();
+        let mut ctx = Context::new();
        let a = ctx.parameter("a", [], xla::ElementType::F32).expect("a");
        let two = ctx.scalar(2, xla::ElementType::F32).expect("2");
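For reference, a minimal end-to-end sketch of the new rng_uniform API, assembled from test_uniform_dist above; rng_normal is used the same way with (mu, sigma) operands. The use path is an assumption based on this repo's src/core/graph layout, and "rng_demo" is a hypothetical executable name; every call below mirrors the test code:

// Sketch only, not part of the patch. Assumes the graph module is
// reachable at unda::core::graph; the calls mirror test_uniform_dist.
use unda::core::graph::Context;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut ctx = Context::new();

    // min and max are ordinary graph nodes; rng_uniform rejects operands
    // whose dtypes differ with ContextError::IncompatibleOperandTypes.
    let min = ctx.scalar(0, xla::ElementType::F32)?;
    let max = ctx.scalar(1, xla::ElementType::F32)?;

    // A 10x1 tensor of samples drawn uniformly between min and max.
    let samples = ctx.rng_uniform(min, max, &[10, 1])?;

    // Compile and run on the CPU PjRt client, exactly as the tests do.
    let client = xla::PjRtClient::cpu()?;
    let name = "rng_demo";
    let executable = ctx.compile(&name, [samples], &client)?;

    let device_result = executable.execute::<xla::Literal>(&[])?;
    let host_result = device_result[0][0].to_literal_sync()?;
    let values = host_result.to_tuple1()?.to_vec::<f32>()?;
    println!("{:?}", values);
    Ok(())
}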