diff --git a/controller.py b/controller.py index 8ff6951..5e293d6 100644 --- a/controller.py +++ b/controller.py @@ -288,6 +288,7 @@ def run_encoding( send_project_update(project_id, f"notification_created:{user_id}", True) embedding.delete_tensors(embedding_id, with_commit=True) chunk = 0 + embedding_canceled = False for pair in generate_batches( project_id, record_ids, @@ -301,9 +302,19 @@ def run_encoding( record_ids_batched = pair["record_ids"] attribute_values_encoded_batch = pair["embeddings"] - if not embedding.get(project_id, embedding_id): + embedding_entity = embedding.get(project_id, embedding_id) + if not embedding_entity: logger.info(f"Aborted {embedding_name}") break + elif embedding_entity.state == enums.EmbeddingState.FAILED.value: + embedding_canceled = True + send_project_update( + project_id, + f"embedding:{embedding_id}:state:{enums.EmbeddingState.FAILED.value}", + ) + logger.info(f"Canceled {embedding_name}") + break + embedding.create_tensors( project_id, embedding_id, @@ -401,7 +412,7 @@ def run_encoding( doc_ock.post_embedding_failed(user_id, f"{model}-{platform}") return status.HTTP_500_INTERNAL_SERVER_ERROR - if embedding.get(project_id, embedding_id): + if embedding.get(project_id, embedding_id) and not embedding_canceled: for warning_type, idx_list in embedder.get_warnings().items(): # use last record with warning as example example_record_id = record_ids[idx_list[-1]] diff --git a/submodules/model b/submodules/model index 92610a4..454a0e8 160000 --- a/submodules/model +++ b/submodules/model @@ -1 +1 @@ -Subproject commit 92610a4f27d6fddeea838a91d7a83b0e6d0265f0 +Subproject commit 454a0e84f7b6f9d81ae3ddd8908e6ac36afec992