Skip to content

Commit

Permalink
GH-38255: [Java] Implement Flight SQL Bulk Ingestion (#43551)
Browse files Browse the repository at this point in the history
Please look at #38255 for details on this functionality. Support for Go and C++ was added as part of #38385.
This pull request is to add the required support for Java.
* GitHub Issue: #38255

Lead-authored-by: Amit Mittal <[email protected]>
Co-authored-by: Amit Mittal <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
eramitmittal authored Sep 5, 2024
1 parent 50219ef commit 5ca12bd
Show file tree
Hide file tree
Showing 15 changed files with 990 additions and 103 deletions.
2 changes: 1 addition & 1 deletion dev/archery/archery/integration/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True,
Scenario(
"flight_sql:ingestion",
description="Ensure Flight SQL ingestion works as expected.",
skip_testers={"JS", "C#", "Rust", "Java"}
skip_testers={"JS", "C#", "Rust"}
),
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,17 @@
*/
package org.apache.arrow.flight.integration.tests;

import java.util.HashMap;
import java.util.Map;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.SchemaResult;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.sql.CancelResult;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.FlightSqlProducer;
import org.apache.arrow.flight.sql.impl.FlightSql;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.DenseUnionVector;
import org.apache.arrow.vector.types.pojo.Schema;

/**
* Integration test scenario for validating Flight SQL specs across multiple implementations. This
Expand All @@ -53,69 +46,32 @@ public void client(BufferAllocator allocator, Location location, FlightClient cl
}

/**
 * Requests the server's SqlInfo and checks the values this scenario expects.
 *
 * <p>The schema check and stream handling are delegated to {@code validate}, and the rows are
 * decoded with {@code readSqlInfoStream}; the previous inline copy of that decoding loop is gone
 * so the stream is fetched and asserted exactly once.
 *
 * @param sqlClient client used to issue the GetSqlInfo request and fetch its stream
 * @throws Exception if the request, the stream retrieval, or an assertion fails
 */
private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception {
  validate(
      FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA,
      sqlClient.getSqlInfo(),
      sqlClient,
      stream -> {
        // Map of SqlInfo code -> decoded value for every row in the result stream.
        Map<Integer, Object> infoValues = readSqlInfoStream(stream);
        IntegrationAssertions.assertEquals(
            Boolean.FALSE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SQL_VALUE));
        IntegrationAssertions.assertEquals(
            Boolean.TRUE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_VALUE));
        IntegrationAssertions.assertEquals(
            "min_version",
            infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION_VALUE));
        IntegrationAssertions.assertEquals(
            "max_version",
            infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION_VALUE));
        IntegrationAssertions.assertEquals(
            FlightSql.SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_SAVEPOINT_VALUE,
            infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_VALUE));
        IntegrationAssertions.assertEquals(
            Boolean.TRUE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_CANCEL_VALUE));
        IntegrationAssertions.assertEquals(
            42, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT_VALUE));
        IntegrationAssertions.assertEquals(
            7, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT_VALUE));
      });
}

private void validateStatementExecution(FlightSqlClient sqlClient) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight.integration.tests;

import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.Map;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.FlightSqlClient.ExecuteIngestOptions;
import org.apache.arrow.flight.sql.FlightSqlProducer;
import org.apache.arrow.flight.sql.impl.FlightSql;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;

/**
* Integration test scenario for validating Flight SQL specs across multiple implementations. This
* should ensure that RPC objects are being built and parsed correctly for multiple languages and
* that the Arrow schemas are returned as expected.
*/
public class FlightSqlIngestionScenario extends FlightSqlScenario {

@Override
public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception {
FlightSqlScenarioProducer producer =
(FlightSqlScenarioProducer) super.producer(allocator, location);
producer
.getSqlInfoBuilder()
.withFlightSqlServerBulkIngestionTransaction(true)
.withFlightSqlServerBulkIngestion(true);
return producer;
}

@Override
public void client(BufferAllocator allocator, Location location, FlightClient client)
throws Exception {
try (final FlightSqlClient sqlClient = new FlightSqlClient(client)) {
validateMetadataRetrieval(sqlClient);
validateIngestion(allocator, sqlClient);
}
}

private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception {
validate(
FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA,
sqlClient.getSqlInfo(
FlightSql.SqlInfo.FLIGHT_SQL_SERVER_INGEST_TRANSACTIONS_SUPPORTED,
FlightSql.SqlInfo.FLIGHT_SQL_SERVER_BULK_INGESTION),
sqlClient,
s -> {
Map<Integer, Object> infoValues = readSqlInfoStream(s);
IntegrationAssertions.assertEquals(
Boolean.TRUE,
infoValues.get(
FlightSql.SqlInfo.FLIGHT_SQL_SERVER_INGEST_TRANSACTIONS_SUPPORTED_VALUE));
IntegrationAssertions.assertEquals(
Boolean.TRUE,
infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_BULK_INGESTION_VALUE));
});
}

