[Factors] Integrate llm factor #2742

Merged: 1 commit, Aug 22, 2024
39 changes: 21 additions & 18 deletions Cargo.lock

Some generated files are not rendered by default.

8 changes: 7 additions & 1 deletion crates/factor-llm/Cargo.toml
@@ -1,5 +1,5 @@
[package]
name = "factor-llm"
name = "spin-factor-llm"
version.workspace = true
authors.workspace = true
edition.workspace = true
@@ -11,10 +11,16 @@ rust-version.workspace = true
[dependencies]
anyhow = "1.0"
async-trait = "0.1"
serde = "1.0"
spin-factors = { path = "../factors" }
spin-llm-local = { path = "../llm-local" }
Collaborator: There are feature flags in the root spin-cli and old spin-trigger and spin-trigger-http crates that control dependencies for this.

Collaborator (author): I'm going to do some small refactorings in a follow-up that will make this easier to bring back.

spin-llm-remote-http = { path = "../llm-remote-http" }
spin-locked-app = { path = "../locked-app" }
spin-world = { path = "../world" }
tracing = { workspace = true }
tokio = { version = "1", features = ["sync"] }
toml = "0.8"
url = "2"

[dev-dependencies]
spin-factors-test = { path = "../factors-test" }
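An aside on the review thread above: the feature flags the reviewer mentions follow Cargo's usual optional-dependency pattern. A minimal sketch of that pattern — the flag name and wiring here are hypothetical, not the actual spin-cli manifest:

[dependencies]
# Heavy local-inference backend, only compiled when the `llm` feature is on.
spin-llm-local = { path = "../llm-local", optional = true }

[features]
default = []
llm = ["dep:spin-llm-local"]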
4 changes: 3 additions & 1 deletion crates/factor-llm/src/host.rs
@@ -16,6 +16,8 @@ impl v2::Host for InstanceState {
return Err(access_denied_error(&model));
}
self.engine
.lock()
.await
.infer(
model,
prompt,
@@ -39,7 +41,7 @@ impl v2::Host for InstanceState {
if !self.allowed_models.contains(&m) {
return Err(access_denied_error(&m));
}
-        self.engine.generate_embeddings(m, data).await
+        self.engine.lock().await.generate_embeddings(m, data).await
}

fn convert_error(&mut self, error: v2::Error) -> anyhow::Result<v2::Error> {
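A note on the .lock().await calls introduced above: the engine now lives behind tokio::sync::Mutex, whose lock is asynchronous and whose guard can be held across .await points; holding a std::sync::Mutex guard across an .await would make the future non-Send. A minimal self-contained sketch of the pattern, with a placeholder Engine type:

use std::sync::Arc;
use tokio::sync::Mutex;

// Placeholder engine type for illustration only.
struct Engine;

impl Engine {
    async fn infer(&mut self) { /* run inference */ }
}

async fn call(engine: Arc<Mutex<Engine>>) {
    let mut guard = engine.lock().await; // async lock: yields instead of blocking
    guard.infer().await;                 // the guard is held across this await
}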
52 changes: 42 additions & 10 deletions crates/factor-llm/src/lib.rs
@@ -1,4 +1,5 @@
mod host;
pub mod spin;

use std::collections::{HashMap, HashSet};
use std::sync::Arc;
@@ -11,26 +12,28 @@ use spin_factors::{
use spin_locked_app::MetadataKey;
use spin_world::v1::llm::{self as v1};
use spin_world::v2::llm::{self as v2};
use tokio::sync::Mutex;

pub const ALLOWED_MODELS_KEY: MetadataKey<Vec<String>> = MetadataKey::new("ai_models");

/// The factor for LLMs.
pub struct LlmFactor {
-    create_engine: Box<dyn Fn() -> Box<dyn LlmEngine> + Send + Sync>,
+    default_engine_creator: Box<dyn LlmEngineCreator>,
Comment on lines 20 to +21 —
Collaborator: Out of scope for this PR but I'd like to standardize on the approach used with e.g. OutboundMysqlFactor<C = MysqlClient> for this sort of thing.
Collaborator (author): Yeah, I'd personally like to move away from generics and towards using dynamic dispatch like we do here.

}

impl LlmFactor {
-    pub fn new<F>(create_engine: F) -> Self
-    where
-        F: Fn() -> Box<dyn LlmEngine> + Send + Sync + 'static,
-    {
+    /// Creates a new LLM factor with the given default engine creator.
+    ///
+    /// The default engine creator is used to create the engine if no runtime configuration is provided.
+    pub fn new<F: LlmEngineCreator + 'static>(default_engine_creator: F) -> Self {
Self {
-            create_engine: Box::new(create_engine),
+            default_engine_creator: Box::new(default_engine_creator),
}
}
}

impl Factor for LlmFactor {
-    type RuntimeConfig = ();
+    type RuntimeConfig = RuntimeConfig;
type AppState = AppState;
type InstanceBuilder = InstanceState;

@@ -45,7 +48,7 @@ impl Factor for LlmFactor {

fn configure_app<T: RuntimeFactors>(
&self,
-        ctx: ConfigureAppContext<T, Self>,
+        mut ctx: ConfigureAppContext<T, Self>,
) -> anyhow::Result<Self::AppState> {
let component_allowed_models = ctx
.app()
@@ -62,7 +65,12 @@ impl Factor for LlmFactor {
))
})
.collect::<anyhow::Result<_>>()?;
let engine = ctx
.take_runtime_config()
.map(|c| c.engine)
.unwrap_or_else(|| self.default_engine_creator.create());
Ok(AppState {
engine,
component_allowed_models,
})
}
@@ -78,25 +86,35 @@ impl Factor for LlmFactor {
.get(ctx.app_component().id())
.cloned()
.unwrap_or_default();
let engine = ctx.app_state().engine.clone();

Ok(InstanceState {
-            engine: (self.create_engine)(),
+            engine,
allowed_models,
})
}
}

/// The application state for the LLM factor.
pub struct AppState {
engine: Arc<Mutex<dyn LlmEngine>>,
component_allowed_models: HashMap<String, Arc<HashSet<String>>>,
}

/// The instance state for the LLM factor.
pub struct InstanceState {
-    engine: Box<dyn LlmEngine>,
+    engine: Arc<Mutex<dyn LlmEngine>>,
pub allowed_models: Arc<HashSet<String>>,
}

/// The runtime configuration for the LLM factor.
pub struct RuntimeConfig {
engine: Arc<Mutex<dyn LlmEngine>>,
}

impl SelfInstanceBuilder for InstanceState {}

/// The interface for a language model engine.
#[async_trait]
pub trait LlmEngine: Send + Sync {
async fn infer(
@@ -112,3 +130,17 @@ pub trait LlmEngine: Send + Sync {
data: Vec<String>,
) -> Result<v2::EmbeddingsResult, v2::Error>;
}

/// A creator for an LLM engine.
pub trait LlmEngineCreator: Send + Sync {
fn create(&self) -> Arc<Mutex<dyn LlmEngine>>;
}

impl<F> LlmEngineCreator for F
where
F: Fn() -> Arc<Mutex<dyn LlmEngine>> + Send + Sync,
{
fn create(&self) -> Arc<Mutex<dyn LlmEngine>> {
self()
}
}
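On the review thread above about OutboundMysqlFactor<C = MysqlClient>-style generics versus dynamic dispatch: a rough, self-contained sketch of the two shapes being compared (hypothetical names; neither block is the code in this PR):

trait Engine {
    fn run(&self);
}

struct LocalEngine;
impl Engine for LocalEngine {
    fn run(&self) {}
}

// Static dispatch with a defaulted type parameter: the engine type is fixed
// at compile time, but callers get a sensible default.
struct GenericFactor<E: Engine = LocalEngine> {
    engine: E,
}

// Dynamic dispatch, the style this PR adopts: the engine can be chosen at
// runtime (e.g. from runtime config) behind a trait object.
struct DynFactor {
    engine: Box<dyn Engine>,
}

fn main() {
    let static_factor = GenericFactor { engine: LocalEngine };
    let dyn_factor = DynFactor { engine: Box::new(LocalEngine) };
    static_factor.engine.run();
    dyn_factor.engine.run();
}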
106 changes: 106 additions & 0 deletions crates/factor-llm/src/spin.rs
@@ -0,0 +1,106 @@
use std::path::PathBuf;
use std::sync::Arc;

pub use spin_llm_local::LocalLlmEngine;

use spin_llm_remote_http::RemoteHttpLlmEngine;
use spin_world::async_trait;
use spin_world::v1::llm::{self as v1};
use spin_world::v2::llm::{self as v2};
use tokio::sync::Mutex;
use url::Url;

use crate::{LlmEngine, LlmEngineCreator, RuntimeConfig};

#[async_trait]
impl LlmEngine for LocalLlmEngine {
async fn infer(
&mut self,
model: v1::InferencingModel,
prompt: String,
params: v2::InferencingParams,
) -> Result<v2::InferencingResult, v2::Error> {
self.infer(model, prompt, params).await
}

async fn generate_embeddings(
&mut self,
model: v2::EmbeddingModel,
data: Vec<String>,
) -> Result<v2::EmbeddingsResult, v2::Error> {
self.generate_embeddings(model, data).await
}
}

#[async_trait]
impl LlmEngine for RemoteHttpLlmEngine {
async fn infer(
&mut self,
model: v1::InferencingModel,
prompt: String,
params: v2::InferencingParams,
) -> Result<v2::InferencingResult, v2::Error> {
self.infer(model, prompt, params).await
}

async fn generate_embeddings(
&mut self,
model: v2::EmbeddingModel,
data: Vec<String>,
) -> Result<v2::EmbeddingsResult, v2::Error> {
self.generate_embeddings(model, data).await
}
}

pub fn runtime_config_from_toml(
table: &toml::Table,
state_dir: PathBuf,
use_gpu: bool,
) -> anyhow::Result<Option<RuntimeConfig>> {
let Some(value) = table.get("llm_compute") else {
return Ok(None);
};
let config: LlmCompute = value.clone().try_into()?;

Ok(Some(RuntimeConfig {
engine: config.into_engine(state_dir, use_gpu),
}))
}

#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum LlmCompute {
Spin,
RemoteHttp(RemoteHttpCompute),
}

impl LlmCompute {
fn into_engine(self, state_dir: PathBuf, use_gpu: bool) -> Arc<Mutex<dyn LlmEngine>> {
match self {
LlmCompute::Spin => default_engine_creator(state_dir, use_gpu).create(),
LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new(
config.url,
config.auth_token,
))),
}
}
}

#[derive(Debug, serde::Deserialize)]
pub struct RemoteHttpCompute {
url: Url,
auth_token: String,
}

/// The default engine creator for the LLM factor when used in the Spin CLI.
pub fn default_engine_creator(
state_dir: PathBuf,
use_gpu: bool,
) -> impl LlmEngineCreator + 'static {
move || {
Arc::new(Mutex::new(LocalLlmEngine::new(
state_dir.join("ai-models"),
use_gpu,
))) as _
}
}
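For reference, the shape of runtime configuration that runtime_config_from_toml above accepts can be read off the serde derive (snake_case variant names, internally tagged by "type"); the URL and token values here are placeholders:

# Local inference via the bundled engine:
[llm_compute]
type = "spin"

# ...or remote inference over HTTP (uncomment to use):
# [llm_compute]
# type = "remote_http"
# url = "https://llm.example.com"
# auth_token = "my-secret-token"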
8 changes: 5 additions & 3 deletions crates/factor-llm/tests/factor_test.rs
@@ -1,10 +1,12 @@
use std::collections::HashSet;
use std::sync::Arc;

-use factor_llm::{LlmEngine, LlmFactor};
+use spin_factor_llm::{LlmEngine, LlmFactor};
use spin_factors::{anyhow, RuntimeFactors};
use spin_factors_test::{toml, TestEnvironment};
use spin_world::v1::llm::{self as v1};
use spin_world::v2::llm::{self as v2, Host};
use tokio::sync::Mutex;

#[derive(RuntimeFactors)]
struct TestFactors {
@@ -37,9 +39,9 @@ async fn llm_works() -> anyhow::Result<()> {
});
let factors = TestFactors {
llm: LlmFactor::new(move || {
-            Box::new(FakeLLm {
+            Arc::new(Mutex::new(FakeLLm {
handle: handle.clone(),
-            }) as _
+            })) as _
}),
};
let env = TestEnvironment::new(factors).extend_manifest(toml! {
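The FakeLLm implementation itself is elided from the diff above. For illustration, a minimal LlmEngine implementation might look like the following sketch; StubLlm is hypothetical, and the v2::Error::RuntimeError variant is an assumption based on Spin's llm world:

use spin_factor_llm::LlmEngine;
use spin_world::async_trait;
use spin_world::v1::llm as v1;
use spin_world::v2::llm as v2;

struct StubLlm;

#[async_trait]
impl LlmEngine for StubLlm {
    async fn infer(
        &mut self,
        _model: v1::InferencingModel,
        _prompt: String,
        _params: v2::InferencingParams,
    ) -> Result<v2::InferencingResult, v2::Error> {
        // Hypothetical stub: a real engine would run inference here.
        Err(v2::Error::RuntimeError("no inference backend".into()))
    }

    async fn generate_embeddings(
        &mut self,
        _model: v2::EmbeddingModel,
        _data: Vec<String>,
    ) -> Result<v2::EmbeddingsResult, v2::Error> {
        Err(v2::Error::RuntimeError("no embeddings backend".into()))
    }
}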
1 change: 0 additions & 1 deletion crates/llm-local/Cargo.toml
@@ -20,7 +20,6 @@ safetensors = "0.3.3"
serde = { version = "1.0.150", features = ["derive"] }
spin-common = { path = "../common" }
spin-core = { path = "../core" }
-spin-llm = { path = "../llm" }
spin-world = { path = "../world" }
terminal = { path = "../terminal" }
tokenizers = "0.13.4"