Added support for an 'enum' extension type to the v2 reader/writer
westonpace committed Oct 23, 2024
1 parent 536e73d commit f0c02c4
Showing 11 changed files with 541 additions and 179 deletions.
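
At the schema level, the new type is an ordinary `dictionary<int16, utf8>` field tagged with the `polars.enum` Arrow extension annotation, with the fixed category list carried as JSON in the field metadata. A minimal sketch of that representation using arrow-rs directly (the field name and categories here are made up for illustration; the metadata keys and JSON shape follow the test and `dict_enum.rs` below):

```rust
use std::collections::HashMap;

use arrow_schema::{DataType, Field};

fn main() {
    // An enum column is a dictionary<int16, utf8> field whose metadata names the
    // "polars.enum" extension and lists the fixed categories as JSON.
    let metadata = HashMap::from([
        (
            "ARROW:extension:name".to_string(),
            "polars.enum".to_string(),
        ),
        (
            "ARROW:extension:metadata".to_string(),
            r#"{"categories": ["blue", "red", "green", "yellow"]}"#.to_string(),
        ),
    ]);
    let field = Field::new(
        "dictionary",
        DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)),
        true,
    )
    .with_metadata(metadata);
    println!("{field:?}");
}
```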
103 changes: 103 additions & 0 deletions python/python/tests/test_file.py
@@ -436,3 +436,106 @@ def test_blob(tmp_path):
    reader = LanceFileReader(str(path))
    assert len(reader.metadata().columns[0].pages) == 1
    assert reader.read_all().to_table() == pa.table({"val": vals})


def test_enum_vs_categorical(tmp_path):
    # Helper method to make two dict arrays, with same dictionary values
    # but different indices
    def make_tbls(values, indices1, indices2):
        # Need to make two separate dictionaries here or else arrow-rs won't concat
        d1 = pa.array(values, pa.string())
        d2 = pa.array(values, pa.string())
        i1 = pa.array(indices1, pa.int16())
        i2 = pa.array(indices2, pa.int16())

        dict1 = pa.DictionaryArray.from_arrays(i1, d1)
        dict2 = pa.DictionaryArray.from_arrays(i2, d2)
        tab1 = pa.table({"dictionary": dict1})
        tab2 = pa.table({"dictionary": dict2})
        return tab1, tab2

    # Helper method to round trip two tables through lance and return the decoded
    # dictionary array
    def round_trip_dict(tab1: pa.Table, tab2: pa.Table) -> pa.DictionaryArray:
        with LanceFileWriter(tmp_path / "categorical.lance") as writer:
            writer.write_batch(tab1)
            writer.write_batch(tab2)

        reader = LanceFileReader(tmp_path / "categorical.lance")
        round_tripped = reader.read_all().to_table()

        arr2 = round_tripped.column("dictionary").chunk(0).dictionary
        return arr2

    # Helper method to convert a table with dictionary array into a table with
    # enum array
    def enumify(tbl) -> pa.Table:
        categories = ",".join(tbl.column(0).chunk(0).dictionary.to_pylist())
        enum_schema = pa.schema(
            [
                pa.field(
                    "dictionary",
                    pa.dictionary(pa.int16(), pa.string()),
                    metadata={
                        "ARROW:extension:name": "polars.enum",
                        "ARROW:extension:metadata": '{"categories": ['
                        + categories
                        + "]}",
                    },
                )
            ]
        )
        return pa.table([tbl.column(0)], schema=enum_schema)

    tab1, tab2 = make_tbls(
        ["blue", "red", "green", "yellow"],
        [0, 1, 0, 1, 0, 1],
        [1, 2, 1, 2, 1, 2],
    )

    round_trip = round_trip_dict(tab1, tab2)

    # Sometimes array concatenation will just concatenate the dictionaries
    assert round_trip.to_pylist() == [
        "blue",
        "red",
        "green",
        "yellow",
        "blue",
        "red",
        "green",
        "yellow",
    ]

    tab1 = enumify(tab1)
    tab2 = enumify(tab2)

    round_trip = round_trip_dict(tab1, tab2)

    # However, there should be no concatenation with the enum type
    assert round_trip.to_pylist() == [
        "blue",
        "red",
        "green",
        "yellow",
    ]

    tab1, tab2 = make_tbls(
        [str(i) for i in range(1000)],
        list(range(500)),
        list(range(500, 900)),
    )

    round_trip = round_trip_dict(tab1, tab2)

    # Other times array concatenation will combine the
    # dictionaries and remove unused items
    assert round_trip.to_pylist() == [str(i) for i in range(900)]

    # Again, no concatenation with enum type
    tab1 = enumify(tab1)
    tab2 = enumify(tab2)

    round_trip = round_trip_dict(tab1, tab2)

    assert round_trip.to_pylist() == [str(i) for i in range(1000)]
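
The "sometimes concatenate, sometimes combine" behavior the first two assertions observe comes from how arrow-rs merges dictionary arrays when batches are stitched back together. A rough standalone sketch of that merge step (arrow-rs only, not Lance's reader, with made-up example values) is:

```rust
use std::sync::Arc;

use arrow_array::{cast::AsArray, types::Int16Type, Array, DictionaryArray, Int16Array, StringArray};
use arrow_select::concat::concat;

fn main() {
    // Two dictionary arrays with identical values but separate value buffers,
    // mirroring make_tbls above.
    let values1: Arc<dyn Array> =
        Arc::new(StringArray::from(vec!["blue", "red", "green", "yellow"]));
    let values2: Arc<dyn Array> =
        Arc::new(StringArray::from(vec!["blue", "red", "green", "yellow"]));
    let dict1 =
        DictionaryArray::<Int16Type>::try_new(Int16Array::from(vec![0i16, 1, 0, 1, 0, 1]), values1)
            .unwrap();
    let dict2 =
        DictionaryArray::<Int16Type>::try_new(Int16Array::from(vec![1i16, 2, 1, 2, 1, 2]), values2)
            .unwrap();

    // arrow-rs decides how to merge the two dictionaries: it may simply append one
    // after the other or unify them and drop unused entries, which is why the plain
    // categorical round trip above cannot rely on a stable dictionary.
    let arrays: Vec<&dyn Array> = vec![&dict1, &dict2];
    let combined = concat(&arrays).unwrap();
    let combined = combined.as_dictionary::<Int16Type>();
    println!("merged dictionary has {} values", combined.values().len());
}
```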
2 changes: 2 additions & 0 deletions rust/lance-arrow/Cargo.toml
@@ -23,6 +23,8 @@ arrow-select = { workspace = true }
half = { workspace = true }
num-traits = { workspace = true }
rand.workspace = true
serde = { workspace = true }
serde_json = { workspace = true }

[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2", features = ["js"] }
146 changes: 146 additions & 0 deletions rust/lance-arrow/src/dict_enum.rs
@@ -0,0 +1,146 @@
use std::sync::Arc;

use arrow_array::{cast::AsArray, Array, StringArray};
use arrow_schema::{ArrowError, Field as ArrowField};
use serde::{Deserialize, Serialize};

use crate::{
    bfloat16::{ARROW_EXT_META_KEY, ARROW_EXT_NAME_KEY},
    DataTypeExt, Result,
};

const ENUM_TYPE: &str = "polars.enum";

// TODO: Could be slightly more efficient to use custom JSON serialization
// to go straight from JSON to StringArray without the Vec<String> intermediate
// but this is fine for now
#[derive(Deserialize, Serialize)]
struct DictionaryEnumMetadata {
    categories: Vec<String>,
}

pub struct DictionaryEnumType {
    pub categories: Arc<dyn Array>,
}

impl DictionaryEnumType {
    /// Adds extension type metadata to the given field
    ///
    /// Fails if the field is already an extension type of some kind
    pub fn wrap_field(&self, field: &ArrowField) -> Result<ArrowField> {
        let mut metadata = field.metadata().clone();
        if metadata.contains_key(ARROW_EXT_NAME_KEY) {
            return Err(ArrowError::InvalidArgumentError(
                "Field already has extension metadata".to_string(),
            ));
        }
        metadata.insert(ARROW_EXT_NAME_KEY.to_string(), ENUM_TYPE.to_string());
        metadata.insert(
            ARROW_EXT_META_KEY.to_string(),
            serde_json::to_string(&DictionaryEnumMetadata {
                categories: self
                    .categories
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .unwrap()
                    .values()
                    .iter()
                    .map(|x| x.to_string())
                    .collect(),
            })
            .unwrap(),
        );
        Ok(field.clone().with_metadata(metadata))
    }

    /// Creates a new enum type from the given dictionary array
    ///
    /// # Arguments
    ///
    /// * `arr` - The dictionary array to create the enum type from
    ///
    /// # Errors
    ///
    /// An error is returned if the array is not a dictionary array or if the dictionary
    /// array does not have string values
    pub fn from_dict_array(arr: &dyn Array) -> Result<Self> {
        let arr = arr.as_any_dictionary_opt().ok_or_else(|| {
            ArrowError::InvalidArgumentError(
                "Expected a dictionary array for enum type".to_string(),
            )
        })?;
        if !arr.values().data_type().is_binary_like() {
            Err(ArrowError::InvalidArgumentError(
                "Expected a dictionary array with string values for enum type".to_string(),
            ))
        } else {
            Ok(Self {
                categories: Arc::new(arr.values().clone()),
            })
        }
    }

    /// Attempts to parse the type from the given field
    ///
    /// If the field is not an enum type then None is returned
    ///
    /// Errors can occur if the field is an enum type but the metadata
    /// is not correctly formatted
    ///
    /// # Arguments
    ///
    /// * `field` - The field to parse
    /// * `sample_arr` - An optional sample array. If provided then categories will be extracted
    ///   from this array, avoiding the need to parse the metadata. This array should be a dictionary
    ///   array where the dictionary items are the categories.
    ///
    /// The sample_arr is only used if the field is an enum type. E.g. it is safe to do something
    /// like:
    ///
    /// ```ignore
    /// let arr = batch.column(0);
    /// let field = batch.schema().field(0);
    /// let enum_type = DictionaryEnumType::from_field(field, Some(arr));
    /// ```
    pub fn from_field(
        field: &ArrowField,
        sample_arr: Option<&Arc<dyn Array>>,
    ) -> Result<Option<Self>> {
        if field
            .metadata()
            .get(ARROW_EXT_NAME_KEY)
            .map(|k| k.eq_ignore_ascii_case(ENUM_TYPE))
            .unwrap_or(false)
        {
            // Prefer extracting values from the first array if possible as it's cheaper
            if let Some(arr) = sample_arr {
                let dict_arr = arr.as_any_dictionary_opt().ok_or_else(|| {
                    ArrowError::InvalidArgumentError(
                        "Expected a dictionary array for enum type".to_string(),
                    )
                })?;
                Ok(Some(Self {
                    categories: dict_arr.values().clone(),
                }))
            } else {
                // No arrays, need to use the field metadata
                let meta = field.metadata().get(ARROW_EXT_META_KEY).ok_or_else(|| {
                    ArrowError::InvalidArgumentError(format!(
                        "Field {} is missing extension metadata",
                        field.name()
                    ))
                })?;
                let meta: DictionaryEnumMetadata = serde_json::from_str(meta).map_err(|e| {
                    ArrowError::InvalidArgumentError(format!(
                        "Arrow extension metadata for enum was not correctly formed: {}",
                        e
                    ))
                })?;
                let categories = Arc::new(StringArray::from_iter_values(meta.categories));
                Ok(Some(Self { categories }))
            }
        } else {
            Ok(None)
        }
    }
}
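
For reference, a rough usage sketch of the new `DictionaryEnumType` helper. This is not part of the commit: the column name and values are invented, and errors are simply unwrapped on the assumption that `lance_arrow::Result` carries an `ArrowError`.

```rust
use std::sync::Arc;

use arrow_array::{types::Int16Type, Array, DictionaryArray, Int16Array, StringArray};
use arrow_schema::{DataType, Field};
use lance_arrow::dict_enum::DictionaryEnumType;

fn main() {
    // A dictionary<int16, utf8> column whose distinct values are the enum categories.
    let categories: Arc<dyn Array> = Arc::new(StringArray::from(vec!["blue", "red", "green"]));
    let keys = Int16Array::from(vec![0i16, 2, 1, 0]);
    let dict_arr = DictionaryArray::<Int16Type>::try_new(keys, categories).unwrap();

    // Derive the enum type from the array and stamp the extension metadata onto the field.
    let enum_type = DictionaryEnumType::from_dict_array(&dict_arr).unwrap();
    let field = Field::new(
        "color",
        DataType::Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)),
        true,
    );
    let wrapped = enum_type.wrap_field(&field).unwrap();

    // On the read side, from_field recognizes the annotation; with no sample array it
    // falls back to parsing the JSON metadata to recover the categories.
    let parsed = DictionaryEnumType::from_field(&wrapped, None)
        .unwrap()
        .expect("field should be recognized as an enum");
    assert_eq!(parsed.categories.len(), 3);
}
```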
1 change: 1 addition & 0 deletions rust/lance-arrow/src/lib.rs
@@ -19,6 +19,7 @@ use arrow_select::{interleave::interleave, take::take};
use rand::prelude::*;

pub mod deepcopy;
pub mod dict_enum;
pub mod schema;
pub use schema::*;
pub mod bfloat16;