feat: input prompt is stored on firebase #130

Merged · 7 commits · Sep 2, 2024
Changes from all commits
16 changes: 9 additions & 7 deletions Cargo.toml
@@ -2,21 +2,22 @@
resolver = "2"

members = [
    "atoma-client",
    "atoma-crypto",
    "atoma-event-subscribe/arbitrum",
    "atoma-event-subscribe/solana",
    "atoma-event-subscribe/sui",
    "atoma-helpers",
    "atoma-inference",
    "atoma-input-manager",
    "atoma-json-rpc",
    "atoma-networking",
    "atoma-node",
    "atoma-output-manager",
    "atoma-storage",
    "atoma-streamer",
    "atoma-types",
    "atoma-vllm",
]

@@ -31,10 +32,11 @@
atoma-client = { path = "./atoma-client/" }
atoma-crypto = { path = "./atoma-crypto/" }
atoma-helpers = { path = "./atoma-helpers/" }
atoma-inference = { path = "./atoma-inference/" }
atoma-input-manager = { path = "./atoma-input-manager/" }
atoma-output-manager = { path = "./atoma-output-manager/" }
atoma-paged-attention = { git = "https://github.com/atoma-network/atoma-paged-attention", branch = "main" }
atoma-streamer = { path = "./atoma-streamer" }
atoma-sui = { path = "./atoma-event-subscribe/sui/" }
atoma-types = { path = "./atoma-types" }
axum = "0.7.5"
blake2 = "0.10.6"
1 change: 1 addition & 0 deletions atoma-event-subscribe/sui/Cargo.toml
@@ -11,6 +11,7 @@ atoma-types.workspace = true
clap.workspace = true
config.workspace = true
futures.workspace = true
hex.workspace = true
serde.workspace = true
serde_json.workspace = true
sui-sdk.workspace = true
19 changes: 19 additions & 0 deletions atoma-event-subscribe/sui/src/main.rs
@@ -1,8 +1,10 @@
use std::time::Duration;

use atoma_sui::subscriber::{SuiSubscriber, SuiSubscriberError};
use atoma_types::InputSource;
use clap::Parser;
use sui_sdk::types::base_types::ObjectID;
use tokio::sync::oneshot;
use tracing::{error, info};

#[derive(Debug, Parser)]
@@ -28,6 +30,22 @@ async fn main() -> Result<(), SuiSubscriberError> {
let ws_url = args.ws_addr;

let (event_sender, mut event_receiver) = tokio::sync::mpsc::channel(32);
let (input_manager_tx, mut input_manager_rx) =
tokio::sync::mpsc::channel::<(InputSource, oneshot::Sender<String>)>(32);

    // Spawn a stand-in input manager task that answers each request with its own input data
tokio::spawn(async move {
while let Some((input_source, oneshot)) = input_manager_rx.recv().await {
info!("Received input from source: {:?}", input_source);
let data = match input_source {
InputSource::Firebase { request_id } => request_id,
InputSource::Raw { prompt } => prompt,
};
if let Err(err) = oneshot.send(data) {
error!("Failed to send response: {:?}", err);
}
}
});

let event_subscriber = SuiSubscriber::new(
1,
Expand All @@ -36,6 +54,7 @@ async fn main() -> Result<(), SuiSubscriberError> {
package_id,
event_sender,
Some(Duration::from_secs(5 * 60)),
input_manager_tx,
)
.await?;

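The task spawned above stands in for the real input manager: each message pairs an `InputSource` with a `oneshot::Sender` on which the reply must be sent back. Below is a minimal, self-contained sketch of this request/response pattern; the local `InputSource` definition is illustrative only (the real one lives in `atoma-types`), and it assumes tokio with the `rt-multi-thread` and `macros` features.

use tokio::sync::{mpsc, oneshot};

#[derive(Debug)]
enum InputSource {
    Firebase { request_id: String },
    Raw { prompt: String },
}

#[tokio::main]
async fn main() {
    // Requests carry the input source plus a reply handle.
    let (tx, mut rx) = mpsc::channel::<(InputSource, oneshot::Sender<String>)>(32);

    // Responder: resolve each request and answer on its oneshot channel.
    tokio::spawn(async move {
        while let Some((source, reply)) = rx.recv().await {
            let data = match source {
                InputSource::Firebase { request_id } => request_id,
                InputSource::Raw { prompt } => prompt,
            };
            // send() only fails if the requester dropped its receiver.
            let _ = reply.send(data);
        }
    });

    // Requester: send the source, then await the reply.
    let (reply_tx, reply_rx) = oneshot::channel();
    tx.send((InputSource::Raw { prompt: "hello".into() }, reply_tx))
        .await
        .expect("responder task has shut down");
    assert_eq!(reply_rx.await.unwrap(), "hello");
}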
51 changes: 37 additions & 14 deletions atoma-event-subscribe/sui/src/subscriber.rs
@@ -1,4 +1,4 @@
use std::{path::Path, str::FromStr, time::Duration};

use futures::StreamExt;
use serde_json::Value;
@@ -7,15 +7,16 @@ use sui_sdk::types::base_types::{ObjectID, ObjectIDParseError};
use sui_sdk::types::event::EventID;
use sui_sdk::{SuiClient, SuiClientBuilder};
use thiserror::Error;
use tokio::sync::oneshot::error::RecvError;
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, error, info, instrument};

use crate::config::SuiSubscriberConfig;
use crate::AtomaEvent;
use atoma_types::{InputSource, Request, SmallId, NON_SAMPLED_NODE_ERR};

/// Maximum time to wait for the input manager's response, in seconds.
const WAIT_FOR_INPUT_MANAGER_RESPONSE_SECS: u64 = 5;

/// `SuiSubscriber` - Responsible for listening to events emitted from the Atoma smart contract
/// on the Sui blockchain.
@@ -41,6 +42,8 @@ pub struct SuiSubscriber {
request_timeout: Option<Duration>,
/// The websocket address of a Sui RPC node
ws_addr: Option<String>,
/// Input manager sender: sends the input metadata together with a oneshot sender, so the input manager service can return the user prompt.
input_manager_tx: mpsc::Sender<(InputSource, oneshot::Sender<String>)>,
}

impl SuiSubscriber {
@@ -52,6 +55,7 @@ impl SuiSubscriber {
package_id: ObjectID,
event_sender: mpsc::Sender<Request>,
request_timeout: Option<Duration>,
input_manager_tx: mpsc::Sender<(InputSource, oneshot::Sender<String>)>,
) -> Result<Self, SuiSubscriberError> {
let filter = EventFilter::Package(package_id);
Ok(Self {
@@ -62,6 +66,7 @@
event_sender,
request_timeout,
last_event_id: None,
input_manager_tx,
})
}

@@ -86,6 +91,7 @@
pub async fn new_from_config<P: AsRef<Path>>(
config_path: P,
event_sender: mpsc::Sender<Request>,
input_manager_tx: mpsc::Sender<(InputSource, oneshot::Sender<String>)>,
) -> Result<Self, SuiSubscriberError> {
let config = SuiSubscriberConfig::from_file_path(config_path);
let small_id = config.small_id();
@@ -100,6 +106,7 @@
package_id,
event_sender,
Some(request_timeout),
input_manager_tx,
)
.await
}
@@ -213,18 +220,26 @@
#[instrument(skip(self, event_data))]
async fn handle_prompt_event(&self, event_data: Value) -> Result<(), SuiSubscriberError> {
debug!("event data: {}", event_data);
let mut request = Request::try_from((self.id, event_data))?;

// Ask the input manager for the real prompt, to be delivered over a oneshot channel.
let (oneshot_sender, oneshot_receiver) = tokio::sync::oneshot::channel();
self.input_manager_tx
.send((request.params().prompt(), oneshot_sender))
.await
.map_err(Box::new)?;
let result = tokio::time::timeout(
Duration::from_secs(WAIT_FOR_INPUT_MANAGER_RESPONSE_SECS),
oneshot_receiver,
)
.await
.map_err(|_| SuiSubscriberError::TimeoutError)??;
// Swap the placeholder (the Firebase request id) for the real prompt.
request.set_raw_prompt(result);
info!("Received new request: {:?}", request);
let request_id = request.id();
info!(
"Current node has been sampled for request with id: {}",
"Current node has been sampled for request with id: {:?}",
request_id
);
self.event_sender.send(request).await.map_err(Box::new)?;
@@ -323,4 +338,12 @@ pub enum SuiSubscriberError {
TypeConversionError(#[from] anyhow::Error),
#[error("Malformed event: `{0}`")]
MalformedEvent(String),
#[error("Sending input to input manager error: `{0}`")]
SendInputError(#[from] Box<mpsc::error::SendError<(InputSource, oneshot::Sender<String>)>>),
#[error("Error while sending request to input manager: `{0}`")]
InputManagerError(#[from] Box<tokio::sync::oneshot::error::RecvError>),
#[error("Timeout error getting the input from the input manager")]
TimeoutError,
#[error("Request receive error {0}")]
RecvError(#[from] RecvError),
}
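`handle_prompt_event` bounds its wait on the input manager with `tokio::time::timeout`; the double `??` first unwraps the timeout (mapped to `TimeoutError`), then the oneshot `RecvError`. A hedged sketch of the same pattern in isolation; the error enum here is illustrative, not the PR's `SuiSubscriberError`:

use std::time::Duration;
use tokio::sync::oneshot;

#[derive(Debug)]
enum FetchError {
    Timeout,       // the input manager did not answer in time
    ChannelClosed, // the responder dropped the oneshot sender
}

async fn await_prompt(
    rx: oneshot::Receiver<String>,
    secs: u64,
) -> Result<String, FetchError> {
    tokio::time::timeout(Duration::from_secs(secs), rx)
        .await
        .map_err(|_| FetchError::Timeout)?      // outer: the timeout elapsed
        .map_err(|_| FetchError::ChannelClosed) // inner: the oneshot RecvError
}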
21 changes: 20 additions & 1 deletion atoma-helpers/src/firebase/mod.rs
@@ -1,5 +1,7 @@
mod auth;
pub use auth::*;
use reqwest::{Client, Url};
use serde_json::json;
use std::sync::Arc;
use tokio::sync::Mutex;

@@ -26,10 +28,27 @@ impl Firebase {
email: String,
password: String,
api_key: String,
firebase_url: &Url,
node_id: u64,
) -> Result<FirebaseAuth, FirebaseAuthError> {
// Prevent concurrent calls to add_user: if the user does not exist yet, simultaneous calls would trigger multiple signups.
let _guard = self.add_user_lock.lock().await;
let mut firebase_auth = FirebaseAuth::new(email, password, api_key).await?;
let client = Client::new();
let token = firebase_auth.get_id_token().await?;
let mut url = firebase_url.clone();
{
let mut path_segment = url.path_segments_mut().unwrap();
path_segment.push("nodes");
path_segment.push(&format!("{}.json", firebase_auth.get_local_id()?));
}
url.set_query(Some(&format!("auth={token}")));
let data = json!({
"id":node_id.to_string()
});
let response = client.put(url).json(&data).send().await?;
response.text().await?;

Ok(firebase_auth)
}
}
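Beyond signing in, `add_user` now registers the node id under `/nodes/<uid>.json` in the Firebase Realtime Database through its REST API, passing the fresh ID token as the `auth` query parameter. A standalone sketch of that write, assuming reqwest with the `json` feature; the function and parameter names are illustrative, not part of the PR:

use reqwest::{Client, Url};
use serde_json::json;

async fn register_node(
    base: &Url,     // e.g. https://<project>.firebaseio.com/
    local_id: &str, // Firebase user id returned at sign-in
    id_token: &str, // short-lived ID token from Firebase Auth
    node_id: u64,
) -> Result<(), reqwest::Error> {
    let mut url = base.clone();
    {
        // Append the path segments: /nodes/<local_id>.json
        let mut segments = url.path_segments_mut().expect("base URL cannot be a base");
        segments.push("nodes");
        segments.push(&format!("{local_id}.json"));
    }
    // The RTDB REST API takes the ID token as the `auth` query parameter.
    url.set_query(Some(&format!("auth={id_token}")));
    let body = json!({ "id": node_id.to_string() });
    Client::new().put(url).json(&body).send().await?.error_for_status()?;
    Ok(())
}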
1 change: 1 addition & 0 deletions atoma-inference/Cargo.toml
@@ -22,6 +22,7 @@ serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
image = { workspace = true }
metrics.workspace = true
rmp-serde.workspace = true
thiserror.workspace = true
tokenizers = { workspace = true }
tokio = { workspace = true, features = ["full", "tracing"] }
19 changes: 14 additions & 5 deletions atoma-inference/src/model_thread.rs
@@ -2,8 +2,9 @@ use std::{
collections::HashMap, fmt::Debug, path::PathBuf, str::FromStr, sync::mpsc, thread::JoinHandle,
};

use atoma_types::{AtomaStreamingData, OutputType, PromptParams, Request, Response};
use futures::stream::FuturesUnordered;
use serde::Deserialize;
use thiserror::Error;
use tokio::sync::oneshot::{self, error::RecvError};
use tracing::{debug, error, info, instrument, warn, Span};
@@ -97,7 +98,15 @@
PromptParams::Text2ImagePromptParams(_) => OutputType::Image,
PromptParams::Text2TextPromptParams(_) => OutputType::Text,
};
// Decode the MessagePack-encoded output destination carried by the request.
let output_destination = Deserialize::deserialize(&mut rmp_serde::Deserializer::new(
    request.output_destination().as_slice(),
))
.unwrap();
let output_id = match output_destination {
atoma_types::OutputDestination::Firebase { request_id } => request_id,
atoma_types::OutputDestination::Gateway { gateway_user_id } => gateway_user_id,
};
let model_input = M::Input::try_from((output_id, params))?;
let model_output = self.model.run(model_input)?;
let time_to_generate = model_output.time_to_generate();
let num_input_tokens = model_output.num_input_tokens();
@@ -145,7 +154,7 @@ impl ModelThreadDispatcher {
#[instrument(skip_all)]
pub(crate) fn start(
config: ModelsConfig,
stream_tx: tokio::sync::mpsc::Sender<AtomaStreamingData>,
) -> Result<(Self, Vec<ModelThreadHandle>), ModelThreadError> {
let mut handles = Vec::new();
let mut model_senders = HashMap::new();
@@ -236,7 +245,7 @@ pub(crate) fn dispatch_model_thread(
model_type: ModelType,
model_config: ModelConfig,
model_receiver: mpsc::Receiver<ModelThreadCommand>,
stream_tx: tokio::sync::mpsc::Sender<AtomaStreamingData>,
) -> JoinHandle<Result<(), ModelThreadError>> {
if model_config.device_ids().len() > 1 {
#[cfg(not(feature = "nccl"))]
@@ -423,7 +432,7 @@ pub(crate) fn spawn_model_thread<M>(
cache_dir: PathBuf,
model_config: ModelConfig,
model_receiver: mpsc::Receiver<ModelThreadCommand>,
stream_tx: tokio::sync::mpsc::Sender<AtomaStreamingData>,
) -> JoinHandle<Result<(), ModelThreadError>>
where
M: ModelTrait + Send + 'static,
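The dispatcher now decodes the request's output destination from MessagePack with rmp-serde (hence the new rmp-serde dependency above). A minimal round-trip sketch with an enum shaped like the one matched in the code; the real `OutputDestination` lives in `atoma-types`, so this local definition is illustrative only, and `rmp_serde::from_slice` is equivalent to the `Deserializer::new` call in the diff. Assumes serde with the `derive` feature:

use serde::{Deserialize, Serialize};

#[derive(Debug, PartialEq, Serialize, Deserialize)]
enum OutputDestination {
    Firebase { request_id: String },
    Gateway { gateway_user_id: String },
}

fn main() {
    let dest = OutputDestination::Firebase { request_id: "abc123".into() };
    // Encode to MessagePack bytes, as carried on the request...
    let bytes = rmp_serde::to_vec(&dest).expect("serialization failed");
    // ...and decode them back, mirroring the dispatcher's Deserialize call.
    let decoded: OutputDestination =
        rmp_serde::from_slice(&bytes).expect("deserialization failed");
    assert_eq!(dest, decoded);
}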
6 changes: 3 additions & 3 deletions atoma-inference/src/models/candle/falcon.rs
@@ -1,6 +1,6 @@
use std::{path::PathBuf, str::FromStr, time::Instant};

use atoma_types::AtomaStreamingData;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
@@ -48,7 +48,7 @@ impl FalconModel {
dtype: DType,
model_type: ModelType,
tokenizer: Tokenizer,
stream_tx: mpsc::Sender<AtomaStreamingData>,
) -> Self {
Self {
model,
@@ -112,7 +112,7 @@ impl ModelTrait for FalconModel {
#[instrument(skip_all)]
fn load(
load_data: Self::LoadData,
stream_tx: mpsc::Sender<AtomaStreamingData>,
) -> Result<Self, ModelError>
where
Self: Sized,
6 changes: 3 additions & 3 deletions atoma-inference/src/models/candle/flux.rs
@@ -15,7 +15,7 @@ use crate::{
ModelError, ModelTrait,
},
};
use atoma_types::{AtomaStreamingData, Digest, PromptParams};
use candle::{DType, Device, Module, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::{clip, flux, t5};
@@ -176,7 +176,7 @@ impl ModelTrait for Flux {

fn load(
load_data: Self::LoadData,
_stream_tx: tokio::sync::mpsc::Sender<AtomaStreamingData>,
) -> Result<Self, ModelError> {
info!("Loading T5 model..");
let start = std::time::Instant::now();
@@ -379,7 +379,7 @@ impl TryFrom<(Digest, PromptParams)> for FluxInput {
PromptParams::Text2ImagePromptParams(p) => {
let height = p.height().map(|h| h as usize);
let width = p.width().map(|w| w as usize);
let prompt = p.get_input_text(); // TODO: for now this uses the raw prompt; it will likely be fetched from an external source in the future
let decode_only = p.decode_only();
Ok(Self {
prompt,
4 changes: 2 additions & 2 deletions atoma-inference/src/models/candle/llama.rs
@@ -1,6 +1,6 @@
use std::{path::PathBuf, str::FromStr, time::Instant};

use atoma_types::AtomaStreamingData;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
@@ -103,7 +103,7 @@ impl ModelTrait for LlamaModel {
#[instrument(skip_all)]
fn load(
load_data: Self::LoadData,
stream_tx: mpsc::Sender<AtomaStreamingData>,
) -> Result<Self, ModelError> {
info!("Loading Llama model ...");

4 changes: 2 additions & 2 deletions atoma-inference/src/models/candle/llama_nccl.rs
@@ -53,7 +53,7 @@ impl LlamaNcclWorker {
model_weights_file_paths: &[PathBuf],
tokenizer_file_path: &PathBuf,
device_id: usize,
stream_tx: tokio::sync::mpsc::Sender<AtomaStreamingData>,
) -> Result<Self, ModelError> {
let device = CudaDevice::new(rank)?;
// Initialize the Communicator from Nvidia Collective Communication Library. This is for the inter gpu communication.
@@ -227,7 +227,7 @@ impl ModelTrait for LlamaNcclModel {

fn load(
load_data: Self::LoadData,
stream_tx: tokio::sync::mpsc::Sender<AtomaStreamingData>,
) -> Result<Self, ModelError> {
info!("Loading Llama model ...");
let start = Instant::now();