Skip to content

Commit

Permalink
Finish distributed stddev
Browse files Browse the repository at this point in the history
  • Loading branch information
raunakab committed Oct 8, 2024
1 parent 9669a08 commit 64bff42
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 48 deletions.
7 changes: 7 additions & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ mod sketch_percentile;
mod sort;
pub(crate) mod sparse_tensor;
mod sqrt;
mod square_sum;
mod stddev;
mod struct_;
mod sum;
Expand Down Expand Up @@ -172,6 +173,12 @@ pub trait DaftSumAggable {
fn grouped_sum(&self, groups: &GroupIndices) -> Self::Output;
}

/// Aggregation that reduces an array to the sum of the squares of its values.
///
/// Mirrors the shape of the other `Daft*Aggable` traits in this module: one
/// method for a whole-array reduction and one for a per-group reduction.
pub trait DaftSquareSumAggable {
/// Result type produced by both aggregation methods.
type Output;
/// Computes the sum of squares over the entire array.
fn square_sum(&self) -> Self::Output;
/// Computes one sum of squares per entry in `groups`, where each group
/// lists the row indices that belong to it.
fn grouped_square_sum(&self, groups: &GroupIndices) -> Self::Output;
}

pub trait DaftApproxSketchAggable {
type Output;
fn approx_sketch(&self) -> Self::Output;
Expand Down
37 changes: 37 additions & 0 deletions src/daft-core/src/array/ops/square_sum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use arrow2::array::PrimitiveArray;
use common_error::DaftResult;

use crate::array::{
ops::{DaftSquareSumAggable, GroupIndices},
prelude::Float64Array,
};

impl DaftSquareSumAggable for Float64Array {
    type Output = DaftResult<Self>;

    /// Sums the squares of every non-null value and returns the total as a
    /// single-element array that reuses this array's field.
    ///
    /// Null entries are skipped; an array with no valid values yields `0.0`.
    fn square_sum(&self) -> Self::Output {
        let mut total = 0.;
        for value in self.into_iter().flatten() {
            total += value.powi(2);
        }
        let values = PrimitiveArray::from([Some(total)]);
        Self::new(self.field.clone(), values.boxed())
    }

    /// Sums the squares of the non-null values addressed by each group of row
    /// indices, producing one (always valid) output value per group.
    ///
    /// Indices whose value is null contribute nothing to their group's total,
    /// so a group with no valid values yields `0.0`.
    fn grouped_square_sum(&self, groups: &GroupIndices) -> Self::Output {
        let per_group = groups.iter().map(|indices| {
            let mut total = 0.;
            for &index in indices.iter() {
                if let Some(value) = self.get(index as _) {
                    total += value.powi(2);
                }
            }
            Some(total)
        });
        let values = PrimitiveArray::from_trusted_len_iter(per_group);
        Self::new(self.field.clone(), values.boxed())
    }
}
7 changes: 2 additions & 5 deletions src/daft-core/src/array/ops/stddev.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use std::sync::Arc;

use arrow2::array::PrimitiveArray;
use common_error::DaftResult;
use daft_schema::{dtype::DataType, field::Field};

use crate::{
array::{
Expand All @@ -20,7 +17,7 @@ impl DaftStddevAggable for DataArray<Float64Type> {
let stats = stats::calculate_stats(self)?;
let values = self.into_iter().flatten().copied();
let stddev = stats::calculate_stddev(stats, values);
let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64));
let field = self.field.clone();
let data = PrimitiveArray::<f64>::from([stddev]).boxed();
Self::new(field, data)
}
Expand All @@ -30,7 +27,7 @@ impl DaftStddevAggable for DataArray<Float64Type> {
let values = group.iter().filter_map(|&index| self.get(index as _));
stats::calculate_stddev(stats, values)
});
let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64));
let field = self.field.clone();
let data = PrimitiveArray::<f64>::from_iter(grouped_stddevs_iter).boxed();
Self::new(field, data)
}
Expand Down
38 changes: 25 additions & 13 deletions src/daft-core/src/series/ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use logical::Decimal128Array;

