fix(go/adbc/driver/snowflake): split files properly after reaching targetSize on ingestion (#2026)

Fixes: #1997 

**Core Changes**
- Change ingestion `writeParquet` function to use unbuffered writes, skipping 0-row records to avoid recurrence of #1847
- Use parquet writer's internal `RowGroupTotalBytesWritten()` method to track output file size, replacing the `limitWriter` wrapper (see the sketch after this list)
- Add unit test to validate that file cutoff occurs precisely when expected
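
Taken together, these changes give `writeParquet` a contract that the new unit test below relies on: it returns nil as soon as the bytes written reach `targetSize`, and `io.EOF` once the input channel is drained. The following is a minimal sketch of a caller loop built on that contract, assuming it sits alongside `writeParquet` in the snowflake driver package; `writeAllFiles` and `uploadFile` are illustrative names, not part of this commit:

```go
package snowflake

import (
	"bytes"
	"errors"
	"io"

	"github.com/apache/arrow/go/v18/arrow"
	"github.com/apache/arrow/go/v18/parquet"
	"github.com/apache/arrow/go/v18/parquet/pqarrow"
)

// writeAllFiles is an illustrative helper (not part of this commit): it drains
// the record channel into successive Parquet files, starting a new file each
// time writeParquet returns nil (targetSize reached) and stopping once it
// returns io.EOF (input channel exhausted). uploadFile is a placeholder for
// whatever consumes each finished file.
func writeAllFiles(
	schema *arrow.Schema,
	records <-chan arrow.Record,
	targetSize int,
	parquetProps *parquet.WriterProperties,
	arrowProps pqarrow.ArrowWriterProperties,
	uploadFile func(*bytes.Buffer) error,
) error {
	for {
		var buf bytes.Buffer
		err := writeParquet(schema, &buf, records, targetSize, parquetProps, arrowProps)
		if err != nil && !errors.Is(err, io.EOF) {
			return err
		}
		// A production caller would also skip files containing no rows.
		if uploadErr := uploadFile(&buf); uploadErr != nil {
			return uploadErr
		}
		if errors.Is(err, io.EOF) {
			return nil // all input records consumed
		}
		// err == nil: targetSize reached, loop around and start a new file
	}
}
```

In the actual driver the finished files are handed off for ingestion; the sketch only shows how the nil-vs-`io.EOF` distinction drives file splitting.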

**Secondary Changes**
- Bump arrow dependency to `v18` to pull in the changes from
[ARROW-43326](apache/arrow#43326)
- Fix flightsql test that depends on hardcoded arrow version
joellubi authored Jul 29, 2024
1 parent d6255ac commit d05e1ba
Showing 3 changed files with 127 additions and 33 deletions.
go/adbc/driver/snowflake/bulk_ingestion.go (15 additions, 29 deletions)
@@ -334,20 +334,31 @@ func writeParquet(
 	parquetProps *parquet.WriterProperties,
 	arrowProps pqarrow.ArrowWriterProperties,
 ) error {
-	limitWr := &limitWriter{w: w, limit: targetSize}
-	pqWriter, err := pqarrow.NewFileWriter(schema, limitWr, parquetProps, arrowProps)
+	pqWriter, err := pqarrow.NewFileWriter(schema, w, parquetProps, arrowProps)
 	if err != nil {
 		return err
 	}
 	defer pqWriter.Close()
 
+	var bytesWritten int64
 	for rec := range in {
-		err = pqWriter.WriteBuffered(rec)
+		if rec.NumRows() == 0 {
+			rec.Release()
+			continue
+		}
+
+		err = pqWriter.Write(rec)
 		rec.Release()
 		if err != nil {
 			return err
 		}
-		if limitWr.LimitExceeded() {
+
+		if targetSize < 0 {
+			continue
+		}
+
+		bytesWritten += pqWriter.RowGroupTotalBytesWritten()
+		if bytesWritten >= int64(targetSize) {
 			return nil
 		}
 	}
@@ -584,28 +595,3 @@ func (bp *bufferPool) PutBuffer(buf *bytes.Buffer) {
 	buf.Reset()
 	bp.Pool.Put(buf)
 }
-
-// Wraps an io.Writer and specifies a limit.
-// Keeps track of how many bytes have been written and can report whether the limit has been exceeded.
-// TODO(ARROW-39789): We prefer to use RowGroupTotalBytesWritten on the ParquetWriter, but there seems to be a discrepency with the count.
-type limitWriter struct {
-	w     io.Writer
-	limit int
-
-	bytesWritten int
-}
-
-func (lw *limitWriter) Write(p []byte) (int, error) {
-	n, err := lw.w.Write(p)
-	lw.bytesWritten += n
-
-	return n, err
-}
-
-func (lw *limitWriter) LimitExceeded() bool {
-	if lw.limit > 0 {
-		return lw.bytesWritten > lw.limit
-	}
-	// Limit disabled
-	return false
-}
go/adbc/driver/snowflake/bulk_ingestion_test.go (108 additions, 0 deletions)
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package snowflake
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"testing"
+
+	"github.com/apache/arrow/go/v18/arrow"
+	"github.com/apache/arrow/go/v18/arrow/array"
+	"github.com/apache/arrow/go/v18/arrow/memory"
+	"github.com/apache/arrow/go/v18/parquet/pqarrow"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestIngestBatchedParquetWithFileLimit(t *testing.T) {
+	var buf bytes.Buffer
+	ctx := context.Background()
+	mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+	defer mem.AssertSize(t, 0)
+
+	ingestOpts := DefaultIngestOptions()
+	parquetProps, arrowProps := newWriterProps(mem, ingestOpts)
+
+	nCols := 3
+	nRecs := 10
+	nRows := 1000
+	targetFileSize := 10000
+
+	rec := makeRec(mem, nCols, nRows)
+	defer rec.Release()
+
+	// Create a temporary parquet writer and write a single row group so we know
+	// approximately how many bytes it should take
+	tempWriter, err := pqarrow.NewFileWriter(rec.Schema(), &buf, parquetProps, arrowProps)
+	require.NoError(t, err)
+
+	// Write 1 record and check the size before closing so footer bytes are not included
+	require.NoError(t, tempWriter.Write(rec))
+	expectedRowGroupSize := buf.Len()
+	require.NoError(t, tempWriter.Close())
+
+	recs := make([]arrow.Record, nRecs)
+	for i := 0; i < nRecs; i++ {
+		recs[i] = rec
+	}
+
+	rdr, err := array.NewRecordReader(rec.Schema(), recs)
+	require.NoError(t, err)
+	defer rdr.Release()
+
+	records := make(chan arrow.Record)
+	go func() { assert.NoError(t, readRecords(ctx, rdr, records)) }()
+
+	buf.Reset()
+	// Expected to read multiple records but then stop after targetFileSize, indicated by nil error
+	require.NoError(t, writeParquet(rdr.Schema(), &buf, records, targetFileSize, parquetProps, arrowProps))
+
+	// Expect to exceed the targetFileSize but by no more than the size of 1 row group
+	assert.Greater(t, buf.Len(), targetFileSize)
+	assert.Less(t, buf.Len(), targetFileSize+expectedRowGroupSize)
+
+	// Drain the remaining records with no limit on file size, expect EOF
+	require.ErrorIs(t, writeParquet(rdr.Schema(), &buf, records, -1, parquetProps, arrowProps), io.EOF)
+}
+
+func makeRec(mem memory.Allocator, nCols, nRows int) arrow.Record {
+	vals := make([]int8, nRows)
+	for val := 0; val < nRows; val++ {
+		vals[val] = int8(val)
+	}
+
+	bldr := array.NewInt8Builder(mem)
+	defer bldr.Release()
+
+	bldr.AppendValues(vals, nil)
+	arr := bldr.NewArray()
+	defer arr.Release()
+
+	fields := make([]arrow.Field, nCols)
+	cols := make([]arrow.Array, nCols)
+	for i := 0; i < nCols; i++ {
+		fields[i] = arrow.Field{Name: fmt.Sprintf("field_%d", i), Type: arrow.PrimitiveTypes.Int8}
+		cols[i] = arr // array.NewRecord will retain these
+	}
+
+	schema := arrow.NewSchema(fields, nil)
+	return array.NewRecord(schema, cols, int64(nRows))
+}
go/adbc/pkg/_tmpl/driver.go.tmpl (4 additions, 4 deletions)
@@ -59,10 +59,10 @@ import (
 	"unsafe"
 
 	"github.com/apache/arrow-adbc/go/adbc"
-	"github.com/apache/arrow/go/v17/arrow/array"
-	"github.com/apache/arrow/go/v17/arrow/cdata"
-	"github.com/apache/arrow/go/v17/arrow/memory"
-	"github.com/apache/arrow/go/v17/arrow/memory/mallocator"
+	"github.com/apache/arrow/go/v18/arrow/array"
+	"github.com/apache/arrow/go/v18/arrow/cdata"
+	"github.com/apache/arrow/go/v18/arrow/memory"
+	"github.com/apache/arrow/go/v18/arrow/memory/mallocator"
 )
 
 // Must use malloc() to respect CGO rules
