Skip to content

Commit

Permalink
calculate_image_tokens_openai
Browse files Browse the repository at this point in the history
Signed-off-by: Valeryi <[email protected]>
  • Loading branch information
valaises committed Apr 29, 2024
1 parent 6abec13 commit bfcc5f0
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 14 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,6 @@ typetag = "0.2"
dyn_partial_eq = "=0.1.2"
rayon = "1.8.0"
backtrace = "0.3.71"
image = "0.23.14"
base64 = "0.13.0"

70 changes: 56 additions & 14 deletions src/scratchpads/chat_utils_limit_history.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,52 @@
use std::io::Cursor;
use image::io::Reader as ImageReader;
use tracing::{error, info};

use crate::scratchpad_abstract::HasTokenizerAndEot;
use crate::call_validation::ChatMessage;


/// Estimate how many tokens a base64-encoded image costs for OpenAI vision models.
///
/// Follows the pricing algorithm described at
/// https://platform.openai.com/docs/guides/vision:
///   1. scale the image (preserving aspect ratio) to fit inside a 2048x2048 square;
///   2. scale so the image's *shortest* side is at most 768 px;
///   3. count how many 512x512 tiles cover the result — each tile costs 170 tokens,
///      plus a flat 85-token base cost.
///
/// `image_string` is the raw base64 payload (no data-URI prefix).
/// Returns the estimated token count, or an error string if the payload
/// cannot be decoded or its dimensions cannot be read.
fn calculate_image_tokens_openai(image_string: &str) -> Result<i32, String> {
    // as per https://platform.openai.com/docs/guides/vision
    const SMALL_CHUNK_SIZE: u32 = 512;
    const COST_PER_SMALL_CHUNK: i32 = 170;
    const BIG_CHUNK_SIZE: u32 = 2048;
    const SHORT_SIDE_LIMIT: u32 = 768;
    const CONST_COST: i32 = 85;

    let image_bytes = base64::decode(image_string).map_err(|_| "base64 decode failed".to_string())?;
    let cursor = Cursor::new(image_bytes);
    // with_guessed_format() sniffs the header bytes, so we do not depend on a filename.
    let reader = ImageReader::new(cursor).with_guessed_format().map_err(|e| e.to_string())?;
    // into_dimensions() decodes only the header — the full pixel data is never decoded.
    let (mut width, mut height) = reader.into_dimensions().map_err(|_| "Failed to get dimensions".to_string())?;

    // Step 1: shrink to fit within a 2048x2048 square, preserving aspect ratio.
    let shrink_factor = (width.max(height) as f64) / (BIG_CHUNK_SIZE as f64);
    if shrink_factor > 1.0 {
        width = (width as f64 / shrink_factor) as u32;
        height = (height as f64 / shrink_factor) as u32;
    }

    // Step 2: shrink so the shortest side is at most 768 px (required by the
    // OpenAI algorithm; without this step large images are overpriced).
    let short_side_factor = (width.min(height) as f64) / (SHORT_SIDE_LIMIT as f64);
    if short_side_factor > 1.0 {
        width = (width as f64 / short_side_factor) as u32;
        height = (height as f64 / short_side_factor) as u32;
    }

    // Step 3: count 512x512 tiles (partial tiles count as whole ones).
    let width_chunks = (width as f64 / SMALL_CHUNK_SIZE as f64).ceil() as u32;
    let height_chunks = (height as f64 / SMALL_CHUNK_SIZE as f64).ceil() as u32;
    let small_chunks_needed = width_chunks * height_chunks;

    Ok(small_chunks_needed as i32 * COST_PER_SMALL_CHUNK + CONST_COST)
}

