Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Implement standard deviation #3005

Merged
merged 31 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
818c076
Clean up code
raunakab Oct 6, 2024
a53dfaa
Add all structure for stddev
raunakab Oct 6, 2024
2ae9d08
Implement structure for local and distributed stddev
raunakab Oct 7, 2024
c7f189c
Implement non-grouped stddev
raunakab Oct 7, 2024
797358f
Implement grouped standard deviation
raunakab Oct 7, 2024
93b125b
Remove unwraps that may have panicked because of invalid first element
raunakab Oct 7, 2024
db48aa7
Merge branch 'main' into feat/stddev
raunakab Oct 7, 2024
a55ed2b
Add `#[pyfunctions]` functions to code
raunakab Oct 7, 2024
bd509d4
Add basic test for stddev
raunakab Oct 7, 2024
1b4d039
Add partition based testing
raunakab Oct 7, 2024
633f486
Add first stage pass to stddev distributed implementation
raunakab Oct 7, 2024
9b9ac18
Add `StddevMerge` variant to finish the second stage aggregations
raunakab Oct 7, 2024
9669a08
Implement `stddev_merge` todo
raunakab Oct 7, 2024
64bff42
Finish distributed stddev
raunakab Oct 8, 2024
81d7f75
Merge branch 'main' into feat/stddev
raunakab Oct 8, 2024
d825a21
Edit data-type of `square_sum` field in `to_field` impl
raunakab Oct 8, 2024
9b94626
Fix errors in multi-partition aggregation planning
raunakab Oct 8, 2024
70577ab
Add some tests for stddev (single- and multi- partitioned)
raunakab Oct 8, 2024
265a7a7
Finish tests for stddev feature
raunakab Oct 8, 2024
a76fade
Explicitly import typing module; fix lints
raunakab Oct 8, 2024
7a5a36a
Remove `SquareSum` since it can just be implemented as `AggExpr::Sum(…
raunakab Oct 8, 2024
c6eba4e
Add debug_assertions to length checking during stats calculations
raunakab Oct 8, 2024
4581104
Remove dead function and remove re-calculation of mean
raunakab Oct 8, 2024
dd941b0
Change type of count to f64 to avoid casts in loop; remove panic asse…
raunakab Oct 8, 2024
823e3af
Change name of data-type function
raunakab Oct 8, 2024
53a0566
Add comment to `populate_aggregation_stages`; explains what each agg-…
raunakab Oct 8, 2024
e4222f5
Add docs to dataframe stddev API
raunakab Oct 8, 2024
0c976a4
Change `assert_eq` to `debug_assert_eq`
raunakab Oct 8, 2024
65e443a
Update grouped-mean impl to use stats
raunakab Oct 8, 2024
04dcb04
Merge branch 'main' into feat/stddev
raunakab Oct 8, 2024
1874f43
Add to docs
raunakab Oct 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,7 @@ class PyExpr:
def approx_count_distinct(self) -> PyExpr: ...
def approx_percentiles(self, percentiles: float | list[float]) -> PyExpr: ...
def mean(self) -> PyExpr: ...
def stddev(self) -> PyExpr: ...
raunakab marked this conversation as resolved.
Show resolved Hide resolved
def min(self) -> PyExpr: ...
def max(self) -> PyExpr: ...
def any_value(self, ignore_nulls: bool) -> PyExpr: ...
Expand Down Expand Up @@ -1336,6 +1337,7 @@ class PySeries:
def count(self, mode: CountMode) -> PySeries: ...
def sum(self) -> PySeries: ...
def mean(self) -> PySeries: ...
def stddev(self) -> PySeries: ...
def min(self) -> PySeries: ...
def max(self) -> PySeries: ...
def agg_list(self) -> PySeries: ...
Expand Down
22 changes: 22 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2118,6 +2118,17 @@
"""
return self._apply_agg_fn(Expression.mean, cols)

@DataframePublicAPI
def stddev(self, *cols: ColumnInputType) -> "DataFrame":
    """Computes the standard deviation of the given columns across the whole DataFrame.

    Args:
        *cols (Union[str, Expression]): columns to take the standard deviation of

    Returns:
        DataFrame: Globally aggregated standard deviation. Should be a single row.
    """
    aggregated = self._apply_agg_fn(Expression.stddev, cols)
    return aggregated

Check warning on line 2130 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L2130

Added line #L2130 was not covered by tests

@DataframePublicAPI
def min(self, *cols: ColumnInputType) -> "DataFrame":
"""Performs a global min on the DataFrame
Expand Down Expand Up @@ -2856,6 +2867,17 @@
"""
return self.df._apply_agg_fn(Expression.mean, cols, self.group_by)

def stddev(self, *cols: ColumnInputType) -> "DataFrame":
    """Computes the standard deviation of the given columns within each group.

    Args:
        *cols (Union[str, Expression]): columns to take the standard deviation of

    Returns:
        DataFrame: DataFrame with grouped standard deviation.
    """
    # Delegate to the wrapped DataFrame, passing along this object's grouping keys.
    grouped = self.df._apply_agg_fn(Expression.stddev, cols, self.group_by)
    return grouped

Check warning on line 2879 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L2879

Added line #L2879 was not covered by tests

def min(self, *cols: ColumnInputType) -> "DataFrame":
"""Perform grouped min on this GroupedDataFrame.

Expand Down
5 changes: 5 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,11 @@ def mean(self) -> Expression:
expr = self._expr.mean()
return Expression._from_pyexpr(expr)

def stddev(self) -> Expression:
    """Calculates the standard deviation of the values in the expression"""
    # Wrap the underlying PyExpr aggregation back into a Python-level Expression.
    return Expression._from_pyexpr(self._expr.stddev())

def min(self) -> Expression:
"""Calculates the minimum value in the expression"""
expr = self._expr.min()
Expand Down
4 changes: 4 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,10 @@
assert self._series is not None
return Series._from_pyseries(self._series.mean())

def stddev(self) -> Series:
    """Calculates the standard deviation of this Series, returned as a new Series."""
    pyseries = self._series
    assert pyseries is not None
    return Series._from_pyseries(pyseries.stddev())

Check warning on line 517 in daft/series.py

View check run for this annotation

Codecov / codecov/patch

daft/series.py#L516-L517

Added lines #L516 - L517 were not covered by tests

def sum(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.sum())
Expand Down
39 changes: 18 additions & 21 deletions src/daft-core/src/array/ops/mean.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,29 @@
use std::sync::Arc;

use arrow2::array::PrimitiveArray;
use common_error::DaftResult;

use super::{as_arrow::AsArrow, DaftCountAggable, DaftMeanAggable, DaftSumAggable};
use crate::{array::ops::GroupIndices, count_mode::CountMode, datatypes::*};
impl DaftMeanAggable for &DataArray<Float64Type> {
type Output = DaftResult<DataArray<Float64Type>>;
use crate::{
array::ops::{
as_arrow::AsArrow, DaftCountAggable, DaftMeanAggable, DaftSumAggable, GroupIndices,
},
count_mode::CountMode,
datatypes::*,
utils::stats,
};

fn mean(&self) -> Self::Output {
let sum_value = DaftSumAggable::sum(self)?.as_arrow().value(0);
let count_value = DaftCountAggable::count(self, CountMode::Valid)?
.as_arrow()
.value(0);
impl DaftMeanAggable for DataArray<Float64Type> {
type Output = DaftResult<Self>;

let result = match count_value {
0 => None,
count_value => Some(sum_value / count_value as f64),
};
let arrow_array = Box::new(arrow2::array::PrimitiveArray::from([result]));

DataArray::new(
Arc::new(Field::new(self.field.name.clone(), DataType::Float64)),
arrow_array,
)
fn mean(&self) -> Self::Output {
let stats = stats::calculate_stats(self)?;
let mean = stats::calculate_mean(stats.sum, stats.count);
raunakab marked this conversation as resolved.
Show resolved Hide resolved
let data = PrimitiveArray::from([mean]).boxed();
let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64));
Self::new(field, data)
}

fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output {
use arrow2::array::PrimitiveArray;
let sum_values = self.grouped_sum(groups)?;
let count_values = self.grouped_count(groups, CountMode::Valid)?;
assert_eq!(sum_values.len(), count_values.len());
Expand All @@ -39,6 +36,6 @@ impl DaftMeanAggable for &DataArray<Float64Type> {
(s, c) => Some(s / (*c as f64)),
});
let mean_array = Box::new(PrimitiveArray::from_trusted_len_iter(mean_per_group));
Ok(DataArray::from((self.field.name.as_ref(), mean_array)))
Ok(Self::from((self.field.name.as_ref(), mean_array)))
raunakab marked this conversation as resolved.
Show resolved Hide resolved
}
}
7 changes: 7 additions & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ mod sketch_percentile;
mod sort;
pub(crate) mod sparse_tensor;
mod sqrt;
mod stddev;
mod struct_;
mod sum;
mod take;
Expand Down Expand Up @@ -189,6 +190,12 @@ pub trait DaftMeanAggable {
fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output;
}

/// Standard-deviation aggregation over an array, either as a whole or per group.
pub trait DaftStddevAggable {
    /// Type returned by both aggregation methods.
    type Output;
    /// Aggregates the entire array into its standard deviation.
    fn stddev(&self) -> Self::Output;
    /// Aggregates the array into one standard deviation per entry of `groups`.
    fn grouped_stddev(&self, groups: &GroupIndices) -> Self::Output;
}

pub trait DaftCompareAggable {
type Output;
fn min(&self) -> Self::Output;
Expand Down
34 changes: 34 additions & 0 deletions src/daft-core/src/array/ops/stddev.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use arrow2::array::PrimitiveArray;
use common_error::DaftResult;

use crate::{
array::{
ops::{DaftStddevAggable, GroupIndices},
DataArray,
},
datatypes::Float64Type,
utils::stats,
};

impl DaftStddevAggable for DataArray<Float64Type> {
    type Output = DaftResult<Self>;

    /// Computes the standard deviation of the whole array.
    ///
    /// The result is a one-element array that reuses this array's field.
    /// Null entries are dropped before the values are handed to
    /// `stats::calculate_stddev`.
    fn stddev(&self) -> Self::Output {
        let summary = stats::calculate_stats(self)?;
        let present_values = self.into_iter().flatten().copied();
        let result = stats::calculate_stddev(summary, present_values);
        Self::new(
            self.field.clone(),
            PrimitiveArray::<f64>::from([result]).boxed(),
        )
    }

    /// Computes one standard deviation per group in `groups`.
    ///
    /// The output has one entry per item yielded by `stats::grouped_stats`
    /// and reuses this array's field. Within a group, indices for which
    /// `get` returns `None` are skipped.
    fn grouped_stddev(&self, groups: &GroupIndices) -> Self::Output {
        let per_group = stats::grouped_stats(self, groups)?.map(|(group_stats, indices)| {
            let present_values = indices.iter().filter_map(|&idx| self.get(idx as _));
            stats::calculate_stddev(group_stats, present_values)
        });
        let data = PrimitiveArray::<f64>::from_iter(per_group).boxed();
        Self::new(self.field.clone(), data)
    }
}
2 changes: 1 addition & 1 deletion src/daft-core/src/datatypes/agg_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub fn try_sum_supertype(dtype: &DataType) -> DaftResult<DataType> {
}

/// Get the data type that a numeric aggregation (such as the mean) of a column of the given data type should be casted to.
pub fn try_mean_supertype(dtype: &DataType) -> DaftResult<DataType> {
pub fn try_numeric_aggregation_supertype(dtype: &DataType) -> DaftResult<DataType> {
raunakab marked this conversation as resolved.
Show resolved Hide resolved
if dtype.is_numeric() {
Ok(DataType::Float64)
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub use infer_datatype::InferDataType;
pub mod prelude;
use std::ops::{Add, Div, Mul, Rem, Sub};

pub use agg_ops::{try_mean_supertype, try_sum_supertype};
pub use agg_ops::{try_numeric_aggregation_supertype, try_sum_supertype};
use arrow2::{
compute::comparison::Simd8,
types::{simd::Simd, NativeType},
Expand Down
60 changes: 30 additions & 30 deletions src/daft-core/src/series/ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use logical::Decimal128Array;

use crate::{
array::{
ops::{DaftHllMergeAggable, GroupIndices},
ops::{
DaftApproxSketchAggable, DaftHllMergeAggable, DaftMeanAggable, DaftStddevAggable,
DaftSumAggable, GroupIndices,
},
ListArray,
},
count_mode::CountMode,
Expand All @@ -26,12 +29,10 @@ impl Series {
}

pub fn sum(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
use crate::{array::ops::DaftSumAggable, datatypes::DataType::*};

match self.data_type() {
// intX -> int64 (in line with numpy)
Int8 | Int16 | Int32 | Int64 => {
let casted = self.cast(&Int64)?;
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
let casted = self.cast(&DataType::Int64)?;
match groups {
Some(groups) => {
Ok(DaftSumAggable::grouped_sum(&casted.i64()?, groups)?.into_series())
Expand All @@ -40,8 +41,8 @@ impl Series {
}
}
// uintX -> uint64 (in line with numpy)
UInt8 | UInt16 | UInt32 | UInt64 => {
let casted = self.cast(&UInt64)?;
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
let casted = self.cast(&DataType::UInt64)?;
match groups {
Some(groups) => {
Ok(DaftSumAggable::grouped_sum(&casted.u64()?, groups)?.into_series())
Expand All @@ -50,23 +51,23 @@ impl Series {
}
}
// floatX -> floatX (in line with numpy)
Float32 => match groups {
DataType::Float32 => match groups {
Some(groups) => Ok(DaftSumAggable::grouped_sum(
&self.downcast::<Float32Array>()?,
groups,
)?
.into_series()),
None => Ok(DaftSumAggable::sum(&self.downcast::<Float32Array>()?)?.into_series()),
},
Float64 => match groups {
DataType::Float64 => match groups {
Some(groups) => Ok(DaftSumAggable::grouped_sum(
&self.downcast::<Float64Array>()?,
groups,
)?
.into_series()),
None => Ok(DaftSumAggable::sum(&self.downcast::<Float64Array>()?)?.into_series()),
},
Decimal128(_, _) => match groups {
DataType::Decimal128(_, _) => match groups {
Some(groups) => Ok(Decimal128Array::new(
Field {
dtype: try_sum_supertype(self.data_type())?,
Expand Down Expand Up @@ -95,12 +96,10 @@ impl Series {
}

pub fn approx_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
use crate::{array::ops::DaftApproxSketchAggable, datatypes::DataType::*};

// Upcast all numeric types to float64 and compute approx_sketch.
match self.data_type() {
dt if dt.is_numeric() => {
let casted = self.cast(&Float64)?;
let casted = self.cast(&DataType::Float64)?;
match groups {
Some(groups) => Ok(DaftApproxSketchAggable::grouped_approx_sketch(
&casted.f64()?,
Expand Down Expand Up @@ -149,24 +148,25 @@ impl Series {
}

pub fn mean(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
use crate::{array::ops::DaftMeanAggable, datatypes::DataType::*};

// Upcast all numeric types to float64 and use f64 mean kernel.
match self.data_type() {
dt if dt.is_numeric() => {
let casted = self.cast(&Float64)?;
match groups {
Some(groups) => {
Ok(DaftMeanAggable::grouped_mean(&casted.f64()?, groups)?.into_series())
}
None => Ok(DaftMeanAggable::mean(&casted.f64()?)?.into_series()),
}
}
other => Err(DaftError::TypeError(format!(
"Numeric mean is not implemented for type {}",
other
))),
}
self.data_type().assert_is_numeric()?;
let casted = self.cast(&DataType::Float64)?;
let casted = casted.f64()?;
let series = groups
.map_or_else(|| casted.mean(), |groups| casted.grouped_mean(groups))?
.into_series();
Ok(series)
}

/// Computes the standard deviation of this series, optionally per group.
///
/// The series must be numeric: the error from `assert_is_numeric` is
/// propagated otherwise. Values are cast to `Float64` and handed to the
/// f64 stddev kernel. With `groups` set to `Some`, the grouped kernel is
/// used; with `None`, the whole series is aggregated.
pub fn stddev(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
    self.data_type().assert_is_numeric()?;
    let as_float = self.cast(&DataType::Float64)?;
    let array = as_float.f64()?;
    let aggregated = match groups {
        Some(groups) => array.grouped_stddev(groups),
        None => array.stddev(),
    }?;
    Ok(aggregated.into_series())
}

pub fn min(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ pub mod arrow;
pub mod display;
pub mod dyn_compare;
pub mod identity_hash_set;
pub mod stats;
pub mod supertype;
Loading
Loading