diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py
index 476b68ce33..903d2a3b64 100644
--- a/dlt/destinations/impl/lancedb/lancedb_client.py
+++ b/dlt/destinations/impl/lancedb/lancedb_client.py
@@ -64,6 +64,7 @@
     make_arrow_table_schema,
     TArrowSchema,
     NULL_SCHEMA,
+    TArrowField,
 )
 from dlt.destinations.impl.lancedb.utils import (
     list_unique_identifiers,
@@ -75,7 +76,9 @@
 
 
 TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"}
-UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()}
+UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {
+    v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()
+}
 
 
 class LanceDBTypeMapper(TypeMapper):
@@ -191,7 +194,9 @@ def upload_batch(
         tbl.add(records, mode="replace")
     elif write_disposition == "merge":
         if not id_field_name:
-            raise ValueError("To perform a merge update, 'id_field_name' must be specified.")
+            raise ValueError(
+                "To perform a merge update, 'id_field_name' must be specified."
+            )
         tbl.merge_insert(
             id_field_name
         ).when_matched_update_all().when_not_matched_insert_all().execute(records)
@@ -262,14 +267,20 @@ def get_table_schema(self, table_name: str) -> TArrowSchema:
             schema,
         )
 
-    def create_table(self, table_name: str, schema: TArrowSchema) -> Table:
+    @lancedb_error
+    def create_table(
+        self, table_name: str, schema: TArrowSchema, mode: str = "create"
+    ) -> Table:
         """Create a LanceDB Table from the provided LanceModel or PyArrow schema.
 
         Args:
             schema: The table schema to create.
             table_name: The name of the table to create.
+            mode: The mode to use when creating the table. Can be either "create" or "overwrite".
+                By default, if the table already exists, an exception is raised.
+                If you want to overwrite the table, use mode="overwrite".
         """
-        return self.db_client.create_table(table_name, schema=schema)
+        return self.db_client.create_table(table_name, schema=schema, mode=mode)
 
     def delete_table(self, table_name: str) -> None:
         """Delete a LanceDB table.
@@ -360,7 +371,9 @@ def update_stored_schema(
         applied_update: TSchemaTables = {}
 
         try:
-            schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash)
+            schema_info = self.get_stored_schema_by_hash(
+                self.schema.stored_version_hash
+            )
         except DestinationUndefinedEntity:
             schema_info = None
 
@@ -397,9 +410,14 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]
                 table_schema[field_type] = schema_c
         return True, table_schema
 
-    def add_table_field(self, table_name: str, field_schema: pa.DataType) -> None:
+    @lancedb_error
+    def add_table_field(self, table_name: str, field_schema: TArrowField) -> Table:
         """Add a field to the LanceDB table.
 
+        Since arrow tables are immutable, this is done via a staging mechanism.
+        The data is read into an in-memory staging arrow table, evolved, then
+        written over the old table.
+
         Args:
             table_name: The name of the table to create the field on.
             field_schema: The field to create.
@@ -407,20 +425,37 @@ def add_table_field(self, table_name: str, field_schema: pa.DataType) -> None:
         # TODO: Arrow tables are immutable.
         #  This is tricky without creating a new table.
         #  Perhaps my performing a merge this can work tbl.merge
-        raise NotImplementedError
+        # Open existing LanceDB table directly as a PyArrow Table.
+        arrow_table = self.db_client.open_table(table_name).to_arrow()
+
+        # Create an array of null values for the new column.
+        null_array = pa.nulls(len(arrow_table), type=field_schema.type)
+
+        # Create staging Table with new column appended.
+        stage = arrow_table.append_column(field_schema, null_array)
+
+        return self.db_client.create_table(table_name, stage, mode="overwrite")
 
     def _execute_schema_update(self, only_tables: Iterable[str]) -> None:
         for table_name in only_tables or self.schema.tables:
             exists, existing_columns = self.get_storage_table(table_name)
-            new_columns = self.schema.get_new_table_columns(table_name, existing_columns)
-            logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}")
+            new_columns = self.schema.get_new_table_columns(
+                table_name, existing_columns
+            )
+            embedding_fields: List[str] = get_columns_names_with_prop(
+                self.schema.get_table(table_name), VECTORIZE_HINT
+            )
+            logger.info(
+                f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}"
+            )
             if len(new_columns) > 0:
                 if exists:
                     for column in new_columns:
-                        field_schema = make_arrow_field_schema(
-                            column["name"], column, self.type_mapper
+                        field_schema: TArrowField = make_arrow_field_schema(
+                            column["name"], column, self.type_mapper, embedding_fields
                         )
-                        self.add_table_field(table_name, field_schema)
+                        fq_table_name = self.make_qualified_table_name(table_name)
+                        self.add_table_field(fq_table_name, field_schema)
                 else:
                     embedding_fields = get_columns_names_with_prop(
                         self.schema.get_table(table_name=table_name), VECTORIZE_HINT
@@ -450,7 +485,9 @@ def update_schema_in_storage(self) -> None:
                 "schema": json.dumps(self.schema.to_dict()),
             }
         ]
-        fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)
+        fq_version_table_name = self.make_qualified_table_name(
+            self.schema.version_table_name
+        )
         write_disposition = self.schema.get_table(self.schema.version_table_name).get(
            "write_disposition"
        )
@@ -464,8 +501,12 @@ def update_schema_in_storage(self) -> None:
     @lancedb_error
     def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
         """Loads compressed state from destination storage by finding a load ID that was completed."""
-        fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name)
-        fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name)
+        fq_state_table_name = self.make_qualified_table_name(
+            self.schema.state_table_name
+        )
+        fq_loads_table_name = self.make_qualified_table_name(
+            self.schema.loads_table_name
+        )
 
         state_records = (
             self.db_client.open_table(fq_state_table_name)
@@ -489,8 +530,12 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
         return None
 
     @lancedb_error
-    def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]:
-        fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)
+    def get_stored_schema_by_hash(
+        self, schema_hash: str
+    ) -> Optional[StorageSchemaInfo]:
+        fq_version_table_name = self.make_qualified_table_name(
+            self.schema.version_table_name
+        )
 
         try:
             response = (
@@ -507,7 +552,9 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI
     @lancedb_error
     def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
         """Retrieves newest schema from destination storage."""
-        fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)
+        fq_version_table_name = self.make_qualified_table_name(
+            self.schema.version_table_name
+        )
 
         try:
             response = (
@@ -542,7 +589,9 @@ def complete_load(self, load_id: str) -> None:
                 "inserted_at": str(pendulum.now()),
             }
         ]
-        fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name)
+        fq_loads_table_name = self.make_qualified_table_name(
+            self.schema.loads_table_name
+        )
         write_disposition = self.schema.get_table(self.schema.loads_table_name).get(
             "write_disposition"
         )
@@ -556,7 +605,9 @@ def complete_load(self, load_id: str) -> None:
     def restore_file_load(self, file_path: str) -> LoadJob:
         return EmptyLoadJob.from_file_path(file_path, "completed")
 
-    def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
+    def start_file_load(
+        self, table: TTableSchema, file_path: str, load_id: str
+    ) -> LoadJob:
         return LoadLanceDBJob(
             self.schema,
             table,
@@ -600,7 +651,9 @@ def __init__(
         self.table_name: str = table_schema["name"]
         self.fq_table_name: str = fq_table_name
         self.unique_identifiers: Sequence[str] = list_unique_identifiers(table_schema)
-        self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT)
+        self.embedding_fields: List[str] = get_columns_names_with_prop(
+            table_schema, VECTORIZE_HINT
+        )
         self.embedding_model_func: TextEmbeddingFunction = model_func
         self.embedding_model_dimensions: int = client_config.embedding_model_dimensions
         self.id_field_name: str = client_config.id_field_name
diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py
index 392b28ccde..c135baa3c0 100644
--- a/dlt/destinations/impl/lancedb/schema.py
+++ b/dlt/destinations/impl/lancedb/schema.py
@@ -30,10 +30,17 @@ def make_arrow_field_schema(
     column_name: str,
     column: TColumnSchema,
     type_mapper: TypeMapper,
-    embedding_model_func: Optional[TextEmbeddingFunction] = None,
     embedding_fields: Optional[List[str]] = None,
-) -> TArrowDataType:
-    raise NotImplementedError
+) -> TArrowField:
+    """Creates a PyArrow field from a dlt column schema."""
+    dtype = cast(TArrowDataType, type_mapper.to_db_type(column))
+
+    if embedding_fields and column_name in embedding_fields:
+        metadata = {"embedding_source": "true"}
+    else:
+        metadata = None
+
+    return pa.field(column_name, dtype, metadata=metadata)
 
 
 def make_arrow_table_schema(
@@ -54,7 +61,9 @@ def make_arrow_table_schema(
 
     if embedding_fields:
         vec_size = embedding_model_dimensions or embedding_model_func.ndims()
-        arrow_schema.append(pa.field(vector_field_name, pa.list_(pa.float32(), vec_size)))
+        arrow_schema.append(
+            pa.field(vector_field_name, pa.list_(pa.float32(), vec_size))
+        )
 
     for column_name, column in schema.get_table_columns(table_name).items():
         dtype = cast(TArrowDataType, type_mapper.to_db_type(column))
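
The staging mechanism introduced in add_table_field above can be exercised on its own. Below is a minimal sketch of the same pattern (read the stored table into memory, append a null-filled column, write the staging table over the original), assuming a local lancedb installation; the database path and the "items"/"note" names are illustrative and not part of the change.

import lancedb
import pyarrow as pa

db = lancedb.connect("/tmp/lancedb_sketch")
db.create_table("items", pa.table({"id": pa.array([1, 2, 3])}), mode="overwrite")

# Arrow tables are immutable, so read the stored table into memory first.
arrow_table = db.open_table("items").to_arrow()

# Build the new field and a null-filled array of matching length.
new_field = pa.field("note", pa.string())
null_array = pa.nulls(len(arrow_table), type=new_field.type)

# Append the column to a staging table, then overwrite the stored table.
stage = arrow_table.append_column(new_field, null_array)
db.create_table("items", stage, mode="overwrite")

Note that mode="overwrite" rewrites the whole table, so adding a column costs a full copy; the patch trades that write amplification for a simple schema-evolution path on an otherwise immutable format.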
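The embedding_source metadata that make_arrow_field_schema now attaches can likewise be inspected with plain pyarrow. Another small sketch, with made-up column names; it assumes only that pyarrow exposes field metadata as a bytes-keyed mapping.

import pyarrow as pa

embedding_fields = ["title", "body"]

def make_field(column_name: str, dtype: pa.DataType) -> pa.Field:
    # Tag columns that feed the embedding model so downstream code can find them.
    metadata = {"embedding_source": "true"} if column_name in embedding_fields else None
    return pa.field(column_name, dtype, metadata=metadata)

assert make_field("title", pa.string()).metadata == {b"embedding_source": b"true"}
assert make_field("price", pa.float64()).metadata is None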