From dc253c760f2bff63c1a8c552a6e884ec44dacf65 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Wed, 19 Jul 2023 20:47:25 -0700 Subject: [PATCH 01/10] Initial changes for Union support --- .../Arrays/ArrayDataConcatenator.cs | 16 +++- .../Arrays/ArrayDataTypeComparer.cs | 12 ++- .../Apache.Arrow/Arrays/ArrowArrayFactory.cs | 2 +- .../Apache.Arrow/Arrays/DenseUnionArray.cs | 34 ++++++++ .../Apache.Arrow/Arrays/SparseUnionArray.cs | 29 +++++++ csharp/src/Apache.Arrow/Arrays/UnionArray.cs | 77 +++++++++++++++---- .../src/Apache.Arrow/C/CArrowArrayImporter.cs | 38 +++++++++ .../Apache.Arrow/C/CArrowSchemaExporter.cs | 18 +++++ .../Apache.Arrow/C/CArrowSchemaImporter.cs | 56 ++++++++++---- .../src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 17 ++++ .../Ipc/ArrowTypeFlatbufferBuilder.cs | 15 +++- csharp/src/Apache.Arrow/Types/UnionType.cs | 7 +- .../IntegrationCommand.cs | 6 ++ .../Apache.Arrow.Tests/ArrayTypeComparer.cs | 19 ++++- .../ArrowArrayConcatenatorTests.cs | 23 +++++- .../Apache.Arrow.Tests/ArrowReaderVerifier.cs | 18 +++++ .../CDataInterfacePythonTests.cs | 34 ++++++-- csharp/test/Apache.Arrow.Tests/TestData.cs | 6 ++ 18 files changed, 384 insertions(+), 43 deletions(-) create mode 100644 csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs create mode 100644 csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs diff --git a/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs b/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs index 2b5e59517cf19..6f25c52398d1b 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs @@ -48,7 +48,8 @@ private class ArrayDataConcatenationVisitor : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { public ArrayData Result { get; private set; } private readonly IReadOnlyList _arrayDataList; @@ -113,6 +114,19 @@ public void Visit(StructType type) Result = new ArrayData(type, _arrayDataList[0].Length, _arrayDataList[0].NullCount, 0, _arrayDataList[0].Buffers, children); } + public void Visit(UnionType type) + { + CheckData(type, type.Mode == UnionMode.Sparse ? 1 : 2); + List children = new List(type.Fields.Count); + + for (int i = 0; i < type.Fields.Count; i++) + { + children.Add(Concatenate(SelectChildren(i), _allocator)); + } + + Result = new ArrayData(type, _arrayDataList[0].Length, _arrayDataList[0].NullCount, 0, _arrayDataList[0].Buffers, children); + } + public void Visit(IArrowType type) { throw new NotImplementedException($"Concatenation for {type.Name} is not supported yet."); diff --git a/csharp/src/Apache.Arrow/Arrays/ArrayDataTypeComparer.cs b/csharp/src/Apache.Arrow/Arrays/ArrayDataTypeComparer.cs index 1e0524e53e309..b55281eab8643 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrayDataTypeComparer.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrayDataTypeComparer.cs @@ -26,7 +26,8 @@ internal sealed class ArrayDataTypeComparer : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { private readonly IArrowType _expectedType; private bool _dataTypeMatch; @@ -111,6 +112,15 @@ public void Visit(StructType actualType) } } + public void Visit(UnionType actualType) + { + if (_expectedType is UnionType expectedType + && CompareNested(expectedType, actualType)) + { + _dataTypeMatch = true; + } + } + private static bool CompareNested(NestedType expectedType, NestedType actualType) { if (expectedType.Fields.Count != actualType.Fields.Count) diff --git a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs index 845cbbd3e56f2..d4c60b31754fd 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs @@ -62,7 +62,7 @@ public static IArrowArray BuildArray(ArrayData data) case ArrowTypeId.Struct: return new StructArray(data); case ArrowTypeId.Union: - return new UnionArray(data); + return UnionArray.Create(data); case ArrowTypeId.Date64: return new Date64Array(data); case ArrowTypeId.Date32: diff --git a/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs b/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs new file mode 100644 index 0000000000000..9f76e83984bbc --- /dev/null +++ b/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow.Types; +using System; + +namespace Apache.Arrow +{ + public class DenseUnionArray : UnionArray + { + public ArrowBuffer ValueOffsetBuffer => Data.Buffers[1]; + + public ReadOnlySpan ValueOffsets => ValueOffsetBuffer.Span.CastTo(); + + public DenseUnionArray(ArrayData data) + : base(data) + { + ValidateMode(UnionMode.Dense, Type.Mode); + data.EnsureBufferCount(2); // TODO: + } + } +} diff --git a/csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs b/csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs new file mode 100644 index 0000000000000..ba12c4aa587ef --- /dev/null +++ b/csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Apache.Arrow.Types; + +namespace Apache.Arrow +{ + public class SparseUnionArray : UnionArray + { + public SparseUnionArray(ArrayData data) + : base(data) + { + ValidateMode(UnionMode.Sparse, Type.Mode); + data.EnsureBufferCount(1); + } + } +} diff --git a/csharp/src/Apache.Arrow/Arrays/UnionArray.cs b/csharp/src/Apache.Arrow/Arrays/UnionArray.cs index 8bccea2b59e31..1781061a4055e 100644 --- a/csharp/src/Apache.Arrow/Arrays/UnionArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/UnionArray.cs @@ -15,37 +15,88 @@ using Apache.Arrow.Types; using System; +using System.Collections.Generic; +using System.Threading; namespace Apache.Arrow { - public class UnionArray: Array + public abstract class UnionArray : IArrowArray { - public UnionType Type => Data.DataType as UnionType; + private IReadOnlyList _fields; - public UnionMode Mode => Type.Mode; + public IReadOnlyList Fields => + LazyInitializer.EnsureInitialized(ref _fields, () => InitializeFields()); + + public ArrayData Data { get; } - public ArrowBuffer TypeBuffer => Data.Buffers[1]; + public UnionType Type => (UnionType)Data.DataType; - public ArrowBuffer ValueOffsetBuffer => Data.Buffers[2]; + public UnionMode Mode => Type.Mode; + + public ArrowBuffer TypeBuffer => Data.Buffers[0]; public ReadOnlySpan TypeIds => TypeBuffer.Span; - public ReadOnlySpan ValueOffsets => ValueOffsetBuffer.Span.CastTo().Slice(0, Length + 1); + public int Length => Data.Length; + + public int Offset => Data.Offset; - public UnionArray(ArrayData data) - : base(data) + public int NullCount => Data.NullCount; + + public bool IsValid(int index) => NullCount == 0 || Fields[TypeIds[index]].IsValid(index); + + public bool IsNull(int index) => !IsValid(index); + + protected UnionArray(ArrayData data) { + Data = data; data.EnsureDataType(ArrowTypeId.Union); - data.EnsureBufferCount(3); } - public IArrowArray GetChild(int index) + public static UnionArray Create(ArrayData data) { - // TODO: Implement - throw new NotImplementedException(); + return ((UnionType)data.DataType).Mode switch + { + UnionMode.Dense => new DenseUnionArray(data), + UnionMode.Sparse => new SparseUnionArray(data), + _ => throw new InvalidOperationException("unknown union mode in array creation") + }; } - public override void Accept(IArrowArrayVisitor visitor) => Accept(this, visitor); + public void Accept(IArrowArrayVisitor visitor) => Array.Accept(this, visitor); + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + Data.Dispose(); + } + } + + protected static void ValidateMode(UnionMode expected, UnionMode actual) + { + if (expected != actual) + { + throw new ArgumentException( + $"Specified union mode <{actual}> does not match expected mode <{expected}>", + "Mode"); + } + } + + private IReadOnlyList InitializeFields() + { + IArrowArray[] result = new IArrowArray[Data.Children.Length]; + for (int i = 0; i < Data.Children.Length; i++) + { + result[i] = ArrowArrayFactory.BuildArray(Data.Children[i]); + } + return result; + } } } diff --git a/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs b/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs index e1314e5a62253..8eebe2b35adcd 100644 --- a/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs +++ b/csharp/src/Apache.Arrow/C/CArrowArrayImporter.cs @@ -161,6 +161,15 @@ private ArrayData GetAsArrayData(CArrowArray* cArray, IArrowType type) buffers = new ArrowBuffer[] { ImportValidityBuffer(cArray) }; break; case ArrowTypeId.Union: + UnionType unionType = (UnionType)type; + children = ProcessStructChildren(cArray, unionType.Fields); + buffers = unionType.Mode switch + { + UnionMode.Dense => ImportDenseUnionBuffers(cArray), + UnionMode.Sparse => ImportSparseUnionBuffers(cArray), + _ => throw new InvalidOperationException("unknown union mode in import") + }; ; + break; case ArrowTypeId.Map: break; case ArrowTypeId.Null: @@ -264,6 +273,35 @@ private ArrowBuffer[] ImportListBuffers(CArrowArray* cArray) return buffers; } + private ArrowBuffer[] ImportDenseUnionBuffers(CArrowArray* cArray) + { + if (cArray->n_buffers != 2) + { + throw new InvalidOperationException("Dense union arrays are expected to have exactly two children"); + } + int length = checked((int)cArray->length); + int offsetsLength = length * 4; + + ArrowBuffer[] buffers = new ArrowBuffer[2]; + buffers[0] = new ArrowBuffer(AddMemory((IntPtr)cArray->buffers[0], 0, length)); + buffers[1] = new ArrowBuffer(AddMemory((IntPtr)cArray->buffers[1], 0, offsetsLength)); + + return buffers; + } + + private ArrowBuffer[] ImportSparseUnionBuffers(CArrowArray* cArray) + { + if (cArray->n_buffers != 1) + { + throw new InvalidOperationException("Sparse union arrays are expected to have exactly one child"); + } + + ArrowBuffer[] buffers = new ArrowBuffer[1]; + buffers[0] = new ArrowBuffer(AddMemory((IntPtr)cArray->buffers[0], 0, checked((int)cArray->length))); + + return buffers; + } + private ArrowBuffer[] ImportFixedWidthBuffers(CArrowArray* cArray, int bitWidth) { if (cArray->n_buffers != 2) diff --git a/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs b/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs index 9053e80664e31..b864fd23f6429 100644 --- a/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs +++ b/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs @@ -124,6 +124,23 @@ public static unsafe void ExportSchema(Schema schema, CArrowSchema* out_schema) _ => throw new InvalidDataException($"Unsupported time unit for export: {unit}"), }; + private static string FormatUnion(UnionType unionType) + { + StringBuilder builder = new StringBuilder(); + builder.Append(unionType.Mode switch + { + UnionMode.Sparse => "+us:", + UnionMode.Dense => "+ud:", + _ => throw new InvalidDataException($"Unsupported time unit for export: {unionType.Mode}"), + }); + for (int i = 0; i < unionType.TypeCodes.Count; i++) + { + if (i > 0) { builder.Append(','); } + builder.Append(unionType.TypeCodes[i]); + } + return builder.ToString(); + } + private static string GetFormat(IArrowType datatype) { switch (datatype) @@ -168,6 +185,7 @@ private static string GetFormat(IArrowType datatype) // Nested case ListType _: return "+l"; case StructType _: return "+s"; + case UnionType u: return FormatUnion(u); // Dictionary case DictionaryType dictionaryType: return GetFormat(dictionaryType.IndexType); diff --git a/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs b/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs index 89c9481270c79..3b6af80ca2d73 100644 --- a/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs +++ b/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs @@ -180,21 +180,7 @@ public ArrowType GetAsType() } else if (format == "+s") { - var child_schemas = new ImportedArrowSchema[_cSchema->n_children]; - - for (int i = 0; i < _cSchema->n_children; i++) - { - if (_cSchema->GetChild(i) == null) - { - throw new InvalidDataException("Expected struct type child to be non-null."); - } - child_schemas[i] = new ImportedArrowSchema(_cSchema->GetChild(i), isRoot: false); - } - - - List childFields = child_schemas.Select(schema => schema.GetAsField()).ToList(); - - return new StructType(childFields); + return new StructType(ParseChildren("struct")); } // TODO: Map type and large list type @@ -240,6 +226,30 @@ public ArrowType GetAsType() return new FixedSizeBinaryType(width); } + // Unions + if (format.StartsWith("+ud:") || format.StartsWith("+us:")) + { + UnionMode unionMode = format[2] == 'd' ? UnionMode.Dense : UnionMode.Sparse; + List typeCodes = new List(); + int pos = 4; + do + { + int next = format.IndexOf(',', pos); + if (next < 0) { next = format.Length; } + + byte code; + if (!byte.TryParse(format.Substring(pos, next - pos), out code)) + { + throw new InvalidDataException($"Invalid type code for union import: {format.Substring(pos, next - pos)}"); + } + typeCodes.Add(code); + + pos = next + 1; + } while (pos < format.Length); + + return new UnionType(ParseChildren("union"), typeCodes, unionMode); + } + return format switch { // Primitives @@ -299,6 +309,22 @@ public Schema GetAsSchema() } } + private List ParseChildren(string typeName) + { + var child_schemas = new ImportedArrowSchema[_cSchema->n_children]; + + for (int i = 0; i < _cSchema->n_children; i++) + { + if (_cSchema->GetChild(i) == null) + { + throw new InvalidDataException($"Expected {typeName} type child to be non-null."); + } + child_schemas[i] = new ImportedArrowSchema(_cSchema->GetChild(i), isRoot: false); + } + + return child_schemas.Select(schema => schema.GetAsField()).ToList(); + } + private unsafe static IReadOnlyDictionary GetMetadata(byte* metadata) { if (metadata == null) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index 5ffda3dfba6f5..a3ae417d71b84 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -54,6 +54,7 @@ internal class ArrowRecordBatchFlatBufferBuilder : IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, + IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, @@ -148,6 +149,22 @@ public void Visit(StructArray array) } } + public void Visit(UnionArray array) + { + _buffers.Add(CreateBuffer(array.TypeBuffer)); + + ArrowBuffer? offsets = (array as DenseUnionArray)?.ValueOffsetBuffer; + if (offsets != null) + { + _buffers.Add(CreateBuffer(offsets.Value)); + } + + for (int i = 0; i < array.Fields.Count; i++) + { + array.Fields[i].Accept(this); + } + } + public void Visit(DictionaryArray array) { // Dictionary is serialized separately in Dictionary serialization. diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs index 400937669e309..bf457155f0811 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs @@ -112,7 +112,10 @@ public void Visit(ListType type) public void Visit(UnionType type) { - throw new NotImplementedException(); + Flatbuf.Union.StartUnion(Builder); + Result = FieldType.Build( + Flatbuf.Type.Union, + Flatbuf.Union.CreateUnion(Builder, ToFlatBuffer(type.Mode))); } public void Visit(StringType type) @@ -271,5 +274,15 @@ private static Flatbuf.TimeUnit ToFlatBuffer(TimeUnit unit) return result; } + + private static Flatbuf.UnionMode ToFlatBuffer(Types.UnionMode mode) + { + return mode switch + { + Types.UnionMode.Dense => Flatbuf.UnionMode.Dense, + Types.UnionMode.Sparse => Flatbuf.UnionMode.Sparse, + _ => throw new ArgumentException(nameof(mode), $"unsupported union mode <{mode}>") + }; + } } } diff --git a/csharp/src/Apache.Arrow/Types/UnionType.cs b/csharp/src/Apache.Arrow/Types/UnionType.cs index 293271018aa26..d290d88aa3d93 100644 --- a/csharp/src/Apache.Arrow/Types/UnionType.cs +++ b/csharp/src/Apache.Arrow/Types/UnionType.cs @@ -24,18 +24,19 @@ public enum UnionMode Dense } - public sealed class UnionType : ArrowType + public sealed class UnionType : NestedType { public override ArrowTypeId TypeId => ArrowTypeId.Union; public override string Name => "union"; public UnionMode Mode { get; } - - public IEnumerable TypeCodes { get; } + + public List TypeCodes { get; } public UnionType( IEnumerable fields, IEnumerable typeCodes, UnionMode mode = UnionMode.Sparse) + : base(fields.ToArray()) { TypeCodes = typeCodes.ToList(); Mode = mode; diff --git a/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs b/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs index 5b1ed98993812..af70f8946da5d 100644 --- a/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs +++ b/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs @@ -275,6 +275,7 @@ private class ArrayCreator : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, + IArrowTypeVisitor, IArrowTypeVisitor { private JsonFieldData JsonFieldData { get; } @@ -486,6 +487,11 @@ public void Visit(StructType type) throw new NotImplementedException(); } + public void Visit(UnionType type) + { + throw new NotImplementedException(); + } + private static byte[] ConvertHexStringToByteArray(string hexString) { byte[] data = new byte[hexString.Length / 2]; diff --git a/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs b/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs index f75111b66d087..bb902190c763b 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs @@ -27,7 +27,8 @@ public class ArrayTypeComparer : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { private readonly IArrowType _expectedType; @@ -103,6 +104,22 @@ public void Visit(StructType actualType) CompareNested(expectedType, actualType); } + public void Visit(UnionType actualType) + { + Assert.IsAssignableFrom(_expectedType); + UnionType expectedType = (UnionType)_expectedType; + + Assert.Equal(expectedType.Mode, actualType.Mode); + + Assert.Equal(expectedType.TypeCodes.Count, actualType.TypeCodes.Count); + for (int i = 0; i < expectedType.TypeCodes.Count; i++) + { + Assert.Equal(expectedType.TypeCodes[i], actualType.TypeCodes[i]); + } + + CompareNested(expectedType, actualType); + } + private static void CompareNested(NestedType expectedType, NestedType actualType) { Assert.Equal(expectedType.Fields.Count, actualType.Fields.Count); diff --git a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs index 6b3277ed572e0..472230d4423d2 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs @@ -76,6 +76,22 @@ private static IEnumerable, IArrowArray>> GenerateTestDa new Field.Builder().Name("Strings").DataType(StringType.Default).Nullable(true).Build(), new Field.Builder().Name("Ints").DataType(Int32Type.Default).Nullable(true).Build() }), + new UnionType( + new List{ + new Field.Builder().Name("Strings").DataType(StringType.Default).Nullable(true).Build(), + new Field.Builder().Name("Ints").DataType(Int32Type.Default).Nullable(true).Build() + }, + new byte[] { 0, 1 }, + UnionMode.Sparse + ), + new UnionType( + new List{ + new Field.Builder().Name("Strings").DataType(StringType.Default).Nullable(true).Build(), + new Field.Builder().Name("Ints").DataType(Int32Type.Default).Nullable(true).Build() + }, + new byte[] { 0, 1 }, + UnionMode.Dense + ), }; foreach (IArrowType type in targetTypes) @@ -117,7 +133,8 @@ private class TestDataGenerator : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, - IArrowTypeVisitor + IArrowTypeVisitor, + IArrowTypeVisitor { private List> _baseData; @@ -354,6 +371,10 @@ public void Visit(StructType type) ExpectedArray = new StructArray(type, 3, new List { resultStringArray, resultInt32Array }, nullBitmapBuffer, 1); } + public void Visit(UnionType type) + { + throw new NotImplementedException(); + } public void Visit(IArrowType type) { diff --git a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs index acfe72f83195e..126a8825a1fb9 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs @@ -145,6 +145,24 @@ public void Visit(StructArray array) } } + public void Visit(UnionArray array) + { + Assert.IsAssignableFrom(_expectedArray); + UnionArray expectedArray = (UnionArray)_expectedArray; + + Assert.Equal(expectedArray.Mode, array.Mode); + Assert.Equal(expectedArray.Length, array.Length); + Assert.Equal(expectedArray.NullCount, array.NullCount); + Assert.Equal(expectedArray.Offset, array.Offset); + Assert.Equal(expectedArray.Data.Children.Length, array.Data.Children.Length); + Assert.Equal(expectedArray.Fields.Count, array.Fields.Count); + + for (int i = 0; i < array.Fields.Count; i++) + { + array.Fields[i].Accept(new ArrayComparer(expectedArray.Fields[i], _strictCompare)); + } + } + public void Visit(DictionaryArray array) { Assert.IsAssignableFrom(_expectedArray); diff --git a/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs b/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs index 084d7bfb014cc..14530c9534dd3 100644 --- a/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs +++ b/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs @@ -102,6 +102,9 @@ private static Schema GetTestSchema() .Field(f => f.Name("dict_string_ordered").DataType(new DictionaryType(Int32Type.Default, StringType.Default, true)).Nullable(false)) .Field(f => f.Name("list_dict_string").DataType(new ListType(new DictionaryType(Int32Type.Default, StringType.Default, false))).Nullable(false)) + .Field(f => f.Name("dense_union").DataType(new UnionType(new[] { new Field("i64", Int64Type.Default, false), new Field("f32", FloatType.Default, true), }, new[] { (byte)0, (byte)1 }, UnionMode.Dense))) + .Field(f => f.Name("sparse_union").DataType(new UnionType(new[] { new Field("i32", Int32Type.Default, true), new Field("f64", DoubleType.Default, false), }, new[] { (byte)0, (byte)1 }, UnionMode.Sparse))) + // Checking wider characters. .Field(f => f.Name("hello 你好 😄").DataType(BooleanType.Default).Nullable(true)) @@ -160,6 +163,9 @@ private static IEnumerable GetPythonFields() yield return pa.field("dict_string_ordered", pa.dictionary(pa.int32(), pa.utf8(), true), false); yield return pa.field("list_dict_string", pa.list_(pa.dictionary(pa.int32(), pa.utf8(), false)), false); + yield return pa.field("dense_union", pa.dense_union(List(pa.field("i64", pa.int64(), false), pa.field("f32", pa.float32(), true)))); + yield return pa.field("sparse_union", pa.sparse_union(List(pa.field("i32", pa.int32(), true), pa.field("f64", pa.float64(), false)))); + yield return pa.field("hello 你好 😄", pa.bool_(), true); } } @@ -473,19 +479,26 @@ public unsafe void ImportRecordBatch() pa.array(List(0.0, 1.4, 2.5, 3.6, 4.7)), pa.array(new PyObject[] { List(1, 2), List(3, 4), PyObject.None, PyObject.None, List(5, 4, 3) }), pa.StructArray.from_arrays( - new PyList(new PyObject[] - { + List( List(10, 9, null, null, null), List("banana", "apple", "orange", "cherry", "grape"), - List(null, 4.3, -9, 123.456, 0), - }), + List(null, 4.3, -9, 123.456, 0) + ), new[] { "fld1", "fld2", "fld3" }), pa.DictionaryArray.from_arrays( pa.array(List(1, 0, 1, 1, null)), - pa.array(List("foo", "bar")) + pa.array(List("foo", "bar"))), + pa.UnionArray.from_dense( + pa.array(List(0, 1, 1, 0, 0), type: "int8"), + pa.array(List(0, 0, 1, 1, 2), type: "int32"), + List( + pa.array(List(1, 4, null)), + pa.array(List("two", "three")) ), + /* field name */ List("i32", "s"), + /* type codes */ List(3, 2)), }), - new[] { "col1", "col2", "col3", "col4", "col5", "col6", "col7" }); + new[] { "col1", "col2", "col3", "col4", "col5", "col6", "col7", "col8" }); dynamic batch = table.to_batches()[0]; @@ -546,6 +559,10 @@ public unsafe void ImportRecordBatch() Assert.Equal(2, col7b.Length); Assert.Equal("foo", col7b.GetString(0)); Assert.Equal("bar", col7b.GetString(1)); + + UnionArray col8 = (UnionArray)recordBatch.Column("col8"); + Assert.Equal(5, col8.Length); + Assert.True(col8 is DenseUnionArray); } [SkippableFact] @@ -767,6 +784,11 @@ private static PyObject List(params string[] values) return new PyList(values.Select(i => i == null ? PyObject.None : new PyString(i)).ToArray()); } + private static PyObject List(params PyObject[] values) + { + return new PyList(values); + } + sealed class TestArrayStream : IArrowArrayStream { private readonly RecordBatch[] _batches; diff --git a/csharp/test/Apache.Arrow.Tests/TestData.cs b/csharp/test/Apache.Arrow.Tests/TestData.cs index 96c6fafee270c..a2e13f5df728b 100644 --- a/csharp/test/Apache.Arrow.Tests/TestData.cs +++ b/csharp/test/Apache.Arrow.Tests/TestData.cs @@ -123,6 +123,7 @@ private class ArrayCreator : IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, + IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, IArrowTypeVisitor, @@ -287,6 +288,11 @@ public void Visit(StructType type) Array = new StructArray(type, Length, childArrays, nullBitmap.Build()); } + public void Visit(UnionType type) + { + throw new NotImplementedException(); + } + public void Visit(DictionaryType type) { Int32Array.Builder indicesBuilder = new Int32Array.Builder().Reserve(Length); From 1bb1417cefb46c7124bea20d3c16fff715ac9f12 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Thu, 20 Jul 2023 06:58:11 -0700 Subject: [PATCH 02/10] Union-related fixes --- .../Arrays/ArrayDataConcatenator.cs | 50 ++++++++++- .../Apache.Arrow/Arrays/DenseUnionArray.cs | 18 ++++ .../Arrays/PrimitiveArrayBuilder.cs | 3 + .../Apache.Arrow/Arrays/SparseUnionArray.cs | 17 ++++ csharp/src/Apache.Arrow/Arrays/UnionArray.cs | 2 +- .../ArrowArrayConcatenatorTests.cs | 85 ++++++++++++++++++- .../Apache.Arrow.Tests/ArrowReaderVerifier.cs | 1 + 7 files changed, 172 insertions(+), 4 deletions(-) diff --git a/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs b/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs index 6f25c52398d1b..071a9c25cc9cc 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrayDataConcatenator.cs @@ -116,7 +116,14 @@ public void Visit(StructType type) public void Visit(UnionType type) { - CheckData(type, type.Mode == UnionMode.Sparse ? 1 : 2); + int bufferCount = type.Mode switch + { + UnionMode.Sparse => 1, + UnionMode.Dense => 2, + _ => throw new InvalidOperationException("TODO"), + }; + + CheckData(type, bufferCount); List children = new List(type.Fields.Count); for (int i = 0; i < type.Fields.Count; i++) @@ -124,7 +131,14 @@ public void Visit(UnionType type) children.Add(Concatenate(SelectChildren(i), _allocator)); } - Result = new ArrayData(type, _arrayDataList[0].Length, _arrayDataList[0].NullCount, 0, _arrayDataList[0].Buffers, children); + ArrowBuffer[] buffers = new ArrowBuffer[bufferCount]; + buffers[0] = ConcatenateUnionTypeBuffer(); + if (bufferCount > 1) + { + buffers[1] = ConcatenateUnionOffsetBuffer(); + } + + Result = new ArrayData(type, _totalLength, _totalNullCount, 0, buffers, children); } public void Visit(IArrowType type) @@ -235,6 +249,38 @@ private ArrowBuffer ConcatenateOffsetBuffer() return builder.Build(_allocator); } + private ArrowBuffer ConcatenateUnionTypeBuffer() + { + var builder = new ArrowBuffer.Builder(_totalLength); + + foreach (ArrayData arrayData in _arrayDataList) + { + builder.Append(arrayData.Buffers[0]); + } + + return builder.Build(_allocator); + } + + private ArrowBuffer ConcatenateUnionOffsetBuffer() + { + var builder = new ArrowBuffer.Builder(_totalLength); + int baseOffset = 0; + + foreach (ArrayData arrayData in _arrayDataList) + { + ReadOnlySpan span = arrayData.Buffers[1].Span.CastTo(); + foreach (int offset in span) + { + builder.Append(baseOffset + offset); + } + + // The next offset must start from the current last offset. + baseOffset += span[arrayData.Length]; + } + + return builder.Build(_allocator); + } + private List SelectChildren(int index) { var children = new List(_arrayDataList.Count); diff --git a/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs b/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs index 9f76e83984bbc..b9d32a5b5f9d9 100644 --- a/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs @@ -15,6 +15,8 @@ using Apache.Arrow.Types; using System; +using System.Collections.Generic; +using System.Linq; namespace Apache.Arrow { @@ -24,6 +26,22 @@ public class DenseUnionArray : UnionArray public ReadOnlySpan ValueOffsets => ValueOffsetBuffer.Span.CastTo(); + public DenseUnionArray( + IArrowType dataType, + int length, + IEnumerable children, + ArrowBuffer typeIds, + ArrowBuffer valuesOffsetBuffer, + int nullCount = 0, + int offset = 0) + : base(new ArrayData( + dataType, length, nullCount, offset, new[] { typeIds, valuesOffsetBuffer }, + children.Select(child => child.Data))) + { + _fields = children.ToArray(); + ValidateMode(UnionMode.Sparse, Type.Mode); + } + public DenseUnionArray(ArrayData data) : base(data) { diff --git a/csharp/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs b/csharp/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs index a50d4b52c3257..67fe46633c18f 100644 --- a/csharp/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs +++ b/csharp/src/Apache.Arrow/Arrays/PrimitiveArrayBuilder.cs @@ -137,6 +137,9 @@ public TBuilder Append(T value) return Instance; } + public TBuilder Append(T? value) => + (value == null) ? AppendNull() : Append(value.Value); + public TBuilder Append(ReadOnlySpan span) { int len = ValueBuffer.Length; diff --git a/csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs b/csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs index ba12c4aa587ef..b79c44c979e47 100644 --- a/csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/SparseUnionArray.cs @@ -14,11 +14,28 @@ // limitations under the License. using Apache.Arrow.Types; +using System.Collections.Generic; +using System.Linq; namespace Apache.Arrow { public class SparseUnionArray : UnionArray { + public SparseUnionArray( + IArrowType dataType, + int length, + IEnumerable children, + ArrowBuffer typeIds, + int nullCount = 0, + int offset = 0) + : base(new ArrayData( + dataType, length, nullCount, offset, new[] { typeIds }, + children.Select(child => child.Data))) + { + _fields = children.ToArray(); + ValidateMode(UnionMode.Sparse, Type.Mode); + } + public SparseUnionArray(ArrayData data) : base(data) { diff --git a/csharp/src/Apache.Arrow/Arrays/UnionArray.cs b/csharp/src/Apache.Arrow/Arrays/UnionArray.cs index 1781061a4055e..0a7ae288fd0c5 100644 --- a/csharp/src/Apache.Arrow/Arrays/UnionArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/UnionArray.cs @@ -22,7 +22,7 @@ namespace Apache.Arrow { public abstract class UnionArray : IArrowArray { - private IReadOnlyList _fields; + protected IReadOnlyList _fields; public IReadOnlyList Fields => LazyInitializer.EnsureInitialized(ref _fields, () => InitializeFields()); diff --git a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs index 472230d4423d2..c95e9a51ac341 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs @@ -17,6 +17,8 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; +using System.Reflection.Emit; +using System.Text; using Apache.Arrow.Memory; using Apache.Arrow.Types; using Xunit; @@ -373,7 +375,88 @@ public void Visit(StructType type) public void Visit(UnionType type) { - throw new NotImplementedException(); + bool isDense = type.Mode == UnionMode.Dense; + + StringArray.Builder stringResultBuilder = new StringArray.Builder().Reserve(_baseDataTotalElementCount); + Int32Array.Builder intResultBuilder = new Int32Array.Builder().Reserve(_baseDataTotalElementCount); + ArrowBuffer.Builder typeResultBuilder = new ArrowBuffer.Builder().Reserve(_baseDataTotalElementCount); + ArrowBuffer.Builder offsetResultBuilder = new ArrowBuffer.Builder().Reserve(_baseDataTotalElementCount); + int resultNullCount = 0; + + for (int i = 0; i < _baseDataListCount; i++) + { + List dataList = _baseData[i]; + StringArray.Builder stringBuilder = new StringArray.Builder().Reserve(dataList.Count); + Int32Array.Builder intBuilder = new Int32Array.Builder().Reserve(dataList.Count); + ArrowBuffer.Builder typeBuilder = new ArrowBuffer.Builder().Reserve(dataList.Count); + ArrowBuffer.Builder offsetBuilder = new ArrowBuffer.Builder().Reserve(dataList.Count); + int nullCount = 0; + + for (int j = 0; j < dataList.Count; j++) + { + byte index = (byte)Math.Max(j % 3, 1); + int? intValue = (index == 1) ? dataList[j] : null; + string stringValue = (index == 1) ? null : dataList[j]?.ToString(); + typeBuilder.Append(index); + + if (isDense) + { + if (index == 0) + { + offsetBuilder.Append(stringBuilder.Length); + offsetResultBuilder.Append(stringResultBuilder.Length); + stringBuilder.Append(stringValue); + stringResultBuilder.Append(stringValue); + } + else + { + offsetBuilder.Append(intBuilder.Length); + offsetResultBuilder.Append(intResultBuilder.Length); + intBuilder.Append(intValue); + intResultBuilder.Append(intValue); + } + } + else + { + stringBuilder.Append(stringValue); + stringResultBuilder.Append(stringValue); + intBuilder.Append(intValue); + intResultBuilder.Append(intValue); + } + + if (dataList[j] == null) + { + nullCount++; + resultNullCount++; + } + } + + ArrowBuffer[] buffers; + if (isDense) + { + buffers = new[] { typeBuilder.Build(), offsetBuilder.Build() }; + } + else + { + buffers = new[] { typeBuilder.Build() }; + } + TestTargetArrayList.Add(UnionArray.Create(new ArrayData( + type, dataList.Count, nullCount, 0, buffers, + new[] { stringBuilder.Build().Data, intBuilder.Build().Data }))); + } + + ArrowBuffer[] resultBuffers; + if (isDense) + { + resultBuffers = new[] { typeResultBuilder.Build(), offsetResultBuilder.Build() }; + } + else + { + resultBuffers = new[] { typeResultBuilder.Build() }; + } + ExpectedArray = UnionArray.Create(new ArrayData( + type, _baseDataTotalElementCount, resultNullCount, 0, resultBuffers, + new[] { stringResultBuilder.Build().Data, intResultBuilder.Build().Data })); } public void Visit(IArrowType type) diff --git a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs index 126a8825a1fb9..713d17b55f9cb 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs @@ -88,6 +88,7 @@ private class ArrayComparer : IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, + IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, IArrowArrayVisitor, From 6b3ff656c176907b54eaf10331d839fa27099fbc Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Thu, 20 Jul 2023 14:35:45 -0700 Subject: [PATCH 03/10] Fixed IPC to work correctly for unions --- .../Apache.Arrow/Arrays/ArrowArrayFactory.cs | 14 +++++ .../Apache.Arrow/C/CArrowSchemaExporter.cs | 4 +- .../Apache.Arrow/C/CArrowSchemaImporter.cs | 10 ++-- csharp/src/Apache.Arrow/ChunkedArray.cs | 13 ++-- csharp/src/Apache.Arrow/Column.cs | 2 +- .../Extensions/FlatbufExtensions.cs | 10 ++++ .../Apache.Arrow/Interfaces/IArrowArray.cs | 4 -- .../Ipc/ArrowReaderImplementation.cs | 33 +++++----- .../Ipc/ArrowTypeFlatbufferBuilder.cs | 5 +- .../src/Apache.Arrow/Ipc/MessageSerializer.cs | 4 ++ csharp/src/Apache.Arrow/Table.cs | 4 +- csharp/src/Apache.Arrow/Types/UnionType.cs | 6 +- .../Apache.Arrow.Tests/ArrayTypeComparer.cs | 6 +- .../ArrowArrayConcatenatorTests.cs | 4 +- .../CDataInterfacePythonTests.cs | 4 +- csharp/test/Apache.Arrow.Tests/TableTests.cs | 4 +- csharp/test/Apache.Arrow.Tests/TestData.cs | 60 ++++++++++++++++++- 17 files changed, 133 insertions(+), 54 deletions(-) diff --git a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs index d4c60b31754fd..0b665292ef5d4 100644 --- a/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs +++ b/csharp/src/Apache.Arrow/Arrays/ArrowArrayFactory.cs @@ -89,5 +89,19 @@ public static IArrowArray BuildArray(ArrayData data) throw new NotSupportedException($"An ArrowArray cannot be built for type {data.DataType.TypeId}."); } } + + public static IArrowArray Slice(IArrowArray array, int offset, int length) + { + if (offset > array.Length) + { + throw new ArgumentException($"Offset {offset} cannot be greater than Length {array.Length} for Array.Slice"); + } + + length = Math.Min(array.Data.Length - offset, length); + offset += array.Data.Offset; + + ArrayData newData = array.Data.Slice(offset, length); + return BuildArray(newData); + } } } diff --git a/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs b/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs index b864fd23f6429..21a08c21337d8 100644 --- a/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs +++ b/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs @@ -133,10 +133,10 @@ private static string FormatUnion(UnionType unionType) UnionMode.Dense => "+ud:", _ => throw new InvalidDataException($"Unsupported time unit for export: {unionType.Mode}"), }); - for (int i = 0; i < unionType.TypeCodes.Count; i++) + for (int i = 0; i < unionType.TypeIds.Length; i++) { if (i > 0) { builder.Append(','); } - builder.Append(unionType.TypeCodes[i]); + builder.Append(unionType.TypeIds[i]); } return builder.ToString(); } diff --git a/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs b/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs index 3b6af80ca2d73..b661efbe5ce1d 100644 --- a/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs +++ b/csharp/src/Apache.Arrow/C/CArrowSchemaImporter.cs @@ -230,24 +230,24 @@ public ArrowType GetAsType() if (format.StartsWith("+ud:") || format.StartsWith("+us:")) { UnionMode unionMode = format[2] == 'd' ? UnionMode.Dense : UnionMode.Sparse; - List typeCodes = new List(); + List typeIds = new List(); int pos = 4; do { int next = format.IndexOf(',', pos); if (next < 0) { next = format.Length; } - byte code; - if (!byte.TryParse(format.Substring(pos, next - pos), out code)) + int code; + if (!int.TryParse(format.Substring(pos, next - pos), out code)) { throw new InvalidDataException($"Invalid type code for union import: {format.Substring(pos, next - pos)}"); } - typeCodes.Add(code); + typeIds.Add(code); pos = next + 1; } while (pos < format.Length); - return new UnionType(ParseChildren("union"), typeCodes, unionMode); + return new UnionType(ParseChildren("union"), typeIds, unionMode); } return format switch diff --git a/csharp/src/Apache.Arrow/ChunkedArray.cs b/csharp/src/Apache.Arrow/ChunkedArray.cs index 5f25acfe04a2f..8565032276d09 100644 --- a/csharp/src/Apache.Arrow/ChunkedArray.cs +++ b/csharp/src/Apache.Arrow/ChunkedArray.cs @@ -15,7 +15,6 @@ using System; using System.Collections.Generic; -using Apache.Arrow; using Apache.Arrow.Types; namespace Apache.Arrow @@ -25,7 +24,7 @@ namespace Apache.Arrow /// public class ChunkedArray { - private IList Arrays { get; } + private IList Arrays { get; } public IArrowType DataType { get; } public long Length { get; } public long NullCount { get; } @@ -35,9 +34,9 @@ public int ArrayCount get => Arrays.Count; } - public Array Array(int index) => Arrays[index]; + public IArrowArray Array(int index) => Arrays[index]; - public ChunkedArray(IList arrays) + public ChunkedArray(IList arrays) { Arrays = arrays ?? throw new ArgumentNullException(nameof(arrays)); if (arrays.Count < 1) @@ -45,7 +44,7 @@ public ChunkedArray(IList arrays) throw new ArgumentException($"Count must be at least 1. Got {arrays.Count} instead"); } DataType = arrays[0].Data.DataType; - foreach (Array array in arrays) + foreach (IArrowArray array in arrays) { Length += array.Length; NullCount += array.NullCount; @@ -69,10 +68,10 @@ public ChunkedArray Slice(long offset, long length) curArrayIndex++; } - IList newArrays = new List(); + IList newArrays = new List(); while (curArrayIndex < numArrays && length > 0) { - newArrays.Add(Arrays[curArrayIndex].Slice((int)offset, + newArrays.Add(ArrowArrayFactory.Slice(Arrays[curArrayIndex], (int)offset, length > Arrays[curArrayIndex].Length ? Arrays[curArrayIndex].Length : (int)length)); length -= Arrays[curArrayIndex].Length - offset; offset = 0; diff --git a/csharp/src/Apache.Arrow/Column.cs b/csharp/src/Apache.Arrow/Column.cs index 4eaf9a559e75d..7d99ed24d02d0 100644 --- a/csharp/src/Apache.Arrow/Column.cs +++ b/csharp/src/Apache.Arrow/Column.cs @@ -27,7 +27,7 @@ public class Column public Field Field { get; } public ChunkedArray Data { get; } - public Column(Field field, IList arrays) + public Column(Field field, IList arrays) { Data = new ChunkedArray(arrays); Field = field; diff --git a/csharp/src/Apache.Arrow/Extensions/FlatbufExtensions.cs b/csharp/src/Apache.Arrow/Extensions/FlatbufExtensions.cs index d2a70bca9e4ec..35c5b3e55157d 100644 --- a/csharp/src/Apache.Arrow/Extensions/FlatbufExtensions.cs +++ b/csharp/src/Apache.Arrow/Extensions/FlatbufExtensions.cs @@ -80,6 +80,16 @@ public static Types.TimeUnit ToArrow(this Flatbuf.TimeUnit unit) throw new ArgumentException($"Unexpected Flatbuf TimeUnit", nameof(unit)); } } + + public static Types.UnionMode ToArrow(this Flatbuf.UnionMode mode) + { + return mode switch + { + Flatbuf.UnionMode.Dense => Types.UnionMode.Dense, + Flatbuf.UnionMode.Sparse => Types.UnionMode.Sparse, + _ => throw new ArgumentException($"Unsupported Flatbuf UnionMode", nameof(mode)), + }; + } } } diff --git a/csharp/src/Apache.Arrow/Interfaces/IArrowArray.cs b/csharp/src/Apache.Arrow/Interfaces/IArrowArray.cs index 50fbc3af6dd72..9bcee36ef4eaf 100644 --- a/csharp/src/Apache.Arrow/Interfaces/IArrowArray.cs +++ b/csharp/src/Apache.Arrow/Interfaces/IArrowArray.cs @@ -32,9 +32,5 @@ public interface IArrowArray : IDisposable ArrayData Data { get; } void Accept(IArrowArrayVisitor visitor); - - //IArrowArray Slice(int offset); - - //IArrowArray Slice(int offset, int length); } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index ed93000736e0b..bc2a948527874 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -245,28 +245,27 @@ private ArrayData LoadPrimitiveField( throw new InvalidDataException("Null count length must be >= 0"); // TODO:Localize exception message } - if (field.DataType.TypeId == ArrowTypeId.Null) + int buffers; + switch (field.DataType.TypeId) { - return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, System.Array.Empty()); - } - - ArrowBuffer nullArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer, bufferCreator); - if (!recordBatchEnumerator.MoveNextBuffer()) - { - throw new Exception("Unable to move to the next buffer."); + case ArrowTypeId.Null: + return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, System.Array.Empty()); + case ArrowTypeId.Union: + buffers = ((UnionType)field.DataType).Mode == Types.UnionMode.Dense ? 2 : 1; + break; + case ArrowTypeId.Struct: + buffers = 1; + break; + default: + buffers = 2; + break; } - ArrowBuffer[] arrowBuff; - if (field.DataType.TypeId == ArrowTypeId.Struct) + ArrowBuffer[] arrowBuff = new ArrowBuffer[buffers]; + for (int i = 0; i < buffers; i++) { - arrowBuff = new[] { nullArrowBuffer }; - } - else - { - ArrowBuffer valueArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer, bufferCreator); + arrowBuff[i] = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer, bufferCreator); recordBatchEnumerator.MoveNextBuffer(); - - arrowBuff = new[] { nullArrowBuffer, valueArrowBuffer }; } ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData, bufferCreator); diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs index bf457155f0811..3f2ba03899e77 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs @@ -112,10 +112,9 @@ public void Visit(ListType type) public void Visit(UnionType type) { - Flatbuf.Union.StartUnion(Builder); Result = FieldType.Build( Flatbuf.Type.Union, - Flatbuf.Union.CreateUnion(Builder, ToFlatBuffer(type.Mode))); + Flatbuf.Union.CreateUnion(Builder, ToFlatBuffer(type.Mode), Flatbuf.Union.CreateTypeIdsVector(Builder, type.TypeIds))); } public void Visit(StringType type) @@ -281,7 +280,7 @@ private static Flatbuf.UnionMode ToFlatBuffer(Types.UnionMode mode) { Types.UnionMode.Dense => Flatbuf.UnionMode.Dense, Types.UnionMode.Sparse => Flatbuf.UnionMode.Sparse, - _ => throw new ArgumentException(nameof(mode), $"unsupported union mode <{mode}>") + _ => throw new ArgumentException($"unsupported union mode <{mode}>", nameof(mode)), }; } } diff --git a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs index 7426ba4e86959..63c34e3e5f68d 100644 --- a/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs +++ b/csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs @@ -195,6 +195,10 @@ private static Types.IArrowType GetFieldArrowType(Flatbuf.Field field, Field[] c case Flatbuf.Type.Struct_: Debug.Assert(childFields != null); return new Types.StructType(childFields); + case Flatbuf.Type.Union: + Debug.Assert(childFields != null); + Flatbuf.Union unionMetadata = field.Type().Value; + return new Types.UnionType(childFields, unionMetadata.GetTypeIdsArray(), unionMetadata.Mode.ToArrow()); default: throw new InvalidDataException($"Arrow primitive '{field.TypeType}' is unsupported."); } diff --git a/csharp/src/Apache.Arrow/Table.cs b/csharp/src/Apache.Arrow/Table.cs index 0b9f31557bec8..939ec23f54ff2 100644 --- a/csharp/src/Apache.Arrow/Table.cs +++ b/csharp/src/Apache.Arrow/Table.cs @@ -37,10 +37,10 @@ public static Table TableFromRecordBatches(Schema schema, IList rec List columns = new List(nColumns); for (int icol = 0; icol < nColumns; icol++) { - List columnArrays = new List(nBatches); + List columnArrays = new List(nBatches); for (int jj = 0; jj < nBatches; jj++) { - columnArrays.Add(recordBatches[jj].Column(icol) as Array); + columnArrays.Add(recordBatches[jj].Column(icol)); } columns.Add(new Column(schema.GetFieldByIndex(icol), columnArrays)); } diff --git a/csharp/src/Apache.Arrow/Types/UnionType.cs b/csharp/src/Apache.Arrow/Types/UnionType.cs index d290d88aa3d93..23fa3b45ab278 100644 --- a/csharp/src/Apache.Arrow/Types/UnionType.cs +++ b/csharp/src/Apache.Arrow/Types/UnionType.cs @@ -31,14 +31,14 @@ public sealed class UnionType : NestedType public UnionMode Mode { get; } - public List TypeCodes { get; } + public int[] TypeIds { get; } public UnionType( - IEnumerable fields, IEnumerable typeCodes, + IEnumerable fields, IEnumerable typeIds, UnionMode mode = UnionMode.Sparse) : base(fields.ToArray()) { - TypeCodes = typeCodes.ToList(); + TypeIds = typeIds.ToArray(); Mode = mode; } diff --git a/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs b/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs index bb902190c763b..15cc16e65a1b0 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrayTypeComparer.cs @@ -111,10 +111,10 @@ public void Visit(UnionType actualType) Assert.Equal(expectedType.Mode, actualType.Mode); - Assert.Equal(expectedType.TypeCodes.Count, actualType.TypeCodes.Count); - for (int i = 0; i < expectedType.TypeCodes.Count; i++) + Assert.Equal(expectedType.TypeIds.Length, actualType.TypeIds.Length); + for (int i = 0; i < expectedType.TypeIds.Length; i++) { - Assert.Equal(expectedType.TypeCodes[i], actualType.TypeCodes[i]); + Assert.Equal(expectedType.TypeIds[i], actualType.TypeIds[i]); } CompareNested(expectedType, actualType); diff --git a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs index c95e9a51ac341..d0044eab2e0ce 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs @@ -83,7 +83,7 @@ private static IEnumerable, IArrowArray>> GenerateTestDa new Field.Builder().Name("Strings").DataType(StringType.Default).Nullable(true).Build(), new Field.Builder().Name("Ints").DataType(Int32Type.Default).Nullable(true).Build() }, - new byte[] { 0, 1 }, + new[] { 0, 1 }, UnionMode.Sparse ), new UnionType( @@ -91,7 +91,7 @@ private static IEnumerable, IArrowArray>> GenerateTestDa new Field.Builder().Name("Strings").DataType(StringType.Default).Nullable(true).Build(), new Field.Builder().Name("Ints").DataType(Int32Type.Default).Nullable(true).Build() }, - new byte[] { 0, 1 }, + new[] { 0, 1 }, UnionMode.Dense ), }; diff --git a/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs b/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs index 14530c9534dd3..2baed14e0190c 100644 --- a/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs +++ b/csharp/test/Apache.Arrow.Tests/CDataInterfacePythonTests.cs @@ -102,8 +102,8 @@ private static Schema GetTestSchema() .Field(f => f.Name("dict_string_ordered").DataType(new DictionaryType(Int32Type.Default, StringType.Default, true)).Nullable(false)) .Field(f => f.Name("list_dict_string").DataType(new ListType(new DictionaryType(Int32Type.Default, StringType.Default, false))).Nullable(false)) - .Field(f => f.Name("dense_union").DataType(new UnionType(new[] { new Field("i64", Int64Type.Default, false), new Field("f32", FloatType.Default, true), }, new[] { (byte)0, (byte)1 }, UnionMode.Dense))) - .Field(f => f.Name("sparse_union").DataType(new UnionType(new[] { new Field("i32", Int32Type.Default, true), new Field("f64", DoubleType.Default, false), }, new[] { (byte)0, (byte)1 }, UnionMode.Sparse))) + .Field(f => f.Name("dense_union").DataType(new UnionType(new[] { new Field("i64", Int64Type.Default, false), new Field("f32", FloatType.Default, true), }, new[] { 0, 1 }, UnionMode.Dense))) + .Field(f => f.Name("sparse_union").DataType(new UnionType(new[] { new Field("i32", Int32Type.Default, true), new Field("f64", DoubleType.Default, false), }, new[] { 0, 1 }, UnionMode.Sparse))) // Checking wider characters. .Field(f => f.Name("hello 你好 😄").DataType(BooleanType.Default).Nullable(true)) diff --git a/csharp/test/Apache.Arrow.Tests/TableTests.cs b/csharp/test/Apache.Arrow.Tests/TableTests.cs index 45a14cc25616e..d3b9f057b6556 100644 --- a/csharp/test/Apache.Arrow.Tests/TableTests.cs +++ b/csharp/test/Apache.Arrow.Tests/TableTests.cs @@ -30,7 +30,7 @@ public static Table MakeTableWithOneColumnOfTwoIntArrays(int lengthOfEachArray) Field field = new Field.Builder().Name("f0").DataType(Int32Type.Default).Build(); Schema s0 = new Schema.Builder().Field(field).Build(); - Column column = new Column(field, new List { intArray, intArrayCopy }); + Column column = new Column(field, new List { intArray, intArrayCopy }); Table table = new Table(s0, new List { column }); return table; } @@ -60,7 +60,7 @@ public void TestTableFromRecordBatches() Table table1 = Table.TableFromRecordBatches(recordBatch1.Schema, recordBatches); Assert.Equal(20, table1.RowCount); - Assert.Equal(23, table1.ColumnCount); + Assert.Equal(25, table1.ColumnCount); FixedSizeBinaryType type = new FixedSizeBinaryType(17); Field newField1 = new Field(type.Name, type, false); diff --git a/csharp/test/Apache.Arrow.Tests/TestData.cs b/csharp/test/Apache.Arrow.Tests/TestData.cs index a2e13f5df728b..4269198de764e 100644 --- a/csharp/test/Apache.Arrow.Tests/TestData.cs +++ b/csharp/test/Apache.Arrow.Tests/TestData.cs @@ -59,6 +59,8 @@ public static RecordBatch CreateSampleRecordBatch(int length, int columnSetCount { builder.Field(CreateField(new DictionaryType(Int32Type.Default, StringType.Default, false), i)); builder.Field(CreateField(new FixedSizeBinaryType(16), i)); + builder.Field(CreateField(new UnionType(new[] { CreateField(StringType.Default, i), CreateField(Int32Type.Default, i) }, new[] { 0, 1 }, UnionMode.Sparse), i)); + builder.Field(CreateField(new UnionType(new[] { CreateField(StringType.Default, i), CreateField(Int32Type.Default, i) }, new[] { 0, 1 }, UnionMode.Dense), -i)); } //builder.Field(CreateField(HalfFloatType.Default)); @@ -290,7 +292,63 @@ public void Visit(StructType type) public void Visit(UnionType type) { - throw new NotImplementedException(); + int[] lengths = new int[type.Fields.Count]; + if (type.Mode == UnionMode.Sparse) + { + for (int i = 0; i < lengths.Length; i++) + { + lengths[i] = Length; + } + } + else + { + int totalLength = Length; + int oneLength = Length / lengths.Length; + for (int i = 1; i < lengths.Length; i++) + { + lengths[i] = oneLength; + totalLength -= oneLength; + } + lengths[0] = totalLength; + } + + ArrayData[] childArrays = new ArrayData[type.Fields.Count]; + for (int i = 0; i < childArrays.Length; i++) + { + childArrays[i] = CreateArray(type.Fields[i], lengths[i]).Data; + } + + ArrowBuffer.Builder typeIdBuilder = new ArrowBuffer.Builder(Length); + byte index = 0; + for (int i = 0; i < Length; i++) + { + typeIdBuilder.Append(index); + index++; + if (index == lengths.Length) + { + index = 0; + } + } + + ArrowBuffer[] buffers; + if (type.Mode == UnionMode.Sparse) + { + buffers = new ArrowBuffer[1]; + } + else + { + ArrowBuffer.Builder offsetBuilder = new ArrowBuffer.Builder(Length); + for (int i = 0; i < Length; i++) + { + offsetBuilder.Append(i / lengths.Length); + } + + buffers = new ArrowBuffer[2]; + buffers[1] = offsetBuilder.Build(); + } + buffers[0] = typeIdBuilder.Build(); + + Array = UnionArray.Create(new ArrayData(type, Length, 0, 0, buffers, childArrays)); } public void Visit(DictionaryType type) From c576ecae0cd9e95b371e7d1ae5e4213a21574403 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Thu, 20 Jul 2023 16:58:33 -0700 Subject: [PATCH 04/10] Implement Archery support for C# unions --- csharp/src/Apache.Arrow/Arrays/Array.cs | 13 +--- .../IntegrationCommand.cs | 77 +++++++++++++++++-- .../Apache.Arrow.IntegrationTest/JsonFile.cs | 4 + dev/archery/archery/integration/datagen.py | 1 - docs/source/status.rst | 4 +- 5 files changed, 80 insertions(+), 19 deletions(-) diff --git a/csharp/src/Apache.Arrow/Arrays/Array.cs b/csharp/src/Apache.Arrow/Arrays/Array.cs index a453b0807267f..0838134b19c6d 100644 --- a/csharp/src/Apache.Arrow/Arrays/Array.cs +++ b/csharp/src/Apache.Arrow/Arrays/Array.cs @@ -62,16 +62,7 @@ internal static void Accept(T array, IArrowArrayVisitor visitor) public Array Slice(int offset, int length) { - if (offset > Length) - { - throw new ArgumentException($"Offset {offset} cannot be greater than Length {Length} for Array.Slice"); - } - - length = Math.Min(Data.Length - offset, length); - offset += Data.Offset; - - ArrayData newData = Data.Slice(offset, length); - return ArrowArrayFactory.BuildArray(newData) as Array; + return ArrowArrayFactory.Slice(this, offset, length) as Array; } public void Dispose() @@ -88,4 +79,4 @@ protected virtual void Dispose(bool disposing) } } } -} \ No newline at end of file +} diff --git a/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs b/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs index af70f8946da5d..608692939d5af 100644 --- a/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs +++ b/csharp/test/Apache.Arrow.IntegrationTest/IntegrationCommand.cs @@ -128,7 +128,7 @@ private RecordBatch CreateRecordBatch(Schema schema, JsonRecordBatch jsonRecordB for (int i = 0; i < jsonRecordBatch.Columns.Count; i++) { JsonFieldData data = jsonRecordBatch.Columns[i]; - Field field = schema.GetFieldByName(data.Name); + Field field = schema.FieldsList[i]; ArrayCreator creator = new ArrayCreator(data); field.DataType.Accept(creator); arrays.Add(creator.Array); @@ -149,8 +149,20 @@ private static Schema CreateSchema(JsonSchema jsonSchema) private static void CreateField(Field.Builder builder, JsonField jsonField) { + Field[] children = null; + if (jsonField.Children?.Count > 0) + { + children = new Field[jsonField.Children.Count]; + for (int i = 0; i < jsonField.Children.Count; i++) + { + Field.Builder field = new Field.Builder(); + CreateField(field, jsonField.Children[i]); + children[i] = field.Build(); + } + } + builder.Name(jsonField.Name) - .DataType(ToArrowType(jsonField.Type)) + .DataType(ToArrowType(jsonField.Type, children)) .Nullable(jsonField.Nullable); if (jsonField.Metadata != null) @@ -159,7 +171,7 @@ private static void CreateField(Field.Builder builder, JsonField jsonField) } } - private static IArrowType ToArrowType(JsonArrowType type) + private static IArrowType ToArrowType(JsonArrowType type, Field[] children) { return type.Name switch { @@ -173,6 +185,7 @@ private static IArrowType ToArrowType(JsonArrowType type) "date" => ToDateArrowType(type), "time" => ToTimeArrowType(type), "timestamp" => ToTimestampArrowType(type), + "union" => ToUnionArrowType(type, children), "null" => NullType.Default, _ => throw new NotSupportedException($"JsonArrowType not supported: {type.Name}") }; @@ -251,6 +264,17 @@ private static IArrowType ToTimestampArrowType(JsonArrowType type) }; } + private static IArrowType ToUnionArrowType(JsonArrowType type, Field[] children) + { + UnionMode mode = type.Mode switch + { + "SPARSE" => UnionMode.Sparse, + "DENSE" => UnionMode.Dense, + _ => throw new NotSupportedException($"Union mode not supported: {type.Mode}"), + }; + return new UnionType(children, type.TypeIds, mode); + } + private class ArrayCreator : IArrowTypeVisitor, IArrowTypeVisitor, @@ -278,7 +302,7 @@ private class ArrayCreator : IArrowTypeVisitor, IArrowTypeVisitor { - private JsonFieldData JsonFieldData { get; } + private JsonFieldData JsonFieldData { get; set; } public IArrowArray Array { get; private set; } public ArrayCreator(JsonFieldData jsonFieldData) @@ -489,7 +513,39 @@ public void Visit(StructType type) public void Visit(UnionType type) { - throw new NotImplementedException(); + ArrowBuffer[] buffers; + if (type.Mode == UnionMode.Dense) + { + buffers = new ArrowBuffer[2]; + buffers[1] = GetOffsetBuffer(); + } + else + { + buffers = new ArrowBuffer[1]; + } + buffers[0] = GetTypeIdBuffer(); + + ArrayData[] children = GetChildren(type); + + int nullCount = 0; + ArrayData arrayData = new ArrayData(type, JsonFieldData.Count, nullCount, 0, buffers, children); + Array = UnionArray.Create(arrayData); + } + + private ArrayData[] GetChildren(NestedType type) + { + ArrayData[] children = new ArrayData[type.Fields.Count]; + + var data = JsonFieldData; + for (int i = 0; i < children.Length; i++) + { + JsonFieldData = data.Children[i]; + type.Fields[i].DataType.Accept(this); + children[i] = Array.Data; + } + JsonFieldData = data; + + return children; } private static byte[] ConvertHexStringToByteArray(string hexString) @@ -555,11 +611,22 @@ private void GenerateLongArray(Func valueOffsets = new ArrowBuffer.Builder(JsonFieldData.Offset.Length); valueOffsets.AppendRange(JsonFieldData.Offset); return valueOffsets.Build(default); } + private ArrowBuffer GetTypeIdBuffer() + { + ArrowBuffer.Builder typeIds = new ArrowBuffer.Builder(JsonFieldData.TypeId.Length); + for (int i = 0; i < JsonFieldData.TypeId.Length; i++) + { + typeIds.Append(checked((byte)JsonFieldData.TypeId[i])); + } + return typeIds.Build(default); + } + private ArrowBuffer GetValidityBuffer(out int nullCount) { if (JsonFieldData.Validity == null) diff --git a/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs b/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs index f074afc010103..721b4b14ee123 100644 --- a/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs +++ b/csharp/test/Apache.Arrow.IntegrationTest/JsonFile.cs @@ -68,6 +68,10 @@ public class JsonArrowType // FixedSizeBinary fields public int ByteWidth { get; set; } + // union fields + public string Mode { get; set; } + public int[] TypeIds { get; set; } + [JsonExtensionData] public Dictionary ExtensionData { get; set; } } diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py index 11cb02a9f4ebc..8df5a6c3b11dd 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py @@ -1713,7 +1713,6 @@ def _temp_path(): .skip_category('JS'), generate_unions_case() - .skip_category('C#') .skip_category('JS'), generate_custom_metadata_case() diff --git a/docs/source/status.rst b/docs/source/status.rst index 5c8895b114ae3..0fabc72ca9dba 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -83,9 +83,9 @@ Data Types +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | Map | ✓ | ✓ | ✓ | ✓ | | ✓ | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ -| Dense Union | ✓ | ✓ | ✓ | | | ✓ | ✓ | | +| Dense Union | ✓ | ✓ | ✓ | | ✓ | ✓ | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ -| Sparse Union | ✓ | ✓ | ✓ | | | ✓ | ✓ | | +| Sparse Union | ✓ | ✓ | ✓ | | ✓ | ✓ | ✓ | | +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +-------------------+-------+-------+-------+------------+-------+-------+-------+-------+ From a113a80beaa4b14db8fc56b25b712d4885746ed0 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Thu, 20 Jul 2023 18:02:32 -0700 Subject: [PATCH 05/10] Better backwards compatibility --- csharp/src/Apache.Arrow/ChunkedArray.cs | 21 ++++++++++++++-- csharp/src/Apache.Arrow/Column.cs | 24 +++++++++++-------- .../ArrowArrayConcatenatorTests.cs | 2 -- csharp/test/Apache.Arrow.Tests/ColumnTests.cs | 2 +- csharp/test/Apache.Arrow.Tests/TableTests.cs | 6 ++--- 5 files changed, 37 insertions(+), 18 deletions(-) diff --git a/csharp/src/Apache.Arrow/ChunkedArray.cs b/csharp/src/Apache.Arrow/ChunkedArray.cs index 8565032276d09..f5909f5adfe48 100644 --- a/csharp/src/Apache.Arrow/ChunkedArray.cs +++ b/csharp/src/Apache.Arrow/ChunkedArray.cs @@ -34,7 +34,14 @@ public int ArrayCount get => Arrays.Count; } - public IArrowArray Array(int index) => Arrays[index]; + public Array Array(int index) => Arrays[index] as Array; + + public IArrowArray ArrowArray(int index) => Arrays[index]; + + public ChunkedArray(IList arrays) + : this(Cast(arrays)) + { + } public ChunkedArray(IList arrays) { @@ -51,7 +58,7 @@ public ChunkedArray(IList arrays) } } - public ChunkedArray(Array array) : this(new[] { array }) { } + public ChunkedArray(Array array) : this(new IArrowArray[] { array }) { } public ChunkedArray Slice(long offset, long length) { @@ -85,6 +92,16 @@ public ChunkedArray Slice(long offset) return Slice(offset, Length - offset); } + private static IArrowArray[] Cast(IList arrays) + { + IArrowArray[] arrowArrays = new IArrowArray[arrays.Count]; + for (int i = 0; i < arrays.Count; i++) + { + arrowArrays[i] = arrays[i]; + } + return arrowArrays; + } + // TODO: Flatten for Structs } } diff --git a/csharp/src/Apache.Arrow/Column.cs b/csharp/src/Apache.Arrow/Column.cs index 7d99ed24d02d0..0709b9142cafd 100644 --- a/csharp/src/Apache.Arrow/Column.cs +++ b/csharp/src/Apache.Arrow/Column.cs @@ -27,20 +27,24 @@ public class Column public Field Field { get; } public ChunkedArray Data { get; } + public Column(Field field, IList arrays) + : this(field, new ChunkedArray(arrays), doValidation: true) + { + } + public Column(Field field, IList arrays) + : this(field, new ChunkedArray(arrays), doValidation: true) { - Data = new ChunkedArray(arrays); - Field = field; - if (!ValidateArrayDataTypes()) - { - throw new ArgumentException($"{Field.DataType} must match {Data.DataType}"); - } } - private Column(Field field, ChunkedArray arrays) + private Column(Field field, ChunkedArray data, bool doValidation = false) { + Data = data; Field = field; - Data = arrays; + if (doValidation && !ValidateArrayDataTypes()) + { + throw new ArgumentException($"{Field.DataType} must match {Data.DataType}"); + } } public long Length => Data.Length; @@ -64,12 +68,12 @@ private bool ValidateArrayDataTypes() for (int i = 0; i < Data.ArrayCount; i++) { - if (Data.Array(i).Data.DataType.TypeId != Field.DataType.TypeId) + if (Data.ArrowArray(i).Data.DataType.TypeId != Field.DataType.TypeId) { return false; } - Data.Array(i).Data.DataType.Accept(dataTypeComparer); + Data.ArrowArray(i).Data.DataType.Accept(dataTypeComparer); if (!dataTypeComparer.DataTypeMatch) { diff --git a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs index d0044eab2e0ce..352c9fd834e84 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs @@ -17,8 +17,6 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; -using System.Reflection.Emit; -using System.Text; using Apache.Arrow.Memory; using Apache.Arrow.Types; using Xunit; diff --git a/csharp/test/Apache.Arrow.Tests/ColumnTests.cs b/csharp/test/Apache.Arrow.Tests/ColumnTests.cs index b90c681622d5f..2d867b79176aa 100644 --- a/csharp/test/Apache.Arrow.Tests/ColumnTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ColumnTests.cs @@ -39,7 +39,7 @@ public void TestColumn() Array intArrayCopy = MakeIntArray(10); Field field = new Field.Builder().Name("f0").DataType(Int32Type.Default).Build(); - Column column = new Column(field, new[] { intArray, intArrayCopy }); + Column column = new Column(field, new IArrowArray[] { intArray, intArrayCopy }); Assert.True(column.Name == field.Name); Assert.True(column.Field == field); diff --git a/csharp/test/Apache.Arrow.Tests/TableTests.cs b/csharp/test/Apache.Arrow.Tests/TableTests.cs index d3b9f057b6556..1f462cf6177a2 100644 --- a/csharp/test/Apache.Arrow.Tests/TableTests.cs +++ b/csharp/test/Apache.Arrow.Tests/TableTests.cs @@ -86,13 +86,13 @@ public void TestTableAddRemoveAndSetColumn() Array nonEqualLengthIntArray = ColumnTests.MakeIntArray(10); Field field1 = new Field.Builder().Name("f1").DataType(Int32Type.Default).Build(); - Column nonEqualLengthColumn = new Column(field1, new[] { nonEqualLengthIntArray}); + Column nonEqualLengthColumn = new Column(field1, new IArrowArray[] { nonEqualLengthIntArray }); Assert.Throws(() => table.InsertColumn(-1, nonEqualLengthColumn)); Assert.Throws(() => table.InsertColumn(1, nonEqualLengthColumn)); Array equalLengthIntArray = ColumnTests.MakeIntArray(20); Field field2 = new Field.Builder().Name("f2").DataType(Int32Type.Default).Build(); - Column equalLengthColumn = new Column(field2, new[] { equalLengthIntArray}); + Column equalLengthColumn = new Column(field2, new IArrowArray[] { equalLengthIntArray }); Column existingColumn = table.Column(0); Table newTable = table.InsertColumn(0, equalLengthColumn); @@ -118,7 +118,7 @@ public void TestBuildFromRecordBatch() RecordBatch batch = TestData.CreateSampleRecordBatch(schema, 10); Table table = Table.TableFromRecordBatches(schema, new[] { batch }); - Assert.NotNull(table.Column(0).Data.Array(0) as Int64Array); + Assert.NotNull(table.Column(0).Data.ArrowArray(0) as Int64Array); } } From 039ec9e2e2e3dea254d2c885ab37511f4e21099b Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Mon, 21 Aug 2023 16:10:37 -0700 Subject: [PATCH 06/10] Fix deserialization of fixed-size list --- csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index 92e5edb21c61b..5a069a5fc01ff 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -254,6 +254,7 @@ private ArrayData LoadPrimitiveField( buffers = ((UnionType)field.DataType).Mode == Types.UnionMode.Dense ? 2 : 1; break; case ArrowTypeId.Struct: + case ArrowTypeId.FixedSizeList: buffers = 1; break; default: From 0327be29939143e4b06be8e6f84357c6e025bf14 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Thu, 24 Aug 2023 20:11:01 -0700 Subject: [PATCH 07/10] Fix bug in ctor --- csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs b/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs index b9d32a5b5f9d9..e7e4f6256957b 100644 --- a/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs @@ -39,7 +39,7 @@ public DenseUnionArray( children.Select(child => child.Data))) { _fields = children.ToArray(); - ValidateMode(UnionMode.Sparse, Type.Mode); + ValidateMode(UnionMode.Dense, Type.Mode); } public DenseUnionArray(ArrayData data) From 832f7cb0e0c7b8a30c116580a80ec6fe0c9d7b39 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Tue, 5 Sep 2023 21:00:41 -0700 Subject: [PATCH 08/10] PR feedback --- csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs | 2 +- csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs b/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs index e7e4f6256957b..1aacbe11f08b9 100644 --- a/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/DenseUnionArray.cs @@ -46,7 +46,7 @@ public DenseUnionArray(ArrayData data) : base(data) { ValidateMode(UnionMode.Dense, Type.Mode); - data.EnsureBufferCount(2); // TODO: + data.EnsureBufferCount(2); } } } diff --git a/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs b/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs index 876b1ace30773..c1a12362a942a 100644 --- a/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs +++ b/csharp/src/Apache.Arrow/C/CArrowSchemaExporter.cs @@ -131,7 +131,7 @@ private static string FormatUnion(UnionType unionType) { UnionMode.Sparse => "+us:", UnionMode.Dense => "+ud:", - _ => throw new InvalidDataException($"Unsupported time unit for export: {unionType.Mode}"), + _ => throw new InvalidDataException($"Unsupported union mode for export: {unionType.Mode}"), }); for (int i = 0; i < unionType.TypeIds.Length; i++) { From 5ebefcf6d951aafc9907cb0906f161bb96f9f598 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Fri, 22 Sep 2023 16:10:17 -0700 Subject: [PATCH 09/10] Increment metadata version to V5 and add handling for V4 unions. --- .../src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs | 10 ++++++++++ csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index 5a069a5fc01ff..0bc5ba5e80603 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -251,6 +251,16 @@ private ArrayData LoadPrimitiveField( case ArrowTypeId.Null: return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, System.Array.Empty()); case ArrowTypeId.Union: + if (fieldNullCount > 0) + { + if (recordBatchEnumerator.CurrentBuffer.Length > 0) + { + // With V4 metadata we can get a validity bitmap. Fixing up union data is hard, + // so we will just quit. + throw new NotSupportedException("Cannot read pre-1.0.0 Union array with top-level validity bitmap"); + } + recordBatchEnumerator.MoveNextBuffer(); + } buffers = ((UnionType)field.DataType).Mode == Types.UnionMode.Dense ? 2 : 1; break; case ArrowTypeId.Struct: diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index 23a0bc3211ad9..2b3815af71142 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -235,7 +235,7 @@ public void Visit(IArrowArray array) private readonly bool _leaveOpen; private readonly IpcOptions _options; - private protected const Flatbuf.MetadataVersion CurrentMetadataVersion = Flatbuf.MetadataVersion.V4; + private protected const Flatbuf.MetadataVersion CurrentMetadataVersion = Flatbuf.MetadataVersion.V5; private static readonly byte[] s_padding = new byte[64]; From 143a469c785abd7f3f2b98ec3fba48ed41c90903 Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Sat, 23 Sep 2023 06:34:07 -0700 Subject: [PATCH 10/10] Correctly skip buffer in pre-V5 metadata --- .../Ipc/ArrowReaderImplementation.cs | 41 ++++++++++++------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index 0bc5ba5e80603..d3115da52cc6c 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -116,11 +116,11 @@ protected RecordBatch CreateArrowObjectFromMessage( break; case Flatbuf.MessageHeader.DictionaryBatch: Flatbuf.DictionaryBatch dictionaryBatch = message.Header().Value; - ReadDictionaryBatch(dictionaryBatch, bodyByteBuffer, memoryOwner); + ReadDictionaryBatch(message.Version, dictionaryBatch, bodyByteBuffer, memoryOwner); break; case Flatbuf.MessageHeader.RecordBatch: Flatbuf.RecordBatch rb = message.Header().Value; - List arrays = BuildArrays(Schema, bodyByteBuffer, rb); + List arrays = BuildArrays(message.Version, Schema, bodyByteBuffer, rb); return new RecordBatch(Schema, memoryOwner, arrays, (int)rb.Length); default: // NOTE: Skip unsupported message type @@ -136,7 +136,11 @@ internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory buffer) return new ByteBuffer(new ReadOnlyMemoryBufferAllocator(buffer), 0); } - private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBuffer bodyByteBuffer, IMemoryOwner memoryOwner) + private void ReadDictionaryBatch( + MetadataVersion version, + Flatbuf.DictionaryBatch dictionaryBatch, + ByteBuffer bodyByteBuffer, + IMemoryOwner memoryOwner) { long id = dictionaryBatch.Id; IArrowType valueType = DictionaryMemo.GetDictionaryType(id); @@ -149,7 +153,7 @@ private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBu Field valueField = new Field("dummy", valueType, true); var schema = new Schema(new[] { valueField }, default); - IList arrays = BuildArrays(schema, bodyByteBuffer, recordBatch.Value); + IList arrays = BuildArrays(version, schema, bodyByteBuffer, recordBatch.Value); if (arrays.Count != 1) { @@ -167,6 +171,7 @@ private void ReadDictionaryBatch(Flatbuf.DictionaryBatch dictionaryBatch, ByteBu } private List BuildArrays( + MetadataVersion version, Schema schema, ByteBuffer messageBuffer, Flatbuf.RecordBatch recordBatchMessage) @@ -187,8 +192,8 @@ private List BuildArrays( Flatbuf.FieldNode fieldNode = recordBatchEnumerator.CurrentNode; ArrayData arrayData = field.DataType.IsFixedPrimitive() - ? LoadPrimitiveField(ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator) - : LoadVariableField(ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator); + ? LoadPrimitiveField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator) + : LoadVariableField(version, ref recordBatchEnumerator, field, in fieldNode, messageBuffer, bufferCreator); arrays.Add(ArrowArrayFactory.BuildArray(arrayData)); } while (recordBatchEnumerator.MoveNextNode()); @@ -225,6 +230,7 @@ private IBufferCreator GetBufferCreator(BodyCompression? compression) } private ArrayData LoadPrimitiveField( + MetadataVersion version, ref RecordBatchEnumerator recordBatchEnumerator, Field field, in Flatbuf.FieldNode fieldNode, @@ -251,13 +257,16 @@ private ArrayData LoadPrimitiveField( case ArrowTypeId.Null: return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, System.Array.Empty()); case ArrowTypeId.Union: - if (fieldNullCount > 0) + if (version < MetadataVersion.V5) { - if (recordBatchEnumerator.CurrentBuffer.Length > 0) + if (fieldNullCount > 0) { - // With V4 metadata we can get a validity bitmap. Fixing up union data is hard, - // so we will just quit. - throw new NotSupportedException("Cannot read pre-1.0.0 Union array with top-level validity bitmap"); + if (recordBatchEnumerator.CurrentBuffer.Length > 0) + { + // With older metadata we can get a validity bitmap. Fixing up union data is hard, + // so we will just quit. + throw new NotSupportedException("Cannot read pre-1.0.0 Union array with top-level validity bitmap"); + } } recordBatchEnumerator.MoveNextBuffer(); } @@ -279,7 +288,7 @@ private ArrayData LoadPrimitiveField( recordBatchEnumerator.MoveNextBuffer(); } - ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData, bufferCreator); + ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator); IArrowArray dictionary = null; if (field.DataType.TypeId == ArrowTypeId.Dictionary) @@ -292,6 +301,7 @@ private ArrayData LoadPrimitiveField( } private ArrayData LoadVariableField( + MetadataVersion version, ref RecordBatchEnumerator recordBatchEnumerator, Field field, in Flatbuf.FieldNode fieldNode, @@ -326,7 +336,7 @@ private ArrayData LoadVariableField( } ArrowBuffer[] arrowBuff = new[] { nullArrowBuffer, offsetArrowBuffer, valueArrowBuffer }; - ArrayData[] children = GetChildren(ref recordBatchEnumerator, field, bodyData, bufferCreator); + ArrayData[] children = GetChildren(version, ref recordBatchEnumerator, field, bodyData, bufferCreator); IArrowArray dictionary = null; if (field.DataType.TypeId == ArrowTypeId.Dictionary) @@ -339,6 +349,7 @@ private ArrayData LoadVariableField( } private ArrayData[] GetChildren( + MetadataVersion version, ref RecordBatchEnumerator recordBatchEnumerator, Field field, ByteBuffer bodyData, @@ -355,8 +366,8 @@ private ArrayData[] GetChildren( Field childField = type.Fields[index]; ArrayData child = childField.DataType.IsFixedPrimitive() - ? LoadPrimitiveField(ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator) - : LoadVariableField(ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator); + ? LoadPrimitiveField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator) + : LoadVariableField(version, ref recordBatchEnumerator, childField, in childFieldNode, bodyData, bufferCreator); children[index] = child; }