diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index 0ea244720cc1d..66c8721519ede 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -645,7 +645,7 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, Scenario( "flight_sql:ingestion", description="Ensure Flight SQL ingestion works as expected.", - skip_testers={"JS", "C#", "Rust", "Java"} + skip_testers={"JS", "C#", "Rust"} ), ] diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlExtensionScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlExtensionScenario.java index 76d79b226623d..69b02030ccd3d 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlExtensionScenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlExtensionScenario.java @@ -16,24 +16,17 @@ */ package org.apache.arrow.flight.integration.tests; -import java.util.HashMap; import java.util.Map; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightInfo; -import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; import org.apache.arrow.flight.SchemaResult; -import org.apache.arrow.flight.Ticket; import org.apache.arrow.flight.sql.CancelResult; import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.flight.sql.FlightSqlProducer; import org.apache.arrow.flight.sql.impl.FlightSql; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.util.Preconditions; -import org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.complex.DenseUnionVector; -import org.apache.arrow.vector.types.pojo.Schema; /** * Integration test scenario for 
validating Flight SQL specs across multiple implementations. This @@ -53,69 +46,32 @@ public void client(BufferAllocator allocator, Location location, FlightClient cl } private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception { - FlightInfo info = sqlClient.getSqlInfo(); - Ticket ticket = info.getEndpoints().get(0).getTicket(); - - Map infoValues = new HashMap<>(); - try (FlightStream stream = sqlClient.getStream(ticket)) { - Schema actualSchema = stream.getSchema(); - IntegrationAssertions.assertEquals( - FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, actualSchema); - - while (stream.next()) { - UInt4Vector infoName = (UInt4Vector) stream.getRoot().getVector(0); - DenseUnionVector value = (DenseUnionVector) stream.getRoot().getVector(1); - - for (int i = 0; i < stream.getRoot().getRowCount(); i++) { - final int code = infoName.get(i); - if (infoValues.containsKey(code)) { - throw new AssertionError("Duplicate SqlInfo value: " + code); - } - Object object; - byte typeId = value.getTypeId(i); - switch (typeId) { - case 0: // string - object = - Preconditions.checkNotNull( - value.getVarCharVector(typeId).getObject(value.getOffset(i))) - .toString(); - break; - case 1: // bool - object = value.getBitVector(typeId).getObject(value.getOffset(i)); - break; - case 2: // int64 - object = value.getBigIntVector(typeId).getObject(value.getOffset(i)); - break; - case 3: // int32 - object = value.getIntVector(typeId).getObject(value.getOffset(i)); - break; - default: - throw new AssertionError("Decoding SqlInfo of type code " + typeId); - } - infoValues.put(code, object); - } - } - } - - IntegrationAssertions.assertEquals( - Boolean.FALSE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SQL_VALUE)); - IntegrationAssertions.assertEquals( - Boolean.TRUE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_VALUE)); - IntegrationAssertions.assertEquals( - "min_version", - 
infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION_VALUE)); - IntegrationAssertions.assertEquals( - "max_version", - infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION_VALUE)); - IntegrationAssertions.assertEquals( - FlightSql.SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_SAVEPOINT_VALUE, - infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_VALUE)); - IntegrationAssertions.assertEquals( - Boolean.TRUE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_CANCEL_VALUE)); - IntegrationAssertions.assertEquals( - 42, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT_VALUE)); - IntegrationAssertions.assertEquals( - 7, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT_VALUE)); + validate( + FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, + sqlClient.getSqlInfo(), + sqlClient, + s -> { + Map infoValues = readSqlInfoStream(s); + IntegrationAssertions.assertEquals( + Boolean.FALSE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SQL_VALUE)); + IntegrationAssertions.assertEquals( + Boolean.TRUE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_VALUE)); + IntegrationAssertions.assertEquals( + "min_version", + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION_VALUE)); + IntegrationAssertions.assertEquals( + "max_version", + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION_VALUE)); + IntegrationAssertions.assertEquals( + FlightSql.SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_SAVEPOINT_VALUE, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_VALUE)); + IntegrationAssertions.assertEquals( + Boolean.TRUE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_CANCEL_VALUE)); + IntegrationAssertions.assertEquals( + 42, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT_VALUE)); + IntegrationAssertions.assertEquals( + 7, 
infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT_VALUE)); + }); } private void validateStatementExecution(FlightSqlClient sqlClient) throws Exception { diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlIngestionScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlIngestionScenario.java new file mode 100644 index 0000000000000..981ce89f1b88a --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlIngestionScenario.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.arrow.flight.integration.tests; + +import com.google.common.collect.ImmutableMap; +import java.util.HashMap; +import java.util.Map; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.FlightSqlClient.ExecuteIngestOptions; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Integration test scenario for validating Flight SQL specs across multiple implementations. This + * should ensure that RPC objects are being built and parsed correctly for multiple languages and + * that the Arrow schemas are returned as expected. 
+ */ +public class FlightSqlIngestionScenario extends FlightSqlScenario { + + @Override + public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception { + FlightSqlScenarioProducer producer = + (FlightSqlScenarioProducer) super.producer(allocator, location); + producer + .getSqlInfoBuilder() + .withFlightSqlServerBulkIngestionTransaction(true) + .withFlightSqlServerBulkIngestion(true); + return producer; + } + + @Override + public void client(BufferAllocator allocator, Location location, FlightClient client) + throws Exception { + try (final FlightSqlClient sqlClient = new FlightSqlClient(client)) { + validateMetadataRetrieval(sqlClient); + validateIngestion(allocator, sqlClient); + } + } + + private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception { + validate( + FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, + sqlClient.getSqlInfo( + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_INGEST_TRANSACTIONS_SUPPORTED, + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_BULK_INGESTION), + sqlClient, + s -> { + Map infoValues = readSqlInfoStream(s); + IntegrationAssertions.assertEquals( + Boolean.TRUE, + infoValues.get( + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_INGEST_TRANSACTIONS_SUPPORTED_VALUE)); + IntegrationAssertions.assertEquals( + Boolean.TRUE, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_BULK_INGESTION_VALUE)); + }); + } + + private VectorSchemaRoot getIngestVectorRoot(BufferAllocator allocator) { + Schema schema = FlightSqlScenarioProducer.getIngestSchema(); + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + root.setRowCount(3); + return root; + } + + private void validateIngestion(BufferAllocator allocator, FlightSqlClient sqlClient) { + try (VectorSchemaRoot data = getIngestVectorRoot(allocator)) { + TableDefinitionOptions tableDefinitionOptions = + TableDefinitionOptions.newBuilder() + .setIfExists(TableDefinitionOptions.TableExistsOption.TABLE_EXISTS_OPTION_REPLACE) + .setIfNotExist( + 
TableDefinitionOptions.TableNotExistOption.TABLE_NOT_EXIST_OPTION_CREATE) + .build(); + Map options = new HashMap<>(ImmutableMap.of("key1", "val1", "key2", "val2")); + ExecuteIngestOptions executeIngestOptions = + new ExecuteIngestOptions( + "test_table", tableDefinitionOptions, true, "test_catalog", "test_schema", options); + FlightSqlClient.Transaction transaction = + new FlightSqlClient.Transaction(BULK_INGEST_TRANSACTION_ID); + long updatedRows = sqlClient.executeIngest(data, executeIngestOptions, transaction); + + IntegrationAssertions.assertEquals(3L, updatedRows); + } + } +} diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java index 8918b252700ac..e370a30bdc6ff 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java @@ -16,8 +16,14 @@ */ package org.apache.arrow.flight.integration.tests; +import static java.util.Objects.isNull; + +import com.google.protobuf.Any; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; import org.apache.arrow.flight.CallOption; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightInfo; @@ -29,10 +35,14 @@ import org.apache.arrow.flight.Ticket; import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.FlightSqlUtils; import org.apache.arrow.flight.sql.impl.FlightSql; import org.apache.arrow.flight.sql.util.TableRef; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; +import 
org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.DenseUnionVector; import org.apache.arrow.vector.types.pojo.Schema; /** @@ -52,6 +62,7 @@ public class FlightSqlScenario implements Scenario { public static final FlightSqlClient.SubstraitPlan SUBSTRAIT_PLAN = new FlightSqlClient.SubstraitPlan(SUBSTRAIT_PLAN_TEXT, SUBSTRAIT_VERSION); public static final byte[] TRANSACTION_ID = "transaction_id".getBytes(StandardCharsets.UTF_8); + public static final byte[] BULK_INGEST_TRANSACTION_ID = "123".getBytes(StandardCharsets.UTF_8); @Override public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception { @@ -150,15 +161,23 @@ private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Excepti validateSchema( FlightSqlProducer.Schemas.GET_TYPE_INFO_SCHEMA, sqlClient.getXdbcTypeInfoSchema(options)); - validate( - FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, + FlightInfo sqlInfoFlightInfo = sqlClient.getSqlInfo( new FlightSql.SqlInfo[] { FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME, FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY }, - options), - sqlClient); + options); + + Ticket ticket = sqlInfoFlightInfo.getEndpoints().get(0).getTicket(); + FlightSql.CommandGetSqlInfo requestSqlInfoCommand = + FlightSqlUtils.unpackOrThrow( + Any.parseFrom(ticket.getBytes()), FlightSql.CommandGetSqlInfo.class); + IntegrationAssertions.assertEquals( + requestSqlInfoCommand.getInfo(0), FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE); + IntegrationAssertions.assertEquals( + requestSqlInfoCommand.getInfo(1), FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE); + validate(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, sqlInfoFlightInfo, sqlClient); validateSchema( FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, sqlClient.getSqlInfoSchema(options)); } @@ -194,14 +213,64 @@ private void validatePreparedStatementExecution( protected void validate(Schema expectedSchema, FlightInfo 
flightInfo, FlightSqlClient sqlClient) throws Exception { + validate(expectedSchema, flightInfo, sqlClient, null); + } + + protected void validate( + Schema expectedSchema, + FlightInfo flightInfo, + FlightSqlClient sqlClient, + Consumer streamConsumer) + throws Exception { Ticket ticket = flightInfo.getEndpoints().get(0).getTicket(); try (FlightStream stream = sqlClient.getStream(ticket)) { Schema actualSchema = stream.getSchema(); IntegrationAssertions.assertEquals(expectedSchema, actualSchema); + if (!isNull(streamConsumer)) { + streamConsumer.accept(stream); + } } } protected void validateSchema(Schema expected, SchemaResult actual) { IntegrationAssertions.assertEquals(expected, actual.getSchema()); } + + protected Map readSqlInfoStream(FlightStream stream) { + Map infoValues = new HashMap<>(); + while (stream.next()) { + UInt4Vector infoName = (UInt4Vector) stream.getRoot().getVector(0); + DenseUnionVector value = (DenseUnionVector) stream.getRoot().getVector(1); + + for (int i = 0; i < stream.getRoot().getRowCount(); i++) { + final int code = infoName.get(i); + if (infoValues.containsKey(code)) { + throw new AssertionError("Duplicate SqlInfo value: " + code); + } + Object object; + byte typeId = value.getTypeId(i); + switch (typeId) { + case 0: // string + object = + Preconditions.checkNotNull( + value.getVarCharVector(typeId).getObject(value.getOffset(i))) + .toString(); + break; + case 1: // bool + object = value.getBitVector(typeId).getObject(value.getOffset(i)); + break; + case 2: // int64 + object = value.getBigIntVector(typeId).getObject(value.getOffset(i)); + break; + case 3: // int32 + object = value.getIntVector(typeId).getObject(value.getOffset(i)); + break; + default: + throw new AssertionError("Decoding SqlInfo of type code " + typeId); + } + infoValues.put(code, object); + } + } + return infoValues; + } } diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java 
b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java index b7a75b459d176..be746b575761d 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java @@ -16,13 +16,16 @@ */ package org.apache.arrow.flight.integration.tests; +import com.google.common.collect.ImmutableMap; import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.Criteria; import org.apache.arrow.flight.FlightDescriptor; @@ -38,6 +41,8 @@ import org.apache.arrow.flight.sql.FlightSqlProducer; import org.apache.arrow.flight.sql.SqlInfoBuilder; import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions.TableExistsOption; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions.TableNotExistOption; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.VectorSchemaRoot; @@ -48,10 +53,27 @@ /** Hardcoded Flight SQL producer used for cross-language integration tests. */ public class FlightSqlScenarioProducer implements FlightSqlProducer { + public static final String SERVER_NAME = "Flight SQL Integration Test Server"; private final BufferAllocator allocator; + private final SqlInfoBuilder sqlInfoBuilder; + + /** Constructor. 
*/ public FlightSqlScenarioProducer(BufferAllocator allocator) { this.allocator = allocator; + sqlInfoBuilder = + new SqlInfoBuilder() + .withFlightSqlServerName(SERVER_NAME) + .withFlightSqlServerReadOnly(false) + .withFlightSqlServerSql(false) + .withFlightSqlServerSubstrait(true) + .withFlightSqlServerSubstraitMinVersion("min_version") + .withFlightSqlServerSubstraitMaxVersion("max_version") + .withFlightSqlServerTransaction( + FlightSql.SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_SAVEPOINT) + .withFlightSqlServerCancel(true) + .withFlightSqlServerStatementTimeout(42) + .withFlightSqlServerTransactionTimeout(7); } /** @@ -109,6 +131,15 @@ static Schema getQueryWithTransactionSchema() { null))); } + static Schema getIngestSchema() { + return new Schema( + Collections.singletonList(Field.nullable("test_field", new ArrowType.Int(64, true)))); + } + + protected SqlInfoBuilder getSqlInfoBuilder() { + return sqlInfoBuilder; + } + @Override public void beginSavepoint( FlightSql.ActionBeginSavepointRequest request, @@ -511,6 +542,44 @@ public Runnable acceptPutStatement( : FlightSqlScenario.UPDATE_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); } + @Override + public Runnable acceptPutStatementBulkIngest( + FlightSql.CommandStatementIngest command, + CallContext context, + FlightStream flightStream, + StreamListener ackStream) { + + IntegrationAssertions.assertEquals( + TableExistsOption.TABLE_EXISTS_OPTION_REPLACE, + command.getTableDefinitionOptions().getIfExists()); + IntegrationAssertions.assertEquals( + TableNotExistOption.TABLE_NOT_EXIST_OPTION_CREATE, + command.getTableDefinitionOptions().getIfNotExist()); + IntegrationAssertions.assertEquals("test_table", command.getTable()); + IntegrationAssertions.assertEquals("test_catalog", command.getCatalog()); + IntegrationAssertions.assertEquals("test_schema", command.getSchema()); + IntegrationAssertions.assertEquals(true, command.getTemporary()); + IntegrationAssertions.assertEquals( + 
FlightSqlScenario.BULK_INGEST_TRANSACTION_ID, command.getTransactionId().toByteArray()); + + Map expectedOptions = + new HashMap<>(ImmutableMap.of("key1", "val1", "key2", "val2")); + IntegrationAssertions.assertEquals(expectedOptions.size(), command.getOptionsCount()); + + for (Map.Entry optionEntry : expectedOptions.entrySet()) { + String key = optionEntry.getKey(); + IntegrationAssertions.assertEquals(optionEntry.getValue(), command.getOptionsOrThrow(key)); + } + + IntegrationAssertions.assertEquals(getIngestSchema(), flightStream.getSchema()); + long rowCount = 0; + while (flightStream.next()) { + rowCount += flightStream.getRoot().getRowCount(); + } + + return acceptPutReturnConstant(ackStream, rowCount); + } + @Override public Runnable acceptPutSubstraitPlan( FlightSql.CommandStatementSubstraitPlan command, @@ -577,35 +646,19 @@ public Runnable acceptPutPreparedStatementQuery( @Override public FlightInfo getFlightInfoSqlInfo( FlightSql.CommandGetSqlInfo request, CallContext context, FlightDescriptor descriptor) { - if (request.getInfoCount() == 2) { - // Integration test for the protocol messages - IntegrationAssertions.assertEquals( - request.getInfo(0), FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE); - IntegrationAssertions.assertEquals( - request.getInfo(1), FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE); - } return getFlightInfoForSchema(request, descriptor, Schemas.GET_SQL_INFO_SCHEMA); } @Override public void getStreamSqlInfo( FlightSql.CommandGetSqlInfo command, CallContext context, ServerStreamListener listener) { - if (command.getInfoCount() == 2) { + if (command.getInfoCount() == 2 + && command.getInfo(0) == FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE + && command.getInfo(1) == FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE) { // Integration test for the protocol messages putEmptyBatchToStreamListener(listener, Schemas.GET_SQL_INFO_SCHEMA); return; } - SqlInfoBuilder sqlInfoBuilder = - new SqlInfoBuilder() - 
.withFlightSqlServerSql(false) - .withFlightSqlServerSubstrait(true) - .withFlightSqlServerSubstraitMinVersion("min_version") - .withFlightSqlServerSubstraitMaxVersion("max_version") - .withFlightSqlServerTransaction( - FlightSql.SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_SAVEPOINT) - .withFlightSqlServerCancel(true) - .withFlightSqlServerStatementTimeout(42) - .withFlightSqlServerTransactionTimeout(7); sqlInfoBuilder.send(command.getInfoList(), listener); } diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java index a294902a26d35..451edb6bd5a34 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java @@ -48,6 +48,7 @@ private Scenarios() { scenarios.put("poll_flight_info", PollFlightInfoScenario::new); scenarios.put("flight_sql", FlightSqlScenario::new); scenarios.put("flight_sql:extension", FlightSqlExtensionScenario::new); + scenarios.put("flight_sql:ingestion", FlightSqlIngestionScenario::new); scenarios.put("app_metadata_flight_info_endpoint", AppMetadataFlightInfoEndpointScenario::new); scenarios.put("session_options", SessionOptionsScenario::new); } diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/TestBufferAllocationListener.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/TestBufferAllocationListener.java new file mode 100644 index 0000000000000..10594d4cf0962 --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/TestBufferAllocationListener.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more 
+ * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.flight.integration.tests; + +import java.util.ArrayList; +import java.util.List; +import org.apache.arrow.memory.AllocationListener; + +class TestBufferAllocationListener implements AllocationListener { + static class Entry { + StackTraceElement[] stackTrace; + long size; + boolean forAllocation; + + public Entry(StackTraceElement[] stackTrace, long size, boolean forAllocation) { + this.stackTrace = stackTrace; + this.size = size; + this.forAllocation = forAllocation; + } + } + + List trail = new ArrayList<>(); + + public void onAllocation(long size) { + trail.add(new Entry(Thread.currentThread().getStackTrace(), size, true)); + } + + public void onRelease(long size) { + trail.add(new Entry(Thread.currentThread().getStackTrace(), size, false)); + } + + public void reThrowWithAddedAllocatorInfo(Exception e) { + StringBuilder sb = new StringBuilder(); + sb.append(e.getMessage()); + sb.append("\n"); + sb.append("[[Buffer allocation and release trail during the test execution: \n"); + for (Entry trailEntry : trail) { + sb.append( + String.format( + "%s: %d: %n%s", + trailEntry.forAllocation ? 
"allocate" : "release", + trailEntry.size, + getStackTraceAsString(trailEntry.stackTrace))); + } + sb.append("]]"); + throw new IllegalStateException(sb.toString(), e); + } + + private String getStackTraceAsString(StackTraceElement[] elements) { + StringBuilder sb = new StringBuilder(); + for (int i = 1; i < elements.length; i++) { + StackTraceElement s = elements[i]; + sb.append("\t"); + sb.append(s); + sb.append("\n"); + } + return sb.toString(); + } +} diff --git a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java index bdf1c43ce9da6..8419432c66227 100644 --- a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java +++ b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java @@ -16,6 +16,10 @@ */ package org.apache.arrow.flight.integration.tests; +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightServer; import org.apache.arrow.flight.Location; @@ -80,6 +84,11 @@ void flightSqlExtension() throws Exception { testScenario("flight_sql:extension"); } + @Test + void flightSqlIngestion() throws Exception { + testScenario("flight_sql:ingestion"); + } + @Test void appMetadataFlightInfoEndpoint() throws Exception { testScenario("app_metadata_flight_info_endpoint"); @@ -91,9 +100,16 @@ void sessionOptions() throws Exception { } void testScenario(String scenarioName) throws Exception { - try (final BufferAllocator allocator = new RootAllocator()) { + TestBufferAllocationListener listener = new TestBufferAllocationListener(); + try (final BufferAllocator allocator = 
new RootAllocator(listener, Long.MAX_VALUE)) { + final ExecutorService exec = + Executors.newCachedThreadPool( + new ThreadFactoryBuilder() + .setNameFormat("integration-test-flight-server-executor-%d") + .build()); final FlightServer.Builder builder = FlightServer.builder() + .executor(exec) .allocator(allocator) .location(Location.forGrpcInsecure("0.0.0.0", 0)); final Scenario scenario = Scenarios.getScenario(scenarioName); @@ -108,6 +124,17 @@ void testScenario(String scenarioName) throws Exception { scenario.client(allocator, location, client); } } + + // Shutdown the executor while allowing existing tasks to finish. + // Without this wait, allocator.close() may get invoked earlier than an executor thread may + // have finished freeing up resources + // In that case, allocator.close() can throw an IllegalStateException for memory leak, leading + // to flaky tests + exec.shutdown(); + final boolean unused = exec.awaitTermination(3, TimeUnit.SECONDS); + } catch (IllegalStateException e) { + // this could be due to Allocator detecting memory leak. Add allocation trail to help debug + listener.reThrowWithAddedAllocatorInfo(e); } } } diff --git a/java/flight/flight-sql/pom.xml b/java/flight/flight-sql/pom.xml index 92bab5e206757..021c1e65ab5b3 100644 --- a/java/flight/flight-sql/pom.xml +++ b/java/flight/flight-sql/pom.xml @@ -110,6 +110,12 @@ under the License. 
2.12.0 test + + org.apache.commons + commons-text + 1.12.0 + test + org.hamcrest hamcrest diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java index 4bc12d86b1d0e..9a6ffdfdca847 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java @@ -16,6 +16,7 @@ */ package org.apache.arrow.flight.sql; +import static java.util.Objects.isNull; import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginSavepointRequest; import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginSavepointResult; import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginTransactionRequest; @@ -54,8 +55,10 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; import java.util.stream.Collectors; import org.apache.arrow.flight.Action; import org.apache.arrow.flight.CallOption; @@ -82,11 +85,14 @@ import org.apache.arrow.flight.sql.impl.FlightSql; import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions; import org.apache.arrow.flight.sql.util.TableRef; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamReader; import org.apache.arrow.vector.ipc.ReadChannel; import 
org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.Schema; @@ -206,6 +212,130 @@ public SchemaResult getExecuteSubstraitSchema( return getExecuteSubstraitSchema(substraitPlan, /*transaction*/ null, options); } + /** + * Execute a bulk ingest on the server. + * + * @param data data to be ingested + * @param ingestOptions options for the ingest request. + * @param options RPC-layer hints for this call. + * @return the number of rows affected. + */ + public long executeIngest( + final VectorSchemaRoot data, + final ExecuteIngestOptions ingestOptions, + final CallOption... options) { + return executeIngest(data, ingestOptions, /*transaction*/ null, options); + } + + /** + * Execute a bulk ingest on the server. + * + * @param dataReader data stream to be ingested + * @param ingestOptions options for the ingest request. + * @param options RPC-layer hints for this call. + * @return the number of rows affected. + */ + public long executeIngest( + final ArrowStreamReader dataReader, + final ExecuteIngestOptions ingestOptions, + final CallOption... options) { + return executeIngest(dataReader, ingestOptions, /*transaction*/ null, options); + } + + /** + * Execute a bulk ingest on the server. + * + * @param data data to be ingested + * @param ingestOptions options for the ingest request. + * @param transaction The transaction that this ingest request is part of. + * @param options RPC-layer hints for this call. + * @return the number of rows affected. + */ + public long executeIngest( + final VectorSchemaRoot data, + final ExecuteIngestOptions ingestOptions, + Transaction transaction, + final CallOption... options) { + return executeIngest( + data, ingestOptions, transaction, FlightClient.ClientStreamListener::putNext, options); + } + + /** + * Execute a bulk ingest on the server. + * + * @param dataReader data stream to be ingested + * @param ingestOptions options for the ingest request.
+ * @param transaction The transaction that this ingest request is part of. + * @param options RPC-layer hints for this call. + * @return the number of rows affected. + */ + public long executeIngest( + final ArrowStreamReader dataReader, + final ExecuteIngestOptions ingestOptions, + Transaction transaction, + final CallOption... options) { + + try { + return executeIngest( + dataReader.getVectorSchemaRoot(), + ingestOptions, + transaction, + listener -> { + while (true) { + try { + if (!dataReader.loadNextBatch()) { + break; + } + } catch (IOException e) { + throw CallStatus.UNKNOWN.withCause(e).toRuntimeException(); + } + listener.putNext(); + } + }, + options); + } catch (IOException e) { + throw CallStatus.UNKNOWN.withCause(e).toRuntimeException(); + } + } + + private long executeIngest( + final VectorSchemaRoot data, + final ExecuteIngestOptions ingestOptions, + final Transaction transaction, + final Consumer<FlightClient.ClientStreamListener> dataPutter, + final CallOption... options) { + try { + final CommandStatementIngest.Builder builder = CommandStatementIngest.newBuilder(); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } + ingestOptions.updateCommandBuilder(builder); + + final FlightDescriptor descriptor = + FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + try (final SyncPutListener putListener = new SyncPutListener()) { + + final FlightClient.ClientStreamListener listener = + client.startPut(descriptor, data, putListener, options); + dataPutter.accept(listener); + listener.completed(); + listener.getResult(); + + try (final PutResult result = putListener.read()) { + final DoPutUpdateResult doPutUpdateResult = + DoPutUpdateResult.parseFrom(result.getApplicationMetadata().nioBuffer()); + return doPutUpdateResult.getRecordCount(); + } + } + } catch (final InterruptedException e) { + throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); + } catch (final ExecutionException e) { + throw
CallStatus.CANCELLED.withCause(e.getCause()).toRuntimeException(); + } catch (final InvalidProtocolBufferException e) { + throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); + } + } + /** * Execute an update query on the server. * @@ -245,8 +375,10 @@ public long executeUpdate( } finally { listener.getResult(); } - } catch (final InterruptedException | ExecutionException e) { + } catch (final InterruptedException e) { throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); + } catch (final ExecutionException e) { + throw CallStatus.CANCELLED.withCause(e.getCause()).toRuntimeException(); } catch (final InvalidProtocolBufferException e) { throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); } @@ -295,8 +427,10 @@ public long executeSubstraitUpdate( } finally { listener.getResult(); } - } catch (final InterruptedException | ExecutionException e) { + } catch (final InterruptedException e) { throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); + } catch (final ExecutionException e) { + throw CallStatus.CANCELLED.withCause(e.getCause()).toRuntimeException(); } catch (final InvalidProtocolBufferException e) { throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); } @@ -1003,6 +1137,82 @@ public void close() throws Exception { AutoCloseables.close(client); } + /** Class to encapsulate Flight SQL bulk ingest request options. */ + public static class ExecuteIngestOptions { + private final String table; + private final TableDefinitionOptions tableDefinitionOptions; + private final boolean useTemporaryTable; + private final String catalog; + private final String schema; + private final Map<String, String> options; + + /** + * Constructor. + * + * @param table The table to load data into. + * @param tableDefinitionOptions The behavior for handling the table definition. + * @param catalog The catalog of the destination table to load data into. If null, a + * backend-specific default may be used.
+ * @param schema The schema of the destination table to load data into. If null, a + * backend-specific default may be used. + * @param options Backend-specific options. Can be null if there are no options to be set. + */ + public ExecuteIngestOptions( + String table, + TableDefinitionOptions tableDefinitionOptions, + String catalog, + String schema, + Map<String, String> options) { + this(table, tableDefinitionOptions, false, catalog, schema, options); + } + + /** + * Constructor. + * + * @param table The table to load data into. + * @param tableDefinitionOptions The behavior for handling the table definition. + * @param useTemporaryTable Use a temporary table for bulk ingestion. Temporary table may get + * placed in a backend-specific schema and/or catalog and gets dropped at the end of the + * session. If backend does not support ingesting using a temporary table or an explicit + * choice of schema or catalog is incompatible with the server's namespacing decision, an + * error is returned as part of {@link #executeIngest} request. + * @param catalog The catalog of the destination table to load data into. If null, a + * backend-specific default may be used. + * @param schema The schema of the destination table to load data into. If null, a + * backend-specific default may be used. + * @param options Backend-specific options. Can be null if there are no options to be set.
+ */ + public ExecuteIngestOptions( + String table, + TableDefinitionOptions tableDefinitionOptions, + boolean useTemporaryTable, + String catalog, + String schema, + Map<String, String> options) { + this.table = table; + this.tableDefinitionOptions = tableDefinitionOptions; + this.useTemporaryTable = useTemporaryTable; + this.catalog = catalog; + this.schema = schema; + this.options = options; + } + + protected void updateCommandBuilder(CommandStatementIngest.Builder builder) { + builder.setTable(table); + builder.setTableDefinitionOptions(tableDefinitionOptions); + builder.setTemporary(useTemporaryTable); + if (!isNull(catalog)) { + builder.setCatalog(catalog); + } + if (!isNull(schema)) { + builder.setSchema(schema); + } + if (!isNull(options)) { + builder.putAllOptions(options); + } + } + } + /** Helper class to encapsulate Flight SQL prepared statement logic. */ public static class PreparedStatement implements AutoCloseable { private final FlightClient client; @@ -1140,10 +1350,12 @@ public FlightInfo execute(final CallOption... options) { } } } - } catch (final InterruptedException | ExecutionException e) { + } catch (final InterruptedException e) { throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); + } catch (final ExecutionException e) { + throw CallStatus.CANCELLED.withCause(e.getCause()).toRuntimeException(); } catch (final InvalidProtocolBufferException e) { - throw CallStatus.INVALID_ARGUMENT.withCause(e).toRuntimeException(); + throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); } } @@ -1198,10 +1410,12 @@ public long executeUpdate(final CallOption...
options) { DoPutUpdateResult.parseFrom(metadata.nioBuffer()); return doPutUpdateResult.getRecordCount(); } - } catch (final InterruptedException | ExecutionException e) { + } catch (final InterruptedException e) { throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); + } catch (final ExecutionException e) { + throw CallStatus.CANCELLED.withCause(e.getCause()).toRuntimeException(); } catch (final InvalidProtocolBufferException e) { - throw CallStatus.INVALID_ARGUMENT.withCause(e).toRuntimeException(); + throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); } } diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java index 0afef79160621..9465e5ff88053 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java @@ -83,6 +83,7 @@ import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest; import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; import org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult; @@ -281,7 +282,8 @@ default void getStream(CallContext context, Ticket ticket, ServerStreamListener /** * Depending on the provided command, method either: 1. Execute provided SQL query as an update * statement, or 2. Execute provided update SQL query prepared statement. In this case, parameters - * binding is allowed, or 3. Binds parameters to the provided prepared statement. + * binding is allowed, or 3. 
Binds parameters to the provided prepared statement, or 4. Bulk + * ingests data provided through the flightStream. * * @param context Per-call context. * @param flightStream The data stream being uploaded. @@ -299,6 +301,12 @@ default Runnable acceptPut( context, flightStream, ackStream); + } else if (command.is(CommandStatementIngest.class)) { + return acceptPutStatementBulkIngest( + FlightSqlUtils.unpackOrThrow(command, CommandStatementIngest.class), + context, + flightStream, + ackStream); } else if (command.is(CommandStatementSubstraitPlan.class)) { return acceptPutSubstraitPlan( FlightSqlUtils.unpackOrThrow(command, CommandStatementSubstraitPlan.class), @@ -777,6 +785,27 @@ Runnable acceptPutStatement( FlightStream flightStream, StreamListener<PutResult> ackStream); + /** + * Accepts uploaded data for a particular bulk ingest data stream. + * + *
<p>`PutResult`s must be in the form of a {@link DoPutUpdateResult}. + * + * @param command The bulk ingestion request. + * @param context Per-call context. + * @param flightStream The data stream being uploaded. + * @param ackStream The result data stream. + * @return A runnable to process the stream. + */ + default Runnable acceptPutStatementBulkIngest( + CommandStatementIngest command, + CallContext context, + FlightStream flightStream, + StreamListener<PutResult> ackStream) { + return () -> { + ackStream.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + }; + } + /** * Handle a Substrait plan with uploaded data. * diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoOpFlightSqlProducer.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoOpFlightSqlProducer.java index 5091017c13cd8..72fcae8c18003 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoOpFlightSqlProducer.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoOpFlightSqlProducer.java @@ -91,6 +91,18 @@ public Runnable acceptPutStatement( throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); } + @Override + public Runnable acceptPutStatementBulkIngest( + FlightSql.CommandStatementIngest command, + CallContext context, + FlightStream flightStream, + StreamListener<PutResult> ackStream) { + return () -> { + ackStream.onError( + CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + }; + } + @Override public Runnable acceptPutPreparedStatementUpdate( FlightSql.CommandPreparedStatementUpdate command, diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java index 2a31bc77365e2..cbe4989d14744 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java +++
b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java @@ -148,6 +148,17 @@ public SqlInfoBuilder withFlightSqlServerCancel(boolean value) { return withBooleanProvider(SqlInfo.FLIGHT_SQL_SERVER_CANCEL_VALUE, value); } + /** Set the boolean SqlInfo value FLIGHT_SQL_SERVER_BULK_INGESTION. */ + public SqlInfoBuilder withFlightSqlServerBulkIngestion(boolean value) { + return withBooleanProvider(SqlInfo.FLIGHT_SQL_SERVER_BULK_INGESTION_VALUE, value); + } + + /** Set the boolean SqlInfo value FLIGHT_SQL_SERVER_INGEST_TRANSACTIONS_SUPPORTED. */ + public SqlInfoBuilder withFlightSqlServerBulkIngestionTransaction(boolean value) { + return withBooleanProvider( + SqlInfo.FLIGHT_SQL_SERVER_INGEST_TRANSACTIONS_SUPPORTED_VALUE, value); + } + /** Set a value for statement timeouts. */ public SqlInfoBuilder withFlightSqlServerStatementTimeout(int value) { return withIntProvider(SqlInfo.FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT_VALUE, value); diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java index 67bfc85c48602..f9d0551a3aa22 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java @@ -55,6 +55,7 @@ import java.nio.file.NoSuchFileException; import java.nio.file.Path; import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.DriverManager; @@ -82,6 +83,7 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Predicate; +import java.util.stream.Collectors; import java.util.stream.Stream; import org.apache.arrow.adapter.jdbc.ArrowVectorIterator; import org.apache.arrow.adapter.jdbc.JdbcFieldInfo; @@ -112,6 +114,10 @@ import
org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions.TableExistsOption; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions.TableNotExistOption; import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity; @@ -146,6 +152,7 @@ import org.apache.commons.dbcp2.PoolingDataSource; import org.apache.commons.pool2.ObjectPool; import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.commons.text.StringEscapeUtils; import org.slf4j.Logger; /** @@ -245,7 +252,9 @@ public FlightSqlExample(final Location location, final String dbName) { : SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UNKNOWN) .withSqlAllTablesAreSelectable(true) .withSqlNullOrdering(SqlNullOrdering.SQL_NULLS_SORTED_AT_END) - .withSqlMaxColumnsInTable(42); + .withSqlMaxColumnsInTable(42) + .withFlightSqlServerBulkIngestion(true) + .withFlightSqlServerBulkIngestionTransaction(false); } catch (SQLException e) { throw new RuntimeException(e); } @@ -714,6 +723,34 @@ private static ByteBuffer serializeMetadata(final Schema schema) { } } + private static String getRootAsCSVNoHeader(final VectorSchemaRoot root) { + StringBuilder sb = new StringBuilder(); + Schema schema = root.getSchema(); + int rowCount = root.getRowCount(); + List<FieldVector> fieldVectors = root.getFieldVectors(); + + List<Object> row = new ArrayList<>(schema.getFields().size()); + for (int i = 0; i <
rowCount; i++) { + if (i > 0) { + sb.append("\n"); + } + row.clear(); + for (FieldVector v : fieldVectors) { + row.add(v.getObject(i)); + } + printRowAsCSV(sb, row); + } + return sb.toString(); + } + + private static void printRowAsCSV(StringBuilder sb, List<Object> values) { + sb.append( + values.stream() + .map(v -> isNull(v) ? "" : v.toString()) + .map(StringEscapeUtils::escapeCsv) + .collect(Collectors.joining(","))); + } + @Override public void getStreamPreparedStatement( final CommandPreparedStatementQuery command, @@ -951,6 +988,152 @@ public Runnable acceptPutStatement( }; } + /** + * Bulk-ingests the uploaded batches into an existing table by spooling them to a temporary CSV + * file and loading it with the SYSCS_UTIL.SYSCS_IMPORT_DATA procedure. Requests that use a + * temporary table, an explicit transaction, a TableNotExistOption other than FAIL, or an + * unspecified/FAIL TableExistsOption are rejected before any data is read. The acknowledged + * record count is always -1. + */ + @Override + public Runnable acceptPutStatementBulkIngest( + CommandStatementIngest command, + CallContext context, + FlightStream flightStream, + StreamListener<PutResult> ackStream) { + + final String schema = command.hasSchema() ? command.getSchema() : null; + final String table = command.getTable(); + final boolean temporary = command.getTemporary(); + final boolean transactionId = command.hasTransactionId(); + final TableDefinitionOptions tableDefinitionOptions = + command.hasTableDefinitionOptions() ? command.getTableDefinitionOptions() : null; + + return () -> { + TableExistsOption ifExists = TableExistsOption.TABLE_EXISTS_OPTION_APPEND; + // Report a validation failure and stop: falling through would still import the data and then + // call onNext/onCompleted on an ack stream that has already been failed. + if (temporary) { + ackStream.onError( + CallStatus.UNIMPLEMENTED + .withDescription("Bulk ingestion using temporary tables is not supported") + .toRuntimeException()); + return; + } else if (transactionId) { + ackStream.onError( + CallStatus.UNIMPLEMENTED + .withDescription( + "Bulk ingestion automatically happens in a transaction.
Specifying explicit transaction is not supported.") + .toRuntimeException()); + return; + } else if (isNull(tableDefinitionOptions)) { + ackStream.onError( + CallStatus.INVALID_ARGUMENT + .withDescription("TableDefinitionOptions not provided.") + .toRuntimeException()); + return; + } else { + TableNotExistOption ifNotExist = tableDefinitionOptions.getIfNotExist(); + ifExists = tableDefinitionOptions.getIfExists(); + + if (!TableNotExistOption.TABLE_NOT_EXIST_OPTION_FAIL.equals(ifNotExist)) { + ackStream.onError( + CallStatus.UNIMPLEMENTED + .withDescription( + "Only supported option is TABLE_NOT_EXIST_OPTION_FAIL for TableNotExistsOption.") + .toRuntimeException()); + return; + } else if (TableExistsOption.TABLE_EXISTS_OPTION_UNSPECIFIED.equals(ifExists)) { + ackStream.onError( + CallStatus.INVALID_ARGUMENT + .withDescription("TableExistsOption must be specified") + .toRuntimeException()); + return; + } else if (TableExistsOption.TABLE_EXISTS_OPTION_FAIL.equals(ifExists)) { + ackStream.onError( + CallStatus.UNIMPLEMENTED + .withDescription("TABLE_EXISTS_OPTION_FAIL is not supported.") + .toRuntimeException()); + return; + } + } + + Path tempFile = null; + try { + tempFile = Files.createTempFile(null, null); + + VectorSchemaRoot root = null; + int counter = 0; + while (flightStream.next()) { + if (counter > 0) { + Files.writeString(tempFile, "\n", StandardCharsets.UTF_8, StandardOpenOption.APPEND); + } + counter += 1; + root = flightStream.getRoot(); + Files.writeString( + tempFile, + getRootAsCSVNoHeader(root), + StandardCharsets.UTF_8, + StandardOpenOption.APPEND); + } + + if (counter > 0) { + Files.writeString(tempFile, "\n", StandardCharsets.UTF_8, StandardOpenOption.APPEND); + } + + if (!isNull(root)) { + String header = + root.getSchema().getFields().stream() + .map(Field::getName) + .collect(Collectors.joining(",")); + + try (final Connection connection = dataSource.getConnection(); + final PreparedStatement preparedStatement = + connection.prepareStatement( + "CALL SYSCS_UTIL.SYSCS_IMPORT_DATA
(?,?,?,null,?,?,?,?,?)")) { + + preparedStatement.setString(1, schema); + preparedStatement.setString(2, table); + preparedStatement.setString(3, header); + preparedStatement.setString(4, tempFile.toString()); + preparedStatement.setString(5, ","); + preparedStatement.setString(6, "\""); + preparedStatement.setString(7, "UTF-8"); + preparedStatement.setInt( + 8, TableExistsOption.TABLE_EXISTS_OPTION_REPLACE.equals(ifExists) ? 1 : 0); + preparedStatement.execute(); + + final DoPutUpdateResult build = + DoPutUpdateResult.newBuilder().setRecordCount(-1).build(); + + try (final ArrowBuf buffer = rootAllocator.buffer(build.getSerializedSize())) { + buffer.writeBytes(build.toByteArray()); + ackStream.onNext(PutResult.metadata(buffer)); + ackStream.onCompleted(); + } + } catch (SQLException e) { + ackStream.onError( + CallStatus.INTERNAL + .withDescription("Failed to execute bulk ingest: " + e) + .toRuntimeException()); + } + } + } catch (IOException e) { + ackStream.onError( + CallStatus.INTERNAL + .withDescription("Failed to create temp file for bulk loading: " + e) + .toRuntimeException()); + } finally { + if (!isNull(tempFile)) { + try { + Files.delete(tempFile); + } catch (IOException e) { + // Best-effort cleanup: a failure to delete the temp file is deliberately ignored. + } + } + } + }; + } + @Override public Runnable acceptPutPreparedStatementUpdate( CommandPreparedStatementUpdate command, diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java index 2eb74adc5bc0e..3f769363fb64d 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java @@ -30,6 +30,10 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import com.google.common.collect.ImmutableList; +import java.io.IOException; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import
java.nio.charset.StandardCharsets; import java.sql.SQLException; import java.util.ArrayList; import java.util.Arrays; @@ -53,6 +57,9 @@ import org.apache.arrow.flight.sql.FlightSqlProducer; import org.apache.arrow.flight.sql.example.FlightSqlExample; import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions.TableExistsOption; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions.TableNotExistOption; import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity; import org.apache.arrow.flight.sql.util.TableRef; import org.apache.arrow.memory.BufferAllocator; @@ -60,11 +67,15 @@ import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.Text; +import org.apache.arrow.vector.util.VectorBatchAppender; import org.hamcrest.Matcher; import org.hamcrest.MatcherAssert; import org.junit.jupiter.api.AfterAll; @@ -96,6 +107,43 @@ public class TestFlightSql { protected static FlightServer server; protected static FlightSqlClient sqlClient; + private static void populateNext10RowsInIngestRootBatch( + int startRowNumber, + IntVector valueVector, + VarCharVector keyNameVector, + IntVector foreignIdVector, + VarCharVector keyNamesToBeDeletedVector, + VectorSchemaRoot ingestRoot) { + + final int NumRowsInBatch = 10; + + valueVector.reset(); + keyNameVector.reset(); 
+ foreignIdVector.reset(); + + final IntStream range = IntStream.range(1, NumRowsInBatch); + + range.forEach( + i -> { + valueVector.setSafe(i - 1, (i + startRowNumber - 1) * NumRowsInBatch); + keyNameVector.setSafe(i - 1, new Text("value" + (i + startRowNumber - 1))); + foreignIdVector.setSafe(i - 1, 1); + }); + // put some comma and double-quote containing string as well + valueVector.setSafe(NumRowsInBatch - 1, (NumRowsInBatch + startRowNumber - 1) * NumRowsInBatch); + keyNameVector.setSafe( + NumRowsInBatch - 1, + new Text( + String.format( + "value%d, is \"%d\"", + (NumRowsInBatch + startRowNumber - 1), + (NumRowsInBatch + startRowNumber - 1) * NumRowsInBatch))); + foreignIdVector.setSafe(NumRowsInBatch - 1, 1); + ingestRoot.setRowCount(NumRowsInBatch); + + VectorBatchAppender.batchAppend(keyNamesToBeDeletedVector, keyNameVector); + } + @BeforeAll public static void setUp() throws Exception { setUpClientServer(); @@ -537,6 +585,119 @@ public void testSimplePreparedStatementUpdateResults() throws SQLException { } } + @Test + public void testBulkIngest() throws IOException { + // For bulk ingest DerbyDB requires uppercase column names + var keyName = new Field("KEYNAME", FieldType.nullable(new ArrowType.Utf8()), null); + var value = new Field("VALUE", FieldType.nullable(new ArrowType.Int(32, true)), null); + var foreignId = new Field("FOREIGNID", FieldType.nullable(new ArrowType.Int(32, true)), null); + + Schema dataSchema = new Schema(List.of(keyName, value, foreignId)); + + try (final VectorSchemaRoot ingestRoot = VectorSchemaRoot.create(dataSchema, allocator); + final VarCharVector keyNamesToBeDeletedVector = new VarCharVector(keyName, allocator)) { + final VarCharVector keyNameVector = (VarCharVector) ingestRoot.getVector(0); + final IntVector valueVector = (IntVector) ingestRoot.getVector(1); + final IntVector foreignIdVector = (IntVector) ingestRoot.getVector(2); + ingestRoot.allocateNew(); + keyNamesToBeDeletedVector.allocateNew(); + + try 
(PipedInputStream inPipe = new PipedInputStream(1024); + PipedOutputStream outPipe = new PipedOutputStream(inPipe); + ArrowStreamReader reader = new ArrowStreamReader(inPipe, allocator)) { + + new Thread( + () -> { + try (ArrowStreamWriter writer = + new ArrowStreamWriter(ingestRoot, null, outPipe)) { + writer.start(); + populateNext10RowsInIngestRootBatch( + 1, + valueVector, + keyNameVector, + foreignIdVector, + keyNamesToBeDeletedVector, + ingestRoot); + writer.writeBatch(); + populateNext10RowsInIngestRootBatch( + 11, + valueVector, + keyNameVector, + foreignIdVector, + keyNamesToBeDeletedVector, + ingestRoot); + writer.writeBatch(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }) + .start(); + + // Ingest from a stream + final long updatedRows = + sqlClient.executeIngest( + reader, + new FlightSqlClient.ExecuteIngestOptions( + "INTTABLE", + TableDefinitionOptions.newBuilder() + .setIfExists(TableExistsOption.TABLE_EXISTS_OPTION_APPEND) + .setIfNotExist(TableNotExistOption.TABLE_NOT_EXIST_OPTION_FAIL) + .build(), + null, + null, + null)); + + MatcherAssert.assertThat(updatedRows, is(-1L)); + + // Ingest directly using VectorSchemaRoot + populateNext10RowsInIngestRootBatch( + 21, valueVector, keyNameVector, foreignIdVector, keyNamesToBeDeletedVector, ingestRoot); + sqlClient.executeIngest( + ingestRoot, + new FlightSqlClient.ExecuteIngestOptions( + "INTTABLE", + TableDefinitionOptions.newBuilder() + .setIfExists(TableExistsOption.TABLE_EXISTS_OPTION_APPEND) + .setIfNotExist(TableNotExistOption.TABLE_NOT_EXIST_OPTION_FAIL) + .build(), + null, + null, + null)); + + try (PreparedStatement deletePrepare = + sqlClient.prepare("DELETE FROM INTTABLE WHERE keyName = ?")) { + final long deletedRows; + try (final VectorSchemaRoot deleteRoot = VectorSchemaRoot.of(keyNamesToBeDeletedVector)) { + deletePrepare.setParameters(deleteRoot); + deletedRows = deletePrepare.executeUpdate(); + } + + MatcherAssert.assertThat(deletedRows, is(30L)); + } + } + } + } + 
+ @Test + public void testBulkIngestTransaction() { + assertThrows( + RuntimeException.class, + () -> { + sqlClient.executeIngest( + VectorSchemaRoot.create(new Schema(List.of()), allocator), + new FlightSqlClient.ExecuteIngestOptions( + "INTTABLE", + TableDefinitionOptions.newBuilder() + .setIfExists(TableExistsOption.TABLE_EXISTS_OPTION_APPEND) + .setIfNotExist(TableNotExistOption.TABLE_NOT_EXIST_OPTION_FAIL) + .build(), + null, + null, + null), + new FlightSqlClient.Transaction("123".getBytes(StandardCharsets.UTF_8))); + }); + } + @Test public void testSimplePreparedStatementUpdateResultsWithoutParameters() throws SQLException { try (PreparedStatement prepare =