From e2195a96d1dabb03123474d3e169b393a6303c14 Mon Sep 17 00:00:00 2001 From: Sean Hatfield Date: Wed, 25 Sep 2024 13:30:20 -0700 Subject: [PATCH] Workspace agent autoselection (#2357) * refactor agent to add fallback to workspace, then to chat provider/model * commenting update logic for bedrock and fireworks fallbacks --------- Co-authored-by: timothycarambat --- server/utils/agents/index.js | 78 +++++++++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 18 deletions(-) diff --git a/server/utils/agents/index.js b/server/utils/agents/index.js index 9fdfdd1baa..389d9d71b0 100644 --- a/server/utils/agents/index.js +++ b/server/utils/agents/index.js @@ -16,6 +16,7 @@ class AgentHandler { lmstudio: "LMSTUDIO_MODEL_PREF", textgenwebui: null, // does not even use `model` in API req "generic-openai": "GENERIC_OPEN_AI_MODEL_PREF", + bedrock: "AWS_BEDROCK_LLM_MODEL_PREFERENCE", }; invocation = null; aibitat = null; @@ -149,20 +150,16 @@ class AgentHandler { if ( !process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID || !process.env.AWS_BEDROCK_LLM_ACCESS_KEY || - !process.env.AWS_BEDROCK_LLM_REGION || - !process.env.AWS_BEDROCK_LLM_MODEL_PREFERENCE + !process.env.AWS_BEDROCK_LLM_REGION ) throw new Error( - "AWS Bedrock Access Keys, model and region must be provided to use agents." + "AWS Bedrock Access Keys and region must be provided to use agents." ); break; case "fireworksai": - if ( - !process.env.FIREWORKS_AI_LLM_API_KEY || - !process.env.FIREWORKS_AI_LLM_MODEL_PREF - ) + if (!process.env.FIREWORKS_AI_LLM_API_KEY) throw new Error( - "FireworksAI API Key & model must be provided to use agents." + "FireworksAI API Key must be provided to use agents." ); break; @@ -173,8 +170,8 @@ class AgentHandler { } } - providerDefault() { - switch (this.provider) { + providerDefault(provider = this.provider) { + switch (provider) { case "openai": return "gpt-4o"; case "anthropic": @@ -214,6 +211,32 @@ class AgentHandler { } } + #getFallbackProvider() { + // First, fallback to the workspace chat provider and model if they exist + if ( + this.invocation.workspace.chatProvider && + this.invocation.workspace.chatModel + ) { + return { + provider: this.invocation.workspace.chatProvider, + model: this.invocation.workspace.chatModel, + }; + } + + // If workspace does not have chat provider and model fallback + // to system provider and try to load provider default model + const systemProvider = process.env.LLM_PROVIDER; + const systemModel = this.providerDefault(systemProvider); + if (systemProvider && systemModel) { + return { + provider: systemProvider, + model: systemModel, + }; + } + + return null; + } + /** * Finds or assumes the model preference value to use for API calls. * If multi-model loading is supported, we use their agent model selection of the workspace @@ -222,22 +245,41 @@ class AgentHandler { * @returns {string} the model preference value to use in API calls */ #fetchModel() { - if (!Object.keys(this.noProviderModelDefault).includes(this.provider)) - return this.invocation.workspace.agentModel || this.providerDefault(); + // Provider was not explicitly set for workspace, so we are going to run our fallback logic + // that will set a provider and model for us to use. + if (!this.provider) { + const fallback = this.#getFallbackProvider(); + if (!fallback) throw new Error("No valid provider found for the agent."); + this.provider = fallback.provider; // re-set the provider to the fallback provider so it is not null. + return fallback.model; // set its defined model based on fallback logic. + } + + // The provider was explicitly set, so check if the workspace has an agent model set. + if (this.invocation.workspace.agentModel) { + return this.invocation.workspace.agentModel; + } - // Provider has no reliable default (cant load many models) - so we need to look at system - // for the model param. + // If the provider we are using is not supported or does not support multi-model loading + // then we use the default model for the provider. + if (!Object.keys(this.noProviderModelDefault).includes(this.provider)) { + return this.providerDefault(); + } + + // Load the model from the system environment variable for providers with no multi-model loading. const sysModelKey = this.noProviderModelDefault[this.provider]; - if (!!sysModelKey) - return process.env[sysModelKey] ?? this.providerDefault(); + if (sysModelKey) return process.env[sysModelKey] ?? this.providerDefault(); - // If all else fails - look at the provider default list + // Otherwise, we have no model to use - so guess a default model to use. return this.providerDefault(); } #providerSetupAndCheck() { - this.provider = this.invocation.workspace.agentProvider; + this.provider = this.invocation.workspace.agentProvider ?? null; // set provider to workspace agent provider if it exists this.model = this.#fetchModel(); + + if (!this.provider) + throw new Error("No valid provider found for the agent."); + this.log(`Start ${this.#invocationUUID}::${this.provider}:${this.model}`); this.checkSetup(); }