diff --git a/Cargo.lock b/Cargo.lock index c382223c42..4764d26a8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1884,6 +1884,7 @@ dependencies = [ "indexmap", "itertools", "log", + "num-traits", "petgraph", "pp-rs", "ron", diff --git a/naga/Cargo.toml b/naga/Cargo.toml index 5a2f983175..2a4eca491d 100644 --- a/naga/Cargo.toml +++ b/naga/Cargo.toml @@ -82,6 +82,7 @@ petgraph = { version = "0.6", optional = true } pp-rs = { version = "0.2.1", optional = true } hexf-parse = { version = "0.2.1", optional = true } unicode-xid = { version = "0.2.5", optional = true } +num-traits = "0.2" [build-dependencies] cfg_aliases.workspace = true diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 1b7f5cf910..28c1fb0274 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -254,6 +254,256 @@ gen_component_wise_extractor! { ], } +/// Vector for each [`Literal`] type +/// +/// This type ensures that all elements have same type +#[derive(Debug)] +enum LiteralVector { + F64(ArrayVec), + F32(ArrayVec), + U32(ArrayVec), + I32(ArrayVec), + U64(ArrayVec), + I64(ArrayVec), + Bool(ArrayVec), + AbstractInt(ArrayVec), + AbstractFloat(ArrayVec), +} + +impl LiteralVector { + #[allow(clippy::pattern_type_mismatch, clippy::missing_const_for_fn)] + fn len(&self) -> usize { + match self { + LiteralVector::F64(v) => v.len(), + LiteralVector::F32(v) => v.len(), + LiteralVector::U32(v) => v.len(), + LiteralVector::I32(v) => v.len(), + LiteralVector::U64(v) => v.len(), + LiteralVector::I64(v) => v.len(), + LiteralVector::Bool(v) => v.len(), + LiteralVector::AbstractInt(v) => v.len(), + LiteralVector::AbstractFloat(v) => v.len(), + } + } + /// Creates [`LiteralVector`] of size 1 from single [`Literal`] + fn from_literal(literal: Literal) -> Self { + match literal { + Literal::F64(e) => Self::F64(ArrayVec::from_iter(iter::once(e))), + Literal::F32(e) => Self::F32(ArrayVec::from_iter(iter::once(e))), + Literal::U32(e) => Self::U32(ArrayVec::from_iter(iter::once(e))), + Literal::I32(e) => Self::I32(ArrayVec::from_iter(iter::once(e))), + Literal::U64(e) => Self::U64(ArrayVec::from_iter(iter::once(e))), + Literal::I64(e) => Self::I64(ArrayVec::from_iter(iter::once(e))), + Literal::Bool(e) => Self::Bool(ArrayVec::from_iter(iter::once(e))), + Literal::AbstractInt(e) => Self::AbstractInt(ArrayVec::from_iter(iter::once(e))), + Literal::AbstractFloat(e) => Self::AbstractFloat(ArrayVec::from_iter(iter::once(e))), + } + } + + #[allow(dead_code)] + /// Creates [`LiteralVector`] from Array of [`Literal`]s + /// + /// Panics if vector is empty + fn from_literal_vec( + components: ArrayVec, + ) -> Result { + let scalar = components[0].scalar(); + Self::from_literal_vec_with_scalar_type(components, scalar) + } + + /// Creates [`LiteralVector`] of type provided by scalar from Array of [`Literal`]s + /// + /// Panics if vector is empty, returns error if types do not match + fn from_literal_vec_with_scalar_type( + components: ArrayVec, + scalar: crate::Scalar, + ) -> Result { + assert!(!components.is_empty()); + Ok(match scalar { + crate::Scalar::I32 => Self::I32( + components + .iter() + .map(|l| match l { + &Literal::I32(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::U32 => Self::U32( + components + .iter() + .map(|l| match l { + &Literal::U32(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::I64 => Self::I64( + components + .iter() + .map(|l| match l { + &Literal::I64(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::U64 => Self::U64( + components + .iter() + .map(|l| match l { + &Literal::U64(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::F32 => Self::F32( + components + .iter() + .map(|l| match l { + &Literal::F32(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::F64 => Self::F64( + components + .iter() + .map(|l| match l { + &Literal::F64(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::BOOL => Self::Bool( + components + .iter() + .map(|l| match l { + &Literal::Bool(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::ABSTRACT_INT => Self::AbstractInt( + components + .iter() + .map(|l| match l { + &Literal::AbstractInt(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + crate::Scalar::ABSTRACT_FLOAT => Self::AbstractFloat( + components + .iter() + .map(|l| match l { + &Literal::AbstractFloat(v) => Ok(v), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?, + ), + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }) + } + + fn from_expr( + expr: Handle, + eval: &mut ConstantEvaluator<'_>, + span: Span, + allow_single: bool, + ) -> Result { + let expr = eval + .eval_zero_value_and_splat(expr, span) + .map(|expr| &eval.expressions[expr])?; + match *expr { + Expression::Literal(literal) => { + if allow_single { + Ok(Self::from_literal(literal)) + } else { + Err(ConstantEvaluatorError::InvalidMathArg) + } + } + Expression::Compose { ty, ref components } => match eval.types[ty].inner { + TypeInner::Vector { scalar, .. } => { + if components.len() > crate::VectorSize::MAX { + return Err(ConstantEvaluatorError::InvalidMathArg); + } + let components: ArrayVec = + crate::proc::flatten_compose(ty, components, eval.expressions, eval.types) + .map(|expr| match eval.expressions[expr] { + Expression::Literal(l) => Ok(l), + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }) + .collect::>()?; + Self::from_literal_vec_with_scalar_type(components, scalar) + } + _ => Err(ConstantEvaluatorError::InvalidMathArg), + }, + _ => Err(ConstantEvaluatorError::InvalidMathArg), + } + } + + /// Returns [`ArrayVec`] of [`Literal`]s + fn to_literal_vec(&self) -> ArrayVec { + #[allow(clippy::pattern_type_mismatch)] + match self { + LiteralVector::F64(v) => v.iter().map(|e| (Literal::F64(*e))).collect(), + LiteralVector::F32(v) => v.iter().map(|e| (Literal::F32(*e))).collect(), + LiteralVector::U32(v) => v.iter().map(|e| (Literal::U32(*e))).collect(), + LiteralVector::I32(v) => v.iter().map(|e| (Literal::I32(*e))).collect(), + LiteralVector::U64(v) => v.iter().map(|e| (Literal::U64(*e))).collect(), + LiteralVector::I64(v) => v.iter().map(|e| (Literal::I64(*e))).collect(), + LiteralVector::Bool(v) => v.iter().map(|e| (Literal::Bool(*e))).collect(), + LiteralVector::AbstractInt(v) => v.iter().map(|e| (Literal::AbstractInt(*e))).collect(), + LiteralVector::AbstractFloat(v) => { + v.iter().map(|e| (Literal::AbstractFloat(*e))).collect() + } + } + } + + fn to_expr( + &self, + eval: &mut ConstantEvaluator<'_>, + ) -> Result { + let lit_vec = self.to_literal_vec(); + assert!(!lit_vec.is_empty()); + if lit_vec.len() == 1 { + Ok(Expression::Literal(lit_vec[0])) + } else { + Ok(Expression::Compose { + ty: eval.types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: match lit_vec.len() { + 2 => crate::VectorSize::Bi, + 3 => crate::VectorSize::Tri, + 4 => crate::VectorSize::Quad, + _ => unreachable!(), + }, + scalar: lit_vec[0].scalar(), + }, + }, + Span::UNDEFINED, + ), + components: lit_vec + .iter() + .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), Span::UNDEFINED)) + .collect::>()?, + }) + } + } + + /// Puts self into eval's expressions arena and returns handle to it + fn handle( + &self, + eval: &mut ConstantEvaluator<'_>, + span: Span, + ) -> Result, ConstantEvaluatorError> { + let expr = self.to_expr(eval)?; + eval.register_evaluated_expr(expr, span) + } +} + #[derive(Debug)] enum Behavior<'a> { Wgsl(WgslRestrictions<'a>), @@ -917,9 +1167,10 @@ impl<'a> ConstantEvaluator<'a> { Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented( "select built-in function".into(), )), - Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented( - format!("{fun:?} built-in function"), - )), + Expression::Relational { fun, argument } => { + let arg = self.check_and_get(argument)?; + self.relational_op(fun, arg, span) + } Expression::ArrayLength(expr) => match self.behavior { Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength), Behavior::Glsl(_) => { @@ -1230,6 +1481,149 @@ impl<'a> ConstantEvaluator<'a> { }) } + // geometry + crate::MathFunction::Dot => { + let e1 = LiteralVector::from_expr(arg, self, span, false)?; + let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, false)?; + if e1.len() != e2.len() { + return Err(ConstantEvaluatorError::InvalidMathArg); + } + + fn float_dot(a: ArrayVec, b: ArrayVec) -> F + where + F: std::ops::Mul, + F: num_traits::Float + std::iter::Sum, + { + a.iter().zip(b.iter()).map(|(&aa, &bb)| aa * bb).sum() + } + + fn int_dot( + a: ArrayVec, + b: ArrayVec, + ) -> Result + where + P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul, + { + a.iter() + .zip(b.iter()) + .map(|(&aa, bb)| aa.checked_mul(bb)) + .try_fold(P::zero(), |acc, x| { + if let Some(x) = x { + acc.checked_add(&x) + } else { + None + } + }) + .ok_or(ConstantEvaluatorError::Overflow( + "in dot built-in".to_string(), + )) + } + + LiteralVector::from_literal(match (e1, e2) { + (LiteralVector::AbstractFloat(e1), LiteralVector::AbstractFloat(e2)) => { + Literal::AbstractFloat(float_dot(e1, e2)) + } + (LiteralVector::F32(e1), LiteralVector::F32(e2)) => { + Literal::F32(float_dot(e1, e2)) + } + (LiteralVector::AbstractInt(e1), LiteralVector::AbstractInt(e2)) => { + Literal::AbstractInt(int_dot(e1, e2)?) + } + (LiteralVector::I32(e1), LiteralVector::I32(e2)) => { + Literal::I32(int_dot(e1, e2)?) + } + (LiteralVector::U32(e1), LiteralVector::U32(e2)) => { + Literal::U32(int_dot(e1, e2)?) + } + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }) + .handle(self, span) + } + crate::MathFunction::Cross => { + let e1 = LiteralVector::from_expr(arg, self, span, false)?; + let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, false)?; + if e1.len() == 3 && e2.len() == 3 { + fn float_cross( + a: ArrayVec, + b: ArrayVec, + ) -> ArrayVec + where + F: std::ops::Mul, + F: num_traits::Float + std::iter::Sum, + { + [ + a[1] * b[2] - a[2] * b[1], + a[2] * b[0] - a[0] * b[2], + a[0] * b[1] - a[1] * b[0], + ] + .into_iter() + .collect() + } + match (e1, e2) { + (LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => { + LiteralVector::AbstractFloat(float_cross(a, b)) + } + (LiteralVector::F32(a), LiteralVector::F32(b)) => { + LiteralVector::F32(float_cross(a, b)) + } + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + } + .handle(self, span) + } else { + Err(ConstantEvaluatorError::InvalidMathArg) + } + } + crate::MathFunction::Length => { + let e1 = LiteralVector::from_expr(arg, self, span, true)?; + + fn float_length(e: ArrayVec) -> F + where + F: std::ops::Mul, + F: num_traits::Float + std::iter::Sum, + { + e.iter().map(|&ei| ei * ei).sum::().sqrt() + } + + LiteralVector::from_literal(match e1 { + LiteralVector::AbstractFloat(a) => Literal::AbstractFloat(float_length(a)), + LiteralVector::F32(a) => Literal::F32(float_length(a)), + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }) + .handle(self, span) + } + crate::MathFunction::Distance => { + let e1 = LiteralVector::from_expr(arg, self, span, true)?; + let e2 = LiteralVector::from_expr(arg1.unwrap(), self, span, true)?; + if e1.len() != e2.len() { + return Err(ConstantEvaluatorError::InvalidMathArg); + } + + fn float_distance( + a: ArrayVec, + b: ArrayVec, + ) -> F + where + F: std::ops::Mul, + F: num_traits::Float + std::iter::Sum + std::ops::Sub, + { + a.iter() + .zip(b.iter()) + .map(|(&aa, &bb)| aa - bb) + .map(|ei| ei * ei) + .sum::() + .sqrt() + } + LiteralVector::from_literal(match (e1, e2) { + (LiteralVector::AbstractFloat(a), LiteralVector::AbstractFloat(b)) => { + Literal::AbstractFloat(float_distance(a, b)) + } + (LiteralVector::F32(a), LiteralVector::F32(b)) => { + Literal::F32(float_distance(a, b)) + } + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }) + .handle(self, span) + } // computational crate::MathFunction::Sign => { component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) }) @@ -2059,6 +2453,38 @@ impl<'a> ConstantEvaluator<'a> { Ok(Expression::Compose { ty, components }) } + fn relational_op( + &mut self, + fun: crate::RelationalFunction, + arg: Handle, + span: Span, + ) -> Result, ConstantEvaluatorError> { + let arg = LiteralVector::from_expr(arg, self, span, true)?; + let res = LiteralVector::Bool(match fun { + crate::RelationalFunction::IsNan => match arg { + LiteralVector::F64(f) => f.iter().map(|e| e.is_nan()).collect(), + LiteralVector::F32(f) => f.iter().map(|e| e.is_nan()).collect(), + LiteralVector::AbstractFloat(f) => f.iter().map(|e| e.is_nan()).collect(), + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }, + crate::RelationalFunction::IsInf => match arg { + LiteralVector::F64(f) => f.iter().map(|e| e.is_infinite()).collect(), + LiteralVector::F32(f) => f.iter().map(|e| e.is_infinite()).collect(), + LiteralVector::AbstractFloat(f) => f.iter().map(|e| e.is_infinite()).collect(), + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }, + crate::RelationalFunction::All => match arg { + LiteralVector::Bool(bools) => iter::once(bools.iter().all(|b| *b)).collect(), + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }, + crate::RelationalFunction::Any => match arg { + LiteralVector::Bool(bools) => iter::once(bools.iter().any(|b| *b)).collect(), + _ => return Err(ConstantEvaluatorError::InvalidMathArg), + }, + }); + res.handle(self, span) + } + /// Deep copy `expr` from `expressions` into `self.expressions`. /// /// Return the root of the new copy.