Skip to content

Commit

Permalink
Code cleanup - RequestHandler
Browse files (browse the repository at this point in the history)
  • Loading branch information
ddebowczyk committed Oct 5, 2024
1 parent 4c8d030 commit d20cf91
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 64 deletions.
5 changes: 2 additions & 3 deletions examples/A01_Basics/Validation/run.php
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@ class UserDetails
messages: [['role' => 'user', 'content' => "you can reply to me via mail -- Jason"]],
responseModel: UserDetails::class,
)->get();
dump($user);
} catch(Exception $e) {
$caughtException = true;
}

dump($user);

assert($user === null);
assert(!isset($user));
assert($caughtException === true);
?>
```
2 changes: 1 addition & 1 deletion examples/A01_Basics/ValidationCustom/run.php
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class UserDetails
#[Assert\Callback]
public function validateName(ExecutionContextInterface $context, mixed $payload) {
if ($this->name !== strtoupper($this->name)) {
$context->buildViolation("Name must be in all uppercase letters.")
$context->buildViolation("Name must be all uppercase.")
->atPath('name')
->setInvalidValue($this->name)
->addViolation();
Expand Down
6 changes: 1 addition & 5 deletions src/Core/PartialsGenerator.php
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public function getPartialResponses(Generator $stream, ResponseModel $responseMo
$this->resetPartialResponse();

// receive data
/** @var \Cognesy\Instructor\Extras\LLM\Data\PartialLLMResponse $partialResponse */
/** @var PartialLLMResponse $partialResponse */
foreach($stream as $partialResponse) {
$this->events->dispatch(new StreamedResponseReceived($partialResponse));
// store partial response
Expand Down Expand Up @@ -108,10 +108,6 @@ public function getPartialResponses(Generator $stream, ResponseModel $responseMo
}
$this->events->dispatch(new PartialJsonReceived($this->responseJson));

// yield new PartialProcessedResponse(
// result: $result,
// partialLLMResponse: $partialResponse,
// );
yield $result->unwrap();
}
$this->events->dispatch(new StreamedResponseFinished($this->lastPartialResponse()));
Expand Down
63 changes: 20 additions & 43 deletions src/Core/RequestHandler.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
use Cognesy\Instructor\Contracts\CanGeneratePartials;
use Cognesy\Instructor\Contracts\CanGenerateResponse;
use Cognesy\Instructor\Data\Request;
use Cognesy\Instructor\Data\ResponseModel;
use Cognesy\Instructor\Enums\Mode;
use Cognesy\Instructor\Events\EventDispatcher;
use Cognesy\Instructor\Events\Instructor\InstructorDone;
Expand All @@ -30,9 +29,7 @@ class RequestHandler
protected EventDispatcher $events;

protected int $retries = 0;
protected array $messages = [];
protected array $errors = [];
protected ?ResponseModel $responseModel;

public function __construct(
protected Request $request,
Expand Down Expand Up @@ -76,16 +73,21 @@ public function stream() : Stream {
* Generates response value
*/
protected function responseFor(Request $request) : mixed {
$this->init($request);
$this->init();

$processingResult = Result::failure("No response generated");
while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) {
$llmResponse = $this->getLLMResponse($request);
$llmResponse = $this->getInference($request)->toLLMResponse();

$llmResponse->content = match($request->mode()) {
Mode::Text => $llmResponse->content,
default => Json::from($llmResponse->content)->toString(),
};
$partialResponses = [];
$processingResult = $this->processResponse($request, $llmResponse, $partialResponses);
}

$value = $this->processResult($processingResult, $request, $llmResponse, $partialResponses);
$value = $this->finalizeResult($processingResult, $request, $llmResponse, $partialResponses);

return $value;
}
Expand All @@ -96,61 +98,36 @@ protected function responseFor(Request $request) : mixed {
* @return Generator<mixed>
*/
protected function streamResponseFor(Request $request) : Generator {
$this->init($request);
$this->init();

$processingResult = Result::failure("No response generated");
while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) {
yield from $this->getStreamedLLMResponses($request);
$stream = $this->getInference($request)->toPartialLLMResponses();
yield from $this->partialsGenerator->getPartialResponses($stream, $request->responseModel());

$llmResponse = $this->partialsGenerator->getCompleteResponse();
$partialResponses = $this->partialsGenerator->partialResponses();
$processingResult = $this->processResponse($request, $llmResponse, $partialResponses);
}

$value = $this->processResult($processingResult, $request, $llmResponse, $partialResponses);
$value = $this->finalizeResult($processingResult, $request, $llmResponse, $partialResponses);

yield $value;
}

protected function init(Request $request) : void {
$this->responseModel = $request->responseModel();
if ($this->responseModel === null) {
throw new Exception("Request does not have a response model");
}

protected function init() : void {
$this->retries = 0;
$this->messages = $request->messages(); // TODO: tx messages to Scripts
$this->errors = [];
}

protected function getLLMResponse(Request $request) : LLMResponse {
protected function getInference(Request $request) : InferenceResponse {
$this->events->dispatch(new RequestSentToLLM($request));
try {
$this->events->dispatch(new RequestSentToLLM($request));
$llmResponse = $this->makeInference($request)->toLLMResponse();
$llmResponse->content = match($request->mode()) {
Mode::Text => $llmResponse->content,
default => Json::from($llmResponse->content)->toString(),
};
return $this->makeInference($request);
} catch (Exception $e) {
$this->events->dispatch(new RequestToLLMFailed($request, $e->getMessage()));
throw $e;
}
return $llmResponse;
}

/**
* @param Request $request
* @return Generator<mixed>
*/
protected function getStreamedLLMResponses(Request $request) : Generator {
try {
$this->events->dispatch(new RequestSentToLLM($request));
$stream = $this->makeInference($request)->toPartialLLMResponses();
yield from $this->partialsGenerator->getPartialResponses($stream, $request->responseModel());
} catch(Exception $e) {
$this->events->dispatch(new RequestToLLMFailed($request, $e->getMessage()));
throw $e;
}
}

protected function makeInference(Request $request) : InferenceResponse {
Expand All @@ -176,7 +153,7 @@ protected function processResponse(Request $request, LLMResponse $llmResponse, a
$this->events->dispatch(new ResponseReceivedFromLLM($llmResponse));

// we have LLMResponse here - let's process it: deserialize, validate, transform
$processingResult = $this->responseGenerator->makeResponse($llmResponse, $this->responseModel);
$processingResult = $this->responseGenerator->makeResponse($llmResponse, $request->responseModel());

if ($processingResult->isFailure()) {
// retry - we have not managed to deserialize, validate or transform the response
Expand All @@ -186,7 +163,7 @@ protected function processResponse(Request $request, LLMResponse $llmResponse, a
return $processingResult;
}

protected function processResult(Result $processingResult, Request $request, LLMResponse $llmResponse, array $partialResponses) : mixed {
protected function finalizeResult(Result $processingResult, Request $request, LLMResponse $llmResponse, array $partialResponses) : mixed {
if ($processingResult->isFailure()) {
$this->events->dispatch(new ValidationRecoveryLimitReached($this->retries, $this->errors));
throw new Exception("Validation recovery attempts limit reached after {$this->retries} attempts due to: ".implode(", ", $this->errors));
Expand All @@ -195,7 +172,7 @@ protected function processResult(Result $processingResult, Request $request, LLM
// get final value
$value = $processingResult->unwrap();
// store response
$request->setResponse($this->messages, $llmResponse, $partialResponses, $value); // TODO: tx messages to Scripts
$request->setResponse($request->messages(), $llmResponse, $partialResponses, $value); // TODO: tx messages to Scripts
// notify on response generation
$this->events->dispatch(new ResponseGenerated($value));

Expand All @@ -207,7 +184,7 @@ protected function handleError(Result $processingResult, Request $request, LLMRe
$this->errors = is_array($error) ? $error : [$error];

// store failed response
$request->addFailedResponse($this->messages, $llmResponse, $partialResponses, $this->errors); // TODO: tx messages to Scripts
$request->addFailedResponse($request->messages(), $llmResponse, $partialResponses, $this->errors); // TODO: tx messages to Scripts
$this->retries++;
if (!$this->maxRetriesReached($request)) {
$this->events->dispatch(new NewValidationRecoveryAttempt($this->retries, $this->errors));
Expand Down
25 changes: 13 additions & 12 deletions src/Data/Traits/ChatTemplate/HandlesRetries.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@

trait HandlesRetries
{
protected function addRetryMessages() : void {
$failedResponse = $this->request->lastFailedResponse();
if (!$failedResponse || !$this->request->hasLastResponseFailed()) {
return;
}
foreach($this->request->attempts() as $attempt) {
$messages = $this->makeRetryMessages(
[], $attempt->llmResponse()->content, $attempt->errors()
);
$this->script->section('retries')->appendMessages($messages);
}
}

protected function makeRetryMessages(
array $messages,
string $jsonData,
Expand All @@ -19,16 +32,4 @@ protected function makeRetryMessages(
protected function makeRetryPrompt() : string {
return $this->request->retryPrompt() ?: $this->defaultRetryPrompt;
}

protected function addRetryMessages() {
$failedResponse = $this->request->lastFailedResponse();
if (!$failedResponse || !$this->request->hasLastResponseFailed()) {
return;
}
$this->script->section('retries')->appendMessages(
$this->makeRetryMessages(
[], $failedResponse->llmResponse()->content, $failedResponse->errors()
)
);
}
}
1 change: 1 addition & 0 deletions src/Extras/LLM/Inference.php
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ public function create(
$request = new InferenceRequest(
$messages, $model, $tools, $toolChoice, $responseFormat, $options, $mode, $this->cachedContext ?? null
);
dump($request);
$this->events->dispatch(new InferenceRequested($request));
return new InferenceResponse(
response: $this->driver->handle($request),
Expand Down

0 comments on commit d20cf91

Please sign in to comment.