Skip to content

Commit

Permalink
Final adjustments
Browse files Browse the repository at this point in the history
  • Loading branch information
pdet committed Sep 24, 2024
1 parent 61bfd1f commit 3fa6609
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 13 deletions.
15 changes: 8 additions & 7 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ string SubstraitToDuckDB::RemoveExtension(const string &function_name) {
return name;
}

SubstraitToDuckDB::SubstraitToDuckDB(shared_ptr<ClientContext> &context_p, const string &serialized, bool json)
: context(context_p) {
SubstraitToDuckDB::SubstraitToDuckDB(shared_ptr<ClientContext> &context_p, const string &serialized, bool json,
bool acquire_lock_p)
: context(context_p), acquire_lock(acquire_lock_p) {
if (!json) {
if (!plan.ParseFromString(serialized)) {
throw std::runtime_error("Was not possible to convert binary into Substrait plan");
Expand Down Expand Up @@ -549,9 +550,9 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
if (!table_info) {
throw CatalogException("Table '%s' does not exist!", table_name);
}
return make_shared_ptr<TableRelation>(context, std::move(table_info), false);
scan = make_shared_ptr<TableRelation>(context, std::move(table_info), acquire_lock);
} catch (...) {
scan = make_shared_ptr<ViewRelation>(context, DEFAULT_SCHEMA, table_name, false);
scan = make_shared_ptr<ViewRelation>(context, DEFAULT_SCHEMA, table_name, acquire_lock);
}
} else if (sget.has_local_files()) {
vector<Value> parquet_files;
Expand All @@ -573,8 +574,8 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
string name = "parquet_" + StringUtil::GenerateRandomName();
named_parameter_map_t named_parameters({{"binary_as_string", Value::BOOLEAN(false)}});
vector<Value> parameters {Value::LIST(parquet_files)};
auto scan_rel = make_shared_ptr<TableFunctionRelation>(context, "parquet_scan", parameters,
std::move(named_parameters), nullptr, true, false);
auto scan_rel = make_shared_ptr<TableFunctionRelation>(
context, "parquet_scan", parameters, std::move(named_parameters), nullptr, true, acquire_lock);
auto rel = static_cast<Relation *>(scan_rel.get());
scan = rel->Alias(name);
} else if (sget.has_virtual_table()) {
Expand All @@ -590,7 +591,7 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
expression_rows.emplace_back(expression_row);
}
vector<string> column_names;
scan = make_shared_ptr<ValueRelation>(context, expression_rows, column_names, "values", false);
scan = make_shared_ptr<ValueRelation>(context, expression_rows, column_names, "values", acquire_lock);
} else {
throw NotImplementedException("Unsupported type of read operator for substrait");
}
Expand Down
5 changes: 4 additions & 1 deletion src/include/from_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ namespace duckdb {

class SubstraitToDuckDB {
public:
SubstraitToDuckDB(shared_ptr<ClientContext> &context_p, const string &serialized, bool json = false);
SubstraitToDuckDB(shared_ptr<ClientContext> &context_p, const string &serialized, bool json = false,
bool acquire_lock = false);
//! Transforms Substrait Plan to DuckDB Relation
shared_ptr<Relation> TransformPlan();

Expand Down Expand Up @@ -67,5 +68,7 @@ class SubstraitToDuckDB {
static const unordered_map<std::string, std::string> function_names_remap;
static const case_insensitive_set_t valid_extract_subfields;
vector<ParsedExpression *> struct_expressions;
//! Whether we should acquire a client context lock when creating the relations
const bool acquire_lock;
};
} // namespace duckdb
13 changes: 9 additions & 4 deletions src/substrait_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,23 @@ static unique_ptr<FunctionData> ToJsonBind(ClientContext &context, TableFunction
}

shared_ptr<Relation> SubstraitPlanToDuckDBRel(shared_ptr<ClientContext> &context, const string &serialized,
bool json = false) {
SubstraitToDuckDB transformer_s2d(context, serialized, json);
bool json = false, bool acquire_lock = false) {
SubstraitToDuckDB transformer_s2d(context, serialized, json, acquire_lock);
return transformer_s2d.TransformPlan();
}

//! This function compares the results of Substrait plans against direct DuckDB queries.
//! It is only executed when pragma enable_verification = true.
//! It creates extra connections to be able to execute the consumed DuckDB plan
//! and the SQL query itself. Ideally this wouldn't be necessary, and it won't
//! work for round-tripping tests over temporary objects.
static void VerifySubstraitRoundtrip(unique_ptr<LogicalOperator> &query_plan, ClientContext &context,
ToSubstraitFunctionData &data, const string &serialized, bool is_json) {
// We round-trip the generated json and verify if the result is the same
auto con = Connection(*context.db);
auto actual_result = con.Query(data.query);
shared_ptr<ClientContext> c_ptr(&context, do_nothing);
auto sub_relation = SubstraitPlanToDuckDBRel(c_ptr, serialized, is_json);
auto con_2 = Connection(*context.db);
auto sub_relation = SubstraitPlanToDuckDBRel(con_2.context, serialized, is_json, true);
auto substrait_result = sub_relation->Execute();
substrait_result->names = actual_result->names;
unique_ptr<MaterializedQueryResult> substrait_materialized;
Expand Down
2 changes: 1 addition & 1 deletion test/python/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ def test_substrait_pyarrow(require):

arrow_result = execute_query(connection, "arrow_integers")

assert connection.execute("FROM arrow_result").fetchall() == 0
assert connection.execute("FROM arrow_result").fetchall() == [(0, 'a'), (1, 'b')]

0 comments on commit 3fa6609

Please sign in to comment.