-
Notifications
You must be signed in to change notification settings - Fork 247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Factors] Integrate llm factor #2742
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
mod host; | ||
pub mod spin; | ||
|
||
use std::collections::{HashMap, HashSet}; | ||
use std::sync::Arc; | ||
|
@@ -11,26 +12,28 @@ use spin_factors::{ | |
use spin_locked_app::MetadataKey; | ||
use spin_world::v1::llm::{self as v1}; | ||
use spin_world::v2::llm::{self as v2}; | ||
use tokio::sync::Mutex; | ||
|
||
pub const ALLOWED_MODELS_KEY: MetadataKey<Vec<String>> = MetadataKey::new("ai_models"); | ||
|
||
/// The factor for LLMs.
pub struct LlmFactor {
    // Fallback engine factory: used by `configure_app` only when no runtime
    // configuration supplies an engine (see the
    // `take_runtime_config().unwrap_or_else(...)` path there).
    default_engine_creator: Box<dyn LlmEngineCreator>,
}
|
||
impl LlmFactor { | ||
pub fn new<F>(create_engine: F) -> Self | ||
where | ||
F: Fn() -> Box<dyn LlmEngine> + Send + Sync + 'static, | ||
{ | ||
/// Creates a new LLM factor with the given default engine creator. | ||
/// | ||
/// The default engine creator is used to create the engine if no runtime configuration is provided. | ||
pub fn new<F: LlmEngineCreator + 'static>(default_engine_creator: F) -> Self { | ||
Self { | ||
create_engine: Box::new(create_engine), | ||
default_engine_creator: Box::new(default_engine_creator), | ||
} | ||
} | ||
} | ||
|
||
impl Factor for LlmFactor { | ||
type RuntimeConfig = (); | ||
type RuntimeConfig = RuntimeConfig; | ||
type AppState = AppState; | ||
type InstanceBuilder = InstanceState; | ||
|
||
|
@@ -45,7 +48,7 @@ impl Factor for LlmFactor { | |
|
||
fn configure_app<T: RuntimeFactors>( | ||
&self, | ||
ctx: ConfigureAppContext<T, Self>, | ||
mut ctx: ConfigureAppContext<T, Self>, | ||
) -> anyhow::Result<Self::AppState> { | ||
let component_allowed_models = ctx | ||
.app() | ||
|
@@ -62,7 +65,12 @@ impl Factor for LlmFactor { | |
)) | ||
}) | ||
.collect::<anyhow::Result<_>>()?; | ||
let engine = ctx | ||
.take_runtime_config() | ||
.map(|c| c.engine) | ||
.unwrap_or_else(|| self.default_engine_creator.create()); | ||
Ok(AppState { | ||
engine, | ||
component_allowed_models, | ||
}) | ||
} | ||
|
@@ -78,25 +86,35 @@ impl Factor for LlmFactor { | |
.get(ctx.app_component().id()) | ||
.cloned() | ||
.unwrap_or_default(); | ||
let engine = ctx.app_state().engine.clone(); | ||
|
||
Ok(InstanceState { | ||
engine: (self.create_engine)(), | ||
engine, | ||
allowed_models, | ||
}) | ||
} | ||
} | ||
|
||
/// The application state for the LLM factor.
pub struct AppState {
    // Engine shared across the whole app; each instance builder clones this
    // `Arc` handle rather than creating its own engine.
    engine: Arc<Mutex<dyn LlmEngine>>,
    // Per-component allow-list: component ID -> set of model names the
    // component may use (collected from the `ai_models` metadata key).
    component_allowed_models: HashMap<String, Arc<HashSet<String>>>,
}
|
||
/// The instance state for the LLM factor.
pub struct InstanceState {
    // Shared handle to the app-wide engine, cloned from `AppState` when the
    // instance is built.
    engine: Arc<Mutex<dyn LlmEngine>>,
    // Model names this component instance is permitted to use; empty set if
    // the component declared none.
    pub allowed_models: Arc<HashSet<String>>,
}
|
||
/// The runtime configuration for the LLM factor.
pub struct RuntimeConfig {
    // Engine supplied by runtime configuration; when present it takes
    // precedence over the factor's default engine creator.
    engine: Arc<Mutex<dyn LlmEngine>>,
}
|
||
// NOTE(review): `SelfInstanceBuilder` presumably marks `InstanceState` as
// serving directly as its own instance builder — confirm against spin_factors.
impl SelfInstanceBuilder for InstanceState {}
|
||
/// The interface for a language model engine. | ||
#[async_trait] | ||
pub trait LlmEngine: Send + Sync { | ||
async fn infer( | ||
|
@@ -112,3 +130,17 @@ pub trait LlmEngine: Send + Sync { | |
data: Vec<String>, | ||
) -> Result<v2::EmbeddingsResult, v2::Error>; | ||
} | ||
|
||
/// A creator for an LLM engine.
pub trait LlmEngineCreator: Send + Sync {
    /// Creates a shared, lockable LLM engine.
    fn create(&self) -> Arc<Mutex<dyn LlmEngine>>;
}
|
||
impl<F> LlmEngineCreator for F | ||
where | ||
F: Fn() -> Arc<Mutex<dyn LlmEngine>> + Send + Sync, | ||
{ | ||
fn create(&self) -> Arc<Mutex<dyn LlmEngine>> { | ||
self() | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
use std::path::PathBuf; | ||
use std::sync::Arc; | ||
|
||
pub use spin_llm_local::LocalLlmEngine; | ||
|
||
use spin_llm_remote_http::RemoteHttpLlmEngine; | ||
use spin_world::async_trait; | ||
use spin_world::v1::llm::{self as v1}; | ||
use spin_world::v2::llm::{self as v2}; | ||
use tokio::sync::Mutex; | ||
use url::Url; | ||
|
||
use crate::{LlmEngine, LlmEngineCreator, RuntimeConfig}; | ||
|
||
#[async_trait] | ||
impl LlmEngine for LocalLlmEngine { | ||
async fn infer( | ||
&mut self, | ||
model: v1::InferencingModel, | ||
prompt: String, | ||
params: v2::InferencingParams, | ||
) -> Result<v2::InferencingResult, v2::Error> { | ||
self.infer(model, prompt, params).await | ||
} | ||
|
||
async fn generate_embeddings( | ||
&mut self, | ||
model: v2::EmbeddingModel, | ||
data: Vec<String>, | ||
) -> Result<v2::EmbeddingsResult, v2::Error> { | ||
self.generate_embeddings(model, data).await | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl LlmEngine for RemoteHttpLlmEngine { | ||
async fn infer( | ||
&mut self, | ||
model: v1::InferencingModel, | ||
prompt: String, | ||
params: v2::InferencingParams, | ||
) -> Result<v2::InferencingResult, v2::Error> { | ||
self.infer(model, prompt, params).await | ||
} | ||
|
||
async fn generate_embeddings( | ||
&mut self, | ||
model: v2::EmbeddingModel, | ||
data: Vec<String>, | ||
) -> Result<v2::EmbeddingsResult, v2::Error> { | ||
self.generate_embeddings(model, data).await | ||
} | ||
} | ||
|
||
pub fn runtime_config_from_toml( | ||
table: &toml::Table, | ||
state_dir: PathBuf, | ||
use_gpu: bool, | ||
) -> anyhow::Result<Option<RuntimeConfig>> { | ||
let Some(value) = table.get("llm_compute") else { | ||
return Ok(None); | ||
}; | ||
let config: LlmCompute = value.clone().try_into()?; | ||
|
||
Ok(Some(RuntimeConfig { | ||
engine: config.into_engine(state_dir, use_gpu), | ||
})) | ||
} | ||
|
||
#[derive(Debug, serde::Deserialize)] | ||
#[serde(rename_all = "snake_case", tag = "type")] | ||
pub enum LlmCompute { | ||
Spin, | ||
RemoteHttp(RemoteHttpCompute), | ||
} | ||
|
||
impl LlmCompute { | ||
fn into_engine(self, state_dir: PathBuf, use_gpu: bool) -> Arc<Mutex<dyn LlmEngine>> { | ||
match self { | ||
LlmCompute::Spin => default_engine_creator(state_dir, use_gpu).create(), | ||
LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new( | ||
config.url, | ||
config.auth_token, | ||
))), | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, serde::Deserialize)] | ||
pub struct RemoteHttpCompute { | ||
url: Url, | ||
auth_token: String, | ||
} | ||
|
||
/// The default engine creator for the LLM factor when used in the Spin CLI. | ||
pub fn default_engine_creator( | ||
state_dir: PathBuf, | ||
use_gpu: bool, | ||
) -> impl LlmEngineCreator + 'static { | ||
move || { | ||
Arc::new(Mutex::new(LocalLlmEngine::new( | ||
state_dir.join("ai-models"), | ||
use_gpu, | ||
))) as _ | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are feature flags in the root `spin-cli` and the old `spin-trigger` and `spin-trigger-http` crates that control dependencies for this.

There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm going to do some small refactorings in a follow up that will make this easier to bring back.