Skip to content

Commit

Permalink
Document and allow more operations on decimals
Browse files Browse the repository at this point in the history
  • Loading branch information
philss committed Sep 20, 2024
1 parent 24e60bb commit 2d16ef0
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 14 deletions.
13 changes: 8 additions & 5 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3262,6 +3262,7 @@ defmodule Explorer.Series do
* `:time`
* `:datetime`
* `:duration`
* `:decimal`
## Examples
Expand Down Expand Up @@ -3344,6 +3345,7 @@ defmodule Explorer.Series do
* floats: #{Shared.inspect_dtypes(@float_dtypes, backsticks: true)}
* integers: #{Shared.inspect_dtypes(@integer_types, backsticks: true)}
* `:decimal`
## Examples
Expand Down Expand Up @@ -3422,6 +3424,7 @@ defmodule Explorer.Series do
* floats: #{Shared.inspect_dtypes(@float_dtypes, backsticks: true)}
* integers: #{Shared.inspect_dtypes(@integer_types, backsticks: true)}
* `:decimal`
## Examples
Expand Down Expand Up @@ -3492,7 +3495,7 @@ defmodule Explorer.Series do
* floats: #{Shared.inspect_dtypes(@float_dtypes, backsticks: true)}
* integers: #{Shared.inspect_dtypes(@integer_types, backsticks: true)}
* decimals - returning decimal series.
* `:decimal` - returning decimal series.
## Examples
Expand Down Expand Up @@ -3549,7 +3552,7 @@ defmodule Explorer.Series do
* floats: #{Shared.inspect_dtypes(@float_dtypes, backsticks: true)}
* integers: #{Shared.inspect_dtypes(@integer_types, backsticks: true)}
* decimals - returning f64 series.
* `:decimal` - returning f64 series.
## Examples
Expand Down Expand Up @@ -3629,7 +3632,7 @@ defmodule Explorer.Series do
* floats: #{Shared.inspect_dtypes(@float_dtypes, backsticks: true)}
* integers: #{Shared.inspect_dtypes(@integer_types, backsticks: true)}
* decimals - returning f64 series.
* `:decimal` - returning f64 series.
## Examples
Expand Down Expand Up @@ -3701,7 +3704,7 @@ defmodule Explorer.Series do
* floats: #{Shared.inspect_dtypes(@float_dtypes, backsticks: true)}
* integers: #{Shared.inspect_dtypes(@integer_types, backsticks: true)}
* decimals.
* `:decimal` - returns f64 series.
## Examples
Expand All @@ -3726,7 +3729,7 @@ defmodule Explorer.Series do
* floats: #{Shared.inspect_dtypes(@float_dtypes, backsticks: true)}
* integers: #{Shared.inspect_dtypes(@integer_types, backsticks: true)}
* decimals.
* `:decimal`.
## Examples
Expand Down
3 changes: 2 additions & 1 deletion lib/explorer/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ defmodule Explorer.Shared do
def merge_numeric_dtype({:decimal, _, _} = decimal, :null), do: decimal
def merge_numeric_dtype(:null, {:decimal, _, _} = decimal), do: decimal

# For now, float has priority over decimals due to Polars.
def merge_numeric_dtype({:decimal, _, _}, {:f, _} = float), do: float
def merge_numeric_dtype({:f, _} = float, {:decimal, _, _}), do: float

Expand Down Expand Up @@ -512,7 +513,7 @@ defmodule Explorer.Shared do
def dtype_to_string({:f, size}), do: "f" <> Integer.to_string(size)
def dtype_to_string({:s, size}), do: "s" <> Integer.to_string(size)
def dtype_to_string({:u, size}), do: "u" <> Integer.to_string(size)
def dtype_to_string({:decimal, precision, scale}), do: "decimal[#{precision}, #{scale}]"
def dtype_to_string({:decimal, precision, scale}), do: "decimal[#{precision || "*"}, #{scale}]"
def dtype_to_string(other) when is_atom(other), do: Atom.to_string(other)

defp precision_string(:millisecond), do: "ms"
Expand Down
10 changes: 10 additions & 0 deletions native/explorer/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,16 @@ pub struct ExDecimal {
}

