From c75edf58423387ba0b1f11716da0d5f556cada69 Mon Sep 17 00:00:00 2001 From: ddebowczyk Date: Wed, 2 Oct 2024 11:35:23 +0200 Subject: [PATCH] Fixed PartialsGenerator --- README.md | 3 + docs/advanced/model_options.mdx | 3 + .../examples/advanced/context_cache.mdx | 4 +- .../examples/advanced/custom_client.mdx | 3 + docs/cookbook/examples/advanced/partials.mdx | 1 + examples/A02_Advanced/ContextCaching/run.php | 4 +- .../CustomClientParameters/run.php | 3 + examples/A02_Advanced/PartialUpdates/run.php | 1 + src/Core/StreamResponse/PartialsGenerator.php | 8 +- src/Utils/Json/ResilientJsonParser.php | 212 +++++++++--------- .../Feature/Utils/ResilientJsonParserTest.php | 159 +++++++++++++ 11 files changed, 295 insertions(+), 106 deletions(-) create mode 100644 tests/Feature/Utils/ResilientJsonParserTest.php diff --git a/README.md b/README.md index 2cb588cc..78f02821 100644 --- a/README.md +++ b/README.md @@ -673,7 +673,10 @@ $yourApiKey = Env::get('OPENAI_API_KEY'); // use your own API key $driver = new OpenAIDriver(new LLMConfig( apiUrl: 'https://api.openai.com/v1', // you can change base URI apiKey: $yourApiKey, + endpoint: '/chat/completions', metadata: ['organization' => ''], + model: 'gpt-4o-mini', + maxTokens: 128, )); /// Get Instructor with the default client component overridden with your own diff --git a/docs/advanced/model_options.mdx b/docs/advanced/model_options.mdx index 2f412d99..dfca5fd0 100644 --- a/docs/advanced/model_options.mdx +++ b/docs/advanced/model_options.mdx @@ -31,7 +31,10 @@ use Cognesy\Instructor\Extras\LLM\Drivers\OpenAIDriver; $driver = new OpenAIDriver(new LLMConfig( apiUrl: 'https://api.openai.com/v1', // you can change base URI apiKey: $yourApiKey, + endpoint: '/chat/completions', metadata: ['organization' => ''], + model: 'gpt-4o-mini', + maxTokens: 128, )); /// Get Instructor with the default client component overridden with your own diff --git a/docs/cookbook/examples/advanced/context_cache.mdx b/docs/cookbook/examples/advanced/context_cache.mdx index cdf582e2..6b24c48c 100644 --- a/docs/cookbook/examples/advanced/context_cache.mdx +++ b/docs/cookbook/examples/advanced/context_cache.mdx @@ -46,7 +46,7 @@ class Project { public array $applications; #[Description('Explain the purpose of the project and the domain specific problems it solves')] public string $description; - #[Description('Example code in Markdown demonstrating domain specific application of the library')] + #[Description('Example code as Markdown fragment, demonstrating domain specific application of the library')] public string $code; } ?> @@ -93,7 +93,7 @@ which results in faster processing and lower costs. ```php respond( - messages: "Describe the project in a way compelling to my audience: boutique CMS consulting company owner.", + messages: "Describe the project in a way compelling to my audience: lead gen software vendor.", responseModel: Project::class, options: ['max_tokens' => 4096], mode: Mode::Json, diff --git a/docs/cookbook/examples/advanced/custom_client.mdx b/docs/cookbook/examples/advanced/custom_client.mdx index ba9d864a..8c04b0d5 100644 --- a/docs/cookbook/examples/advanced/custom_client.mdx +++ b/docs/cookbook/examples/advanced/custom_client.mdx @@ -32,7 +32,10 @@ class User { $driver = new OpenAIDriver(new LLMConfig( apiUrl: 'https://api.openai.com/v1', apiKey: Env::get('OPENAI_API_KEY'), + endpoint: '/chat/completions', metadata: ['organization' => ''], + model: 'gpt-4o-mini', + maxTokens: 128, ) ); diff --git a/docs/cookbook/examples/advanced/partials.mdx b/docs/cookbook/examples/advanced/partials.mdx index d6a24a8f..074c22da 100644 --- a/docs/cookbook/examples/advanced/partials.mdx +++ b/docs/cookbook/examples/advanced/partials.mdx @@ -18,6 +18,7 @@ response is received. $loader = require 'vendor/autoload.php'; $loader->add('Cognesy\\Instructor\\', __DIR__ . '../../src/'); +use Cognesy\Instructor\Events\Event; use Cognesy\Instructor\Instructor; class UserRole diff --git a/examples/A02_Advanced/ContextCaching/run.php b/examples/A02_Advanced/ContextCaching/run.php index cdf582e2..6b24c48c 100644 --- a/examples/A02_Advanced/ContextCaching/run.php +++ b/examples/A02_Advanced/ContextCaching/run.php @@ -46,7 +46,7 @@ class Project { public array $applications; #[Description('Explain the purpose of the project and the domain specific problems it solves')] public string $description; - #[Description('Example code in Markdown demonstrating domain specific application of the library')] + #[Description('Example code as Markdown fragment, demonstrating domain specific application of the library')] public string $code; } ?> @@ -93,7 +93,7 @@ class Project { ```php respond( - messages: "Describe the project in a way compelling to my audience: boutique CMS consulting company owner.", + messages: "Describe the project in a way compelling to my audience: lead gen software vendor.", responseModel: Project::class, options: ['max_tokens' => 4096], mode: Mode::Json, diff --git a/examples/A02_Advanced/CustomClientParameters/run.php b/examples/A02_Advanced/CustomClientParameters/run.php index ba9d864a..edda9942 100644 --- a/examples/A02_Advanced/CustomClientParameters/run.php +++ b/examples/A02_Advanced/CustomClientParameters/run.php @@ -32,7 +32,10 @@ class User { $driver = new OpenAIDriver(new LLMConfig( apiUrl: 'https://api.openai.com/v1', apiKey: Env::get('OPENAI_API_KEY'), + endpoint: '/chat/completions', metadata: ['organization' => ''], + model: 'gpt-3.5-turbo', + maxTokens: 128, ) ); diff --git a/examples/A02_Advanced/PartialUpdates/run.php b/examples/A02_Advanced/PartialUpdates/run.php index d6a24a8f..074c22da 100644 --- a/examples/A02_Advanced/PartialUpdates/run.php +++ b/examples/A02_Advanced/PartialUpdates/run.php @@ -18,6 +18,7 @@ $loader = require 'vendor/autoload.php'; $loader->add('Cognesy\\Instructor\\', __DIR__ . '../../src/'); +use Cognesy\Instructor\Events\Event; use Cognesy\Instructor\Instructor; class UserRole diff --git a/src/Core/StreamResponse/PartialsGenerator.php b/src/Core/StreamResponse/PartialsGenerator.php index f4baf95a..4e2a96f7 100644 --- a/src/Core/StreamResponse/PartialsGenerator.php +++ b/src/Core/StreamResponse/PartialsGenerator.php @@ -146,7 +146,13 @@ protected function tryGetPartialObject( string $partialJsonData, ResponseModel $responseModel, ) : Result { - return Chain::from(fn() => Json::fix(Json::find($partialJsonData))) +// dump('raw:', $partialJsonData); +// $found = Json::findPartial($partialJsonData); +// dump('found:', $found); +// $json = Json::fix($found); +// dump('fixed:', $json); + + return Chain::from(fn() => Json::fix(Json::findPartial($partialJsonData))) ->through(fn($jsonData) => $this->responseDeserializer->deserialize($jsonData, $responseModel, $this?->toolCalls->last()->name)) ->through(fn($object) => $this->responseTransformer->transform($object)) ->result(); diff --git a/src/Utils/Json/ResilientJsonParser.php b/src/Utils/Json/ResilientJsonParser.php index d8aa427a..12447d09 100644 --- a/src/Utils/Json/ResilientJsonParser.php +++ b/src/Utils/Json/ResilientJsonParser.php @@ -2,30 +2,32 @@ namespace Cognesy\Instructor\Utils\Json; +namespace Cognesy\Instructor\Utils\Json; + class ResilientJsonParser { private string $input; private int $position = 0; private int $length; - private bool $inCodeBlock = false; public function __construct(string $input) { - $this->input = $input; + $this->input = trim($input); $this->length = strlen($input); } - // PUBLIC ///////////////////////////////////////////////////////////////// - public function parse(): mixed { + if (empty($this->input) || ($this->length === 0)) { + throw new \RuntimeException("Cannot parse an empty string"); + } $this->skipWhitespace(); + if ($this->position >= $this->length) { + throw new \RuntimeException("Input contains only whitespace"); + } return $this->parseValue(); } - // INTERNAL //////////////////////////////////////////////////////////////// - private function parseValue(): mixed { - $char = $this->getCurrentChar(); - return match ($char) { + return match ($this->getCurrentChar()) { '{' => $this->parseObject(), '[' => $this->parseArray(), '"' => $this->parseString(), @@ -41,20 +43,21 @@ private function parseObject(): array { $this->consume('{'); $this->skipWhitespace(); - while ($this->getCurrentChar() !== '}') { + if ($this->getCurrentChar() === '}') { + $this->consume('}'); + return $result; + } + + do { + $this->skipWhitespace(); $key = $this->parseString(); $this->skipWhitespace(); $this->consume(':'); $this->skipWhitespace(); $value = $this->parseValue(); $result[$key] = $value; - $this->skipWhitespace(); - if ($this->getCurrentChar() === ',') { - $this->consume(','); - $this->skipWhitespace(); - } - } + } while ($this->consumeIf(',')); $this->consume('}'); return $result; @@ -65,87 +68,67 @@ private function parseArray(): array { $this->consume('['); $this->skipWhitespace(); - while ($this->getCurrentChar() !== ']') { - $value = $this->parseValue(); - $result[] = $value; + if ($this->getCurrentChar() === ']') { + $this->consume(']'); + return $result; + } + do { $this->skipWhitespace(); - if ($this->getCurrentChar() === ',') { - $this->consume(','); - $this->skipWhitespace(); - } - } + $result[] = $this->parseValue(); + $this->skipWhitespace(); + } while ($this->consumeIf(',')); $this->consume(']'); return $result; } -// private function parseString(): string { -// $result = ''; -// $this->consume('"'); -// -// while (true) { -// $char = $this->getCurrentChar(); -// if ($char === '"' && $this->getPreviousChar() !== '\\') { -// break; -// } -// if ($char === "\n" || $char === "\r") { -// $result .= '\n'; -// $this->position++; -// } elseif ($char === '\\') { -// $result .= $char . $this->getNextChar(); -// $this->position += 2; -// } else { -// $result .= $char; -// $this->position++; -// } -// } -// -// $this->consume('"'); -// return $result; -// } - - private function parseString(): string - { - $result = ''; + private function parseString(): string { $this->consume('"'); - - while (true) { - $char = $this->getCurrentChar(); - if ($char === '`' && $this->getNextChar() === '`' && $this->getNextNextChar() === '`') { - $this->inCodeBlock = !$this->inCodeBlock; - $result .= '```'; - $this->position += 3; - continue; - } - if ($char === '"' && $this->getPreviousChar() !== '\\' && !$this->inCodeBlock) { - break; + $result = ''; + while ($this->position < $this->length) { + $char = $this->input[$this->position]; + if ($char === '"') { + $this->position++; + return $result; } - if ($char === "\n" || $char === "\r") { - $result .= '\n'; + if ($char === '\\') { $this->position++; - } elseif ($char === '\\') { - $result .= $char . $this->getNextChar(); - $this->position += 2; + if ($this->position >= $this->length) { + throw new \RuntimeException("Unterminated string escape at position {$this->position}"); + } + $escapeChar = $this->input[$this->position]; + $result .= $this->parseEscapeChar($escapeChar); } else { $result .= $char; - $this->position++; } + $this->position++; } - - $this->consume('"'); - return $result; + throw new \RuntimeException("Unterminated string at position {$this->position}"); + } + + private function parseEscapeChar(string $char): string { + return match($char) { + '"' => '"', + '\\' => '\\', + '/' => '/', + 'b' => "\b", + 'f' => "\f", + 'n' => "\n", + 'r' => "\r", + 't' => "\t", + 'u' => $this->parseUnicodeEscape(), + default => throw new \RuntimeException("Invalid escape character '\\$char' at position {$this->position}") + }; } - private function parseNumber(): float|int { - $start = $this->position; - while (preg_match('/[\d.+-e]/i', $this->getCurrentChar())) { - $this->position++; + private function parseUnicodeEscape(): string { + $hex = substr($this->input, $this->position + 1, 4); + if (strlen($hex) !== 4 || !ctype_xdigit($hex)) { + throw new \RuntimeException("Invalid Unicode escape sequence at position {$this->position}"); } - $numberString = substr($this->input, $start, $this->position - $start); - return is_numeric($numberString) - ? $this->toNumber($numberString) - : 0; + $this->position += 4; + return html_entity_decode("&#x$hex;", ENT_QUOTES, 'UTF-8'); } private function parseTrue(): bool { @@ -163,40 +146,67 @@ private function parseNull(): ?string { return null; } - private function skipWhitespace(): void { - while ($this->position < $this->length && ctype_space($this->getCurrentChar())) { + private function parseNumber(): float|int { + $start = $this->position; + $allowedChars = '0123456789.eE+-'; + $gotDecimalPoint = false; + $gotExponent = false; + + while ($this->position < $this->length && strpos($allowedChars, $this->getCurrentChar()) !== false) { + $char = $this->getCurrentChar(); + if ($char === '.') { + if ($gotDecimalPoint) { + throw new \RuntimeException("Invalid number format: multiple decimal points at position {$this->position}"); + } + $gotDecimalPoint = true; + } elseif ($char === 'e' || $char === 'E') { + if ($gotExponent) { + throw new \RuntimeException("Invalid number format: multiple exponents at position {$this->position}"); + } + $gotExponent = true; + } $this->position++; } - } - private function consume(string $expected): void { - $length = strlen($expected); - if (substr($this->input, $this->position, $length) !== $expected) { - throw new \RuntimeException("Expected '$expected' at position {$this->position}"); + $numberString = substr($this->input, $start, $this->position - $start); + if (!is_numeric($numberString)) { + throw new \RuntimeException("Invalid number format at position $start"); } - $this->position += $length; + + return $this->toNumber($numberString); } - private function getCurrentChar(): string { - return $this->position < $this->length ? $this->input[$this->position] : ''; + private function skipWhitespace(): void { + while ($this->position < $this->length && preg_match('/\s/', $this->getCurrentChar())) { + $this->position++; + } } - private function getNextChar(): string { - return $this->position + 1 < $this->length ? $this->input[$this->position + 1] : ''; + private function consume(string $expected): void { + $this->skipWhitespace(); + if (substr($this->input, $this->position, strlen($expected)) !== $expected) { + throw new \RuntimeException("Expected '$expected' at position {$this->position}"); + } + $this->position += strlen($expected); } - private function getPreviousChar(): string { - return $this->position > 0 ? $this->input[$this->position - 1] : ''; + private function consumeIf(string $expected): bool { + $this->skipWhitespace(); + if (substr($this->input, $this->position, strlen($expected)) === $expected) { + $this->position += strlen($expected); + return true; + } + return false; } - private function getNextNextChar(): string - { - return $this->position + 2 < $this->length ? $this->input[$this->position + 2] : ''; + private function getCurrentChar(): string { + return $this->position < $this->length ? $this->input[$this->position] : ''; } - private function toNumber(float|int|string $numberString) : float|int { - return strpos($numberString, '.') !== false - ? (float) $numberString - : (int) $numberString; + private function toNumber(string $numberString): float|int { + if (strpos($numberString, '.') !== false || stripos($numberString, 'e') !== false) { + return (float) $numberString; + } + return (int) $numberString; } -} \ No newline at end of file +} diff --git a/tests/Feature/Utils/ResilientJsonParserTest.php b/tests/Feature/Utils/ResilientJsonParserTest.php new file mode 100644 index 00000000..5f485b3d --- /dev/null +++ b/tests/Feature/Utils/ResilientJsonParserTest.php @@ -0,0 +1,159 @@ +parser = new ResilientJsonParser(''); +}); + +test('parse empty object', function () { + $result = (new ResilientJsonParser('{}'))->parse(); + expect($result)->toBe([]); +}); + +test('parse empty array', function () { + $result = (new ResilientJsonParser('[]'))->parse(); + expect($result)->toBe([]); +}); + +test('parse simple object', function () { + $result = (new ResilientJsonParser('{"key": "value"}'))->parse(); + expect($result)->toBe(['key' => 'value']); +}); + +test('parse simple array', function () { + $result = (new ResilientJsonParser('["value1", "value2"]'))->parse(); + expect($result)->toBe(['value1', 'value2']); +}); + +test('parse nested object', function () { + $json = '{"outer": {"inner": "value"}}'; + $result = (new ResilientJsonParser($json))->parse(); + expect($result)->toBe(['outer' => ['inner' => 'value']]); +}); + +test('parse nested array', function () { + $json = '["outer", ["inner1", "inner2"]]'; + $result = (new ResilientJsonParser($json))->parse(); + expect($result)->toBe(['outer', ['inner1', 'inner2']]); +}); + +test('parse numbers', function () { + $json = '{"integer": 42, "float": 3.14, "exponent": 1.23e-4, "negative": -10, "large": 1234567890}'; + $result = (new ResilientJsonParser($json))->parse(); + expect($result)->toBe([ + 'integer' => 42, + 'float' => 3.14, + 'exponent' => 1.23e-4, + 'negative' => -10, + 'large' => 1234567890 + ]); +}); + +test('parse boolean and null values', function () { + $json = '{"bool_true": true, "bool_false": false, "null_value": null}'; + $result = (new ResilientJsonParser($json))->parse(); + expect($result)->toBe([ + 'bool_true' => true, + 'bool_false' => false, + 'null_value' => null + ]); +}); + +test('parse string with escaped characters', function () { + $json = '{"escaped": "This is a \"quoted\" string with backslash and \t tab and \n newline"}'; + $result = (new ResilientJsonParser($json))->parse(); + expect($result['escaped'])->toBe("This is a \"quoted\" string with backslash and \t tab and \n newline"); +}); + +test('parse string with code block', function () { + $json = '{"code": "Here is some code: ```var x = 5;``` End of code."}'; + $result = (new ResilientJsonParser($json))->parse(); + expect($result['code'])->toBe('Here is some code: ```var x = 5;``` End of code.'); +}); + +test('throw exception for invalid number', function () { + $json = '{"invalid_number": 12.34.56}'; + expect(fn() => (new ResilientJsonParser($json))->parse())->toThrow(\RuntimeException::class); +}); + +test('throw exception for invalid boolean', function () { + $json = '{"invalid_boolean": truefalse}'; + expect(fn() => (new ResilientJsonParser($json))->parse())->toThrow(\RuntimeException::class); +}); + +test('throw exception for unclosed object', function () { + $json = '{"unclosed": "object"'; + expect(fn() => (new ResilientJsonParser($json))->parse())->toThrow(\RuntimeException::class); +}); + +test('throw exception for unclosed array', function () { + $json = '["unclosed", "array"'; + expect(fn() => (new ResilientJsonParser($json))->parse())->toThrow(\RuntimeException::class); +}); + +test('throw exception for invalid JSON structure', function () { + $json = '{"key": "value",}'; // Trailing comma + expect(fn() => (new ResilientJsonParser($json))->parse())->toThrow(\RuntimeException::class); +}); + +test('parse empty string', function () { + expect(fn() => (new ResilientJsonParser(''))->parse())->toThrow(\RuntimeException::class); +}); + +test('parse whitespace-only string', function () { + expect(fn() => (new ResilientJsonParser(' '))->parse())->toThrow(\RuntimeException::class); +}); + +test('parse complex nested structure', function () { + $json = ' + { + "name": "John Doe", + "age": 30, + "isStudent": false, + "courses": ["Math", "Physics", "Chemistry"], + "address": { + "street": "123 Main St", + "city": "Anytown", + "zipCode": "12345" + }, + "grades": [ + {"subject": "Math", "score": 90}, + {"subject": "Physics", "score": 85}, + {"subject": "Chemistry", "score": 92} + ] + }'; + $result = (new ResilientJsonParser($json))->parse(); + expect($result)->toBe([ + "name" => "John Doe", + "age" => 30, + "isStudent" => false, + "courses" => ["Math", "Physics", "Chemistry"], + "address" => [ + "street" => "123 Main St", + "city" => "Anytown", + "zipCode" => "12345" + ], + "grades" => [ + ["subject" => "Math", "score" => 90], + ["subject" => "Physics", "score" => 85], + ["subject" => "Chemistry", "score" => 92] + ] + ]); +}); + +test('parse deeply nested structure', function () { + $json = '{"level1":{"level2":{"level3":{"level4":{"level5":"deep value"}}}}}'; + $result = (new ResilientJsonParser($json))->parse(); + expect($result)->toBe([ + "level1" => [ + "level2" => [ + "level3" => [ + "level4" => [ + "level5" => "deep value" + ] + ] + ] + ] + ]); +}); \ No newline at end of file