Fixing tests with provided OPENAI_API_KEY #85

Merged · 3 commits · Sep 20, 2024
7 changes: 4 additions & 3 deletions src/curate_gpt/evaluation/dae_evaluator.py
@@ -17,7 +17,6 @@
 logger = logging.getLogger(__name__)


-# TODO: missing abstract class evaluate_object, causes src/tests/evaluation/test_runner to fail
 @dataclass
 class DatabaseAugmentedCompletionEvaluator(BaseEvaluator):
     """
@@ -50,8 +49,7 @@ def evaluate(
         """
         agent = self.agent
         db = agent.knowledge_source
-        # TODO: use get()
-        test_objs = list(db.peek(collection=test_collection, limit=num_tests))
+        test_objs = list(db.find(collection=test_collection))
         if any(obj for obj in test_objs if any(f not in obj for f in self.fields_to_predict)):
             logger.info("Alternate strategy to get test objs; query whole collection")
             test_objs = db.peek(collection=test_collection, limit=1000000)
@@ -133,3 +131,6 @@ def evaluate(
             report_tsv_file.flush()
         aggregated = aggregate_metrics(all_metrics)
         return aggregated
+
+    def evaluate_object(self, obj, **kwargs) -> ClassificationMetrics:
+        pass
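
For context on the new stub: if BaseEvaluator declares evaluate_object as abstract, as the removed TODO suggests, a subclass that omits it cannot be instantiated, which is what broke src/tests/evaluation/test_runner. A minimal repro sketch with hypothetical names, not the project's actual classes:

from abc import ABC, abstractmethod
from dataclasses import dataclass


class EvaluatorBase(ABC):  # hypothetical stand-in for BaseEvaluator
    @abstractmethod
    def evaluate_object(self, obj, **kwargs): ...


@dataclass
class IncompleteEvaluator(EvaluatorBase):
    pass  # no evaluate_object defined


# IncompleteEvaluator() raises:
# TypeError: Can't instantiate abstract class IncompleteEvaluator with abstract method evaluate_object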
2 changes: 1 addition & 1 deletion src/curate_gpt/evaluation/splitter.py
@@ -122,7 +122,7 @@ def stratify_collection_to_store(
         size = len(objs)
         cn = f"{collection}_{sn}_{size}"
         collections[sn] = cn
-        logging.info(f"Writing {size} objects to {cn}")
+        logger.info(f"Writing {size} objects to {cn}")
         if cn in existing_collections:
             logger.info(f"Collection {cn} already exists")
             if not force:
2 changes: 1 addition & 1 deletion src/curate_gpt/extract/openai_extractor.py
@@ -113,7 +113,7 @@ def extract(
         logger.debug(f"RESPONSE = {response}")
         # print(response)
         choice = response.choices[0]
-        message = choice["message"]
+        message = choice.message
         if "function_call" not in message:
             if self.raise_error_if_unparsable:
                 raise ValueError("No function call in response")
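
Context for this one-line change: it tracks the openai v1 client, where chat completions return typed objects rather than dicts, so choice["message"] raises a TypeError. A hedged sketch of the v1 access pattern; the model name and prompt are placeholders, not from the PR:

from openai import OpenAI  # assumes openai>=1.0

client = OpenAI()
response = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "hello"}],
)
choice = response.choices[0]
message = choice.message  # typed object: attribute access, not subscripting
if message.function_call is None:
    print("no function call in response")

Note that the unchanged guard `if "function_call" not in message:` still assumes dict-style membership; with a typed message object, a check such as `message.function_call is None` may be the safer test.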
4 changes: 3 additions & 1 deletion src/curate_gpt/store/duckdb_adapter.py
@@ -1064,8 +1064,10 @@ def _get_embedding_dimension(self, model_name: str) -> int:
         if isinstance(model_name, str):
             if model_name.startswith("openai:"):
                 model_key = model_name.split("openai:", 1)[1]
+                if model_key == "" or model_key not in MODEL_MAP.keys():
+                    model_key = DEFAULT_OPENAI_MODEL
                 model_info = MODEL_MAP.get(model_key, DEFAULT_OPENAI_MODEL)
-                return MODEL_MAP[model_info][1]
+                return model_info[1]
             else:
                 return MODEL_MAP[DEFAULT_OPENAI_MODEL][1]
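
Why the old return line failed on a known key: MODEL_MAP.get(model_key, DEFAULT_OPENAI_MODEL) yields a map entry on a hit but the default key string on a miss, so MODEL_MAP[model_info] only resolved in the miss case. The new guard normalizes the key up front, which also makes the .get fallback unreachable. A sketch with a hypothetical MODEL_MAP shape; the real map's entries may differ:

# Hypothetical shape: key -> (model name, embedding dimension)
MODEL_MAP = {
    "text-embedding-ada-002": ("text-embedding-ada-002", 1536),
    "text-embedding-3-large": ("text-embedding-3-large", 3072),
}
DEFAULT_OPENAI_MODEL = "text-embedding-ada-002"

model_key = "text-embedding-3-large"
model_info = MODEL_MAP.get(model_key, DEFAULT_OPENAI_MODEL)  # a tuple on a hit
# old: MODEL_MAP[model_info][1] -> KeyError, because the tuple is not a key
# new: normalize the key first, then index the entry once
if model_key == "" or model_key not in MODEL_MAP:
    model_key = DEFAULT_OPENAI_MODEL
dimension = MODEL_MAP[model_key][1]  # 3072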
6 changes: 3 additions & 3 deletions src/curate_gpt/wrappers/clinical/clinvar_wrapper.py
@@ -33,8 +33,8 @@ def objects_from_dict(self, results: Dict) -> List[Dict]:
         for r in results["eSummaryResult"]["DocumentSummarySet"]["DocumentSummary"]:
             obj = {}
             obj["id"] = "clinvar:" + r["accession"]
-            obj["clinical_significance"] = r["clinical_significance"]["description"]
-            obj["clinical_significance_status"] = r["clinical_significance"]["review_status"]
+            obj["clinical_significance"] = r["germline_classification"]["description"]
+            obj["clinical_significance_status"] = r["germline_classification"]["review_status"]
             obj["gene_sort"] = r["gene_sort"]
             if "genes" in r and r["genes"]:
                 if "gene" in r["genes"]:
@@ -46,7 +46,7 @@ def objects_from_dict(self, results: Dict) -> List[Dict]:
             obj["protein_change"] = r["protein_change"]
             obj["title"] = r["title"]
             obj["traits"] = [
-                self._trait_from_dict(t) for t in r["trait_set"]["trait"] if isinstance(t, dict)
+                self._trait_from_dict(t) for t in r.get("trait_set", {}).get("trait", []) if isinstance(t, dict)
             ]
             objs.append(obj)
         return objs
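
Two separate fixes here: the classification fields now read from germline_classification (this appears to track an upstream rename in ClinVar's esummary output, which replaced the old clinical_significance block), and the trait lookup now tolerates records with no trait_set. A minimal sketch of the defensive pattern, using an illustrative record rather than a real esummary payload:

record = {"accession": "VCV000000001"}  # illustrative record with no "trait_set"

# old: record["trait_set"]["trait"] raises KeyError
# new: missing keys degrade to an empty list, so the comprehension yields nothing
traits = [t for t in record.get("trait_set", {}).get("trait", []) if isinstance(t, dict)]
assert traits == []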
8 changes: 7 additions & 1 deletion tests/agents/conftest.py
@@ -6,13 +6,19 @@
 from tests import INPUT_DBS


+# TODO: this has to be reviewed; isolate more, don't use one db for multiple tests
+# - the current setup does not allow reset
+# - set collection is very vulnerable; easier to set/create a new collection for each test
+# - using a loaded test ontology can also be mocked for ease
+# - ? use structure from tests/wrapper (vstore, wrapper fixtures)
+# - or create a collection in each test, load all collections with the whole data, and reset/remove the collection after
 @pytest.fixture
 def go_test_chroma_db() -> ChromaDBAdapter:
     """
     Fixture for a ChromaDBAdapter instance with the test ontology loaded.

     Note: the chromadb is not checked into github - instead,
-    this relies on test_chromadb_dapter.test_store to create the test db.
+    this relies on test_chromadb_adapter.test_store to create the test db.
     """
     db = ChromaDBAdapter(str(INPUT_DBS / "go-nucleus-chroma"))
     db.schema_proxy = SchemaProxy(ONTOLOGY_MODEL_PATH)
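
A sketch of the per-test isolation the TODO above asks for: each test gets its own store under pytest's tmp_path, and teardown resets it the way the other fixtures in this PR do. The fixture name is hypothetical, not part of the PR:

import pytest

from curate_gpt import ChromaDBAdapter


@pytest.fixture
def isolated_go_db(tmp_path) -> ChromaDBAdapter:
    """Hypothetical per-test fixture: a fresh store per test, nothing shared."""
    db = ChromaDBAdapter(str(tmp_path / "go-nucleus-chroma"))
    yield db
    db.reset()  # teardown: no state leaks into the next test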
1 change: 1 addition & 0 deletions tests/store/test_duckdb_adapter.py
@@ -160,6 +160,7 @@ def test_the_embedding_function_variations(
         expected_name = "test_collection"
     else:
         # Specific case: Collection specified, model may or may not be specified
+        print("\n\n", model, "\n\n")
         db.insert(objs, collection=collection, model=model)
         expected_model = model if model else "all-MiniLM-L6-v2"
         expected_name = collection
25 changes: 25 additions & 0 deletions tests/utils/helper.py
@@ -0,0 +1,25 @@
+from pathlib import Path
+
+from src.curate_gpt.store import ChromaDBAdapter
+
+DEBUG_MODE = False
+
+
+def create_db_dir(tmp_path, out_dir) -> Path:
+    """Creates a temporary directory or uses the provided debug directory."""
+    if DEBUG_MODE:
+        temp_dir = out_dir
+        if not temp_dir.exists():
+            temp_dir.mkdir(parents=True, exist_ok=True)
+        return temp_dir
+    else:
+        return tmp_path
+
+
+def setup_db(temp_dir: Path) -> ChromaDBAdapter:
+    """Sets up the DBAdapter and optionally resets it."""
+    # TODO: for now ChromaDB, later add DuckDB
+    # db = get_store("chromadb", str(temp_dir))
+    db = ChromaDBAdapter(str(temp_dir))
+    # reset only where the db is used: in the try block, or in the test itself
+    return db

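A usage sketch for the new helpers, mirroring how the updated tests below consume them; the test name, debug path, and inserted object are hypothetical:

from pathlib import Path

from tests import OUTPUT_DIR
from tests.utils.helper import DEBUG_MODE, create_db_dir, setup_db

TEMP_EXAMPLE_DB = OUTPUT_DIR / "example_tmp"  # debug location, used only when DEBUG_MODE is True


def test_with_helpers(tmp_path: Path):
    temp_dir = create_db_dir(tmp_path, TEMP_EXAMPLE_DB)  # pytest tmp_path unless DEBUG_MODE
    db = setup_db(temp_dir)
    db.reset()
    try:
        db.insert([{"id": "X:1", "label": "example"}])
    finally:
        if not DEBUG_MODE:
            db.reset()  # keep artifacts on disk only when debugging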
20 changes: 14 additions & 6 deletions tests/wrappers/test_bioportal.py
@@ -2,11 +2,11 @@

 import pytest

-from curate_gpt import ChromaDBAdapter
 from curate_gpt.extract import BasicExtractor
 from curate_gpt.wrappers.ontology.bioportal_wrapper import BioportalWrapper
 from curate_gpt.wrappers.ontology.ontology_wrapper import OntologyWrapper
 from tests import OUTPUT_DIR
+from tests.utils.helper import DEBUG_MODE, create_db_dir, setup_db

 TEMP_OAKVIEW_DB = OUTPUT_DIR / "bioportal_tmp"

@@ -18,15 +18,23 @@


 @pytest.fixture
-def vstore() -> OntologyWrapper:
-    db = ChromaDBAdapter(str(TEMP_OAKVIEW_DB))
+def vstore(tmp_path) -> OntologyWrapper:
+    tmp_dir = create_db_dir(tmp_path=tmp_path, out_dir=TEMP_OAKVIEW_DB)
+    db = setup_db(tmp_dir)
     db.reset()
-    view = BioportalWrapper(local_store=db, extractor=BasicExtractor())
-    assert view.fetch_definitions is False
+    try:
+        view = BioportalWrapper(local_store=db, extractor=BasicExtractor())
+        assert view.fetch_definitions is False
+        yield view
+    except Exception as e:
+        raise e
+    finally:
+        if not DEBUG_MODE:
+            db.reset()
+
     # view = BioportalView(oak_adapter=adapter, local_store=db, extractor=BasicExtractor())
     # view.fetch_definitions = False
     # view.fetch_relationships = False
-    return view


 @pytest.mark.skip(reason="OAK bp wrapper doesn't support definitions yets")
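
The fixtures in this PR share one shape: build the store, yield the wrapper, and reset in finally so teardown runs even when the dependent test fails, unless DEBUG_MODE preserves the artifacts for inspection. A distilled sketch with hypothetical names; note that the `except Exception as e: raise e` arms are no-op re-raises, so plain try/finally behaves identically:

import pytest

from tests import OUTPUT_DIR
from tests.utils.helper import DEBUG_MODE, create_db_dir, setup_db

DEBUG_DIR = OUTPUT_DIR / "pattern_tmp"  # hypothetical debug location


def build_wrapper(db):
    """Stand-in for constructing the wrapper under test."""
    return db


@pytest.fixture
def wrapped_store(tmp_path):
    db = setup_db(create_db_dir(tmp_path, DEBUG_DIR))
    try:
        yield build_wrapper(db)
    finally:
        if not DEBUG_MODE:
            db.reset()  # runs after the dependent test finishes, pass or fail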
20 changes: 13 additions & 7 deletions tests/wrappers/test_clinvar.py
@@ -1,17 +1,17 @@
 import logging
-import shutil
 import time

 import pytest
+import requests
 import yaml

-from curate_gpt import ChromaDBAdapter
 from curate_gpt.agents.chat_agent import ChatAgent
 from curate_gpt.agents.dragon_agent import DragonAgent
 from curate_gpt.extract import BasicExtractor
 from curate_gpt.wrappers.clinical.clinvar_wrapper import ClinVarWrapper
 from tests import INPUT_DIR, OUTPUT_DIR
 from tests.store.conftest import requires_openai_api_key
+from tests.utils.helper import DEBUG_MODE, create_db_dir, setup_db

 TEMP_DB = OUTPUT_DIR / "obj_tmp"

@@ -29,12 +29,18 @@ def test_clinvar_transform():


 @pytest.fixture
-def wrapper() -> ClinVarWrapper:
-    shutil.rmtree(TEMP_DB, ignore_errors=True)
-    db = ChromaDBAdapter(str(TEMP_DB))
+def wrapper(tmp_path) -> ClinVarWrapper:
+    temp_dir = create_db_dir(tmp_path, TEMP_DB)
+    db = setup_db(temp_dir)
     extractor = BasicExtractor()
-    db.reset()
-    return ClinVarWrapper(local_store=db, extractor=extractor)
+    try:
+        yield ClinVarWrapper(local_store=db, extractor=extractor)
+    except requests.exceptions.ConnectionError as e:
+        logger.error(f"Connection error occurred: {e}")
+        raise e
+    finally:
+        if not DEBUG_MODE:
+            db.reset()


 @requires_openai_api_key
20 changes: 13 additions & 7 deletions tests/wrappers/test_evidence_agent.py
@@ -1,17 +1,16 @@
 import logging
-import shutil
 from typing import Type

 import pytest
 import yaml

-from curate_gpt import ChromaDBAdapter
 from curate_gpt.agents.evidence_agent import EvidenceAgent
 from curate_gpt.extract import BasicExtractor
 from curate_gpt.wrappers import BaseWrapper
 from curate_gpt.wrappers.literature import PubmedWrapper, WikipediaWrapper
 from tests import OUTPUT_DIR
 from tests.store.conftest import requires_openai_api_key
+from tests.utils.helper import DEBUG_MODE, create_db_dir, setup_db

 TEMP_PUBMED_DB = OUTPUT_DIR / "pmid_tmp"

@@ -30,12 +29,19 @@
         WikipediaWrapper,
     ],
 )
-def test_evidence_inference(source: Type[BaseWrapper]):
-    shutil.rmtree(TEMP_PUBMED_DB, ignore_errors=True)
-    db = ChromaDBAdapter(str(TEMP_PUBMED_DB))
+def test_evidence_inference(tmp_path, source: Type[BaseWrapper]):
+    tmp_dir = create_db_dir(tmp_path=tmp_path, out_dir=TEMP_PUBMED_DB)
+    db = setup_db(tmp_dir)
     extractor = BasicExtractor()
-    db.reset()
-    pubmed = source(local_store=db, extractor=extractor)
+    try:
+        pubmed = source(local_store=db, extractor=extractor)
+    except Exception as e:
+        raise e
+    finally:
+        if not DEBUG_MODE:
+            if tmp_dir.exists():
+                db.reset()

     ea = EvidenceAgent(chat_agent=pubmed)
     obj = {
         "label": "acinar cells of the salivary gland",
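
One caveat when adapting this change: unlike the fixtures, test_evidence_inference has no yield, so the finally arm runs as soon as the try block exits, before pubmed is used further down. A compressed illustration of that ordering; hypothetical snippet, not from the PR:

def sketch():
    try:
        resource = "ready"
    finally:
        print("teardown")  # runs as soon as the try block exits, not at function exit
    print("use", resource)  # teardown has already happened by this point


sketch()  # prints "teardown", then "use ready"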
15 changes: 7 additions & 8 deletions tests/wrappers/test_ncbi_biosample.py
@@ -1,25 +1,24 @@
 import logging
-import shutil
 import time

 import yaml

-from curate_gpt import ChromaDBAdapter
 from curate_gpt.agents.chat_agent import ChatAgent
 from curate_gpt.extract import BasicExtractor
 from curate_gpt.wrappers.investigation.ncbi_biosample_wrapper import NCBIBiosampleWrapper
 from tests import OUTPUT_DIR
 from tests.store.conftest import requires_openai_api_key
+from tests.utils.helper import create_db_dir, setup_db

 TEMP_BIOSAMPLE_DB = OUTPUT_DIR / "biosample_tmp"

 logger = logging.getLogger(__name__)


 @requires_openai_api_key
-def test_biosample_search():
-    shutil.rmtree(TEMP_BIOSAMPLE_DB, ignore_errors=True)
-    db = ChromaDBAdapter(str(TEMP_BIOSAMPLE_DB))
+def test_biosample_search(tmp_path):
+    temp_dir = create_db_dir(tmp_path, TEMP_BIOSAMPLE_DB)
+    db = setup_db(temp_dir)
     extractor = BasicExtractor()
     db.reset()
     wrapper = NCBIBiosampleWrapper(local_store=db, extractor=extractor)
@@ -33,9 +32,9 @@ def test_biosample_search():


 @requires_openai_api_key
-def test_biosample_chat():
-    shutil.rmtree(TEMP_BIOSAMPLE_DB, ignore_errors=True)
-    db = ChromaDBAdapter(str(TEMP_BIOSAMPLE_DB))
+def test_biosample_chat(tmp_path):
+    temp_dir = create_db_dir(tmp_path, TEMP_BIOSAMPLE_DB)
+    db = setup_db(temp_dir)
     extractor = BasicExtractor()
     db.reset()
     wrapper = NCBIBiosampleWrapper(local_store=db, extractor=extractor)
45 changes: 23 additions & 22 deletions tests/wrappers/test_ontology.py
@@ -1,21 +1,19 @@
 import logging
-import os
-import shutil
-import tempfile
 from pprint import pprint

 import pytest
 from oaklib import get_adapter
 from oaklib.datamodels.obograph import GraphDocument

-from curate_gpt import ChromaDBAdapter
 from curate_gpt.extract import BasicExtractor
 from curate_gpt.wrappers.ontology.ontology_wrapper import OntologyWrapper
 from tests import INPUT_DIR, OUTPUT_DIR
 from tests.store.conftest import requires_openai_api_key
+from tests.utils.helper import DEBUG_MODE, create_db_dir, setup_db

-TEMP_OAKVIEW_DB = OUTPUT_DIR / "oaktmp"
-TEMP_OAKVIEW_DB2 = OUTPUT_DIR / "oaktmp2"
+TEMP_OAK_OBJ = OUTPUT_DIR / "oak_tmp_obj"
+TEMP_OAK_IND = OUTPUT_DIR / "oak_tmp_ind"
+TEMP_OAK_SEARCH = OUTPUT_DIR / "oak_tmp_search"

 # logger = logging.getLogger(__name__)

@@ -25,20 +23,27 @@


 @pytest.fixture
-def vstore():
-    with tempfile.TemporaryDirectory() as temp_dir:
-        db_path = os.path.join(temp_dir, "test_db")
-        adapter = get_adapter(INPUT_DIR / "go-nucleus.db")
-        db = ChromaDBAdapter(db_path)
-        wrapper = OntologyWrapper(oak_adapter=adapter, local_store=db, extractor=BasicExtractor())
+def vstore(request, tmp_path):
+    temp_db_base = request.param
+    temp_dir = create_db_dir(tmp_path, temp_db_base)
+    db = setup_db(temp_dir)
+    extractor = BasicExtractor()
+    # mock, possible connection error?
+    adapter = get_adapter(INPUT_DIR / "go-nucleus.db")
+    try:
+        wrapper = OntologyWrapper(oak_adapter=adapter, local_store=db, extractor=extractor)
         db.insert(wrapper.objects())
         yield wrapper
+    except Exception as e:
+        raise e
+    finally:
+        if not DEBUG_MODE:
+            db.reset()


+@pytest.mark.parametrize('vstore', [TEMP_OAK_OBJ], indirect=True)
 def test_oak_objects(vstore):
     """Test that the objects are extracted from the oak adapter."""
-    shutil.rmtree(TEMP_OAKVIEW_DB, ignore_errors=True)
-    # vstore.local_store.reset()
     objs = list(vstore.objects())
     [nucleus] = [obj for obj in objs if obj["id"] == "Nucleus"]
     assert nucleus["label"] == "nucleus"
@@ -50,22 +55,17 @@ def test_oak_objects(vstore):
     assert len(reversed.graphs[0].edges) == 2


+@pytest.mark.parametrize('vstore', [TEMP_OAK_IND], indirect=True)
 def test_oak_index(vstore):
     """Test that the objects are indexed in the local store."""
-    shutil.rmtree(TEMP_OAKVIEW_DB2, ignore_errors=True)
-    adapter = get_adapter(INPUT_DIR / "go-nucleus.db")
-    db = ChromaDBAdapter(str(TEMP_OAKVIEW_DB2))
-    db.reset()
-    wrapper = OntologyWrapper(oak_adapter=adapter, local_store=db, extractor=BasicExtractor())
-    db.insert(wrapper.objects())
-    g = wrapper.unwrap_object(
+    g = vstore.unwrap_object(
         {
             "id": "Nucleus",
             "label": "nucleus",
             "relationships": [{"predicate": "rdfs:subClassOf", "target": "Organelle"}],
             "original_id": "GO:0005634",
         },
-        store=db,
+        store=vstore.local_store,
     )
     if isinstance(g, GraphDocument):
         pprint(g.__dict__, width=100, indent=2)
@@ -80,6 +80,7 @@ def test_oak_index(vstore):
             print(edge.sub, edge.pred, edge.obj)


+@pytest.mark.parametrize('vstore', [TEMP_OAK_SEARCH], indirect=True)
 @requires_openai_api_key
 def test_oak_search(vstore):
     """Test that the objects are indexed and searchable in the local store."""
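
The new indirect=True parametrization is what lets one fixture serve three tests with different output directories: pytest hands each parameter to the fixture as request.param before the test body runs. A distilled sketch of the mechanism with hypothetical names:

import pytest


@pytest.fixture
def store_dir(request, tmp_path):
    # request.param is the value supplied by parametrize(..., indirect=True)
    return tmp_path / request.param


@pytest.mark.parametrize("store_dir", ["obj_case", "search_case"], indirect=True)
def test_uses_parametrized_fixture(store_dir):
    assert store_dir.name in ("obj_case", "search_case")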