private VectorSchemaRoot getIngestVectorRoot(BufferAllocator allocator) {
Schema schema = FlightSqlScenarioProducer.getIngestSchema();
VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
root.setRowCount(3);
return root;
}

private void validateIngestion(BufferAllocator allocator, FlightSqlClient sqlClient) {
try (VectorSchemaRoot data = getIngestVectorRoot(allocator)) {
TableDefinitionOptions tableDefinitionOptions =
TableDefinitionOptions.newBuilder()
.setIfExists(TableDefinitionOptions.TableExistsOption.TABLE_EXISTS_OPTION_REPLACE)
.setIfNotExist(
TableDefinitionOptions.TableNotExistOption.TABLE_NOT_EXIST_OPTION_CREATE)
.build();
Map<String, String> options = new HashMap<>(ImmutableMap.of("key1", "val1", "key2", "val2"));
ExecuteIngestOptions executeIngestOptions =
new ExecuteIngestOptions(
"test_table", tableDefinitionOptions, true, "test_catalog", "test_schema", options);
FlightSqlClient.Transaction transaction =
new FlightSqlClient.Transaction(BULK_INGEST_TRANSACTION_ID);
long updatedRows = sqlClient.executeIngest(data, executeIngestOptions, transaction);

IntegrationAssertions.assertEquals(3L, updatedRows);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,14 @@
*/
package org.apache.arrow.flight.integration.tests;

import static java.util.Objects.isNull;

import com.google.protobuf.Any;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
Expand All @@ -29,10 +35,14 @@
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.FlightSqlProducer;
import org.apache.arrow.flight.sql.FlightSqlUtils;
import org.apache.arrow.flight.sql.impl.FlightSql;
import org.apache.arrow.flight.sql.util.TableRef;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.DenseUnionVector;
import org.apache.arrow.vector.types.pojo.Schema;

/**
Expand All @@ -52,6 +62,7 @@ public class FlightSqlScenario implements Scenario {
/** Substrait plan built from {@code SUBSTRAIT_PLAN_TEXT} and {@code SUBSTRAIT_VERSION}. */
public static final FlightSqlClient.SubstraitPlan SUBSTRAIT_PLAN =
new FlightSqlClient.SubstraitPlan(SUBSTRAIT_PLAN_TEXT, SUBSTRAIT_VERSION);
/** UTF-8 bytes of the literal {@code "transaction_id"}. */
public static final byte[] TRANSACTION_ID = "transaction_id".getBytes(StandardCharsets.UTF_8);
/**
 * UTF-8 bytes of the literal {@code "123"}; the bulk ingestion scenario wraps this in the
 * {@code FlightSqlClient.Transaction} it passes to {@code executeIngest}.
 */
public static final byte[] BULK_INGEST_TRANSACTION_ID = "123".getBytes(StandardCharsets.UTF_8);

@Override
public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception {
Expand Down Expand Up @@ -150,15 +161,23 @@ private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Excepti
validateSchema(
FlightSqlProducer.Schemas.GET_TYPE_INFO_SCHEMA, sqlClient.getXdbcTypeInfoSchema(options));

validate(
FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA,
FlightInfo sqlInfoFlightInfo =
sqlClient.getSqlInfo(
new FlightSql.SqlInfo[] {
FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME,
FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY
},
options),
sqlClient);
options);

Ticket ticket = sqlInfoFlightInfo.getEndpoints().get(0).getTicket();
FlightSql.CommandGetSqlInfo requestSqlInfoCommand =
FlightSqlUtils.unpackOrThrow(
Any.parseFrom(ticket.getBytes()), FlightSql.CommandGetSqlInfo.class);
IntegrationAssertions.assertEquals(
requestSqlInfoCommand.getInfo(0), FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE);
IntegrationAssertions.assertEquals(
requestSqlInfoCommand.getInfo(1), FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE);
validate(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, sqlInfoFlightInfo, sqlClient);
validateSchema(
FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, sqlClient.getSqlInfoSchema(options));
}
Expand Down Expand Up @@ -194,14 +213,64 @@ private void validatePreparedStatementExecution(

/**
 * Schema-only variant of the four-argument {@code validate}: fetches the stream for the first
 * endpoint of {@code flightInfo} and compares its schema with {@code expectedSchema}, without any
 * further inspection of the stream.
 *
 * @param expectedSchema schema the returned stream must have
 * @param flightInfo flight info whose first endpoint's ticket is fetched
 * @param sqlClient client used to fetch the stream
 * @throws Exception if fetching the stream or the schema assertion fails
 */
protected void validate(Schema expectedSchema, FlightInfo flightInfo, FlightSqlClient sqlClient)
    throws Exception {
  // A null consumer tells the four-argument overload to skip the per-stream callback.
  final Consumer<FlightStream> noStreamChecks = null;
  validate(expectedSchema, flightInfo, sqlClient, noStreamChecks);
}

/**
 * Fetches the stream behind the first endpoint of {@code flightInfo}, asserts that its schema
 * equals {@code expectedSchema}, and then hands the still-open stream to {@code streamConsumer}
 * when one is supplied. The stream is closed when this method returns.
 *
 * @param expectedSchema schema the returned stream must have
 * @param flightInfo flight info whose first endpoint's ticket is fetched
 * @param sqlClient client used to fetch the stream
 * @param streamConsumer optional extra checks to run against the stream; may be {@code null}
 * @throws Exception if fetching the stream, the schema assertion, or the consumer fails
 */
protected void validate(
    Schema expectedSchema,
    FlightInfo flightInfo,
    FlightSqlClient sqlClient,
    Consumer<FlightStream> streamConsumer)
    throws Exception {
  final Ticket firstEndpointTicket = flightInfo.getEndpoints().get(0).getTicket();
  try (FlightStream resultStream = sqlClient.getStream(firstEndpointTicket)) {
    IntegrationAssertions.assertEquals(expectedSchema, resultStream.getSchema());
    if (isNull(streamConsumer)) {
      // Schema-only validation: nothing more to inspect.
      return;
    }
    streamConsumer.accept(resultStream);
  }
}

/**
 * Asserts that the schema carried by {@code actual} equals {@code expected}.
 *
 * @param expected schema the result should contain
 * @param actual schema result returned by the server
 */
protected void validateSchema(Schema expected, SchemaResult actual) {
  final Schema returnedSchema = actual.getSchema();
  IntegrationAssertions.assertEquals(expected, returnedSchema);
}

/**
 * Reads every batch of a SqlInfo result stream into a map from info code to decoded value.
 *
 * <p>Column 0 is read as the info code and column 1 as a dense union holding the value.
 *
 * @param stream stream to drain; it is advanced until {@code next()} returns {@code false}
 * @return decoded values keyed by info code
 * @throws AssertionError if an info code appears more than once, or a value uses a union type id
 *     other than 0 (string), 1 (bool), 2 (int64) or 3 (int32)
 */
protected Map<Integer, Object> readSqlInfoStream(FlightStream stream) {
  final Map<Integer, Object> decodedByCode = new HashMap<>();
  while (stream.next()) {
    final VectorSchemaRoot batch = stream.getRoot();
    final UInt4Vector codes = (UInt4Vector) batch.getVector(0);
    final DenseUnionVector payloads = (DenseUnionVector) batch.getVector(1);

    final int rowCount = batch.getRowCount();
    for (int row = 0; row < rowCount; row++) {
      final int infoCode = codes.get(row);
      if (decodedByCode.containsKey(infoCode)) {
        throw new AssertionError("Duplicate SqlInfo value: " + infoCode);
      }
      decodedByCode.put(infoCode, decodeSqlInfoValue(payloads, row));
    }
  }
  return decodedByCode;
}

/** Decodes the union value at {@code row}, choosing the child vector by the row's type id. */
private static Object decodeSqlInfoValue(DenseUnionVector union, int row) {
  final byte typeId = union.getTypeId(row);
  if (typeId == 0) { // string
    return Preconditions.checkNotNull(
            union.getVarCharVector(typeId).getObject(union.getOffset(row)))
        .toString();
  }
  if (typeId == 1) { // bool
    return union.getBitVector(typeId).getObject(union.getOffset(row));
  }
  if (typeId == 2) { // int64
    return union.getBigIntVector(typeId).getObject(union.getOffset(row));
  }
  if (typeId == 3) { // int32
    return union.getIntVector(typeId).getObject(union.getOffset(row));
  }
  throw new AssertionError("Decoding SqlInfo of type code " + typeId);
}
}
Loading

0 comments on commit 5ca12bd

Please sign in to comment.