Skip to content

Commit

Permalink
GH-44214: [C++] JsonExtensionType equality check ignores storage type (
Browse files Browse the repository at this point in the history
…#44215)

### Rationale for this change

As noted in #13901 (review):
```cpp
bool JsonExtensionType::ExtensionEquals(const ExtensionType& other) const {
  return other.extension_name() == this->extension_name();
}
```
> This equality check does not take into account the storage type, but only the name.
> As a consequence, a JsonExtensionType<string> type will be seen as equal to JsonExtensionType<large_string>.

### What changes are included in this PR?

This change introduces storage equality check into `JsonExtensionType` equality check.

This also fixes a storage type check in `JsonExtensionType::Make`.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

No.
* GitHub Issue: #44214

Lead-authored-by: Rok Mihevc <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
rok and pitrou authored Oct 8, 2024
1 parent 61c99a5 commit 64891d1
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 22 deletions.
31 changes: 19 additions & 12 deletions cpp/src/arrow/extension/json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,13 @@
namespace arrow::extension {

bool JsonExtensionType::ExtensionEquals(const ExtensionType& other) const {
return other.extension_name() == this->extension_name();
return other.extension_name() == this->extension_name() &&
other.storage_type()->Equals(storage_type_);
}

Result<std::shared_ptr<DataType>> JsonExtensionType::Deserialize(
std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
if (storage_type->id() != Type::STRING && storage_type->id() != Type::STRING_VIEW &&
storage_type->id() != Type::LARGE_STRING) {
return Status::Invalid("Invalid storage type for JsonExtensionType: ",
storage_type->ToString());
}
return std::make_shared<JsonExtensionType>(storage_type);
return JsonExtensionType::Make(std::move(storage_type));
}

std::string JsonExtensionType::Serialize() const { return ""; }
Expand All @@ -51,11 +47,22 @@ std::shared_ptr<Array> JsonExtensionType::MakeArray(
return std::make_shared<ExtensionArray>(data);
}

std::shared_ptr<DataType> json(const std::shared_ptr<DataType> storage_type) {
ARROW_CHECK(storage_type->id() != Type::STRING ||
storage_type->id() != Type::STRING_VIEW ||
storage_type->id() != Type::LARGE_STRING);
return std::make_shared<JsonExtensionType>(storage_type);
bool JsonExtensionType::IsSupportedStorageType(Type::type type_id) {
return type_id == Type::STRING || type_id == Type::STRING_VIEW ||
type_id == Type::LARGE_STRING;
}

Result<std::shared_ptr<DataType>> JsonExtensionType::Make(
std::shared_ptr<DataType> storage_type) {
if (!IsSupportedStorageType(storage_type->id())) {
return Status::Invalid("Invalid storage type for JsonExtensionType: ",
storage_type->ToString());
}
return std::make_shared<JsonExtensionType>(std::move(storage_type));
}

std::shared_ptr<DataType> json(std::shared_ptr<DataType> storage_type) {
return JsonExtensionType::Make(std::move(storage_type)).ValueOrDie();
}

} // namespace arrow::extension
4 changes: 4 additions & 0 deletions cpp/src/arrow/extension/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class ARROW_EXPORT JsonExtensionType : public ExtensionType {

std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

static Result<std::shared_ptr<DataType>> Make(std::shared_ptr<DataType> storage_type);

static bool IsSupportedStorageType(Type::type type_id);

private:
std::shared_ptr<DataType> storage_type_;
};
Expand Down
14 changes: 14 additions & 0 deletions cpp/src/arrow/extension/json_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,18 @@ TEST_F(TestJsonExtensionType, InvalidUTF8) {
}
}

TEST_F(TestJsonExtensionType, StorageTypeValidation) {
ASSERT_TRUE(json(utf8())->Equals(json(utf8())));
ASSERT_FALSE(json(large_utf8())->Equals(json(utf8())));
ASSERT_FALSE(json(utf8_view())->Equals(json(utf8())));
ASSERT_FALSE(json(utf8_view())->Equals(json(large_utf8())));

for (const auto& storage_type : {int16(), binary(), float64(), null()}) {
ASSERT_RAISES_WITH_MESSAGE(Invalid,
"Invalid: Invalid storage type for JsonExtensionType: " +
storage_type->ToString(),
extension::JsonExtensionType::Make(storage_type));
}
}

} // namespace arrow
26 changes: 19 additions & 7 deletions cpp/src/parquet/arrow/arrow_schema_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -757,23 +757,35 @@ TEST_F(TestConvertParquetSchema, ParquetSchemaArrowExtensions) {

{
// Parquet file does not contain Arrow schema.
// If Arrow extensions are enabled, both fields should be treated as json() extension
// fields.
// If Arrow extensions are enabled, fields will be interpreted as json(utf8())
// extension fields.
ArrowReaderProperties props;
props.set_arrow_extensions_enabled(true);
auto arrow_schema = ::arrow::schema(
{::arrow::field("json_1", ::arrow::extension::json(), true),
::arrow::field("json_2", ::arrow::extension::json(::arrow::large_utf8()),
true)});
::arrow::field("json_2", ::arrow::extension::json(::arrow::utf8()), true)});
std::shared_ptr<KeyValueMetadata> metadata{};
ASSERT_OK(ConvertSchema(parquet_fields, metadata, props));
CheckFlatSchema(arrow_schema);

