diff --git a/crates/llm-chain-openai/src/chatgpt/executor.rs b/crates/llm-chain-openai/src/chatgpt/executor.rs
index 2cd9f911..ea7a697c 100644
--- a/crates/llm-chain-openai/src/chatgpt/executor.rs
+++ b/crates/llm-chain-openai/src/chatgpt/executor.rs
@@ -8,7 +8,7 @@ use llm_chain::tokens::TokenCollection;
 use super::prompt::create_chat_completion_request;
 use super::prompt::format_chat_messages;
-use async_openai::error::OpenAIError;
+use async_openai::{error::OpenAIError, types::ChatCompletionRequestMessage};
 use llm_chain::prompt::Prompt;
 
 use llm_chain::tokens::PromptTokensError;
@@ -19,7 +19,7 @@ use llm_chain::traits::{ExecutorCreationError, ExecutorError};
 use async_trait::async_trait;
 use llm_chain::tokens::TokenCount;
-use tiktoken_rs::async_openai::num_tokens_from_messages;
+use tiktoken_rs::get_chat_completion_max_tokens;
 
 use std::sync::Arc;
@@ -109,9 +109,12 @@ impl traits::Executor for Executor {
     ) -> Result<TokenCount, PromptTokensError> {
         let opts_cas = self.cascade(Some(opts));
         let model = self.get_model_from_invocation_options(&opts_cas);
-        let messages = format_chat_messages(prompt.to_chat())?;
-        let tokens_used = num_tokens_from_messages(&model, &messages)
-            .map_err(|_| PromptTokensError::NotAvailable)?;
+        let messages: Vec<ChatCompletionRequestMessage> = format_chat_messages(prompt.to_chat())?;
+        let no_messages: Vec<tiktoken_rs::ChatCompletionRequestMessage> = Vec::new();
+        let tokens_used = get_chat_completion_max_tokens(&model, no_messages.as_slice())
+            .map_err(|_| PromptTokensError::NotAvailable)?
+            - get_chat_completion_max_tokens(&model, as_tiktoken_messages(messages).as_slice())
+                .map_err(|_| PromptTokensError::NotAvailable)?;
 
         Ok(TokenCount::new(
             self.max_tokens_allowed(opts),
@@ -136,6 +139,22 @@ impl traits::Executor for Executor {
     }
 }
 
+fn as_tiktoken_message(
+    message: &ChatCompletionRequestMessage,
+) -> tiktoken_rs::ChatCompletionRequestMessage {
+    tiktoken_rs::ChatCompletionRequestMessage {
+        role: message.role.to_string(),
+        content: message.content.clone(),
+        name: message.name.clone(),
+    }
+}
+
+fn as_tiktoken_messages(
+    messages: Vec<ChatCompletionRequestMessage>,
+) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
+    messages.iter().map(|x| as_tiktoken_message(x)).collect()
+}
+
 pub struct OpenAITokenizer {
     model_name: String,
 }
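
The patch derives the prompt's token usage by calling `tiktoken_rs::get_chat_completion_max_tokens` twice: once with an empty message list (which yields the model's full completion budget) and once with the real messages, then subtracting. The following standalone sketch, which is not part of the patch, illustrates that arithmetic; the model name and message contents are placeholders, and the exact fields of `tiktoken_rs::ChatCompletionRequestMessage` depend on the tiktoken_rs version in use (here assumed to match the one the diff targets, where `content` is a `String`).

```rust
use tiktoken_rs::{get_chat_completion_max_tokens, ChatCompletionRequestMessage};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Illustrative prompt; in the executor these come from format_chat_messages.
    let messages = vec![ChatCompletionRequestMessage {
        role: "user".to_string(),
        content: "Hello, world!".to_string(),
        name: None,
    }];

    // Remaining completion tokens when no prompt is supplied.
    let empty: Vec<ChatCompletionRequestMessage> = Vec::new();
    let budget_without_prompt = get_chat_completion_max_tokens("gpt-3.5-turbo", &empty)?;

    // Remaining completion tokens once the prompt is accounted for.
    let budget_with_prompt = get_chat_completion_max_tokens("gpt-3.5-turbo", &messages)?;

    // The difference is the number of tokens the prompt consumes,
    // which is the same subtraction the patched tokens_used performs.
    let tokens_used = budget_without_prompt - budget_with_prompt;
    println!("prompt uses ~{tokens_used} tokens");
    Ok(())
}
```

This avoids depending on `tiktoken_rs::async_openai::num_tokens_from_messages`, at the cost of converting `async_openai` messages into tiktoken_rs's own message type via the `as_tiktoken_messages` helper added in the diff.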