From d05e1ba142fb38d47749cb7904a7db4182093e64 Mon Sep 17 00:00:00 2001
From: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com>
Date: Mon, 29 Jul 2024 17:54:23 -0400
Subject: [PATCH] fix(go/adbc/driver/snowflake): split files properly after
 reaching targetSize on ingestion (#2026)

Fixes: #1997

**Core Changes**
- Change the ingestion `writeParquet` function to use an unbuffered writer, skipping 0-row records to avoid a recurrence of #1847
- Use the parquet writer's internal `RowGroupTotalBytesWritten()` method to track the output file size in place of `limitWriter`
- Add a unit test to validate that the file cutoff occurs precisely when expected

**Secondary Changes**
- Bump the arrow dependency to `v18` to pull in the changes from [apache/arrow#43326](https://github.com/apache/arrow/pull/43326)
- Fix the flightsql test that depends on a hardcoded arrow version
---
 go/adbc/driver/snowflake/bulk_ingestion.go    |  44 +++----
 .../driver/snowflake/bulk_ingestion_test.go   | 108 ++++++++++++++++++
 go/adbc/pkg/_tmpl/driver.go.tmpl              |   8 +-
 3 files changed, 127 insertions(+), 33 deletions(-)
 create mode 100644 go/adbc/driver/snowflake/bulk_ingestion_test.go

diff --git a/go/adbc/driver/snowflake/bulk_ingestion.go b/go/adbc/driver/snowflake/bulk_ingestion.go
index 8b80ee49d8..7f3d6bbd24 100644
--- a/go/adbc/driver/snowflake/bulk_ingestion.go
+++ b/go/adbc/driver/snowflake/bulk_ingestion.go
@@ -334,20 +334,31 @@ func writeParquet(
 	parquetProps *parquet.WriterProperties,
 	arrowProps pqarrow.ArrowWriterProperties,
 ) error {
-	limitWr := &limitWriter{w: w, limit: targetSize}
-	pqWriter, err := pqarrow.NewFileWriter(schema, limitWr, parquetProps, arrowProps)
+	pqWriter, err := pqarrow.NewFileWriter(schema, w, parquetProps, arrowProps)
 	if err != nil {
 		return err
 	}
 	defer pqWriter.Close()
 
+	var bytesWritten int64
 	for rec := range in {
-		err = pqWriter.WriteBuffered(rec)
+		if rec.NumRows() == 0 {
+			rec.Release()
+			continue
+		}
+
+		err = pqWriter.Write(rec)
 		rec.Release()
 		if err != nil {
 			return err
 		}
-		if limitWr.LimitExceeded() {
+
+		if targetSize < 0 {
+			continue
+		}
+
+		bytesWritten += pqWriter.RowGroupTotalBytesWritten()
+		if bytesWritten >= int64(targetSize) {
 			return nil
 		}
 	}
@@ -584,28 +595,3 @@ func (bp *bufferPool) PutBuffer(buf *bytes.Buffer) {
 	buf.Reset()
 	bp.Pool.Put(buf)
 }
-
-// Wraps an io.Writer and specifies a limit.
-// Keeps track of how many bytes have been written and can report whether the limit has been exceeded.
-// TODO(ARROW-39789): We prefer to use RowGroupTotalBytesWritten on the ParquetWriter, but there seems to be a discrepency with the count.
-type limitWriter struct {
-	w     io.Writer
-	limit int
-
-	bytesWritten int
-}
-
-func (lw *limitWriter) Write(p []byte) (int, error) {
-	n, err := lw.w.Write(p)
-	lw.bytesWritten += n
-
-	return n, err
-}
-
-func (lw *limitWriter) LimitExceeded() bool {
-	if lw.limit > 0 {
-		return lw.bytesWritten > lw.limit
-	}
-	// Limit disabled
-	return false
-}
diff --git a/go/adbc/driver/snowflake/bulk_ingestion_test.go b/go/adbc/driver/snowflake/bulk_ingestion_test.go
new file mode 100644
index 0000000000..6ae0e3a4d6
--- /dev/null
+++ b/go/adbc/driver/snowflake/bulk_ingestion_test.go
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package snowflake
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"testing"
+
+	"github.com/apache/arrow/go/v18/arrow"
+	"github.com/apache/arrow/go/v18/arrow/array"
+	"github.com/apache/arrow/go/v18/arrow/memory"
+	"github.com/apache/arrow/go/v18/parquet/pqarrow"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestIngestBatchedParquetWithFileLimit(t *testing.T) {
+	var buf bytes.Buffer
+	ctx := context.Background()
+	mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+	defer mem.AssertSize(t, 0)
+
+	ingestOpts := DefaultIngestOptions()
+	parquetProps, arrowProps := newWriterProps(mem, ingestOpts)
+
+	nCols := 3
+	nRecs := 10
+	nRows := 1000
+	targetFileSize := 10000
+
+	rec := makeRec(mem, nCols, nRows)
+	defer rec.Release()
+
+	// Create a temporary parquet writer and write a single row group so we know
+	// approximately how many bytes it should take
+	tempWriter, err := pqarrow.NewFileWriter(rec.Schema(), &buf, parquetProps, arrowProps)
+	require.NoError(t, err)
+
+	// Write 1 record and check the size before closing so footer bytes are not included
+	require.NoError(t, tempWriter.Write(rec))
+	expectedRowGroupSize := buf.Len()
+	require.NoError(t, tempWriter.Close())
+
+	recs := make([]arrow.Record, nRecs)
+	for i := 0; i < nRecs; i++ {
+		recs[i] = rec
+	}
+
+	rdr, err := array.NewRecordReader(rec.Schema(), recs)
+	require.NoError(t, err)
+	defer rdr.Release()
+
+	records := make(chan arrow.Record)
+	go func() { assert.NoError(t, readRecords(ctx, rdr, records)) }()
+
+	buf.Reset()
+	// Expected to read multiple records but then stop after targetFileSize, indicated by nil error
+	require.NoError(t, writeParquet(rdr.Schema(), &buf, records, targetFileSize, parquetProps, arrowProps))
+
+	// Expect to exceed the targetFileSize but by no more than the size of 1 row group
+	assert.Greater(t, buf.Len(), targetFileSize)
+	assert.Less(t, buf.Len(), targetFileSize+expectedRowGroupSize)
+
+	// Drain the remaining records with no limit on file size, expect EOF
+	require.ErrorIs(t, writeParquet(rdr.Schema(), &buf, records, -1, parquetProps, arrowProps), io.EOF)
+}
+
+func makeRec(mem memory.Allocator, nCols, nRows int) arrow.Record {
+	vals := make([]int8, nRows)
+	for val := 0; val < nRows; val++ {
+		vals[val] = int8(val)
+	}
+
+	bldr := array.NewInt8Builder(mem)
+	defer bldr.Release()
+
+	bldr.AppendValues(vals, nil)
+	arr := bldr.NewArray()
+	defer arr.Release()
+
+	fields := make([]arrow.Field, nCols)
+	cols := make([]arrow.Array, nCols)
+	for i := 0; i < nCols; i++ {
+		fields[i] = arrow.Field{Name: fmt.Sprintf("field_%d", i), Type: arrow.PrimitiveTypes.Int8}
+		cols[i] = arr // array.NewRecord will retain these
+	}
+
+	schema := arrow.NewSchema(fields, nil)
+	return array.NewRecord(schema, cols, int64(nRows))
+}
diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl
index 8dc352c0e2..77513949fd 100644
--- a/go/adbc/pkg/_tmpl/driver.go.tmpl
+++ b/go/adbc/pkg/_tmpl/driver.go.tmpl
@@ -59,10 +59,10 @@ import (
 	"unsafe"
 
 	"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v17/arrow/array" - "github.com/apache/arrow/go/v17/arrow/cdata" - "github.com/apache/arrow/go/v17/arrow/memory" - "github.com/apache/arrow/go/v17/arrow/memory/mallocator" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/cdata" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/apache/arrow/go/v18/arrow/memory/mallocator" ) // Must use malloc() to respect CGO rules