// If original data was e.g. json(large_utf8()) it will be interpreted as json(utf8())
// in absence of Arrow schema.
arrow_schema = ::arrow::schema(
{::arrow::field("json_1", ::arrow::extension::json(), true),
::arrow::field("json_2", ::arrow::extension::json(::arrow::large_utf8()),
true)});
metadata = std::shared_ptr<KeyValueMetadata>{};
ASSERT_OK(ConvertSchema(parquet_fields, metadata, props));
EXPECT_TRUE(result_schema_->field(1)->type()->Equals(
::arrow::extension::json(::arrow::utf8())));
EXPECT_FALSE(
result_schema_->field(1)->type()->Equals(arrow_schema->field(1)->type()));
}

{
// Parquet file contains Arrow schema.
// Both json_1 and json_2 should be returned as a json() field
// even though extensions are not enabled.
// json_1 and json_2 will be interpreted as json(utf8()) and json(large_utf8())
// fields even though extensions are not enabled.
ArrowReaderProperties props;
props.set_arrow_extensions_enabled(false);
std::shared_ptr<KeyValueMetadata> field_metadata =
Expand All @@ -791,7 +803,7 @@ TEST_F(TestConvertParquetSchema, ParquetSchemaArrowExtensions) {

{
// Parquet file contains Arrow schema. Extensions are enabled.
// Both json_1 and json_2 should be returned as a json() field
// json_1 and json_2 will be interpreted as json(utf8()) and json(large_utf8()).
ArrowReaderProperties props;
props.set_arrow_extensions_enabled(true);
std::shared_ptr<KeyValueMetadata> field_metadata =
Expand Down
17 changes: 14 additions & 3 deletions cpp/src/parquet/arrow/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -997,9 +997,8 @@ Result<bool> ApplyOriginalMetadata(const Field& origin_field, SchemaField* infer
const auto& ex_type = checked_cast<const ::arrow::ExtensionType&>(*origin_type);
if (inferred_type->id() != ::arrow::Type::EXTENSION &&
ex_type.extension_name() == std::string("arrow.json") &&
(inferred_type->id() == ::arrow::Type::STRING ||
inferred_type->id() == ::arrow::Type::LARGE_STRING ||
inferred_type->id() == ::arrow::Type::STRING_VIEW)) {
::arrow::extension::JsonExtensionType::IsSupportedStorageType(
inferred_type->id())) {
// Schema mismatch.
//
// Arrow extensions are DISABLED in Parquet.
Expand All @@ -1009,6 +1008,18 @@ Result<bool> ApplyOriginalMetadata(const Field& origin_field, SchemaField* infer
// Origin type is restored as Arrow should be considered the source of truth.
inferred->field = inferred->field->WithType(origin_type);
RETURN_NOT_OK(ApplyOriginalStorageMetadata(origin_field, inferred));
} else if (inferred_type->id() == ::arrow::Type::EXTENSION &&
ex_type.extension_name() == std::string("arrow.json")) {
// Potential schema mismatch.
//
// Arrow extensions are ENABLED in Parquet.
// origin_type is arrow::extension::json(...)
// inferred_type is arrow::extension::json(arrow::utf8())
auto origin_storage_field = origin_field.WithType(ex_type.storage_type());

// Apply metadata recursively to storage type
RETURN_NOT_OK(ApplyOriginalStorageMetadata(*origin_storage_field, inferred));
inferred->field = inferred->field->WithType(origin_type);
} else {
auto origin_storage_field = origin_field.WithType(ex_type.storage_type());

Expand Down

0 comments on commit 64891d1

Please sign in to comment.