/// Token count for a single chat message, dispatched on the message kind.
///
/// "text" messages cost their tokenized length plus 3 tokens of role
/// overhead (the "\n\nASSISTANT:"-style framing); "image" messages are
/// priced via the OpenAI vision formula, falling back to the documented
/// worst case of 2805 tokens if the image cannot be analyzed.
/// Any other kind is an error.
fn calculate_t_cnt(msg: &ChatMessage, t: &HasTokenizerAndEot) -> Result<i32, String> {
    match msg.kind.as_str() {
        "text" => Ok(3 + t.count_tokens(msg.content.as_str())?),
        "image" => {
            // A failed estimate must not abort history trimming; charge the
            // maximum possible image cost instead and log what went wrong.
            let t_cnt = calculate_image_tokens_openai(&msg.content).unwrap_or_else(|e| {
                error!("calculate_image_tokens_openai failed: {}; applying max value: 2805", e);
                2805
            });
            Ok(t_cnt)
        }
        _ => Err(format!("unknown msg kind: {}", msg.kind)),
    }
}

pub fn limit_messages_history(
t: &HasTokenizerAndEot,
messages: &Vec<ChatMessage>,
Expand All @@ -12,25 +57,22 @@ pub fn limit_messages_history(
) -> Result<Vec<ChatMessage>, String>
{
let tokens_limit: i32 = context_size as i32 - max_new_tokens as i32;
tracing::info!("limit_messages_history tokens_limit={} <= context_size={} - max_new_tokens={}", tokens_limit, context_size, max_new_tokens);
info!("limit_messages_history tokens_limit={} <= context_size={} - max_new_tokens={}", tokens_limit, context_size, max_new_tokens);
let mut tokens_used: i32 = 0;
let mut message_token_count: Vec<i32> = vec![0; messages.len()];
let mut message_take: Vec<bool> = vec![false; messages.len()];
let mut have_system = false;
for (i, msg) in messages.iter().enumerate() {
let tcnt = (3 + t.count_tokens(msg.content.as_str())?) as i32; // 3 for role "\n\nASSISTANT:" kind of thing
message_token_count[i] = tcnt;
let t_cnt = calculate_t_cnt(msg, t)?;
message_token_count[i] = t_cnt;
if i==0 && msg.role == "system" {
message_take[i] = true;
tokens_used += tcnt;
tokens_used += t_cnt;
have_system = true;
}
if msg.kind == "image" {
message_take[i] = true;
}
if i >= last_user_msg_starts {
message_take[i] = true;
tokens_used += tcnt;
tokens_used += t_cnt;
}
}
let need_default_system_msg = !have_system && default_system_message.len() > 0;
Expand All @@ -39,18 +81,18 @@ pub fn limit_messages_history(
tokens_used += tcnt;
}
for i in (0..messages.len()).rev() {
let tcnt = 3 + message_token_count[i];
let t_cnt = 3 + message_token_count[i];
if !message_take[i] {
if tokens_used + tcnt < tokens_limit {
if tokens_used + t_cnt < tokens_limit {
message_take[i] = true;
tokens_used += tcnt;
tracing::info!("take {:?}, tokens_used={} < {}", crate::nicer_logs::first_n_chars(&messages[i].content, 30), tokens_used, tokens_limit);
tokens_used += t_cnt;
info!("take {:?}, tokens_used={} < {}", crate::nicer_logs::first_n_chars(&messages[i].content, 30), tokens_used, tokens_limit);
} else {
tracing::info!("drop {:?} with {} tokens, quit", crate::nicer_logs::first_n_chars(&messages[i].content, 30), tcnt);
info!("drop {:?} with {} tokens, quit", crate::nicer_logs::first_n_chars(&messages[i].content, 30), t_cnt);
break;
}
} else {
tracing::info!("not allowed to drop {:?}, tokens_used={} < {}", crate::nicer_logs::first_n_chars(&messages[i].content, 30), tokens_used, tokens_limit);
info!("not allowed to drop {:?}, tokens_used={} < {}", crate::nicer_logs::first_n_chars(&messages[i].content, 30), tokens_used, tokens_limit);
}
}
let mut messages_out: Vec<ChatMessage> = messages.iter().enumerate().filter(|(i, _)| message_take[*i]).map(|(_, x)| x.clone()).collect();
Expand Down

0 comments on commit bfcc5f0

Please sign in to comment.