From 159d1ec6d6ee34515755424019abf996e061f993 Mon Sep 17 00:00:00 2001 From: Isamu Isozaki Date: Fri, 17 May 2024 18:05:17 +0900 Subject: [PATCH 1/2] Fixing stream stopping at wrong location (#898) Fixes https://github.com/outlines-dev/outlines/issues/896 --- outlines/generate/api.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 3f4f182d2..51a995664 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -340,15 +340,6 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: return generated_token_ids = sequence.token_ids[:, -num_generated:] generated_sequences = self.tokenizer.decode(generated_token_ids) - next_tokens = [ - token[len(sequence) :] if not stop else "" - for token, sequence, stop in zip( - generated_sequences, - previously_generated_sequences, - is_stop_at_reached, - ) - ] - previously_generated_sequences = generated_sequences if stop_sequences: is_stop_at_reached = [ stop @@ -360,6 +351,25 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: ) ] + generated_sequences = [ + self.format_sequence( + self.strip_stop_sequences(sequence, stop_sequences) + ) + if stop + else sequence + for sequence, stop in zip( + generated_sequences, is_stop_at_reached + ) + ] + next_tokens = [ + token[len(sequence) :] + for token, sequence, stop in zip( + generated_sequences, + previously_generated_sequences, + is_stop_at_reached, + ) + ] + previously_generated_sequences = generated_sequences # We reshape the output to (batch_size, sample_size) output: List[List[str]] = list() for i in range(batch_size): From 499d19dd3078e5e21cf68c7916a162d5e8ce0990 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 17 May 2024 09:09:03 +0000 Subject: [PATCH 2/2] Prevent Illegal Look-Around for OneOf in JSONSchema (#897) Fixes #823 This comment details the issues error: 
https://github.com/outlines-dev/outlines/issues/823#issuecomment-2116490949 The reproduction code provided results in a JSON schema with `OneOf[pets]`: ``` class Model(BaseModel): pet: Union[Cat, Dog] = Field(..., discriminator='pet_type') ``` Before this PR: `OneOf` uses negative lookaheads to assert that only one schema member is included. This is illegal in `interegular`; more details are available here: https://github.com/outlines-dev/outlines/issues/456 After this PR: `OneOf` uses or-joined non-capturing groups, which don't have the same issues with `interegular`. --- outlines/fsm/json_schema.py | 8 +------- tests/fsm/test_json_schema.py | 31 ++++++++++++++++++++++++++++--- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index d96597d4c..2c53fd240 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -195,13 +195,7 @@ def to_regex( to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"] ] - xor_patterns = [] - # json schema validation ensured there is no overlapping schemas in oneOf - for subregex in subregexes: - other_subregexes = filter(lambda r: r != subregex, subregexes) - other_subregexes_str = "|".join([f"{s}" for s in other_subregexes]) - negative_lookahead = f"(?!.*({other_subregexes_str}))" - xor_patterns.append(f"({subregex}){negative_lookahead}") + xor_patterns = [f"(?:{subregex})" for subregex in subregexes] return rf"({'|'.join(xor_patterns)})" diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 5b3ad9e39..b992f7aa5 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -1,9 +1,10 @@ import json import re -from typing import List +from typing import List, Literal, Union +import interegular import pytest -from pydantic import BaseModel, constr +from pydantic import BaseModel, Field, constr from outlines.fsm.json_schema import ( BOOLEAN, @@ -321,7 +322,7 @@ def test_match_number(pattern, 
does_match): "title": "Foo", "oneOf": [{"type": "string"}, {"type": "number"}, {"type": "boolean"}], }, - rf"(({STRING})(?!.*({NUMBER}|{BOOLEAN}))|({NUMBER})(?!.*({STRING}|{BOOLEAN}))|({BOOLEAN})(?!.*({STRING}|{NUMBER})))", + rf'((?:"{STRING_INNER}*")|(?:{NUMBER})|(?:{BOOLEAN}))', [ ("12.3", True), ("true", True), @@ -750,3 +751,27 @@ class MockModel(BaseModel): assert match_default_ws is None assert re.fullmatch(pattern, mock_result_maybe_ws) + + +def test_one_of_doesnt_produce_illegal_lookaround(): + """Reproduces failure in https://github.com/outlines-dev/outlines/issues/823""" + + class Cat(BaseModel): + pet_type: Literal["cat"] + meows: int + + class Dog(BaseModel): + pet_type: Literal["dog"] + barks: float + + class Model(BaseModel): + pet: Union[Cat, Dog] = Field(..., discriminator="pet_type") + n: int + + json_schema = Model.schema_json() + + json_schema = Model.schema_json() + pattern = build_regex_from_schema(json_schema, whitespace_pattern=None) + + # check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm() + interegular.parse_pattern(pattern).to_fsm()