Skip to content

Commit

Permalink
Merge pull request #82 from monarch-initiative/80-silent-json-parsing-errors-in-the-app
Browse files: browse the repository at this point in the history

Improve handling of Markdown parsing errors
  • Loading branch information
caufieldjh authored Sep 11, 2024
2 parents 5ef0df7 + 4cb6ea6 commit af6fa60
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 9 deletions.
8 changes: 5 additions & 3 deletions src/curate_gpt/agents/mapping_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pydantic import BaseModel, ConfigDict

from curate_gpt.agents.base_agent import BaseAgent
from curate_gpt.formatters.format_utils import remove_formatting
from curate_gpt.store.db_adapter import SEARCH_RESULT
from curate_gpt.utils.tokens import estimate_num_tokens, max_tokens_by_model

Expand Down Expand Up @@ -139,10 +140,10 @@ def match(
raise ValueError(f"Prompt too long: {prompt}.")
kb_results.pop()
response = model.prompt(prompt)

# Need to remove Markdown formatting here or it won't parse as JSON
response_text = response.text()
if response_text.startswith("```json"):
response_text = response_text[7:-3]
response_text = remove_formatting(text=response.text(), expect_format="json")

mappings = []
try:
for m in json.loads(response_text):
Expand All @@ -160,6 +161,7 @@ def match(
)
)
except json.decoder.JSONDecodeError:
# This will happen if the response is still not valid JSON
# This returns an empty set of mappings, but the prompt and response text are retained
return MappingSet(mappings=mappings, prompt=prompt, response_text=response_text)

Expand Down
10 changes: 4 additions & 6 deletions src/curate_gpt/extract/basic_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import yaml
from pydantic import ConfigDict

from curate_gpt.formatters.format_utils import remove_formatting

from ..utils.tokens import estimate_num_tokens, max_tokens_by_model
from .extractor import AnnotatedObject, Extractor

Expand Down Expand Up @@ -87,6 +89,7 @@ def deserialize(self, text: str, format=None, **kwargs) -> AnnotatedObject:
if format == "yaml":
return self.deserialize_yaml(text, **kwargs)
logger.debug(f"Parsing {text}")
text = remove_formatting(text=text, expect_format="json")
try:
obj = json.loads(text)
if isinstance(obj, str):
Expand All @@ -108,12 +111,7 @@ def deserialize(self, text: str, format=None, **kwargs) -> AnnotatedObject:

def deserialize_yaml(self, text: str, multiple=False) -> AnnotatedObject:
logger.debug(f"Parsing YAML: {text}")
if "```" in text:
logger.debug("Removing code block")
text = text.split("```")[1]
text = text.strip()
if text.startswith("yaml"):
text = text[4:]
text = remove_formatting(text=text, expect_format="yaml")
try:
if multiple:
obj = yaml.safe_load_all(text)
Expand Down
13 changes: 13 additions & 0 deletions src/curate_gpt/formatters/format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,16 @@ def object_as_yaml(obj: Dict) -> str:
:return:
"""
return yaml.dump({k: v for k, v in obj.items() if v}, sort_keys=False)


def remove_formatting(text: str, expect_format: str = "") -> str:
    """
    Remove a surrounding Markdown code fence from text if present.

    The text is treated as fenced when, ignoring surrounding whitespace,
    it begins with three backticks immediately followed by
    ``expect_format`` (e.g. three backticks plus "json"). The opening
    fence and format tag are removed, and the closing fence is removed
    only if it is actually there, so an unterminated fence does not lose
    the last characters of the payload.

    Text that does not begin with the expected fence is returned unchanged.
    Whitespace between the fences and the payload is left in place; the
    JSON and YAML parsers that consume the result tolerate it.

    :param text: The text to clean, typically a raw model response.
    :param expect_format: The expected format of the text, e.g., "json" (optional).
        When empty, only the bare fence is removed and any language tag
        after the opening fence is left in the returned text.
    :return: The text with the enclosing code fence removed, or the
        original text if no matching fence was found.
    """
    fence = "```"
    opening = fence + expect_format

    # Responses often carry a trailing newline after the closing fence (or
    # leading whitespace before the opening one); ignore it when detecting
    # and removing the fence.
    candidate = text.strip()
    if not candidate.startswith(opening):
        return text

    body = candidate[len(opening):]
    # Only drop the closing fence when it really is one; blindly slicing
    # off three characters would truncate an unterminated block.
    if body.endswith(fence):
        body = body[: -len(fence)]
    return body

0 comments on commit af6fa60

Please sign in to comment.