Skip to content

Commit

Permalink
feat: initial commit to add velox as a consumer
Browse files Browse the repository at this point in the history
fix: update java lib path for linux

fix: use double instead of decimal in duckdb

fix: update substrait to use doubles and varchar

feat: velox consumer support for extension function tests

fix: add updated schema for isthmus

fix: update rounding functions

fix: update boolean function tests

fix: update schemas for consumers
  • Loading branch information
richtia committed Jan 19, 2023
1 parent ed97f26 commit b534cbb
Show file tree
Hide file tree
Showing 17 changed files with 586 additions and 352 deletions.
6 changes: 3 additions & 3 deletions substrait_consumer/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@ def pytest_addoption(parser):
action="store",
default=",".join([x.__name__ for x in CONSUMERS]),
help=f"A comma separated list of consumers to run against.",
choices=[x.__name__ for x in CONSUMERS]
choices=[x.__name__ for x in CONSUMERS],
)
parser.addoption(
"--producer",
action="store",
default=",".join([x.__name__ for x in PRODUCERS]),
help="A comma separated list of producers to run against.",
choices=[x.__name__ for x in PRODUCERS]
choices=[x.__name__ for x in PRODUCERS],
)


# Producer and consumer classes the suite can run against. The --producer and
# --consumer command-line options build their default value and their allowed
# choices from the __name__ of each class listed here.
PRODUCERS = [DuckDBProducer, IbisProducer, IsthmusProducer]
# Single assignment: an earlier binding without VeloxConsumer was immediately
# overwritten by this one and has been dropped.
CONSUMERS = [AceroConsumer, DuckDBConsumer, VeloxConsumer]


def _get_consumers():
Expand Down
47 changes: 46 additions & 1 deletion substrait_consumer/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.substrait as substrait
import velox

from substrait_consumer.common import SubstraitUtils
from substrait_consumer.schema_updates import PA_SCHEMA, TABLE_TO_RECREATE


