From 614c6522508f40cd6223e10c0f8a71e7206411df Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jorge=20Ant=C3=B3nio?=
Date: Fri, 30 Aug 2024 12:31:28 +0100
Subject: [PATCH 1/2] feat: update llama (#131)

* first commit

* first commit
---
 Cargo.toml                                 |  2 +-
 atoma-inference/src/models/candle/llama.rs | 22 ++++++++++++++++------
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index a93a9fdf..471d2431 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -44,7 +44,7 @@ candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-n
 candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", branch = "main" }
 clap = "4.5.4"
 config = "0.14.0"
-cudarc = { version = "0.11.6", features = [
+cudarc = { version = "0.12.0", features = [
     "std",
     "cublas",
     "cublaslt",
diff --git a/atoma-inference/src/models/candle/llama.rs b/atoma-inference/src/models/candle/llama.rs
index c114062c..e420b045 100644
--- a/atoma-inference/src/models/candle/llama.rs
+++ b/atoma-inference/src/models/candle/llama.rs
@@ -147,10 +147,12 @@ impl ModelTrait for LlamaModel {
             .bos_token_id
             .or_else(|| self.tokenizer.tokenizer().token_to_id(BOS_TOKEN))
             .unwrap();
-        let eos_token_id = self
-            .config
-            .eos_token_id
-            .or_else(|| self.tokenizer.tokenizer().token_to_id(EOS_TOKEN));
+        let eos_token_id = self.config.eos_token_id.clone().or_else(|| {
+            self.tokenizer
+                .tokenizer()
+                .token_to_id(EOS_TOKEN)
+                .map(model::LlamaEosToks::Single)
+        });
         let prompt_ids = self
             .tokenizer
             .tokenizer()
@@ -201,8 +203,16 @@ impl ModelTrait for LlamaModel {
             let next_token = logits_processor.sample(&logits)?;
             tokens.push(next_token);
-            if Some(next_token) == eos_token_id {
-                break;
+            match eos_token_id {
+                Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
+                    break;
+                }
+                Some(model::LlamaEosToks::Multiple(ref eos_ids))
+                    if eos_ids.contains(&next_token) =>
+                {
+                    break;
+                }
+                _ => (),
             }
             if let Some(t) = self.tokenizer.next_token(next_token, request_id.clone())?
             {
                 res += &t;

From 78555e6e80d70bdfca7153928e2aba953082c574 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jorge=20Ant=C3=B3nio?=
Date: Mon, 2 Sep 2024 09:54:21 +0100
Subject: [PATCH 2/2] feat: add flux model (#132)

* first commit

* add flux impl

* clean up the code

* add changes

* add logging

* add more detailed loggings

* running issues

* add changes
---
 atoma-inference/src/model_thread.rs       |  14 +-
 atoma-inference/src/models/candle/flux.rs | 425 ++++++++++++++++++++++
 atoma-inference/src/models/candle/mod.rs  |   1 +
 atoma-inference/src/models/mod.rs         |   2 +
 atoma-inference/src/models/types.rs       |  14 +-
 atoma-types/src/lib.rs                    |  18 +-
 6 files changed, 465 insertions(+), 9 deletions(-)
 create mode 100644 atoma-inference/src/models/candle/flux.rs

diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs
index 1a95a966..140f0f26 100644
--- a/atoma-inference/src/model_thread.rs
+++ b/atoma-inference/src/model_thread.rs
@@ -15,9 +15,9 @@ use crate::models::candle::mixtral_nccl::MixtralNcclModel;
 
 use crate::models::{
     candle::{
-        falcon::FalconModel, llama::LlamaModel, mamba::MambaModel, mistral::MistralModel,
-        mixtral::MixtralModel, phi3::Phi3Model, quantized::QuantizedModel, qwen::QwenModel,
-        stable_diffusion::StableDiffusion,
+        falcon::FalconModel, flux::Flux, llama::LlamaModel, mamba::MambaModel,
+        mistral::MistralModel, mixtral::MixtralModel, phi3::Phi3Model, quantized::QuantizedModel,
+        qwen::QwenModel, stable_diffusion::StableDiffusion,
     },
     config::{ModelConfig, ModelsConfig},
     types::{LlmOutput, ModelType},
@@ -288,6 +288,14 @@ pub(crate) fn dispatch_model_thread(
                 stream_tx,
             )
         }
+        ModelType::FluxSchnell | ModelType::FluxDev => spawn_model_thread::<Flux>(
+            model_name,
+            api_key.clone(),
+            cache_dir.clone(),
+            model_config,
+            model_receiver,
+            stream_tx,
+        ),
         ModelType::LlamaV1
         | ModelType::LlamaV2
         | ModelType::LlamaTinyLlama1_1BChat
diff --git a/atoma-inference/src/models/candle/flux.rs b/atoma-inference/src/models/candle/flux.rs
new file mode 100644
index 00000000..ffb1e57e
--- /dev/null
+++ b/atoma-inference/src/models/candle/flux.rs
@@ -0,0 +1,425 @@
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use std::{path::PathBuf, str::FromStr};
+
+use crate::{
+    bail,
+    models::{
+        candle::convert_to_image,
+        config::ModelConfig,
+        types::{LlmOutput, ModelType},
+        ModelError, ModelTrait,
+    },
+};
+use atoma_types::{Digest, PromptParams};
+use candle::{DType, Device, Module, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::{clip, flux, t5};
+use hf_hub::api::sync::ApiBuilder;
+use serde::{Deserialize, Serialize};
+use tokenizers::Tokenizer;
+use tracing::{info, trace};
+
+use super::device;
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+/// Flux input data
+pub struct FluxInput {
+    /// Text input
+    prompt: String,
+    /// Image height
+    height: Option<usize>,
+    /// Image width
+    width: Option<usize>,
+    /// decode only
+    decode_only: Option<String>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+/// Flux model variants
+pub enum Model {
+    Schnell,
+    Dev,
+}
+
+/// Flux model structure
+pub struct Flux {
+    /// Device that hosts the models
+    device: Device,
+    /// Data type of the models
+    dtype: DType,
+    /// Flux model variant
+    model: Model,
+    /// T5 model
+    t5_model: t5::T5EncoderModel,
+    /// T5 tokenizer
+    t5_tokenizer: Tokenizer,
+    /// CLIP model
+    clip_model: clip::text_model::ClipTextTransformer,
+    /// CLIP tokenizer
+    clip_tokenizer: Tokenizer,
+    /// Biflux model
+    bf_model: flux::model::Flux,
+    /// Autoencoder model
+    ae_model: flux::autoencoder::AutoEncoder,
+}
+
+/// Flux model load data
+pub struct FluxLoadData {
+    /// Device that hosts the models
+    device: Device,
+    /// Data type of the models
+    dtype: DType,
+    /// File paths to load different models
+    /// configurations and tokenizers
+    file_paths: Vec<PathBuf>,
+    /// Flux model variant
+    model: Model,
+}
+
+/// Flux model output
+/// Stable diffusion output
+#[derive(Serialize)]
+pub struct FluxOutput {
+    /// Data buffer of the image encoding
+    pub image_data: Vec<u8>,
+    /// Height of the image
+    pub height: usize,
+    /// Width of the image
+    pub width: usize,
+    /// Number of input tokens
+    input_tokens: usize,
+    /// Time to generate output
+    time_to_generate: f64,
+}
+
+impl ModelTrait for Flux {
+    type Input = FluxInput;
+    type Output = FluxOutput;
+    type LoadData = FluxLoadData;
+
+    fn fetch(
+        api_key: String,
+        cache_dir: PathBuf,
+        config: ModelConfig,
+    ) -> Result<Self::LoadData, ModelError> {
+        info!("Fetching Flux model..");
+        let device = device(config.device_first_id())?;
+        let dtype = DType::from_str(&config.dtype())?;
+
+        let api = ApiBuilder::new()
+            .with_progress(true)
+            .with_token(Some(api_key))
+            .with_cache_dir(cache_dir)
+            .build()?;
+
+        let model_type = ModelType::from_str(&config.model_id())?;
+        let repo_id = model_type.repo().to_string();
+
+        info!("Fetching T5 model files..");
+
+        let bf_repo = api.repo(hf_hub::Repo::model(repo_id));
+        let t5_repo = api.repo(hf_hub::Repo::with_revision(
+            "google/t5-v1_1-xxl".to_string(),
+            hf_hub::RepoType::Model,
+            "refs/pr/2".to_string(),
+        ));
+        let t5_model_file = t5_repo.get("model.safetensors")?;
+        let t5_config_filename = t5_repo.get("config.json")?;
+        let t5_tokenizer_filename = api
+            .model("lmz/mt5-tokenizers".to_string())
+            .get("t5-v1_1-xxl.tokenizer.json")?;
+
+        info!("Fetching CLIP model files..");
+
+        let clip_repo = api.repo(hf_hub::Repo::model(
+            "openai/clip-vit-large-patch14".to_string(),
+        ));
+        let clip_model_file = clip_repo.get("model.safetensors")?;
+        let clip_tokenizer_filename = clip_repo.get("tokenizer.json")?;
+
+        let model = match model_type {
+            ModelType::FluxSchnell => Model::Schnell,
+            ModelType::FluxDev => Model::Dev,
+            _ => bail!("Invalid model type for Flux model"),
+        };
+
+        info!("Fetching Biflux model files..");
+
+        let bf_model_file = match model {
+            Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
+            Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
+        };
+
+        info!("Fetching Autoencoder model files..");
+        let ae_model_file = bf_repo.get("ae.safetensors")?;
+
+        Ok(Self::LoadData {
+            device,
+            dtype,
+            file_paths: vec![
+                t5_config_filename,
+                t5_tokenizer_filename,
+                t5_model_file,
+                clip_tokenizer_filename,
+                clip_model_file,
+                bf_model_file,
+                ae_model_file,
+            ],
+            model,
+        })
+    }
+
+    fn load(
+        load_data: Self::LoadData,
+        _stream_tx: tokio::sync::mpsc::Sender<(Digest, String)>,
+    ) -> Result<Self, ModelError> {
+        info!("Loading T5 model..");
+        let start = std::time::Instant::now();
+        let t5_config_filename = load_data.file_paths[0].clone();
+        let t5_tokenizer_filename = load_data.file_paths[1].clone();
+        let t5_model_filename = load_data.file_paths[2].clone();
+        let clip_tokenizer_filename = load_data.file_paths[3].clone();
+        let clip_model_filename = load_data.file_paths[4].clone();
+        let bf_model_filename = load_data.file_paths[5].clone();
+        let ae_model_filename = load_data.file_paths[6].clone();
+
+        let t5_vb = unsafe {
+            VarBuilder::from_mmaped_safetensors(
+                &[t5_model_filename],
+                load_data.dtype,
+                &load_data.device,
+            )?
+        };
+
+        let t5_config = std::fs::read_to_string(t5_config_filename)?;
+        let t5_config: t5::Config = serde_json::from_str(&t5_config)?;
+
+        let t5_model = t5::T5EncoderModel::load(t5_vb, &t5_config)?;
+        let t5_tokenizer = Tokenizer::from_file(t5_tokenizer_filename)?;
+
+        info!(
+            "Loaded T5 model in {} seconds",
+            start.elapsed().as_secs_f64()
+        );
+        info!("Loading CLIP model..");
+
+        let start = std::time::Instant::now();
+        let clip_vb = unsafe {
+            VarBuilder::from_mmaped_safetensors(
+                &[clip_model_filename],
+                load_data.dtype,
+                &load_data.device,
+            )?
+        };
+        let clip_config = clip::text_model::ClipTextConfig {
+            vocab_size: 49408,
+            projection_dim: 768,
+            activation: clip::text_model::Activation::QuickGelu,
+            intermediate_size: 3072,
+            embed_dim: 768,
+            max_position_embeddings: 77,
+            pad_with: None,
+            num_hidden_layers: 12,
+            num_attention_heads: 12,
+        };
+        let clip_model =
+            clip::text_model::ClipTextTransformer::new(clip_vb.pp("text_model"), &clip_config)?;
+        let clip_tokenizer = Tokenizer::from_file(clip_tokenizer_filename)?;
+
+        info!(
+            "Loaded CLIP model in {} seconds",
+            start.elapsed().as_secs_f64()
+        );
+        info!("Loading Biflux model..");
+
+        let start = std::time::Instant::now();
+        let bf_vb = unsafe {
+            VarBuilder::from_mmaped_safetensors(
+                &[bf_model_filename],
+                load_data.dtype,
+                &load_data.device,
+            )?
+        };
+
+        let bf_config = match load_data.model {
+            Model::Dev => flux::model::Config::dev(),
+            Model::Schnell => flux::model::Config::schnell(),
+        };
+        let bf_model = flux::model::Flux::new(&bf_config, bf_vb)?;
+
+        info!(
+            "Loaded Biflux model in {} seconds",
+            start.elapsed().as_secs_f64()
+        );
+        info!("Loading Autoencoder model..");
+
+        let start = std::time::Instant::now();
+        let ae_vb = unsafe {
+            VarBuilder::from_mmaped_safetensors(
+                &[ae_model_filename],
+                load_data.dtype,
+                &load_data.device,
+            )?
+        };
+        let ae_config = match load_data.model {
+            Model::Dev => flux::autoencoder::Config::dev(),
+            Model::Schnell => flux::autoencoder::Config::schnell(),
+        };
+        let ae_model = flux::autoencoder::AutoEncoder::new(&ae_config, ae_vb)?;
+
+        info!(
+            "Loaded Autoencoder model in {} seconds",
+            start.elapsed().as_secs_f64()
+        );
+
+        Ok(Self {
+            device: load_data.device.clone(),
+            dtype: load_data.dtype,
+            model: load_data.model,
+            t5_model,
+            t5_tokenizer,
+            clip_model,
+            clip_tokenizer,
+            bf_model,
+            ae_model,
+        })
+    }
+
+    fn model_type(&self) -> ModelType {
+        match self.model {
+            Model::Schnell => ModelType::FluxSchnell,
+            Model::Dev => ModelType::FluxDev,
+        }
+    }
+
+    fn run(&mut self, input: Self::Input) -> Result<Self::Output, ModelError> {
+        info!("Running Flux model, on input prompt: {}", input.prompt);
+        let start = std::time::Instant::now();
+
+        let width = input.width.unwrap_or(1360);
+        let height = input.height.unwrap_or(768);
+
+        let mut t5_tokens = self
+            .t5_tokenizer
+            .encode(input.prompt.as_str(), true)?
+            .get_ids()
+            .to_vec();
+        t5_tokens.resize(256, 0);
+        let input_t5_token_ids = Tensor::new(&t5_tokens[..], &self.device)?.unsqueeze(0)?;
+        let t5_embedding = self.t5_model.forward(&input_t5_token_ids)?;
+
+        info!("Produced a T5 embedding");
+
+        info!("Running CLIP model, on input prompt: {}", input.prompt);
+        let clip_tokens = self
+            .clip_tokenizer
+            .encode(input.prompt.as_str(), true)?
+            .get_ids()
+            .to_vec();
+        let input_clip_token_ids = Tensor::new(&clip_tokens[..], &self.device)?.unsqueeze(0)?;
+        let clip_embedding = self.clip_model.forward(&input_clip_token_ids)?;
+
+        info!("Produced a CLIP embedding");
+
+        info!("Running Biflux model, on input prompt: {}", input.prompt);
+        let img =
+            flux::sampling::get_noise(1, height, width, &self.device)?.to_dtype(self.dtype)?;
+        let state = flux::sampling::State::new(&t5_embedding, &clip_embedding, &img)?;
+        let timesteps = match self.model {
+            Model::Dev => flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15))),
+            Model::Schnell => flux::sampling::get_schedule(4, None),
+        };
+
+        trace!("{state:?}");
+        trace!("{timesteps:?}");
+        let img = flux::sampling::denoise(
+            &self.bf_model,
+            &state.img,
+            &state.img_ids,
+            &state.txt,
+            &state.txt_ids,
+            &state.vec,
+            &timesteps,
+            4.,
+        )?;
+        let img = flux::sampling::unpack(&img, height, width)?;
+        trace!("latent img\n{img}");
+
+        let img = self.ae_model.decode(&img)?;
+        trace!("img\n{img}");
+
+        let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?
+            .to_dtype(candle::DType::U8)?
+            .squeeze(0)?;
+        info!("FLAG: img.dims() = {:?}", img.dims());
+        save_image(&img, "flux_output.png")?;
+        let (img, width, height) = convert_to_image(&img)?;
+
+        Ok(FluxOutput {
+            image_data: img,
+            height,
+            width,
+            input_tokens: input.prompt.len(),
+            time_to_generate: start.elapsed().as_secs_f64(),
+        })
+    }
+}
+
+impl TryFrom<(Digest, PromptParams)> for FluxInput {
+    type Error = ModelError;
+
+    fn try_from(value: (Digest, PromptParams)) -> Result<Self, Self::Error> {
+        let prompt_params = value.1;
+        match prompt_params {
+            PromptParams::Text2ImagePromptParams(p) => {
+                let height = p.height().map(|h| h as usize);
+                let width = p.width().map(|w| w as usize);
+                let prompt = p.prompt();
+                let decode_only = p.decode_only();
+                Ok(Self {
+                    prompt,
+                    height,
+                    width,
+                    decode_only,
+                })
+            }
+            PromptParams::Text2TextPromptParams(_) => Err(ModelError::InvalidPromptParams),
+        }
+    }
+}
+
+impl LlmOutput for FluxOutput {
+    fn num_input_tokens(&self) -> usize {
+        self.input_tokens
+    }
+
+    fn num_output_tokens(&self) -> Option<usize> {
+        None
+    }
+
+    fn time_to_generate(&self) -> f64 {
+        self.time_to_generate
+    }
+}
+
+/// Saves an image to disk using the image crate, this expects an input with shape
+/// (c, height, width).
+fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<(), ModelError> {
+    let p = p.as_ref();
+    let (channel, height, width) = img.dims3()?;
+    if channel != 3 {
+        bail!("save_image expects an input of shape (3, height, width)")
+    }
+    let img = img.permute((1, 2, 0))?.flatten_all()?;
+    let pixels = img.to_vec1::<u8>()?;
+    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
+            Some(image) => image,
+            None => bail!("error saving image {p:?}"),
+        };
+    image.save(p)?;
+    Ok(())
+}
diff --git a/atoma-inference/src/models/candle/mod.rs b/atoma-inference/src/models/candle/mod.rs
index 7270ac19..ec769751 100644
--- a/atoma-inference/src/models/candle/mod.rs
+++ b/atoma-inference/src/models/candle/mod.rs
@@ -21,6 +21,7 @@ pub mod mixtral_nccl;
 #[cfg(feature = "nccl")]
 mod mixtral_nccl_model;
 
+pub mod flux;
 pub mod mamba;
 pub mod mistral;
 pub mod mixtral;
diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs
index 7277dce5..13238a78 100644
--- a/atoma-inference/src/models/mod.rs
+++ b/atoma-inference/src/models/mod.rs
@@ -74,6 +74,8 @@ pub enum ModelError {
     InvalidModelInput,
     #[error("Send error: `{0}`")]
     SendError(#[from] mpsc::error::SendError<(Digest, String)>),
+    #[error("Invalid prompt params")]
+    InvalidPromptParams,
     #[cfg(feature = "nccl")]
     #[error("Nccl error: `{}`", 0.0)]
     NcclError(NcclError),
diff --git a/atoma-inference/src/models/types.rs b/atoma-inference/src/models/types.rs
index 7e7e5320..aca04d47 100644
--- a/atoma-inference/src/models/types.rs
+++ b/atoma-inference/src/models/types.rs
@@ -33,6 +33,8 @@ pub enum ModelType {
     Falcon7b,
     Falcon40b,
     Falcon180b,
+    FluxSchnell,
+    FluxDev,
     LlamaV1,
     LlamaV2,
     LlamaSolar10_7B,
@@ -103,6 +105,8 @@ impl FromStr for ModelType {
             "falcon_7b" => Ok(Self::Falcon7b),
             "falcon_40b" => Ok(Self::Falcon40b),
             "falcon_180b" => Ok(Self::Falcon180b),
+            "flux_dev" => Ok(Self::FluxDev),
+            "flux_schnell" => Ok(Self::FluxSchnell),
             "llama_v1" => Ok(Self::LlamaV1),
             "llama_v2" => Ok(Self::LlamaV2),
             "llama_solar_10_7b" => Ok(Self::LlamaSolar10_7B),
@@ -177,6 +181,8 @@ impl ModelType {
             Self::Falcon7b => "tiiuae/falcon-7b",
             Self::Falcon40b => "tiiuae/falcon-40b",
             Self::Falcon180b => "tiiuae/falcon-180b",
+            Self::FluxDev => "black-forest-labs/FLUX.1-dev",
+            Self::FluxSchnell => "black-forest-labs/FLUX.1-schnell",
             Self::LlamaV1 => "Narsil/amall-7b",
             Self::LlamaV2 => "meta-llama/Llama-2-7b-hf",
             Self::LlamaSolar10_7B => "upstage/SOLAR-10.7B-v1.0",
@@ -279,7 +285,9 @@ impl ModelType {
             Self::Mamba790m => "refs/pr/1",
             Self::Mamba1_4b => "refs/pr/1",
             Self::Mamba2_8b => "refs/pr/4",
-            Self::QuantizedL8b
+            Self::FluxDev
+            | Self::FluxSchnell
+            | Self::QuantizedL8b
             | Self::QuantizedLeo13b
             | Self::QuantizedLeo7b
             | Self::QuantizedLlama13b
@@ -314,6 +322,8 @@ impl Display for ModelType {
             Self::Falcon7b => write!(f, "falcon_7b"),
             Self::Falcon40b => write!(f, "falcon_40b"),
             Self::Falcon180b => write!(f, "falcon_180b"),
+            Self::FluxDev => write!(f, "flux_dev"),
+            Self::FluxSchnell => write!(f, "flux_schnell"),
             Self::LlamaV1 => write!(f, "llama_v1"),
             Self::LlamaV2 => write!(f, "llama_v2"),
             Self::LlamaSolar10_7B => write!(f, "llama_solar_10_7b"),
@@ -579,7 +589,7 @@ impl TryFrom<(Digest, PromptParams)> for StableDiffusionInput {
         match value {
             PromptParams::Text2ImagePromptParams(p) => Ok(Self {
                 prompt: p.prompt(),
-                uncond_prompt: p.uncond_prompt(),
+                uncond_prompt: p.uncond_prompt().unwrap_or_default(),
                 height: p.height().map(|t| t.try_into().unwrap()),
                 width: p.width().map(|t| t.try_into().unwrap()),
                 n_steps: p.n_steps().map(|t| t.try_into().unwrap()),
diff --git a/atoma-types/src/lib.rs b/atoma-types/src/lib.rs
index d0856cf7..8b40c18a 100644
--- a/atoma-types/src/lib.rs
+++ b/atoma-types/src/lib.rs
@@ -412,7 +412,7 @@ pub struct Text2ImagePromptParams {
     /// Model to run the inference
     model: String,
     /// Unconditional prompt, used in stable diffusion models
-    uncond_prompt: String,
+    uncond_prompt: Option<String>,
     /// Height of the final generated image
     height: Option<u64>,
     /// Width of the final generated image
@@ -429,6 +429,8 @@ pub struct Text2ImagePromptParams {
     img2img_strength: f64,
     /// The random seed for inference sampling
     random_seed: Option<u32>,
+    /// Only decode the image (applicable to Flux models)
+    decode_only: Option<String>,
 }
 
 impl Text2ImagePromptParams {
@@ -437,7 +439,7 @@ impl Text2ImagePromptParams {
     pub fn new(
         prompt: String,
         model: String,
-        uncond_prompt: String,
+        uncond_prompt: Option<String>,
         height: Option<u64>,
         width: Option<u64>,
         n_steps: Option<u64>,
@@ -446,6 +448,7 @@ impl Text2ImagePromptParams {
         img2img: Option<String>,
         img2img_strength: f64,
         random_seed: Option<u32>,
+        decode_only: Option<String>,
     ) -> Self {
         Self {
             prompt,
@@ -459,6 +462,7 @@ impl Text2ImagePromptParams {
             img2img,
             img2img_strength,
             random_seed,
+            decode_only,
         }
     }
 
@@ -473,7 +477,7 @@ impl Text2ImagePromptParams {
     }
 
     /// Getter for `uncond_prompt`
-    pub fn uncond_prompt(&self) -> String {
+    pub fn uncond_prompt(&self) -> Option<String> {
         self.uncond_prompt.clone()
     }
 
@@ -516,6 +520,11 @@ impl Text2ImagePromptParams {
     pub fn random_seed(&self) -> Option<u32> {
         self.random_seed
     }
+
+    /// Getter for `decode_only`
+    pub fn decode_only(&self) -> Option<String> {
+        self.decode_only.clone()
+    }
 }
 
 impl TryFrom<Value> for Text2ImagePromptParams {
@@ -525,7 +534,7 @@ impl TryFrom<Value> for Text2ImagePromptParams {
         Ok(Self {
             prompt: utils::parse_str(&value["prompt"])?,
             model: utils::parse_str(&value["model"])?,
-            uncond_prompt: utils::parse_str(&value["uncond_prompt"])?,
+            uncond_prompt: utils::parse_optional_str(&value["uncond_prompt"]),
             random_seed: Some(utils::parse_u32(&value["random_seed"])?),
             height: Some(utils::parse_u64(&value["height"])?),
             width: Some(utils::parse_u64(&value["width"])?),
@@ -534,6 +543,7 @@ impl TryFrom<Value> for Text2ImagePromptParams {
             guidance_scale: Some(utils::parse_f32_from_le_bytes(&value["guidance_scale"])? as f64),
             img2img: utils::parse_optional_str(&value["img2img"]),
             img2img_strength: utils::parse_f32_from_le_bytes(&value["img2img_strength"])? as f64,
+            decode_only: utils::parse_optional_str(&value["decode_only"]),
         })
     }
 }
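
Note on PATCH 1/2 (illustrative sketch, not part of the patches above): the eos_token_id rework stops comparing against a single Option<u32> and instead matches on candle-transformers' LlamaEosToks, which can carry either one EOS id or a list of them, as newer Llama checkpoints ship several stop tokens. A minimal, self-contained version of the same stop check; the EosToks enum below is a local stand-in assumed to mirror the upstream type.

/// Local stand-in for candle_transformers::models::llama::LlamaEosToks (assumed shape).
enum EosToks {
    Single(u32),
    Multiple(Vec<u32>),
}

/// Returns true when `next_token` should end generation, covering both the
/// single-id and multi-id cases exactly like the match added in the patch.
fn is_eos(eos: Option<&EosToks>, next_token: u32) -> bool {
    match eos {
        Some(EosToks::Single(id)) => next_token == *id,
        Some(EosToks::Multiple(ids)) => ids.contains(&next_token),
        None => false,
    }
}

fn main() {
    let single = EosToks::Single(2);
    let multiple = EosToks::Multiple(vec![128001, 128009]);
    assert!(is_eos(Some(&single), 2));
    assert!(is_eos(Some(&multiple), 128009));
    assert!(!is_eos(Some(&multiple), 42));
}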
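Note on PATCH 2/2 (illustrative sketch, not part of the patches above): FluxInput keeps height, width and decode_only optional, and run() falls back to a 1360x768 canvas when the caller supplies no dimensions. A small sketch of that defaulting, using stand-in types whose exact field types are assumptions inferred from how the values are used in the patch.

/// Stand-in mirroring the optional fields of FluxInput (field types assumed).
struct FluxInput {
    prompt: String,
    height: Option<usize>,
    width: Option<usize>,
    decode_only: Option<String>,
}

/// Resolves the sampling resolution the way `run` does: caller values win,
/// otherwise the 1360x768 default from the patch applies.
fn resolve_dims(input: &FluxInput) -> (usize, usize) {
    (input.width.unwrap_or(1360), input.height.unwrap_or(768))
}

fn main() {
    let input = FluxInput {
        prompt: "a rusty robot painting a sunset".to_string(),
        height: None,
        width: Some(1024),
        decode_only: None,
    };
    assert_eq!(resolve_dims(&input), (1024, 768));
    println!("prompt: {}, decode_only: {:?}", input.prompt, input.decode_only);
}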
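Note on the image post-processing in run() (illustrative sketch): the decoded autoencoder output lies roughly in [-1, 1] and is mapped to u8 pixels with a clamp and a 127.5 scale before save_image/convert_to_image. The same arithmetic on a plain slice, without any candle dependency:

/// Maps autoencoder outputs in [-1.0, 1.0] to 8-bit pixels, mirroring
/// `((img.clamp(-1, 1) + 1.0) * 127.5)` from the patch.
fn to_u8_pixels(values: &[f32]) -> Vec<u8> {
    values
        .iter()
        .map(|v| ((v.clamp(-1.0, 1.0) + 1.0) * 127.5) as u8)
        .collect()
}

fn main() {
    assert_eq!(to_u8_pixels(&[-1.0, 0.0, 1.0, 2.0]), vec![0, 127, 255, 255]);
}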