use crate::{
array::{
ops::{DaftHllMergeAggable, DaftMeanAggable, DaftStddevAggable, GroupIndices},
ops::{
DaftApproxSketchAggable, DaftHllMergeAggable, DaftMeanAggable, DaftSquareSumAggable,
DaftStddevAggable, DaftSumAggable, GroupIndices,
},
ListArray,
},
count_mode::CountMode,
Expand All @@ -26,12 +29,10 @@ impl Series {
}

pub fn sum(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
use crate::{array::ops::DaftSumAggable, datatypes::DataType::*};

match self.data_type() {
// intX -> int64 (in line with numpy)
Int8 | Int16 | Int32 | Int64 => {
let casted = self.cast(&Int64)?;
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
let casted = self.cast(&DataType::Int64)?;
match groups {
Some(groups) => {
Ok(DaftSumAggable::grouped_sum(&casted.i64()?, groups)?.into_series())
Expand All @@ -40,8 +41,8 @@ impl Series {
}
}
// uintX -> uint64 (in line with numpy)
UInt8 | UInt16 | UInt32 | UInt64 => {
let casted = self.cast(&UInt64)?;
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
let casted = self.cast(&DataType::UInt64)?;
match groups {
Some(groups) => {
Ok(DaftSumAggable::grouped_sum(&casted.u64()?, groups)?.into_series())
Expand All @@ -50,23 +51,23 @@ impl Series {
}
}
// floatX -> floatX (in line with numpy)
Float32 => match groups {
DataType::Float32 => match groups {
Some(groups) => Ok(DaftSumAggable::grouped_sum(
&self.downcast::<Float32Array>()?,
groups,
)?
.into_series()),
None => Ok(DaftSumAggable::sum(&self.downcast::<Float32Array>()?)?.into_series()),
},
Float64 => match groups {
DataType::Float64 => match groups {
Some(groups) => Ok(DaftSumAggable::grouped_sum(
&self.downcast::<Float64Array>()?,
groups,
)?
.into_series()),
None => Ok(DaftSumAggable::sum(&self.downcast::<Float64Array>()?)?.into_series()),
},
Decimal128(_, _) => match groups {
DataType::Decimal128(_, _) => match groups {
Some(groups) => Ok(Decimal128Array::new(
Field {
dtype: try_sum_supertype(self.data_type())?,
Expand Down Expand Up @@ -94,13 +95,24 @@ impl Series {
}
}

pub fn approx_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
use crate::{array::ops::DaftApproxSketchAggable, datatypes::DataType::*};
/// Computes the sum of squares of this series, optionally per group.
///
/// The series must pass `assert_is_numeric`; it is then cast to `Float64`
/// before aggregating. With `groups` set, one value is produced per group;
/// otherwise the whole series is reduced to a single value.
pub fn square_sum(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
    self.data_type().assert_is_numeric()?;
    let as_float = self.cast(&DataType::Float64)?;
    let values = as_float.f64()?;
    let aggregated = match groups {
        Some(groups) => values.grouped_square_sum(groups)?,
        None => values.square_sum()?,
    };
    Ok(aggregated.into_series())
}

pub fn approx_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
// Upcast all numeric types to float64 and compute approx_sketch.
match self.data_type() {
dt if dt.is_numeric() => {
let casted = self.cast(&Float64)?;
let casted = self.cast(&DataType::Float64)?;
match groups {
Some(groups) => Ok(DaftApproxSketchAggable::grouped_approx_sketch(
&casted.f64()?,
Expand Down
24 changes: 12 additions & 12 deletions src/daft-dsl/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ pub enum AggExpr {
#[display("sum({_0})")]
Sum(ExprRef),

#[display("square_sum({_0})")]
SquareSum(ExprRef),

#[display("approx_percentile({}, percentiles={:?}, force_list_output={})", _0.child, _0.percentiles, _0.force_list_output)]
ApproxPercentile(ApproxPercentileParams),

Expand All @@ -127,9 +130,6 @@ pub enum AggExpr {
#[display("stddev({_0})")]
Stddev(ExprRef),

#[display("stddev_merge({_0})")]
StddevMerge(ExprRef),

#[display("min({_0})")]
Min(ExprRef),

Expand Down Expand Up @@ -171,13 +171,13 @@ impl AggExpr {
match self {
Self::Count(expr, ..)
| Self::Sum(expr)
| Self::SquareSum(expr)
| Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. })
| Self::ApproxCountDistinct(expr)
| Self::ApproxSketch(expr, _)
| Self::MergeSketch(expr, _)
| Self::Mean(expr)
| Self::Stddev(expr)
| Self::StddevMerge(expr)
| Self::Min(expr)
| Self::Max(expr)
| Self::AnyValue(expr, _)
Expand All @@ -197,6 +197,10 @@ impl AggExpr {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_sum()"))
}
Self::SquareSum(expr) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_square_sum()"))
}
Self::ApproxPercentile(ApproxPercentileParams {
child: expr,
percentiles,
Expand Down Expand Up @@ -232,10 +236,6 @@ impl AggExpr {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_stddev()"))
}
Self::StddevMerge(expr) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_stddev_merge()"))
}
Self::Min(expr) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_min()"))
Expand Down Expand Up @@ -266,13 +266,13 @@ impl AggExpr {
match self {
Self::Count(expr, ..)
| Self::Sum(expr)
| Self::SquareSum(expr)
| Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. })
| Self::ApproxCountDistinct(expr)
| Self::ApproxSketch(expr, _)
| Self::MergeSketch(expr, _)
| Self::Mean(expr)
| Self::Stddev(expr)
| Self::StddevMerge(expr)
| Self::Min(expr)
| Self::Max(expr)
| Self::AnyValue(expr, _)
Expand All @@ -292,9 +292,9 @@ impl AggExpr {
match self {
Self::Count(_, count_mode) => Self::Count(first_child(), *count_mode),
Self::Sum(_) => Self::Sum(first_child()),
Self::SquareSum(_) => Self::SquareSum(first_child()),
Self::Mean(_) => Self::Mean(first_child()),
Self::Stddev(_) => Self::Stddev(first_child()),
Self::StddevMerge(_) => Self::StddevMerge(first_child()),
Self::Min(_) => Self::Min(first_child()),
Self::Max(_) => Self::Max(first_child()),
Self::AnyValue(_, ignore_nulls) => Self::AnyValue(first_child(), *ignore_nulls),
Expand Down Expand Up @@ -325,7 +325,7 @@ impl AggExpr {
let field = expr.to_field(schema)?;
Ok(Field::new(field.name.as_str(), DataType::UInt64))
}
Self::Sum(expr) => {
Self::Sum(expr) | Self::SquareSum(expr) => {
let field = expr.to_field(schema)?;
Ok(Field::new(
field.name.as_str(),
Expand Down Expand Up @@ -392,7 +392,7 @@ impl AggExpr {
};
Ok(Field::new(field.name, dtype))
}
Self::Mean(expr) | Self::Stddev(expr) | Self::StddevMerge(expr) => {
Self::Mean(expr) | Self::Stddev(expr) => {
let field = expr.to_field(schema)?;
Ok(Field::new(
field.name.as_str(),
Expand Down
4 changes: 1 addition & 3 deletions src/daft-dsl/src/resolve_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ fn extract_agg_expr(expr: &Expr) -> DaftResult<AggExpr> {
AggExpr::Count(Expr::Alias(e, name.clone()).into(), count_mode)
}
AggExpr::Sum(e) => AggExpr::Sum(Expr::Alias(e, name.clone()).into()),
AggExpr::SquareSum(e) => AggExpr::SquareSum(e.alias(name.clone())),
AggExpr::ApproxPercentile(ApproxPercentileParams {
child: e,
percentiles,
Expand All @@ -239,9 +240,6 @@ fn extract_agg_expr(expr: &Expr) -> DaftResult<AggExpr> {
}
AggExpr::Mean(e) => AggExpr::Mean(Expr::Alias(e, name.clone()).into()),
AggExpr::Stddev(e) => AggExpr::Stddev(Expr::Alias(e, name.clone()).into()),
AggExpr::StddevMerge(e) => {
AggExpr::StddevMerge(Expr::Alias(e, name.clone()).into())
}
AggExpr::Min(e) => AggExpr::Min(Expr::Alias(e, name.clone()).into()),
AggExpr::Max(e) => AggExpr::Max(Expr::Alias(e, name.clone()).into()),
AggExpr::AnyValue(e, ignore_nulls) => {
Expand Down
8 changes: 4 additions & 4 deletions src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,10 @@ fn replace_column_with_semantic_id_aggexpr(
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Sum, |_| e)
}
AggExpr::SquareSum(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::SquareSum, |_| e)
}
AggExpr::ApproxPercentile(ApproxPercentileParams {
ref child,
ref percentiles,
Expand Down Expand Up @@ -419,10 +423,6 @@ fn replace_column_with_semantic_id_aggexpr(
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Stddev, |_| e)
}
AggExpr::StddevMerge(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::StddevMerge, |_| e)
}
AggExpr::Min(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Min, |_| e)
Expand Down
51 changes: 43 additions & 8 deletions src/daft-plan/src/physical_planner/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,9 @@ pub fn populate_aggregation_stages(
));
final_exprs.push(col(sum_of_sum_id.clone()).alias(output_name));
}
AggExpr::SquareSum(..) => {
unimplemented!("User-facing square_sum aggregation is not implemented")
}
AggExpr::Mean(e) => {
let sum_id = AggExpr::Sum(e.clone()).semantic_id(schema).id;
let count_id = AggExpr::Count(e.clone(), CountMode::Valid)
Expand Down Expand Up @@ -846,18 +849,50 @@ pub fn populate_aggregation_stages(
);
}
AggExpr::Stddev(sub_expr) => {
// first stage
// first stage aggregation
let sum_expr = AggExpr::Sum(sub_expr.clone());
let sq_sum_expr = AggExpr::SquareSum(sub_expr.clone());
let count_expr = AggExpr::Count(sub_expr.clone(), CountMode::Valid);
add_to_stage(&mut first_stage_aggs, get_id(&sum_expr), sum_expr);
add_to_stage(&mut first_stage_aggs, get_id(&count_expr), count_expr);
let sum_id = get_id(&sum_expr);
let sq_sum_id = get_id(&sq_sum_expr);
let count_id = get_id(&count_expr);
add_to_stage(&mut first_stage_aggs, sum_id.clone(), sum_expr);
add_to_stage(&mut first_stage_aggs, sq_sum_id.clone(), sq_sum_expr);
add_to_stage(&mut first_stage_aggs, count_id.clone(), count_expr);

// second stage aggregation
let global_sum_expr = AggExpr::Sum(col(sum_id));
let global_sq_sum_expr = AggExpr::Sum(col(sq_sum_id));
let global_count_expr = AggExpr::Sum(col(count_id));
let global_sum_id = get_id(&global_sum_expr);
let global_sq_sum_id = get_id(&global_sq_sum_expr);
let global_count_id = get_id(&global_count_expr);
add_to_stage(
&mut second_stage_aggs,
global_sum_id.clone(),
global_sum_expr,
);
add_to_stage(
&mut second_stage_aggs,
global_sq_sum_id.clone(),
global_sq_sum_expr,
);
add_to_stage(
&mut second_stage_aggs,
global_count_id.clone(),
global_count_expr,
);

// second stage
// final projection
let g_sq_sum = col(global_sq_sum_id);
let g_sum = col(global_sum_id);
let g_count = col(global_count_id);
let left = g_sq_sum.div(g_count.clone());
let right = g_sum.div(g_count);
let right = right.clone().mul(right);
let result = left.sub(right);

todo!("stddev")
}
AggExpr::StddevMerge(..) => {
unimplemented!("User-facing stddev_merge aggregation is not implemented")
final_exprs.push(result);
}
AggExpr::Min(e) => {
let min_id = agg_expr.semantic_id(schema).id;
Expand Down
Loading

0 comments on commit 64bff42

Please sign in to comment.