Skip to content

Commit

Permalink
[FEAT]: sql case/when (#2591)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 authored Aug 2, 2024
1 parent 91ec88a commit 9bb4b3a
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 207 deletions.
206 changes: 0 additions & 206 deletions src/daft-sql/src/analyzer.rs

This file was deleted.

1 change: 1 addition & 0 deletions src/daft-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ mod tests {
#[case::orderby("select * from tbl1 order by i32 desc")]
#[case::orderby("select * from tbl1 order by i32 asc")]
#[case::orderby_multi("select * from tbl1 order by i32 desc, f32 asc")]
#[case::whenthen("select case when i32 = 1 then 'a' else 'b' end from tbl1")]
fn test_compiles(#[case] query: &str) -> SQLPlannerResult<()> {
let planner = setup();

Expand Down
30 changes: 29 additions & 1 deletion src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,35 @@ impl SQLPlanner {
SQLExpr::TypedString { .. } => unsupported_sql_err!("TYPED STRING"),
SQLExpr::MapAccess { .. } => unsupported_sql_err!("MAP ACCESS"),
SQLExpr::Function(func) => self.plan_function(func, current_relation),
SQLExpr::Case { .. } => unsupported_sql_err!("CASE"),
SQLExpr::Case {
operand,
conditions,
results,
else_result,
} => {
if operand.is_some() {
unsupported_sql_err!("CASE with operand not yet supported");
}
if results.len() != conditions.len() {
unsupported_sql_err!("CASE with different number of conditions and results");
}

let else_expr = match else_result {
Some(expr) => self.plan_expr(expr, current_relation)?,
None => unsupported_sql_err!("CASE with no else result"),
};

// we need to traverse from back to front to build the if else chain
// because we need to start with the else expression
conditions.iter().zip(results.iter()).rev().try_fold(
else_expr,
|else_expr, (condition, result)| {
let cond = self.plan_expr(condition, current_relation)?;
let res = self.plan_expr(result, current_relation)?;
Ok(cond.if_else(res, else_expr))
},
)
}
SQLExpr::Exists { .. } => unsupported_sql_err!("EXISTS"),
SQLExpr::Subquery(_) => unsupported_sql_err!("SUBQUERY"),
SQLExpr::GroupingSets(_) => unsupported_sql_err!("GROUPING SETS"),
Expand Down
32 changes: 32 additions & 0 deletions tests/sql/test_sql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

import numpy as np
import pytest

import daft
Expand Down Expand Up @@ -48,3 +49,34 @@ def test_parse_ok(name, sql):
print(name)
print(sql)
print("--------------")


def test_fizzbuzz_sql():
arr = np.arange(100)
df = daft.from_pydict({"a": arr})
catalog = SQLCatalog({"test": df})
# test case expression
expected = daft.from_pydict(
{
"a": arr,
"fizzbuzz": [
"FizzBuzz" if x % 15 == 0 else "Fizz" if x % 3 == 0 else "Buzz" if x % 5 == 0 else str(x)
for x in range(0, 100)
],
}
).collect()
df = daft.sql(
"""
SELECT
a,
CASE
WHEN a % 15 = 0 THEN 'FizzBuzz'
WHEN a % 3 = 0 THEN 'Fizz'
WHEN a % 5 = 0 THEN 'Buzz'
ELSE CAST(a AS TEXT)
END AS fizzbuzz
FROM test
""",
catalog=catalog,
).collect()
assert df.to_pydict() == expected.to_pydict()

0 comments on commit 9bb4b3a

Please sign in to comment.