diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index 1c5c8912e5a0b..f0ee295c6347d 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -283,6 +283,25 @@ bool RecordBatch::ApproxEquals(const RecordBatch& other, const EqualOptions& opt return true; } +Result> RecordBatch::ReplaceSchema( + std::shared_ptr schema) const { + if (schema_->num_fields() != schema->num_fields()) + return Status::Invalid("RecordBatch schema fields", schema_->num_fields(), + ", did not match new schema fields: ", schema->num_fields()); + auto fields = schema_->fields(); + int n_fields = static_cast(fields.size()); + for (int i = 0; i < n_fields; i++) { + auto old_type = fields[i]->type(); + auto replace_type = schema->field(i)->type(); + if (!old_type->Equals(replace_type)) { + return Status::Invalid( + "RecordBatch schema field index ", i, " type is ", old_type->ToString(), + ", did not match new schema field type: ", replace_type->ToString()); + } + } + return RecordBatch::Make(std::move(schema), num_rows(), columns()); +} + Result> RecordBatch::SelectColumns( const std::vector& indices) const { int n = static_cast(indices.size()); diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index d728d5eb0da2f..cb1f6d54f7cff 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -114,6 +114,11 @@ class ARROW_EXPORT RecordBatch { /// \return the record batch's schema const std::shared_ptr& schema() const { return schema_; } + /// \brief Replace the schema with another schema with the same types, but potentially + /// different field names and/or metadata. + Result> ReplaceSchema( + std::shared_ptr schema) const; + /// \brief Retrieve all columns at once virtual const std::vector>& columns() const = 0; diff --git a/cpp/src/arrow/record_batch_test.cc b/cpp/src/arrow/record_batch_test.cc index e8180c6740879..bc923a1444160 100644 --- a/cpp/src/arrow/record_batch_test.cc +++ b/cpp/src/arrow/record_batch_test.cc @@ -521,4 +521,38 @@ TEST_F(TestRecordBatchReader, ToTable) { ASSERT_EQ(table->column(0)->chunks().size(), 0); } +TEST_F(TestRecordBatch, ReplaceSchema) { + const int length = 10; + + auto f0 = field("f0", int32()); + auto f1 = field("f1", uint8()); + auto f2 = field("f2", int16()); + auto f3 = field("f3", int8()); + + auto schema = ::arrow::schema({f0, f1, f2}); + + random::RandomArrayGenerator gen(42); + + auto a0 = gen.ArrayOf(int32(), length); + auto a1 = gen.ArrayOf(uint8(), length); + auto a2 = gen.ArrayOf(int16(), length); + + auto b1 = RecordBatch::Make(schema, length, {a0, a1, a2}); + + f0 = field("fd0", int32()); + f1 = field("fd1", uint8()); + f2 = field("fd2", int16()); + + schema = ::arrow::schema({f0, f1, f2}); + ASSERT_OK_AND_ASSIGN(auto mutated, b1->ReplaceSchema(schema)); + auto expected = RecordBatch::Make(schema, length, b1->columns()); + ASSERT_TRUE(mutated->Equals(*expected)); + + schema = ::arrow::schema({f0, f1, f3}); + ASSERT_RAISES(Invalid, b1->ReplaceSchema(schema)); + + schema = ::arrow::schema({f0, f1}); + ASSERT_RAISES(Invalid, b1->ReplaceSchema(schema)); +} + } // namespace arrow