impl ExDecimal {
/// Builds an `ExDecimal` from a signed coefficient and a decimal scale.
///
/// The magnitude of `signed_coef` is stored in `coef`, its sign is stored
/// in `sign` (`1` for zero or positive values, `-1` for negative values),
/// and `exp` is set to the negated `scale`.
///
/// # Panics
///
/// Panics if the magnitude of `signed_coef` does not fit in the type of
/// the `coef` field.
pub fn new(signed_coef: i128, scale: usize) -> Self {
    // `unsigned_abs` is total over `i128`, whereas `abs` overflows on
    // `i128::MIN` (arithmetic-overflow panic in debug builds, a wrapped
    // negative value in release builds). Using it means the only failure
    // path left is the explicit range check below.
    let magnitude = signed_coef.unsigned_abs();

    Self {
        coef: magnitude
            .try_into()
            .expect("signed coef is too large for u64"),
        sign: if signed_coef >= 0 { 1 } else { -1 },
        // NOTE(review): `as` would wrap for a scale above `i64::MAX`;
        // presumably unreachable for real decimal scales — confirm.
        exp: -(scale as i64),
    }
}
/// Returns the coefficient with the sign applied, as an `i128`.
///
/// Both `sign` and `coef` are converted to `i128` before being multiplied.
pub fn signed_coef(self) -> i128 {
    let sign = self.sign as i128;
    let magnitude = self.coef as i128;
    sign * magnitude
}
Expand Down
26 changes: 18 additions & 8 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,9 @@ pub fn s_min(env: Env, s: ExSeries) -> Result<Term, ExplorerError> {
DataType::UInt16 => Ok(s.min::<u16>()?.encode(env)),
DataType::UInt32 => Ok(s.min::<u32>()?.encode(env)),
DataType::UInt64 => Ok(s.min::<u64>()?.encode(env)),
DataType::Float32 | DataType::Float64 => Ok(term_from_optional_float(s.min::<f64>()?, env)),
DataType::Float32 | DataType::Float64 | DataType::Decimal(_, _) => {
Ok(term_from_optional_float(s.min::<f64>()?, env))
}
DataType::Date => Ok(s.min::<i32>()?.map(ExDate::from).encode(env)),
DataType::Time => Ok(s.min::<i64>()?.map(ExTime::from).encode(env)),
DataType::Datetime(unit, _) => Ok(s
Expand All @@ -796,7 +798,9 @@ pub fn s_max(env: Env, s: ExSeries) -> Result<Term, ExplorerError> {
DataType::UInt16 => Ok(s.max::<u16>()?.encode(env)),
DataType::UInt32 => Ok(s.max::<u32>()?.encode(env)),
DataType::UInt64 => Ok(s.max::<u64>()?.encode(env)),
DataType::Float32 | DataType::Float64 => Ok(term_from_optional_float(s.max::<f64>()?, env)),
DataType::Float32 | DataType::Float64 | DataType::Decimal(_, _) => {
Ok(term_from_optional_float(s.max::<f64>()?, env))
}
DataType::Date => Ok(s.max::<i32>()?.map(ExDate::from).encode(env)),
DataType::Time => Ok(s.max::<i64>()?.map(ExTime::from).encode(env)),
DataType::Datetime(unit, _) => Ok(s
Expand All @@ -817,9 +821,13 @@ pub fn s_argmin(env: Env, s: ExSeries) -> Result<Term, ExplorerError> {
Ok(s.arg_min().encode(env))
}

/// Reports whether `dtype` should be handled as a numeric type.
///
/// Decimal dtypes are always accepted; every other dtype defers to
/// `DataType::is_numeric`.
fn is_numeric(dtype: &DataType) -> bool {
    match dtype {
        DataType::Decimal(_, _) => true,
        other => other.is_numeric(),
    }
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_mean(env: Env, s: ExSeries) -> Result<Term, ExplorerError> {
if s.dtype().is_numeric() {
if is_numeric(s.dtype()) {
Ok(term_from_optional_float(s.mean(), env))
} else {
panic!("mean/1 not implemented for {:?}", &s.dtype())
Expand All @@ -828,7 +836,7 @@ pub fn s_mean(env: Env, s: ExSeries) -> Result<Term, ExplorerError> {

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_median(env: Env, s: ExSeries) -> Result<Term, ExplorerError> {
if s.dtype().is_numeric() {
if is_numeric(s.dtype()) {
Ok(term_from_optional_float(s.median(), env))
} else {
panic!("median/1 not implemented for {:?}", &s.dtype())
Expand All @@ -845,7 +853,7 @@ pub fn s_mode(s: ExSeries) -> Result<ExSeries, ExplorerError> {

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_product(s: ExSeries) -> Result<ExSeries, ExplorerError> {
if s.dtype().is_numeric() {
if is_numeric(s.dtype()) {
let series = s
.clone_inner()
.into_frame()
Expand All @@ -863,7 +871,7 @@ pub fn s_product(s: ExSeries) -> Result<ExSeries, ExplorerError> {

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_variance(s: ExSeries, ddof: u8) -> Result<ExSeries, ExplorerError> {
if s.dtype().is_numeric() {
if is_numeric(s.dtype()) {
let var_series = s
.clone_inner()
.into_frame()
Expand All @@ -881,7 +889,7 @@ pub fn s_variance(s: ExSeries, ddof: u8) -> Result<ExSeries, ExplorerError> {

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_standard_deviation(s: ExSeries, ddof: u8) -> Result<ExSeries, ExplorerError> {
if s.dtype().is_numeric() {
if is_numeric(s.dtype()) {
let std_series = s
.clone_inner()
.into_frame()
Expand All @@ -899,7 +907,7 @@ pub fn s_standard_deviation(s: ExSeries, ddof: u8) -> Result<ExSeries, ExplorerE

#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_skew(env: Env, s: ExSeries, bias: bool) -> Result<Term, ExplorerError> {
if s.dtype().is_numeric() {
if is_numeric(s.dtype()) {
Ok(term_from_optional_float(s.skew(bias)?, env))
} else {
panic!("skew/2 not implemented for {:?}", &s.dtype())
Expand Down Expand Up @@ -1040,6 +1048,7 @@ pub fn s_peak_max(s: ExSeries) -> Result<ExSeries, ExplorerError> {

DataType::Float32 => peak_max(s.f32()?),
DataType::Float64 => peak_max(s.f64()?),
DataType::Decimal(_, _) => peak_max(s.decimal()?),

DataType::Date => peak_max(s.date()?),
DataType::Time => peak_max(s.time()?),
Expand All @@ -1066,6 +1075,7 @@ pub fn s_peak_min(s: ExSeries) -> Result<ExSeries, ExplorerError> {

DataType::Float32 => peak_min(s.f32()?),
DataType::Float64 => peak_min(s.f64()?),
DataType::Decimal(_, _) => peak_min(s.decimal()?),

DataType::Date => peak_min(s.date()?),
DataType::Time => peak_min(s.time()?),
Expand Down

0 comments on commit 2d16ef0

Please sign in to comment.