Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add velox consumer #11

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions substrait_consumer/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@ def pytest_addoption(parser):
action="store",
default=",".join([x.__name__ for x in CONSUMERS]),
help=f"A comma separated list of consumers to run against.",
choices=[x.__name__ for x in CONSUMERS]
choices=[x.__name__ for x in CONSUMERS],
)
parser.addoption(
"--producer",
action="store",
default=",".join([x.__name__ for x in PRODUCERS]),
help="A comma separated list of producers to run against.",
choices=[x.__name__ for x in PRODUCERS]
choices=[x.__name__ for x in PRODUCERS],
)


PRODUCERS = [DuckDBProducer, IbisProducer, IsthmusProducer]
CONSUMERS = [AceroConsumer, DuckDBConsumer]
CONSUMERS = [AceroConsumer, DuckDBConsumer, VeloxConsumer]


def _get_consumers():
Expand Down
47 changes: 46 additions & 1 deletion substrait_consumer/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.substrait as substrait
import velox

from substrait_consumer.common import SubstraitUtils
from substrait_consumer.schema_updates import PA_SCHEMA, TABLE_TO_RECREATE


class DuckDBConsumer:
Expand Down Expand Up @@ -68,6 +70,14 @@ def load_tables_from_parquet(
create_table_sql = f"CREATE TABLE {table_name} AS SELECT * FROM read_parquet('{file_path}');"
self.db_connection.execute(create_table_sql)
created_tables.add(table_name)
if table_name in TABLE_TO_RECREATE.keys():
self.db_connection.query(
f"ALTER TABLE {table_name} RENAME TO {table_name}_orig"
)
self.db_connection.query(f"{TABLE_TO_RECREATE[table_name]}")
self.db_connection.query(
f"insert into {table_name} select * from {table_name}_orig"
)
table_names.append(table_name)

return table_names
Expand All @@ -93,7 +103,9 @@ def setup(self, db_connection, file_names: Iterable[str]):
)
if table_name not in self.created_tables:
self.created_tables.add(table_name)
self.tables[table_name] = pq.read_table(file_path)
self.tables[table_name] = pq.read_table(
file_path, schema=PA_SCHEMA[table_name]
)
else:
table = pa.table(
{
Expand Down Expand Up @@ -147,3 +159,36 @@ def run_substrait_query(self, substrait_query: bytes) -> pa.Table:
result = reader.read_all()

return result


class VeloxConsumer:
"""
Adapts the Velox Substrait consumer to the test framework.
"""

def __init__(self):
self.created_tables = set()
self.tables = {}
self.table_provider = lambda names: self.tables[names[0]]

def setup(self, db_connection, file_names: Iterable[str]):
pass

def run_substrait_query(self, substrait_query: bytes) -> pa.Table:
"""
Run the substrait plan against Velox.

Parameters:
substrait_query:
A json formatted byte representation of the substrait query plan.

Returns:
A pyarrow table resulting from running the substrait query plan.
"""
velox_result = velox.from_json(substrait_query)

record_batches = []
for vec in velox_result:
record_batches.append(vec.to_arrow())

return pa.Table.from_batches(record_batches)
8 changes: 8 additions & 0 deletions substrait_consumer/data/schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
CREATE TABLE lineitem(l_orderkey INTEGER, l_partkey INTEGER, l_suppkey INTEGER, l_linenumber INTEGER, l_quantity INTEGER, l_extendedprice DOUBLE, l_discount DOUBLE, l_tax DOUBLE, l_returnflag VARCHAR, l_linestatus VARCHAR, l_shipdate DATE, l_commitdate DATE, l_receiptdate DATE, l_shipinstruct VARCHAR, l_shipmode VARCHAR, l_comment VARCHAR);
CREATE TABLE orders(o_orderkey INTEGER, o_custkey INTEGER, o_orderstatus VARCHAR, o_totalprice DOUBLE, o_orderdate DATE, o_orderpriority VARCHAR, o_clerk VARCHAR, o_shippriority INTEGER, o_comment VARCHAR);
CREATE TABLE partsupp(ps_partkey INTEGER, ps_suppkey INTEGER, ps_availqty INTEGER, ps_supplycost DOUBLE, ps_comment VARCHAR);
CREATE TABLE part(p_partkey INTEGER, p_name VARCHAR, p_mfgr VARCHAR, p_brand VARCHAR, p_type VARCHAR, p_size INTEGER, p_container VARCHAR, p_retailprice DOUBLE, p_comment VARCHAR);
CREATE TABLE customer(c_custkey INTEGER, c_name VARCHAR, c_address VARCHAR, c_nationkey INTEGER, c_phone VARCHAR, c_acctbal DOUBLE, c_mktsegment VARCHAR, c_comment VARCHAR);
CREATE TABLE supplier(s_suppkey INTEGER, s_name VARCHAR, s_address VARCHAR, s_nationkey INTEGER, s_phone VARCHAR, s_acctbal DOUBLE, s_comment VARCHAR);
CREATE TABLE nation(n_nationkey INTEGER, n_name VARCHAR, n_regionkey INTEGER, n_comment VARCHAR);
CREATE TABLE region(r_regionkey INTEGER, r_name VARCHAR, r_comment VARCHAR);
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
def ceil_expr(partsupp):
new_col = partsupp.ps_supplycost.ceil().name("CEIL_SUPPLYCOST")
return partsupp[partsupp.ps_supplycost, new_col]
return partsupp[new_col]


def floor_expr(partsupp):
new_col = partsupp.ps_supplycost.floor().name("FLOOR_SUPPLYCOST")
return partsupp[partsupp.ps_supplycost, new_col]
return partsupp[new_col]


IBIS_SCALAR = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@
SQL_AGGREGATE = {
"sum": (
"""
SELECT sum(L_EXTENDEDPRICE) AS SUM_EXTENDEDPRICE
SELECT CAST(sum(L_EXTENDEDPRICE) AS DECIMAL(15,2)) AS SUM_EXTENDEDPRICE
FROM '{}';
""",
[DuckDBProducer],
),
"avg": (
"""
SELECT avg(L_EXTENDEDPRICE) AS AVG_EXTENDEDPRICE
SELECT CAST(avg(L_EXTENDEDPRICE) AS DECIMAL(15,2)) AVG_EXTENDEDPRICE
FROM '{}';
""",
[DuckDBProducer],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
SQL_SCALAR = {
"ceil": (
"""
SELECT PS_SUPPLYCOST, ceil(PS_SUPPLYCOST) AS CEIL_SUPPLYCOST
SELECT ceil(PS_SUPPLYCOST) AS CEIL_SUPPLYCOST
FROM '{}';
""",
[DuckDBProducer],
),
"floor": (
"""
SELECT PS_SUPPLYCOST, floor(PS_SUPPLYCOST) AS FLOOR_SUPPLYCOST
SELECT floor(PS_SUPPLYCOST) AS FLOOR_SUPPLYCOST
FROM '{}';
""",
[DuckDBProducer],
Expand Down
25 changes: 19 additions & 6 deletions substrait_consumer/producers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from substrait_consumer.common import SubstraitUtils
from substrait_consumer.context import get_schema, produce_isthmus_substrait
from substrait_consumer.schema_updates import TABLE_TO_RECREATE


class DuckDBProducer:
Expand Down Expand Up @@ -46,7 +47,9 @@ def produce_substrait(

def format_sql(self, created_tables, sql_query, file_names):
if len(file_names) > 0:
table_names = load_tables_from_parquet(self.db_connection, created_tables, file_names)
table_names = load_tables_from_parquet(
self.db_connection, created_tables, file_names
)
sql_query = sql_query.format(*table_names)
return sql_query

Expand Down Expand Up @@ -93,7 +96,9 @@ def produce_substrait(

def format_sql(self, created_tables, sql_query, file_names):
if len(file_names) > 0:
table_names = load_tables_from_parquet(self.db_connection, created_tables, file_names)
table_names = load_tables_from_parquet(
self.db_connection, created_tables, file_names
)
sql_query = sql_query.format(*table_names)
return sql_query

Expand All @@ -116,9 +121,7 @@ def __init__(self, db_connection=None):
def set_db_connection(self, db_connection):
self.db_connection = db_connection

def produce_substrait(
self, sql_query: str, consumer, ibis_expr: str = None
) -> str:
def produce_substrait(self, sql_query: str, consumer, ibis_expr: str = None) -> str:
"""
Produce the Isthmus substrait plan using the given SQL query.

Expand All @@ -140,7 +143,9 @@ def format_sql(self, created_tables, sql_query, file_names):
sql_query = sql_query.replace("'t'", "t")
if len(file_names) > 0:
self.file_names = file_names
table_names = load_tables_from_parquet(self.db_connection, created_tables, file_names)
table_names = load_tables_from_parquet(
self.db_connection, created_tables, file_names
)
sql_query = sql_query.format(*table_names)
return sql_query

Expand Down Expand Up @@ -175,6 +180,14 @@ def load_tables_from_parquet(
create_table_sql = f"CREATE TABLE {table_name} AS SELECT * FROM read_parquet('{file_path}');"
db_connection.execute(create_table_sql)
created_tables.add(table_name)
if table_name in TABLE_TO_RECREATE.keys():
db_connection.query(
f"ALTER TABLE {table_name} RENAME TO {table_name}_orig"
)
db_connection.query(f"{TABLE_TO_RECREATE[table_name]}")
db_connection.query(
f"insert into {table_name} select * from {table_name}_orig"
)
table_names.append(table_name)

return table_names
159 changes: 159 additions & 0 deletions substrait_consumer/schema_updates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from pathlib import Path
import pyarrow as pa


REPO_DIR = Path(__file__).parent.parent
schema_file = Path.joinpath(REPO_DIR, "tests/data/schema.sql")

ORDERS_TABLE = """
CREATE TABLE orders(
o_orderkey INTEGER NOT NULL,
o_custkey INTEGER NOT NULL,
o_orderstatus VARCHAR NOT NULL,
o_totalprice DOUBLE NOT NULL,
o_orderdate DATE NOT NULL,
o_orderpriority VARCHAR NOT NULL,
o_clerk VARCHAR NOT NULL,
o_shippriority INTEGER NOT NULL,
o_comment VARCHAR NOT NULL);
"""
ORDERS_PA_SCHEMA = pa.schema([
pa.field('o_orderkey', pa.int32()),
pa.field('o_custkey', pa.int32()),
pa.field('o_orderstatus', pa.string()),
pa.field('o_totalprice', pa.float64()),
pa.field('o_orderdate', pa.date32()),
])
LINEITEM_TABLE = """
CREATE TABLE lineitem(
l_orderkey INTEGER NOT NULL,
l_partkey INTEGER NOT NULL,
l_suppkey INTEGER NOT NULL,
l_linenumber INTEGER NOT NULL,
l_quantity DOUBLE NOT NULL,
l_extendedprice DOUBLE NOT NULL,
l_discount DOUBLE NOT NULL,
l_tax DOUBLE NOT NULL,
l_returnflag VARCHAR NOT NULL,
l_linestatus VARCHAR NOT NULL,
l_shipdate DATE NOT NULL,
l_commitdate DATE NOT NULL,
l_receiptdate DATE NOT NULL,
l_shipinstruct VARCHAR NOT NULL,
l_shipmode VARCHAR NOT NULL,
l_comment VARCHAR NOT NULL);
"""
LINEITEM_PA_SCHEMA = pa.schema([
pa.field('l_orderkey', pa.int32()),
pa.field('l_partkey', pa.int32()),
pa.field('l_suppkey', pa.int32()),
pa.field('l_linenumber', pa.int32()),
pa.field('l_quantity', pa.float64()),
pa.field('l_extendedprice', pa.float64()),
pa.field('l_discount', pa.float64()),
pa.field('l_tax', pa.float64()),
pa.field('l_returnflag', pa.string()),
pa.field('l_linestatus', pa.string()),
pa.field('l_shipdate', pa.date32()),
pa.field('l_commitdate', pa.date32()),
pa.field('l_receiptdate', pa.date32()),
pa.field('l_shipinstruct', pa.string()),
pa.field('l_shipmode', pa.string()),
pa.field('l_comment', pa.string()),
])
PARTSUPP_TABLE = """
CREATE TABLE partsupp(
ps_partkey INTEGER NOT NULL,
ps_suppkey INTEGER NOT NULL,
ps_availqty INTEGER NOT NULL,
ps_supplycost DOUBLE NOT NULL,
ps_comment VARCHAR NOT NULL);
"""
PARTSUPP_PA_SCHEMA = pa.schema([
pa.field('ps_partkey', pa.int32()),
pa.field('ps_suppkey', pa.int32()),
pa.field('ps_availqty', pa.int32()),
pa.field('ps_supplycost', pa.float64()),
pa.field('ps_comment', pa.string()),
])
PART_TABLE = """
CREATE TABLE part(
p_partkey INTEGER NOT NULL,
p_name VARCHAR NOT NULL,
p_mfgr VARCHAR NOT NULL,
p_brand VARCHAR NOT NULL,
p_type VARCHAR NOT NULL,
p_size INTEGER NOT NULL,
p_container VARCHAR NOT NULL,
p_retailprice DOUBLE NOT NULL,
p_comment VARCHAR NOT NULL);
"""
PART_PA_SCHEMA = pa.schema([
pa.field('ps_partkey', pa.int32()),
pa.field('p_name', pa.string()),
pa.field('p_mfgr', pa.string()),
pa.field('p_brand', pa.string()),
pa.field('p_type', pa.string()),
pa.field('p_size', pa.int32()),
pa.field('p_container', pa.string()),
pa.field('p_retailprice', pa.float64()),
pa.field('p_comment', pa.string()),
])
CUSTOMER_TABLE = """
CREATE TABLE customer(
c_custkey INTEGER NOT NULL,
c_name VARCHAR NOT NULL,
c_address VARCHAR NOT NULL,
c_nationkey INTEGER NOT NULL,
c_phone VARCHAR NOT NULL,
c_acctbal DOUBLE NOT NULL,
c_mktsegment VARCHAR NOT NULL,
c_comment VARCHAR NOT NULL);
"""
CUSTOMER_PA_SCHEMA = pa.schema([
pa.field('c_custkey', pa.int32()),
pa.field('c_name', pa.string()),
pa.field('c_address', pa.string()),
pa.field('c_nationkey', pa.int32()),
pa.field('c_phone', pa.string()),
pa.field('c_acctbal', pa.float64()),
pa.field('c_mktsegment', pa.string()),
pa.field('c_comment', pa.string()),
])
SUPPLIER_TABLE = """
CREATE TABLE supplier(
s_suppkey INTEGER NOT NULL,
s_name VARCHAR NOT NULL,
s_address VARCHAR NOT NULL,
s_nationkey INTEGER NOT NULL,
s_phone VARCHAR NOT NULL,
s_acctbal DOUBLE NOT NULL,
s_comment VARCHAR NOT NULL);
"""
CUSTOMER_PA_SCHEMA = pa.schema([
pa.field('s_suppkey', pa.int32()),
pa.field('s_name', pa.string()),
pa.field('s_address', pa.string()),
pa.field('s_nationkey', pa.int32()),
pa.field('s_phone', pa.string()),
pa.field('s_acctbal', pa.float64()),
pa.field('c_comment', pa.string()),
])
TABLE_TO_RECREATE = {
"orders": ORDERS_TABLE,
"lineitem": LINEITEM_TABLE,
"partsupp": PARTSUPP_TABLE,
"part": PART_TABLE,
"customer": CUSTOMER_TABLE,
"supplier": SUPPLIER_TABLE,
}
PA_SCHEMA = {
"orders": ORDERS_PA_SCHEMA,
"lineitem": LINEITEM_PA_SCHEMA,
"partsupp": PARTSUPP_PA_SCHEMA,
"part": PART_PA_SCHEMA,
"customer": CUSTOMER_PA_SCHEMA,
"supplier": PARTSUPP_PA_SCHEMA,
"nation": None,
"region": None
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def test_boolean_functions(
ibis_expr: Callable[[Table], Table],
producer,
consumer,
partsupp
) -> None:
substrait_function_test(
self.db_connection,
Expand All @@ -55,6 +54,5 @@ def test_boolean_functions(
ibis_expr,
producer,
consumer,
partsupp,
self.table_t,
)
Loading