From a7c916305bce624a17ea7eeb81a7285231e4622f Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Thu, 22 Aug 2024 13:33:46 +0200 Subject: [PATCH] Integrate llm factor into trigger2 Signed-off-by: Ryan Levick --- Cargo.lock | 39 ++++----- crates/factor-llm/Cargo.toml | 8 +- crates/factor-llm/src/host.rs | 4 +- crates/factor-llm/src/lib.rs | 52 +++++++++--- crates/factor-llm/src/spin.rs | 106 +++++++++++++++++++++++++ crates/factor-llm/tests/factor_test.rs | 8 +- crates/llm-local/Cargo.toml | 1 - crates/llm-local/src/lib.rs | 22 ++--- crates/llm-remote-http/Cargo.toml | 3 - crates/llm-remote-http/src/lib.rs | 9 +-- crates/runtime-config/Cargo.toml | 1 + crates/runtime-config/src/lib.rs | 47 +++++++---- crates/trigger2/Cargo.toml | 1 + crates/trigger2/src/cli.rs | 6 +- crates/trigger2/src/factors.rs | 8 ++ 15 files changed, 243 insertions(+), 72 deletions(-) create mode 100644 crates/factor-llm/src/spin.rs diff --git a/Cargo.lock b/Cargo.lock index f2b404a08..3295d890c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2435,20 +2435,6 @@ dependencies = [ "syn 2.0.58", ] -[[package]] -name = "factor-llm" -version = "2.8.0-pre0" -dependencies = [ - "anyhow", - "async-trait", - "spin-factors", - "spin-factors-test", - "spin-locked-app", - "spin-world", - "tokio", - "tracing", -] - [[package]] name = "fallible-iterator" version = "0.2.0" @@ -7647,6 +7633,25 @@ dependencies = [ "spin-key-value-sqlite", ] +[[package]] +name = "spin-factor-llm" +version = "2.8.0-pre0" +dependencies = [ + "anyhow", + "async-trait", + "serde 1.0.197", + "spin-factors", + "spin-factors-test", + "spin-llm-local", + "spin-llm-remote-http", + "spin-locked-app", + "spin-world", + "tokio", + "toml 0.8.14", + "tracing", + "url", +] + [[package]] name = "spin-factor-outbound-http" version = "2.8.0-pre0" @@ -7996,7 +8001,6 @@ dependencies = [ "serde 1.0.197", "spin-common", "spin-core", - "spin-llm", "spin-world", "terminal", "tokenizers", @@ -8010,12 +8014,9 @@ version = "2.8.0-pre0" dependencies = [ "anyhow", "http 0.2.12", - "llm", "reqwest 0.11.27", "serde 1.0.197", "serde_json", - "spin-core", - "spin-llm", "spin-telemetry", "spin-world", "tracing", @@ -8170,6 +8171,7 @@ dependencies = [ "spin-factor-key-value-azure", "spin-factor-key-value-redis", "spin-factor-key-value-spin", + "spin-factor-llm", "spin-factor-outbound-http", "spin-factor-outbound-mqtt", "spin-factor-outbound-mysql", @@ -8484,6 +8486,7 @@ dependencies = [ "spin-componentize", "spin-core", "spin-factor-key-value", + "spin-factor-llm", "spin-factor-outbound-http", "spin-factor-outbound-mqtt", "spin-factor-outbound-mysql", diff --git a/crates/factor-llm/Cargo.toml b/crates/factor-llm/Cargo.toml index bcdc1e81f..e5a26b018 100644 --- a/crates/factor-llm/Cargo.toml +++ b/crates/factor-llm/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "factor-llm" +name = "spin-factor-llm" version.workspace = true authors.workspace = true edition.workspace = true @@ -11,10 +11,16 @@ rust-version.workspace = true [dependencies] anyhow = "1.0" async-trait = "0.1" +serde = "1.0" spin-factors = { path = "../factors" } +spin-llm-local = { path = "../llm-local" } +spin-llm-remote-http = { path = "../llm-remote-http" } spin-locked-app = { path = "../locked-app" } spin-world = { path = "../world" } tracing = { workspace = true } +tokio = { version = "1", features = ["sync"] } +toml = "0.8" +url = "2" [dev-dependencies] spin-factors-test = { path = "../factors-test" } diff --git a/crates/factor-llm/src/host.rs b/crates/factor-llm/src/host.rs index 748f97b1a..af980ad7e 100644 --- a/crates/factor-llm/src/host.rs +++ b/crates/factor-llm/src/host.rs @@ -16,6 +16,8 @@ impl v2::Host for InstanceState { return Err(access_denied_error(&model)); } self.engine + .lock() + .await .infer( model, prompt, @@ -39,7 +41,7 @@ impl v2::Host for InstanceState { if !self.allowed_models.contains(&m) { return Err(access_denied_error(&m)); } - self.engine.generate_embeddings(m, data).await + self.engine.lock().await.generate_embeddings(m, data).await } fn convert_error(&mut self, error: v2::Error) -> anyhow::Result { diff --git a/crates/factor-llm/src/lib.rs b/crates/factor-llm/src/lib.rs index 3e40f36a2..543e59b61 100644 --- a/crates/factor-llm/src/lib.rs +++ b/crates/factor-llm/src/lib.rs @@ -1,4 +1,5 @@ mod host; +pub mod spin; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -11,26 +12,28 @@ use spin_factors::{ use spin_locked_app::MetadataKey; use spin_world::v1::llm::{self as v1}; use spin_world::v2::llm::{self as v2}; +use tokio::sync::Mutex; pub const ALLOWED_MODELS_KEY: MetadataKey> = MetadataKey::new("ai_models"); +/// The factor for LLMs. pub struct LlmFactor { - create_engine: Box Box + Send + Sync>, + default_engine_creator: Box, } impl LlmFactor { - pub fn new(create_engine: F) -> Self - where - F: Fn() -> Box + Send + Sync + 'static, - { + /// Creates a new LLM factor with the given default engine creator. + /// + /// The default engine creator is used to create the engine if no runtime configuration is provided. + pub fn new(default_engine_creator: F) -> Self { Self { - create_engine: Box::new(create_engine), + default_engine_creator: Box::new(default_engine_creator), } } } impl Factor for LlmFactor { - type RuntimeConfig = (); + type RuntimeConfig = RuntimeConfig; type AppState = AppState; type InstanceBuilder = InstanceState; @@ -45,7 +48,7 @@ impl Factor for LlmFactor { fn configure_app( &self, - ctx: ConfigureAppContext, + mut ctx: ConfigureAppContext, ) -> anyhow::Result { let component_allowed_models = ctx .app() @@ -62,7 +65,12 @@ impl Factor for LlmFactor { )) }) .collect::>()?; + let engine = ctx + .take_runtime_config() + .map(|c| c.engine) + .unwrap_or_else(|| self.default_engine_creator.create()); Ok(AppState { + engine, component_allowed_models, }) } @@ -78,25 +86,35 @@ impl Factor for LlmFactor { .get(ctx.app_component().id()) .cloned() .unwrap_or_default(); + let engine = ctx.app_state().engine.clone(); Ok(InstanceState { - engine: (self.create_engine)(), + engine, allowed_models, }) } } +/// The application state for the LLM factor. pub struct AppState { + engine: Arc>, component_allowed_models: HashMap>>, } +/// The instance state for the LLM factor. pub struct InstanceState { - engine: Box, + engine: Arc>, pub allowed_models: Arc>, } +/// The runtime configuration for the LLM factor. +pub struct RuntimeConfig { + engine: Arc>, +} + impl SelfInstanceBuilder for InstanceState {} +/// The interface for a language model engine. #[async_trait] pub trait LlmEngine: Send + Sync { async fn infer( @@ -112,3 +130,17 @@ pub trait LlmEngine: Send + Sync { data: Vec, ) -> Result; } + +/// A creator for an LLM engine. +pub trait LlmEngineCreator: Send + Sync { + fn create(&self) -> Arc>; +} + +impl LlmEngineCreator for F +where + F: Fn() -> Arc> + Send + Sync, +{ + fn create(&self) -> Arc> { + self() + } +} diff --git a/crates/factor-llm/src/spin.rs b/crates/factor-llm/src/spin.rs new file mode 100644 index 000000000..6ebd7a706 --- /dev/null +++ b/crates/factor-llm/src/spin.rs @@ -0,0 +1,106 @@ +use std::path::PathBuf; +use std::sync::Arc; + +pub use spin_llm_local::LocalLlmEngine; + +use spin_llm_remote_http::RemoteHttpLlmEngine; +use spin_world::async_trait; +use spin_world::v1::llm::{self as v1}; +use spin_world::v2::llm::{self as v2}; +use tokio::sync::Mutex; +use url::Url; + +use crate::{LlmEngine, LlmEngineCreator, RuntimeConfig}; + +#[async_trait] +impl LlmEngine for LocalLlmEngine { + async fn infer( + &mut self, + model: v1::InferencingModel, + prompt: String, + params: v2::InferencingParams, + ) -> Result { + self.infer(model, prompt, params).await + } + + async fn generate_embeddings( + &mut self, + model: v2::EmbeddingModel, + data: Vec, + ) -> Result { + self.generate_embeddings(model, data).await + } +} + +#[async_trait] +impl LlmEngine for RemoteHttpLlmEngine { + async fn infer( + &mut self, + model: v1::InferencingModel, + prompt: String, + params: v2::InferencingParams, + ) -> Result { + self.infer(model, prompt, params).await + } + + async fn generate_embeddings( + &mut self, + model: v2::EmbeddingModel, + data: Vec, + ) -> Result { + self.generate_embeddings(model, data).await + } +} + +pub fn runtime_config_from_toml( + table: &toml::Table, + state_dir: PathBuf, + use_gpu: bool, +) -> anyhow::Result> { + let Some(value) = table.get("llm_compute") else { + return Ok(None); + }; + let config: LlmCompute = value.clone().try_into()?; + + Ok(Some(RuntimeConfig { + engine: config.into_engine(state_dir, use_gpu), + })) +} + +#[derive(Debug, serde::Deserialize)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum LlmCompute { + Spin, + RemoteHttp(RemoteHttpCompute), +} + +impl LlmCompute { + fn into_engine(self, state_dir: PathBuf, use_gpu: bool) -> Arc> { + match self { + LlmCompute::Spin => default_engine_creator(state_dir, use_gpu).create(), + LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new( + config.url, + config.auth_token, + ))), + } + } +} + +#[derive(Debug, serde::Deserialize)] +pub struct RemoteHttpCompute { + url: Url, + auth_token: String, +} + +/// The default engine creator for the LLM factor when used in the Spin CLI. +pub fn default_engine_creator( + state_dir: PathBuf, + use_gpu: bool, +) -> impl LlmEngineCreator + 'static { + move || { + Arc::new(Mutex::new(LocalLlmEngine::new( + state_dir.join("ai-models"), + use_gpu, + ))) as _ + } +} diff --git a/crates/factor-llm/tests/factor_test.rs b/crates/factor-llm/tests/factor_test.rs index 4504238b5..a0c4e988a 100644 --- a/crates/factor-llm/tests/factor_test.rs +++ b/crates/factor-llm/tests/factor_test.rs @@ -1,10 +1,12 @@ use std::collections::HashSet; +use std::sync::Arc; -use factor_llm::{LlmEngine, LlmFactor}; +use spin_factor_llm::{LlmEngine, LlmFactor}; use spin_factors::{anyhow, RuntimeFactors}; use spin_factors_test::{toml, TestEnvironment}; use spin_world::v1::llm::{self as v1}; use spin_world::v2::llm::{self as v2, Host}; +use tokio::sync::Mutex; #[derive(RuntimeFactors)] struct TestFactors { @@ -37,9 +39,9 @@ async fn llm_works() -> anyhow::Result<()> { }); let factors = TestFactors { llm: LlmFactor::new(move || { - Box::new(FakeLLm { + Arc::new(Mutex::new(FakeLLm { handle: handle.clone(), - }) as _ + })) as _ }), }; let env = TestEnvironment::new(factors).extend_manifest(toml! { diff --git a/crates/llm-local/Cargo.toml b/crates/llm-local/Cargo.toml index b0d4ea397..5b7331642 100644 --- a/crates/llm-local/Cargo.toml +++ b/crates/llm-local/Cargo.toml @@ -20,7 +20,6 @@ safetensors = "0.3.3" serde = { version = "1.0.150", features = ["derive"] } spin-common = { path = "../common" } spin-core = { path = "../core" } -spin-llm = { path = "../llm" } spin-world = { path = "../world" } terminal = { path = "../terminal" } tokenizers = "0.13.4" diff --git a/crates/llm-local/src/lib.rs b/crates/llm-local/src/lib.rs index f5d00c7a1..cf0b9f992 100644 --- a/crates/llm-local/src/lib.rs +++ b/crates/llm-local/src/lib.rs @@ -10,8 +10,6 @@ use llm::{ }; use rand::SeedableRng; use spin_common::ui::quoted_path; -use spin_core::async_trait; -use spin_llm::{LlmEngine, MODEL_ALL_MINILM_L6_V2}; use spin_world::v2::llm::{self as wasi_llm}; use std::{ collections::hash_map::Entry, @@ -23,6 +21,8 @@ use std::{ use tokenizers::PaddingParams; use tracing::{instrument, Level}; +const MODEL_ALL_MINILM_L6_V2: &str = "all-minilm-l6-v2"; + #[derive(Clone)] pub struct LocalLlmEngine { registry: PathBuf, @@ -31,10 +31,9 @@ pub struct LocalLlmEngine { embeddings_models: HashMap>, } -#[async_trait] -impl LlmEngine for LocalLlmEngine { +impl LocalLlmEngine { #[instrument(name = "spin_llm_local.infer", skip(self, prompt), err(level = Level::INFO))] - async fn infer( + pub async fn infer( &mut self, model: wasi_llm::InferencingModel, prompt: String, @@ -94,7 +93,7 @@ impl LlmEngine for LocalLlmEngine { } #[instrument(name = "spin_llm_local.generate_embeddings", skip(self, data), err(level = Level::INFO))] - async fn generate_embeddings( + pub async fn generate_embeddings( &mut self, model: wasi_llm::EmbeddingModel, data: Vec, @@ -107,18 +106,13 @@ impl LlmEngine for LocalLlmEngine { } impl LocalLlmEngine { - pub async fn new(registry: PathBuf, use_gpu: bool) -> Self { - let mut engine = Self { + pub fn new(registry: PathBuf, use_gpu: bool) -> Self { + Self { registry, use_gpu, inferencing_models: Default::default(), embeddings_models: Default::default(), - }; - - let _ = engine.inferencing_model("llama2-chat".into()).await; - let _ = engine.embeddings_model(MODEL_ALL_MINILM_L6_V2.into()).await; - - engine + } } /// Get embeddings model from cache or load from disk diff --git a/crates/llm-remote-http/Cargo.toml b/crates/llm-remote-http/Cargo.toml index 3a9bb8e12..af05459e5 100644 --- a/crates/llm-remote-http/Cargo.toml +++ b/crates/llm-remote-http/Cargo.toml @@ -7,11 +7,8 @@ edition = { workspace = true } [dependencies] anyhow = "1.0" http = "0.2" -llm = { git = "https://github.com/rustformers/llm", rev = "2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663", default-features = false } serde = { version = "1.0.150", features = ["derive"] } serde_json = "1.0" -spin-core = { path = "../core" } -spin-llm = { path = "../llm" } spin-telemetry = { path = "../telemetry" } spin-world = { path = "../world" } reqwest = { version = "0.11", features = ["gzip", "json"] } diff --git a/crates/llm-remote-http/src/lib.rs b/crates/llm-remote-http/src/lib.rs index 4987b14b0..4a0039539 100644 --- a/crates/llm-remote-http/src/lib.rs +++ b/crates/llm-remote-http/src/lib.rs @@ -5,8 +5,6 @@ use reqwest::{ }; use serde::{Deserialize, Serialize}; use serde_json::json; -use spin_core::async_trait; -use spin_llm::LlmEngine; use spin_world::v2::llm::{self as wasi_llm}; use tracing::{instrument, Level}; @@ -53,10 +51,9 @@ struct EmbeddingResponseBody { usage: EmbeddingUsage, } -#[async_trait] -impl LlmEngine for RemoteHttpLlmEngine { +impl RemoteHttpLlmEngine { #[instrument(name = "spin_llm_remote_http.infer", skip(self, prompt), err(level = Level::INFO), fields(otel.kind = "client"))] - async fn infer( + pub async fn infer( &mut self, model: wasi_llm::InferencingModel, prompt: String, @@ -119,7 +116,7 @@ impl LlmEngine for RemoteHttpLlmEngine { } #[instrument(name = "spin_llm_remote_http.generate_embeddings", skip(self, data), err(level = Level::INFO), fields(otel.kind = "client"))] - async fn generate_embeddings( + pub async fn generate_embeddings( &mut self, model: wasi_llm::EmbeddingModel, data: Vec, diff --git a/crates/runtime-config/Cargo.toml b/crates/runtime-config/Cargo.toml index 39f5b6f7a..267dfdec7 100644 --- a/crates/runtime-config/Cargo.toml +++ b/crates/runtime-config/Cargo.toml @@ -15,6 +15,7 @@ spin-factor-key-value = { path = "../factor-key-value" } spin-factor-key-value-spin = { path = "../factor-key-value-spin" } spin-factor-key-value-redis = { path = "../factor-key-value-redis" } spin-factor-key-value-azure = { path = "../factor-key-value-azure" } +spin-factor-llm = { path = "../factor-llm" } spin-factor-outbound-http = { path = "../factor-outbound-http" } spin-factor-outbound-mqtt = { path = "../factor-outbound-mqtt" } spin-factor-outbound-networking = { path = "../factor-outbound-networking" } diff --git a/crates/runtime-config/src/lib.rs b/crates/runtime-config/src/lib.rs index d5b7d88c9..77ca2a8ec 100644 --- a/crates/runtime-config/src/lib.rs +++ b/crates/runtime-config/src/lib.rs @@ -3,6 +3,7 @@ use std::path::{Path, PathBuf}; use anyhow::Context as _; use spin_factor_key_value::runtime_config::spin::{self as key_value, MakeKeyValueStore}; use spin_factor_key_value::{DefaultLabelResolver as _, KeyValueFactor}; +use spin_factor_llm::{spin as llm, LlmFactor}; use spin_factor_outbound_http::OutboundHttpFactor; use spin_factor_outbound_mqtt::OutboundMqttFactor; use spin_factor_outbound_mysql::OutboundMysqlFactor; @@ -39,13 +40,17 @@ where for<'a> >>::Error: Into, { /// Creates a new resolved runtime configuration from a runtime config source TOML file. - pub fn from_file(runtime_config_path: &Path, state_dir: Option<&str>) -> anyhow::Result { + pub fn from_file( + runtime_config_path: &Path, + state_dir: Option<&str>, + use_gpu: bool, + ) -> anyhow::Result { let tls_resolver = SpinTlsRuntimeConfig::new(runtime_config_path); - let key_value_config_resolver = - key_value_config_resolver(PathBuf::from(state_dir.unwrap_or(DEFAULT_STATE_DIR))); + let state_dir = PathBuf::from(state_dir.unwrap_or(DEFAULT_STATE_DIR)); + let key_value_config_resolver = key_value_config_resolver(state_dir.clone()); - let sqlite_config_resolver = - sqlite_config_resolver(state_dir).context("failed to resolve sqlite runtime config")?; + let sqlite_config_resolver = sqlite_config_resolver(state_dir.clone()) + .context("failed to resolve sqlite runtime config")?; let file = std::fs::read_to_string(runtime_config_path).with_context(|| { format!( @@ -61,9 +66,11 @@ where })?; let runtime_config: T = TomlRuntimeConfigSource::new( &toml, + state_dir, &key_value_config_resolver, &tls_resolver, &sqlite_config_resolver, + use_gpu, ) .try_into() .map_err(Into::into)?; @@ -99,12 +106,11 @@ where impl ResolvedRuntimeConfig { pub fn default(state_dir: Option<&str>) -> Self { + let state_dir = state_dir.unwrap_or(DEFAULT_STATE_DIR); Self { - sqlite_resolver: sqlite_config_resolver(state_dir) + sqlite_resolver: sqlite_config_resolver(PathBuf::from(state_dir)) .expect("failed to resolve sqlite runtime config"), - key_value_resolver: key_value_config_resolver(PathBuf::from( - state_dir.unwrap_or(DEFAULT_STATE_DIR), - )), + key_value_resolver: key_value_config_resolver(PathBuf::from(state_dir)), runtime_config: Default::default(), } } @@ -113,23 +119,29 @@ impl ResolvedRuntimeConfig { /// The TOML based runtime configuration source Spin CLI. pub struct TomlRuntimeConfigSource<'a> { table: TomlKeyTracker<'a>, + state_dir: PathBuf, key_value: &'a key_value::RuntimeConfigResolver, tls: &'a SpinTlsRuntimeConfig, sqlite: &'a sqlite::RuntimeConfigResolver, + use_gpu: bool, } impl<'a> TomlRuntimeConfigSource<'a> { pub fn new( table: &'a toml::Table, + state_dir: PathBuf, key_value: &'a key_value::RuntimeConfigResolver, tls: &'a SpinTlsRuntimeConfig, sqlite: &'a sqlite::RuntimeConfigResolver, + use_gpu: bool, ) -> Self { Self { table: TomlKeyTracker::new(table), + state_dir, key_value, tls, sqlite, + use_gpu, } } } @@ -173,6 +185,16 @@ impl FactorRuntimeConfigSource for TomlRuntimeConfigSource< } } +impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_> { + fn get_runtime_config(&mut self) -> anyhow::Result> { + Ok(llm::runtime_config_from_toml( + self.table.as_ref(), + self.state_dir.clone(), + self.use_gpu, + )?) + } +} + impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_> { fn get_runtime_config(&mut self) -> anyhow::Result> { Ok(None) @@ -251,14 +273,11 @@ pub fn key_value_config_resolver( /// The sqlite runtime configuration resolver. /// /// Takes a base path to the state directory. -fn sqlite_config_resolver( - state_dir: Option<&str>, -) -> anyhow::Result { - let default_database_dir = PathBuf::from(state_dir.unwrap_or(DEFAULT_STATE_DIR)); +fn sqlite_config_resolver(state_dir: PathBuf) -> anyhow::Result { let local_database_dir = std::env::current_dir().context("failed to get current working directory")?; Ok(sqlite::RuntimeConfigResolver::new( - default_database_dir, + state_dir, local_database_dir, )) } diff --git a/crates/trigger2/Cargo.toml b/crates/trigger2/Cargo.toml index a87d63b8c..820eba767 100644 --- a/crates/trigger2/Cargo.toml +++ b/crates/trigger2/Cargo.toml @@ -23,6 +23,7 @@ spin-componentize = { path = "../componentize" } spin-core = { path = "../core" } spin-factor-key-value = { path = "../factor-key-value" } spin-factor-outbound-http = { path = "../factor-outbound-http" } +spin-factor-llm = { path = "../factor-llm" } spin-factor-outbound-mqtt = { path = "../factor-outbound-mqtt" } spin-factor-outbound-networking = { path = "../factor-outbound-networking" } spin-factor-outbound-pg = { path = "../factor-outbound-pg" } diff --git a/crates/trigger2/src/cli.rs b/crates/trigger2/src/cli.rs index 2aab88adb..9bc787cef 100644 --- a/crates/trigger2/src/cli.rs +++ b/crates/trigger2/src/cli.rs @@ -10,7 +10,7 @@ use spin_common::ui::quoted_path; use spin_common::url::parse_file_url; use spin_common::{arg_parser::parse_kv, sloth}; use spin_factors_executor::{ComponentLoader, FactorsExecutor}; -use spin_runtime_config::ResolvedRuntimeConfig; +use spin_runtime_config::{ResolvedRuntimeConfig, DEFAULT_STATE_DIR}; use crate::factors::{TriggerFactors, TriggerFactorsRuntimeConfig}; use crate::stdio::{FollowComponents, StdioLoggingExecutorHooks}; @@ -304,11 +304,13 @@ impl TriggerAppBuilder { }; self.trigger.add_to_linker(core_engine_builder.linker())?; + let use_gpu = true; let runtime_config = match options.runtime_config_file { Some(runtime_config_path) => { ResolvedRuntimeConfig::::from_file( runtime_config_path, options.state_dir, + use_gpu, )? } None => ResolvedRuntimeConfig::default(options.state_dir), @@ -319,10 +321,12 @@ impl TriggerAppBuilder { .await?; let factors = TriggerFactors::new( + options.state_dir.unwrap_or(DEFAULT_STATE_DIR), self.working_dir.clone(), options.allow_transient_write, runtime_config.key_value_resolver, runtime_config.sqlite_resolver, + use_gpu, ); // TODO: move these into Factor methods/constructors diff --git a/crates/trigger2/src/factors.rs b/crates/trigger2/src/factors.rs index f7e805f71..6274f6bcf 100644 --- a/crates/trigger2/src/factors.rs +++ b/crates/trigger2/src/factors.rs @@ -1,6 +1,7 @@ use std::path::PathBuf; use spin_factor_key_value::KeyValueFactor; +use spin_factor_llm::LlmFactor; use spin_factor_outbound_http::OutboundHttpFactor; use spin_factor_outbound_mqtt::{NetworkedMqttClient, OutboundMqttFactor}; use spin_factor_outbound_mysql::OutboundMysqlFactor; @@ -25,14 +26,17 @@ pub struct TriggerFactors { pub mqtt: OutboundMqttFactor, pub pg: OutboundPgFactor, pub mysql: OutboundMysqlFactor, + pub llm: LlmFactor, } impl TriggerFactors { pub fn new( + state_dir: impl Into, working_dir: impl Into, allow_transient_writes: bool, default_key_value_label_resolver: impl spin_factor_key_value::DefaultLabelResolver + 'static, default_sqlite_label_resolver: impl spin_factor_sqlite::DefaultLabelResolver + 'static, + use_gpu: bool, ) -> Self { Self { wasi: wasi_factor(working_dir, allow_transient_writes), @@ -45,6 +49,10 @@ impl TriggerFactors { mqtt: OutboundMqttFactor::new(NetworkedMqttClient::creator()), pg: OutboundPgFactor::new(), mysql: OutboundMysqlFactor::new(), + llm: LlmFactor::new(spin_factor_llm::spin::default_engine_creator( + state_dir.into(), + use_gpu, + )), } } }