add support for image2image prompts
jorgeantonio21 committed Sep 3, 2024
1 parent f90061d commit a3517fb
Showing 15 changed files with 158 additions and 162 deletions.
6 changes: 3 additions & 3 deletions atoma-event-subscribe/sui/src/main.rs
@@ -1,7 +1,7 @@
use std::time::Duration;

use atoma_sui::subscriber::{SuiSubscriber, SuiSubscriberError};
use atoma_types::InputSource;
use atoma_types::{InputSource, ModelInput};
use clap::Parser;
use sui_sdk::types::base_types::ObjectID;
use tokio::sync::oneshot;
@@ -31,7 +31,7 @@ async fn main() -> Result<(), SuiSubscriberError> {

let (event_sender, mut event_receiver) = tokio::sync::mpsc::channel(32);
let (input_manager_tx, mut input_manager_rx) =
tokio::sync::mpsc::channel::<(InputSource, oneshot::Sender<String>)>(32);
tokio::sync::mpsc::channel::<(InputSource, oneshot::Sender<ModelInput>)>(32);

// Spawn a task to discard messages
tokio::spawn(async move {
@@ -42,7 +42,7 @@ async fn main() -> Result<(), SuiSubscriberError> {
InputSource::Ipfs { cid, format } => format!("{cid}.{format:?}"),
InputSource::Raw { prompt } => prompt,
};
if let Err(err) = oneshot.send(data) {
if let Err(err) = oneshot.send(ModelInput::Text(data)) {
error!("Failed to send response: {:?}", err);
}
}
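Note: `ModelInput` itself is not defined in this diff. Judging from the variants matched later in the commit (`Text`, `ImageBytes`, `ImageFile`), it is presumably an enum along the following lines; the payload types are assumptions, not taken from the source.

```rust
// Sketch of the assumed ModelInput enum. Only the variant names are confirmed by
// this commit; the payload types (String, Vec<u8>) are guesses.
pub enum ModelInput {
    /// Plain text prompt, as used for text2text requests.
    Text(String),
    /// Raw image bytes, e.g. fetched by the input manager, for image2image prompts.
    ImageBytes(Vec<u8>),
    /// Path or file name pointing at an image (e.g. a CID-derived name).
    ImageFile(String),
}
```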
22 changes: 16 additions & 6 deletions atoma-event-subscribe/sui/src/subscriber.rs
@@ -13,7 +13,7 @@ use tracing::{debug, error, info, instrument};

use crate::config::SuiSubscriberConfig;
use crate::AtomaEvent;
use atoma_types::{InputSource, Request, SmallId, NON_SAMPLED_NODE_ERR};
use atoma_types::{InputSource, ModelInput, Request, SmallId, NON_SAMPLED_NODE_ERR};

/// The size of a request id, expressed in hex format
const WAIT_FOR_INPUT_MANAGER_RESPONSE_SECS: u64 = 5;
@@ -43,7 +43,7 @@ pub struct SuiSubscriber {
/// The websocket address of a Sui RPC node
ws_addr: Option<String>,
/// Input manager sender, responsible for sending the input metadata and a oneshot sender to the input manager service, which sends back the user prompt.
input_manager_tx: mpsc::Sender<(InputSource, oneshot::Sender<String>)>,
input_manager_tx: mpsc::Sender<(InputSource, oneshot::Sender<ModelInput>)>,
}

impl SuiSubscriber {
@@ -55,7 +55,7 @@ impl SuiSubscriber {
package_id: ObjectID,
event_sender: mpsc::Sender<Request>,
request_timeout: Option<Duration>,
input_manager_tx: mpsc::Sender<(InputSource, oneshot::Sender<String>)>,
input_manager_tx: mpsc::Sender<(InputSource, oneshot::Sender<ModelInput>)>,
) -> Result<Self, SuiSubscriberError> {
let filter = EventFilter::Package(package_id);
Ok(Self {
@@ -91,7 +91,7 @@ impl SuiSubscriber {
pub async fn new_from_config<P: AsRef<Path>>(
config_path: P,
event_sender: mpsc::Sender<Request>,
input_manager_tx: mpsc::Sender<(InputSource, oneshot::Sender<String>)>,
input_manager_tx: mpsc::Sender<(InputSource, oneshot::Sender<ModelInput>)>,
) -> Result<Self, SuiSubscriberError> {
let config = SuiSubscriberConfig::from_file_path(config_path);
let small_id = config.small_id();
@@ -235,7 +235,17 @@ impl SuiSubscriber {
.await
.map_err(|_| SuiSubscriberError::TimeoutError)??;
// Replace the placeholder (e.g. the Firebase user id) with the actual model input.
request.set_raw_prompt(result);
match result {
ModelInput::ImageBytes(bytes) => {
request.set_raw_image(bytes);
}
ModelInput::ImageFile(path) => {
request.set_raw_prompt(path);
}
ModelInput::Text(text) => {
request.set_raw_prompt(text);
}
}
info!("Received new request: {:?}", request);
let request_id = request.id();
info!(
@@ -339,7 +349,7 @@ pub enum SuiSubscriberError {
#[error("Malformed event: `{0}`")]
MalformedEvent(String),
#[error("Sending input to input manager error: `{0}`")]
SendInputError(#[from] Box<mpsc::error::SendError<(InputSource, oneshot::Sender<String>)>>),
SendInputError(#[from] Box<mpsc::error::SendError<(InputSource, oneshot::Sender<ModelInput>)>>),
#[error("Error while sending request to input manager: `{0}`")]
InputManagerError(#[from] Box<tokio::sync::oneshot::error::RecvError>),
#[error("Timeout error getting the input from the input manager")]
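Note: the receiving end of `input_manager_tx` is not part of this diff. A minimal sketch of how an input manager loop could answer on the oneshot channel, using only the `InputSource` variants visible in this commit (`Raw`, `Ipfs`), might look like this:

```rust
// Hypothetical receive loop inside the input manager service. Identifiers such as
// `input_manager_rx` are illustrative; other sources (e.g. Firebase) are elided.
while let Some((input_source, reply)) = input_manager_rx.recv().await {
    let model_input = match input_source {
        InputSource::Raw { prompt } => ModelInput::Text(prompt),
        InputSource::Ipfs { cid, format } => ModelInput::ImageFile(format!("{cid}.{format:?}")),
        _ => continue, // remaining variants handled elsewhere
    };
    if reply.send(model_input).is_err() {
        error!("Request was dropped before the input could be delivered");
    }
}
```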
7 changes: 4 additions & 3 deletions atoma-inference/src/model_thread.rs
@@ -2,7 +2,7 @@ use std::{
collections::HashMap, fmt::Debug, path::PathBuf, str::FromStr, sync::mpsc, thread::JoinHandle,
};

use atoma_types::{AtomaStreamingData, OutputType, PromptParams, Request, Response};
use atoma_types::{AtomaStreamingData, ModelParams, OutputType, Request, Response};
use futures::stream::FuturesUnordered;
use serde::Deserialize;
use thiserror::Error;
@@ -95,8 +95,8 @@
let num_sampled_nodes = request.num_sampled_nodes();
let params = request.params();
let output_type = match params {
PromptParams::Text2ImagePromptParams(_) => OutputType::Image,
PromptParams::Text2TextPromptParams(_) => OutputType::Text,
ModelParams::Text2ImageModelParams(_) => OutputType::Image,
ModelParams::Text2TextModelParams(_) => OutputType::Text,
};
let output_destination = Deserialize::deserialize(&mut rmp_serde::Deserializer::new(
request.output_destination().as_slice(),
@@ -105,6 +105,7 @@
let output_id = match output_destination {
atoma_types::OutputDestination::Firebase { request_id } => request_id,
atoma_types::OutputDestination::Gateway { gateway_user_id } => gateway_user_id,
atoma_types::OutputDestination::Ipfs { cid } => cid,
};
let model_input = M::Input::try_from((output_id, params))?;
let model_output = self.model.run(model_input)?;
31 changes: 5 additions & 26 deletions atoma-inference/src/models/candle/flux.rs
@@ -15,7 +15,7 @@ use crate::{
ModelError, ModelTrait,
},
};
use atoma_types::{AtomaStreamingData, Digest, PromptParams};
use atoma_types::{AtomaStreamingData, Digest, ModelParams};
use candle::{DType, Device, Module, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::{clip, flux, t5};
@@ -356,8 +356,6 @@ impl ModelTrait for Flux {
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?
.to_dtype(candle::DType::U8)?
.squeeze(0)?;
info!("FLAG: img.dims() = {:?}", img.dims());
save_image(&img, "flux_output.png")?;
let (img, width, height) = convert_to_image(&img)?;

Ok(FluxOutput {
@@ -370,13 +368,13 @@
}
}

impl TryFrom<(Digest, PromptParams)> for FluxInput {
impl TryFrom<(Digest, ModelParams)> for FluxInput {
type Error = ModelError;

fn try_from(value: (Digest, PromptParams)) -> Result<Self, Self::Error> {
fn try_from(value: (Digest, ModelParams)) -> Result<Self, Self::Error> {
let prompt_params = value.1;
match prompt_params {
PromptParams::Text2ImagePromptParams(p) => {
ModelParams::Text2ImageModelParams(p) => {
let height = p.height().map(|h| h as usize);
let width = p.width().map(|w| w as usize);
let prompt = p.get_input_text(); // TODO: for now we use the raw prompt, but likely to fetch it from an external source in the future
@@ -388,7 +386,7 @@ impl TryFrom<(Digest, PromptParams)> for FluxInput {
decode_only,
})
}
PromptParams::Text2TextPromptParams(_) => Err(ModelError::InvalidPromptParams),
ModelParams::Text2TextModelParams(_) => Err(ModelError::InvalidModelParams),
}
}
}
@@ -404,22 +402,3 @@ impl LlmOutput for FluxOutput {
self.time_to_generate
}
}

/// Saves an image to disk using the image crate, this expects an input with shape
/// (c, height, width).
fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<(), ModelError> {
let p = p.as_ref();
let (channel, height, width) = img.dims3()?;
if channel != 3 {
bail!("save_image expects an input of shape (3, height, width)")
}
let img = img.permute((1, 2, 0))?.flatten_all()?;
let pixels = img.to_vec1::<u8>()?;
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
Some(image) => image,
None => bail!("error saving image {p:?}"),
};
image.save(p)?;
Ok(())
}
4 changes: 2 additions & 2 deletions atoma-inference/src/models/candle/stable_diffusion.rs
@@ -662,8 +662,8 @@ impl StableDiffusion {
}

/// Pre-processes image
fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> Result<Tensor, ModelError> {
let img = image::io::Reader::open(path)?.decode()?;
fn image_preprocess(image_bytes: &[u8]) -> Result<Tensor, ModelError> {
let img = image::load_from_memory(image_bytes)?;
let (height, width) = (img.height() as usize, img.width() as usize);
let height = height - height % 32;
let width = width - width % 32;
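Note: the hunk above is truncated after the width/height rounding. In candle's usual Stable Diffusion preprocessing the decoded image is then resized and normalised into a `(1, 3, height, width)` tensor in the `[-1, 1]` range; a plausible continuation, assuming that convention (not shown in the diff), is:

```rust
// Sketch of the remainder of image_preprocess, following candle's standard
// stable-diffusion example; this part is an assumption, not part of the commit.
let img = img
    .resize_to_fill(
        width as u32,
        height as u32,
        image::imageops::FilterType::CatmullRom,
    )
    .to_rgb8()
    .into_raw();
let tensor = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
    .permute((2, 0, 1))? // HWC -> CHW
    .to_dtype(DType::F32)?
    .affine(2. / 255., -1.)? // scale pixel values to [-1, 1]
    .unsqueeze(0)?; // add a batch dimension
Ok(tensor)
```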
6 changes: 3 additions & 3 deletions atoma-inference/src/models/mod.rs
@@ -1,7 +1,7 @@
use std::path::PathBuf;

use ::candle::{DTypeParseError, Error as CandleError};
use atoma_types::{AtomaStreamingData, Digest, PromptParams};
use atoma_types::{AtomaStreamingData, Digest, ModelParams};
#[cfg(feature = "nccl")]
use cudarc::{driver::DriverError, nccl::result::NcclError};
use thiserror::Error;
@@ -25,7 +25,7 @@ pub type ModelId = String;
/// Such an interface abstracts the fetching, loading and running of an LLM. Moreover, it
/// indirectly expects that fetching is done through some API (most likely the HuggingFace api).
pub trait ModelTrait {
type Input: TryFrom<(Digest, PromptParams), Error = ModelError>;
type Input: TryFrom<(Digest, ModelParams), Error = ModelError>;
type Output: LlmOutput;
type LoadData;

@@ -75,7 +75,7 @@ pub enum ModelError {
#[error("Send error: `{0}`")]
SendError(#[from] mpsc::error::SendError<AtomaStreamingData>),
#[error("Invalid prompt params")]
InvalidPromptParams,
InvalidModelParams,
#[cfg(feature = "nccl")]
#[error("Nccl error: `{}`", 0.0)]
NcclError(NcclError),
18 changes: 9 additions & 9 deletions atoma-inference/src/models/types.rs
@@ -1,6 +1,6 @@
use std::{fmt::Display, path::PathBuf, str::FromStr};

use atoma_types::{Digest, PromptParams};
use atoma_types::{Digest, ModelParams};
use candle::{DType, Device};
use serde::{Deserialize, Serialize};

@@ -468,12 +468,12 @@ impl TextModelInput {
}
}

impl TryFrom<(String, PromptParams)> for TextModelInput {
impl TryFrom<(String, ModelParams)> for TextModelInput {
type Error = ModelError;

fn try_from((request_id, value): (String, PromptParams)) -> Result<Self, Self::Error> {
fn try_from((request_id, value): (String, ModelParams)) -> Result<Self, Self::Error> {
match value {
PromptParams::Text2TextPromptParams(p) => Ok(Self {
ModelParams::Text2TextModelParams(p) => Ok(Self {
request_id,
prompt: p.get_input_text(),
temperature: p.temperature(),
@@ -487,7 +487,7 @@ impl TryFrom<(String, PromptParams)> for TextModelInput {
pre_prompt_tokens: p.pre_prompt_tokens(),
should_stream_output: p.should_stream_output(),
}),
PromptParams::Text2ImagePromptParams(_) => Err(ModelError::InvalidModelInput),
ModelParams::Text2ImageModelParams(_) => Err(ModelError::InvalidModelInput),
}
}
}
@@ -573,7 +573,7 @@ pub struct StableDiffusionInput {
/// Image to image, to be used if one aims to
/// transform the initial generated image in a given
/// specific way
pub img2img: Option<String>,
pub img2img: Option<Vec<u8>>,
/// The strength indicates how much to transform the initial image. The
/// value must be between 0 and 1, a value of 1 discards the initial image
/// information.
@@ -582,12 +582,12 @@
pub random_seed: Option<u32>,
}

impl TryFrom<(Digest, PromptParams)> for StableDiffusionInput {
impl TryFrom<(Digest, ModelParams)> for StableDiffusionInput {
type Error = ModelError;

fn try_from((_, value): (Digest, PromptParams)) -> Result<Self, Self::Error> {
fn try_from((_, value): (Digest, ModelParams)) -> Result<Self, Self::Error> {
match value {
PromptParams::Text2ImagePromptParams(p) => Ok(Self {
ModelParams::Text2ImageModelParams(p) => Ok(Self {
prompt: p.get_input_text(),
uncond_prompt: p.uncond_prompt().unwrap_or_default(),
height: p.height().map(|t| t.try_into().unwrap()),
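Note: with `img2img` now carrying raw bytes instead of a file path, the call site in the Stable Diffusion run path presumably changes accordingly; a minimal sketch, reusing the field and helper names from this diff, is:

```rust
// Hypothetical call site: decode the in-memory image2image bytes instead of
// reading a file from disk, then hand the tensor to the VAE encoder.
if let Some(bytes) = input.img2img.as_deref() {
    let init_image = image_preprocess(bytes)?;
    // ... encode `init_image` with the VAE to obtain the initial latents
}
```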
6 changes: 3 additions & 3 deletions atoma-inference/src/service.rs
@@ -147,7 +147,7 @@ pub enum ModelServiceError {

#[cfg(test)]
mod tests {
use atoma_types::{Digest, PromptParams};
use atoma_types::{Digest, ModelParams};
use serde::Serialize;
use std::io::Write;
use toml::{toml, Value};
@@ -178,10 +178,10 @@ mod tests {
}
}

impl TryFrom<(Digest, PromptParams)> for MockInput {
impl TryFrom<(Digest, ModelParams)> for MockInput {
type Error = ModelError;

fn try_from(_: (Digest, PromptParams)) -> Result<Self, Self::Error> {
fn try_from(_: (Digest, ModelParams)) -> Result<Self, Self::Error> {
Ok(Self {})
}
}
10 changes: 5 additions & 5 deletions atoma-inference/src/tests/mod.rs
@@ -3,14 +3,14 @@ use crate::models::{config::ModelConfig, types::ModelType, ModelError, ModelTrai
use std::{path::PathBuf, time::Duration};

mod prompts;
use atoma_types::Text2TextPromptParams;
use atoma_types::Text2TextModelParams;
use atoma_types::{AtomaStreamingData, Digest};
use prompts::PROMPTS;
use serde::Serialize;

use std::{collections::HashMap, sync::mpsc};

use atoma_types::{PromptParams, Request};
use atoma_types::{ModelParams, Request};
use futures::{stream::FuturesUnordered, StreamExt};
use reqwest::Client;
use serde_json::json;
@@ -47,10 +47,10 @@ impl LlmOutput for MockInputOutput {
}
}

impl TryFrom<(Digest, PromptParams)> for MockInputOutput {
impl TryFrom<(Digest, ModelParams)> for MockInputOutput {
type Error = ModelError;

fn try_from((_, value): (Digest, PromptParams)) -> Result<Self, Self::Error> {
fn try_from((_, value): (Digest, ModelParams)) -> Result<Self, Self::Error> {
Ok(Self {
id: value.into_text2text_prompt_params().unwrap().max_tokens(),
})
@@ -144,7 +144,7 @@ async fn test_mock_model_thread() {
for sender in model_thread_dispatcher.model_senders.values() {
let (response_sender, response_receiver) = oneshot::channel();
let max_tokens = i as u64;
let prompt_params = PromptParams::Text2TextPromptParams(Text2TextPromptParams::new(
let prompt_params = ModelParams::Text2TextModelParams(Text2TextModelParams::new(
atoma_types::InputSource::Raw {
prompt: "".to_string(),
},
1 change: 1 addition & 0 deletions atoma-input-manager/Cargo.toml
@@ -12,6 +12,7 @@ config.workspace = true
futures.workspace = true
gql_client.workspace = true
hex.workspace = true
image.workspace = true
ipfs-api-backend-hyper.workspace = true
reqwest = { workspace = true, features = ["json"] }
serde.workspace = true
5 changes: 3 additions & 2 deletions atoma-input-manager/src/firebase/mod.rs
@@ -1,4 +1,5 @@
use atoma_helpers::{Firebase, FirebaseAuth};
use atoma_types::ModelInput;
use reqwest::Client;
use tracing::{info, instrument};
use url::Url;
@@ -40,7 +41,7 @@ impl FirebaseInputManager {
pub async fn handle_get_request(
&mut self,
request_id: String,
) -> Result<String, AtomaInputManagerError> {
) -> Result<ModelInput, AtomaInputManagerError> {
let client = Client::new();
let token = self.auth.get_id_token().await?;
let mut url = self.firebase_url.clone();
@@ -60,7 +61,7 @@
if response.status().is_success() {
let text = response.text().await?;
info!("Received response with text: {text}");
return Ok(text);
return Ok(ModelInput::Text(text));
}
tokio::time::sleep(tokio::time::Duration::from_secs(SLEEP_BETWEEN_REQUESTS_SEC)).await;
}
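Note: the Firebase handler above still returns `ModelInput::Text`. For image2image prompts, a handler that serves raw image bytes would presumably wrap them in `ModelInput::ImageBytes`; a sketch with an assumed fetch helper (not part of this commit):

```rust
// Hypothetical image-fetching handler; `fetch_image_bytes` is an assumed helper
// and does not appear in this commit.
pub async fn handle_image_request(
    &self,
    cid: String,
) -> Result<ModelInput, AtomaInputManagerError> {
    let bytes: Vec<u8> = self.fetch_image_bytes(&cid).await?;
    Ok(ModelInput::ImageBytes(bytes))
}
```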