class DuckDBConsumer:
Expand Down Expand Up @@ -68,6 +70,14 @@ def load_tables_from_parquet(
create_table_sql = f"CREATE TABLE {table_name} AS SELECT * FROM read_parquet('{file_path}');"
self.db_connection.execute(create_table_sql)
created_tables.add(table_name)
if table_name in TABLE_TO_RECREATE.keys():
self.db_connection.query(
f"ALTER TABLE {table_name} RENAME TO {table_name}_orig"
)
self.db_connection.query(f"{TABLE_TO_RECREATE[table_name]}")
self.db_connection.query(
f"insert into {table_name} select * from {table_name}_orig"
)
table_names.append(table_name)

return table_names
Expand All @@ -93,7 +103,9 @@ def setup(self, db_connection, file_names: Iterable[str]):
)
if table_name not in self.created_tables:
self.created_tables.add(table_name)
self.tables[table_name] = pq.read_table(file_path)
self.tables[table_name] = pq.read_table(
file_path, schema=PA_SCHEMA[table_name]
)
else:
table = pa.table(
{
Expand Down Expand Up @@ -147,3 +159,36 @@ def run_substrait_query(self, substrait_query: bytes) -> pa.Table:
result = reader.read_all()

return result


class VeloxConsumer:
    """
    Test-framework adapter that executes Substrait plans with Velox.
    """

    def __init__(self):
        # Bookkeeping containers; nothing in this class fills them.
        self.tables = {}
        self.created_tables = set()
        # Looks a table up by the first element of the requested name list.
        self.table_provider = lambda names: self.tables[names[0]]

    def setup(self, db_connection, file_names: Iterable[str]):
        """
        Do nothing; this consumer performs no per-test table setup.

        Parameters:
            db_connection:
                Ignored.
            file_names:
                Ignored.
        """

    def run_substrait_query(self, substrait_query: bytes) -> pa.Table:
        """
        Execute a substrait plan with Velox and collect the output.

        Parameters:
            substrait_query:
                A json formatted byte representation of the substrait query plan.

        Returns:
            A pyarrow table built from the Arrow conversion of every vector
            that Velox yields for the plan.
        """
        arrow_batches = [
            vector.to_arrow() for vector in velox.from_json(substrait_query)
        ]
        return pa.Table.from_batches(arrow_batches)
8 changes: 8 additions & 0 deletions substrait_consumer/data/schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
CREATE TABLE lineitem(l_orderkey INTEGER, l_partkey INTEGER, l_suppkey INTEGER, l_linenumber INTEGER, l_quantity INTEGER, l_extendedprice DOUBLE, l_discount DOUBLE, l_tax DOUBLE, l_returnflag VARCHAR, l_linestatus VARCHAR, l_shipdate DATE, l_commitdate DATE, l_receiptdate DATE, l_shipinstruct VARCHAR, l_shipmode VARCHAR, l_comment VARCHAR);
CREATE TABLE orders(o_orderkey INTEGER, o_custkey INTEGER, o_orderstatus VARCHAR, o_totalprice DOUBLE, o_orderdate DATE, o_orderpriority VARCHAR, o_clerk VARCHAR, o_shippriority INTEGER, o_comment VARCHAR);
CREATE TABLE partsupp(ps_partkey INTEGER, ps_suppkey INTEGER, ps_availqty INTEGER, ps_supplycost DOUBLE, ps_comment VARCHAR);
CREATE TABLE part(p_partkey INTEGER, p_name VARCHAR, p_mfgr VARCHAR, p_brand VARCHAR, p_type VARCHAR, p_size INTEGER, p_container VARCHAR, p_retailprice DOUBLE, p_comment VARCHAR);
CREATE TABLE customer(c_custkey INTEGER, c_name VARCHAR, c_address VARCHAR, c_nationkey INTEGER, c_phone VARCHAR, c_acctbal DOUBLE, c_mktsegment VARCHAR, c_comment VARCHAR);
CREATE TABLE supplier(s_suppkey INTEGER, s_name VARCHAR, s_address VARCHAR, s_nationkey INTEGER, s_phone VARCHAR, s_acctbal DOUBLE, s_comment VARCHAR);
CREATE TABLE nation(n_nationkey INTEGER, n_name VARCHAR, n_regionkey INTEGER, n_comment VARCHAR);
CREATE TABLE region(r_regionkey INTEGER, r_name VARCHAR, r_comment VARCHAR);
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
def ceil_expr(partsupp):
    """
    Build the ibis expression for the ``ceil`` scalar-function test.

    Parameters:
        partsupp:
            Table expression exposing a ``ps_supplycost`` column.

    Returns:
        ``partsupp`` indexed by the single derived column
        ``CEIL_SUPPLYCOST`` (the ceiling of ``ps_supplycost``).
    """
    ceil_col = partsupp.ps_supplycost.ceil().name("CEIL_SUPPLYCOST")
    # Only the derived column is selected. A preceding return that also
    # selected the raw ps_supplycost column made this statement unreachable;
    # it is removed so the expression matches the SQL form of the same test
    # (SELECT ceil(PS_SUPPLYCOST) AS CEIL_SUPPLYCOST).
    return partsupp[ceil_col]


def floor_expr(partsupp):
    """
    Build the ibis expression for the ``floor`` scalar-function test.

    Parameters:
        partsupp:
            Table expression exposing a ``ps_supplycost`` column.

    Returns:
        ``partsupp`` indexed by the single derived column
        ``FLOOR_SUPPLYCOST`` (the floor of ``ps_supplycost``).
    """
    floor_col = partsupp.ps_supplycost.floor().name("FLOOR_SUPPLYCOST")
    # Only the derived column is selected. A preceding return that also
    # selected the raw ps_supplycost column made this statement unreachable;
    # it is removed so the expression matches the SQL form of the same test
    # (SELECT floor(PS_SUPPLYCOST) AS FLOOR_SUPPLYCOST).
    return partsupp[floor_col]


IBIS_SCALAR = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@
SQL_AGGREGATE = {
"sum": (
"""
SELECT sum(L_EXTENDEDPRICE) AS SUM_EXTENDEDPRICE
SELECT CAST(sum(L_EXTENDEDPRICE) AS DECIMAL(15,2)) AS SUM_EXTENDEDPRICE
FROM '{}';
""",
[DuckDBProducer],
),
"avg": (
"""
SELECT avg(L_EXTENDEDPRICE) AS AVG_EXTENDEDPRICE
SELECT CAST(avg(L_EXTENDEDPRICE) AS DECIMAL(15,2)) AVG_EXTENDEDPRICE
FROM '{}';
""",
[DuckDBProducer],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
SQL_SCALAR = {
"ceil": (
"""
SELECT PS_SUPPLYCOST, ceil(PS_SUPPLYCOST) AS CEIL_SUPPLYCOST
SELECT ceil(PS_SUPPLYCOST) AS CEIL_SUPPLYCOST
FROM '{}';
""",
[DuckDBProducer],
),
"floor": (
"""
SELECT PS_SUPPLYCOST, floor(PS_SUPPLYCOST) AS FLOOR_SUPPLYCOST
SELECT floor(PS_SUPPLYCOST) AS FLOOR_SUPPLYCOST
FROM '{}';
""",
[DuckDBProducer],
Expand Down
25 changes: 19 additions & 6 deletions substrait_consumer/producers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from substrait_consumer.common import SubstraitUtils
from substrait_consumer.context import get_schema, produce_isthmus_substrait
from substrait_consumer.schema_updates import TABLE_TO_RECREATE


class DuckDBProducer:
Expand Down Expand Up @@ -46,7 +47,9 @@ def produce_substrait(

def format_sql(self, created_tables, sql_query, file_names):
    """
    Load the parquet files as tables and substitute their names into the query.

    Parameters:
        created_tables:
            Collection of already created table names; forwarded to
            load_tables_from_parquet.
        sql_query:
            SQL text containing positional ``{}`` placeholders.
        file_names:
            Parquet files to load. When empty the query is returned as is.

    Returns:
        The SQL query with the placeholders replaced by the table names.
    """
    if len(file_names) > 0:
        # The tables are loaded once; a second, identical call that only
        # re-produced the same list of names has been removed.
        table_names = load_tables_from_parquet(
            self.db_connection, created_tables, file_names
        )
        sql_query = sql_query.format(*table_names)
    return sql_query

Expand Down Expand Up @@ -93,7 +96,9 @@ def produce_substrait(

def format_sql(self, created_tables, sql_query, file_names):
    """
    Load the parquet files as tables and substitute their names into the query.

    Parameters:
        created_tables:
            Collection of already created table names; forwarded to
            load_tables_from_parquet.
        sql_query:
            SQL text containing positional ``{}`` placeholders.
        file_names:
            Parquet files to load. When empty the query is returned as is.

    Returns:
        The SQL query with the placeholders replaced by the table names.
    """
    if len(file_names) > 0:
        # The tables are loaded once; a second, identical call that only
        # re-produced the same list of names has been removed.
        table_names = load_tables_from_parquet(
            self.db_connection, created_tables, file_names
        )
        sql_query = sql_query.format(*table_names)
    return sql_query

Expand All @@ -116,9 +121,7 @@ def __init__(self, db_connection=None):
def set_db_connection(self, db_connection):
self.db_connection = db_connection

def produce_substrait(
self, sql_query: str, consumer, ibis_expr: str = None
) -> str:
def produce_substrait(self, sql_query: str, consumer, ibis_expr: str = None) -> str:
"""
Produce the Isthmus substrait plan using the given SQL query.
Expand All @@ -140,7 +143,9 @@ def format_sql(self, created_tables, sql_query, file_names):
sql_query = sql_query.replace("'t'", "t")
if len(file_names) > 0:
self.file_names = file_names
table_names = load_tables_from_parquet(self.db_connection, created_tables, file_names)
table_names = load_tables_from_parquet(
self.db_connection, created_tables, file_names
)
sql_query = sql_query.format(*table_names)
return sql_query

Expand Down Expand Up @@ -175,6 +180,14 @@ def load_tables_from_parquet(
create_table_sql = f"CREATE TABLE {table_name} AS SELECT * FROM read_parquet('{file_path}');"
db_connection.execute(create_table_sql)
created_tables.add(table_name)
if table_name in TABLE_TO_RECREATE.keys():
db_connection.query(
f"ALTER TABLE {table_name} RENAME TO {table_name}_orig"
)
db_connection.query(f"{TABLE_TO_RECREATE[table_name]}")
db_connection.query(
f"insert into {table_name} select * from {table_name}_orig"
)
table_names.append(table_name)

return table_names
159 changes: 159 additions & 0 deletions substrait_consumer/schema_updates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from pathlib import Path
import pyarrow as pa


# Directory two levels above this module.
REPO_DIR = Path(__file__).parent.parent
# Location of the SQL schema file, relative to REPO_DIR.
# NOTE(review): schema.sql appears to live under substrait_consumer/data/ --
# confirm that this tests/data/ path actually resolves to it.
schema_file = REPO_DIR / "tests/data/schema.sql"

# DDL for the `orders` table with explicit column types and NOT NULL
# constraints; registered in TABLE_TO_RECREATE.
ORDERS_TABLE = """
CREATE TABLE orders(
o_orderkey INTEGER NOT NULL,
o_custkey INTEGER NOT NULL,
o_orderstatus VARCHAR NOT NULL,
o_totalprice DOUBLE NOT NULL,
o_orderdate DATE NOT NULL,
o_orderpriority VARCHAR NOT NULL,
o_clerk VARCHAR NOT NULL,
o_shippriority INTEGER NOT NULL,
o_comment VARCHAR NOT NULL);
"""
# Arrow schema mirroring ORDERS_TABLE column for column
# (INTEGER -> int32, VARCHAR -> string, DOUBLE -> float64, DATE -> date32).
# The last four columns were previously missing, so the Arrow schema did not
# describe the same table as the DDL.
ORDERS_PA_SCHEMA = pa.schema(
    [
        ("o_orderkey", pa.int32()),
        ("o_custkey", pa.int32()),
        ("o_orderstatus", pa.string()),
        ("o_totalprice", pa.float64()),
        ("o_orderdate", pa.date32()),
        ("o_orderpriority", pa.string()),
        ("o_clerk", pa.string()),
        ("o_shippriority", pa.int32()),
        ("o_comment", pa.string()),
    ]
)
# DDL for the `lineitem` table with explicit column types and NOT NULL
# constraints; registered in TABLE_TO_RECREATE.
LINEITEM_TABLE = """
CREATE TABLE lineitem(
l_orderkey INTEGER NOT NULL,
l_partkey INTEGER NOT NULL,
l_suppkey INTEGER NOT NULL,
l_linenumber INTEGER NOT NULL,
l_quantity DOUBLE NOT NULL,
l_extendedprice DOUBLE NOT NULL,
l_discount DOUBLE NOT NULL,
l_tax DOUBLE NOT NULL,
l_returnflag VARCHAR NOT NULL,
l_linestatus VARCHAR NOT NULL,
l_shipdate DATE NOT NULL,
l_commitdate DATE NOT NULL,
l_receiptdate DATE NOT NULL,
l_shipinstruct VARCHAR NOT NULL,
l_shipmode VARCHAR NOT NULL,
l_comment VARCHAR NOT NULL);
"""
# Arrow schema with the same columns, in the same order, as LINEITEM_TABLE
# (INTEGER -> int32, DOUBLE -> float64, VARCHAR -> string, DATE -> date32).
LINEITEM_PA_SCHEMA = pa.schema(
    [
        ("l_orderkey", pa.int32()),
        ("l_partkey", pa.int32()),
        ("l_suppkey", pa.int32()),
        ("l_linenumber", pa.int32()),
        ("l_quantity", pa.float64()),
        ("l_extendedprice", pa.float64()),
        ("l_discount", pa.float64()),
        ("l_tax", pa.float64()),
        ("l_returnflag", pa.string()),
        ("l_linestatus", pa.string()),
        ("l_shipdate", pa.date32()),
        ("l_commitdate", pa.date32()),
        ("l_receiptdate", pa.date32()),
        ("l_shipinstruct", pa.string()),
        ("l_shipmode", pa.string()),
        ("l_comment", pa.string()),
    ]
)
# DDL for the `partsupp` table with explicit column types and NOT NULL
# constraints; registered in TABLE_TO_RECREATE.
PARTSUPP_TABLE = """
CREATE TABLE partsupp(
ps_partkey INTEGER NOT NULL,
ps_suppkey INTEGER NOT NULL,
ps_availqty INTEGER NOT NULL,
ps_supplycost DOUBLE NOT NULL,
ps_comment VARCHAR NOT NULL);
"""
# Arrow schema with the same columns, in the same order, as PARTSUPP_TABLE
# (INTEGER -> int32, DOUBLE -> float64, VARCHAR -> string).
PARTSUPP_PA_SCHEMA = pa.schema(
    [
        ("ps_partkey", pa.int32()),
        ("ps_suppkey", pa.int32()),
        ("ps_availqty", pa.int32()),
        ("ps_supplycost", pa.float64()),
        ("ps_comment", pa.string()),
    ]
)
# DDL for the `part` table with explicit column types and NOT NULL
# constraints; registered in TABLE_TO_RECREATE.
PART_TABLE = """
CREATE TABLE part(
p_partkey INTEGER NOT NULL,
p_name VARCHAR NOT NULL,
p_mfgr VARCHAR NOT NULL,
p_brand VARCHAR NOT NULL,
p_type VARCHAR NOT NULL,
p_size INTEGER NOT NULL,
p_container VARCHAR NOT NULL,
p_retailprice DOUBLE NOT NULL,
p_comment VARCHAR NOT NULL);
"""
# Arrow schema with the same columns, in the same order, as PART_TABLE
# (INTEGER -> int32, VARCHAR -> string, DOUBLE -> float64).
# The key column is p_partkey; it was previously misnamed ps_partkey, which
# is a partsupp column and does not exist in `part`.
PART_PA_SCHEMA = pa.schema(
    [
        ("p_partkey", pa.int32()),
        ("p_name", pa.string()),
        ("p_mfgr", pa.string()),
        ("p_brand", pa.string()),
        ("p_type", pa.string()),
        ("p_size", pa.int32()),
        ("p_container", pa.string()),
        ("p_retailprice", pa.float64()),
        ("p_comment", pa.string()),
    ]
)
# DDL for the `customer` table with explicit column types and NOT NULL
# constraints; registered in TABLE_TO_RECREATE.
CUSTOMER_TABLE = """
CREATE TABLE customer(
c_custkey INTEGER NOT NULL,
c_name VARCHAR NOT NULL,
c_address VARCHAR NOT NULL,
c_nationkey INTEGER NOT NULL,
c_phone VARCHAR NOT NULL,
c_acctbal DOUBLE NOT NULL,
c_mktsegment VARCHAR NOT NULL,
c_comment VARCHAR NOT NULL);
"""
# Arrow schema with the same columns, in the same order, as CUSTOMER_TABLE
# (INTEGER -> int32, VARCHAR -> string, DOUBLE -> float64).
CUSTOMER_PA_SCHEMA = pa.schema(
    [
        ("c_custkey", pa.int32()),
        ("c_name", pa.string()),
        ("c_address", pa.string()),
        ("c_nationkey", pa.int32()),
        ("c_phone", pa.string()),
        ("c_acctbal", pa.float64()),
        ("c_mktsegment", pa.string()),
        ("c_comment", pa.string()),
    ]
)
# DDL for the `supplier` table with explicit column types and NOT NULL
# constraints; registered in TABLE_TO_RECREATE.
SUPPLIER_TABLE = """
CREATE TABLE supplier(
s_suppkey INTEGER NOT NULL,
s_name VARCHAR NOT NULL,
s_address VARCHAR NOT NULL,
s_nationkey INTEGER NOT NULL,
s_phone VARCHAR NOT NULL,
s_acctbal DOUBLE NOT NULL,
s_comment VARCHAR NOT NULL);
"""
# Arrow schema with the same columns, in the same order, as SUPPLIER_TABLE.
# This used to be bound to the name CUSTOMER_PA_SCHEMA, which silently replaced
# the real customer schema, and its last field was named c_comment instead of
# s_comment.
SUPPLIER_PA_SCHEMA = pa.schema(
    [
        ("s_suppkey", pa.int32()),
        ("s_name", pa.string()),
        ("s_address", pa.string()),
        ("s_nationkey", pa.int32()),
        ("s_phone", pa.string()),
        ("s_acctbal", pa.float64()),
        ("s_comment", pa.string()),
    ]
)
# Table name -> CREATE TABLE statement used to rebuild that table with the
# column types declared above. Tables that are absent are not rebuilt.
TABLE_TO_RECREATE = {
    "orders": ORDERS_TABLE,
    "lineitem": LINEITEM_TABLE,
    "partsupp": PARTSUPP_TABLE,
    "part": PART_TABLE,
    "customer": CUSTOMER_TABLE,
    "supplier": SUPPLIER_TABLE,
}
# Table name -> Arrow schema to apply when reading that table's parquet file.
# None means no explicit schema is supplied for the table.
PA_SCHEMA = {
    "orders": ORDERS_PA_SCHEMA,
    "lineitem": LINEITEM_PA_SCHEMA,
    "partsupp": PARTSUPP_PA_SCHEMA,
    "part": PART_PA_SCHEMA,
    "customer": CUSTOMER_PA_SCHEMA,
    # Previously mapped to PARTSUPP_PA_SCHEMA, whose columns belong to a
    # different table.
    "supplier": SUPPLIER_PA_SCHEMA,
    "nation": None,
    "region": None,
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def test_boolean_functions(
ibis_expr: Callable[[Table], Table],
producer,
consumer,
partsupp
) -> None:
substrait_function_test(
self.db_connection,
Expand All @@ -55,6 +54,5 @@ def test_boolean_functions(
ibis_expr,
producer,
consumer,
partsupp,
self.table_t,
)
Loading

0 comments on commit b534cbb

Please sign in to comment.