From 339e3f9f2f68791eccf05babc570f0a88a55e2ab Mon Sep 17 00:00:00 2001 From: Roman Glushko Date: Thu, 8 Aug 2024 17:13:03 +0300 Subject: [PATCH] #67: Picked more idiomatic naming for the ProviderID --- pkg/providers/cohere/chat.go | 2 +- pkg/providers/cohere/chat_stream.go | 10 +-- pkg/providers/cohere/client.go | 4 +- pkg/providers/cohere/errors.go | 4 +- pkg/providers/config.go | 94 +---------------------------- pkg/providers/openai/chat.go | 2 +- pkg/providers/openai/chat_stream.go | 2 +- pkg/providers/openai/client.go | 6 +- pkg/providers/openai/errors.go | 4 +- pkg/providers/openai/register.go | 2 +- pkg/providers/testing/config.go | 2 +- pkg/routers/lang/config_test.go | 26 ++++---- 12 files changed, 34 insertions(+), 124 deletions(-) diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go index 4729d55..754d853 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/providers/cohere/chat.go @@ -118,7 +118,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche response := schemas.ChatResponse{ ID: cohereCompletion.ResponseID, Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this - Provider: providerName, + Provider: ProviderID, ModelName: c.config.ModelName, Cached: false, ModelResponse: schemas.ModelResponse{ diff --git a/pkg/providers/cohere/chat_stream.go b/pkg/providers/cohere/chat_stream.go index 6f19494..392f0a2 100644 --- a/pkg/providers/cohere/chat_stream.go +++ b/pkg/providers/cohere/chat_stream.go @@ -90,7 +90,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { if err != nil { s.tel.L().Warn( "Chat stream is unexpectedly disconnected", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.Error(err), ) @@ -101,7 +101,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { s.tel.L().Debug( "Raw chat stream chunk", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.ByteString("rawChunk", rawChunk), ) @@ -119,7 +119,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { if responseChunk.EventType != TextGenEvent && responseChunk.EventType != StreamEndEvent { s.tel.L().Debug( "Unsupported stream chunk type, skipping it", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.ByteString("chunk", rawChunk), ) @@ -132,7 +132,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { // TODO: use objectpool here return &schemas.ChatStreamChunk{ Cached: false, - Provider: providerName, + Provider: ProviderID, ModelName: s.modelName, ModelResponse: schemas.ModelChunkResponse{ Metadata: &schemas.Metadata{ @@ -151,7 +151,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { // TODO: use objectpool here return &schemas.ChatStreamChunk{ Cached: false, - Provider: providerName, + Provider: ProviderID, ModelName: s.modelName, ModelResponse: schemas.ModelChunkResponse{ Metadata: &schemas.Metadata{ diff --git a/pkg/providers/cohere/client.go b/pkg/providers/cohere/client.go index a842659..3393e01 100644 --- a/pkg/providers/cohere/client.go +++ b/pkg/providers/cohere/client.go @@ -11,7 +11,7 @@ import ( ) const ( - providerName = "cohere" + ProviderID = "cohere" ) // Client is a client for accessing Cohere API @@ -54,7 +54,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * } func (c *Client) Provider() string { - return providerName + return ProviderID } func (c *Client) ModelName() string { diff --git a/pkg/providers/cohere/errors.go b/pkg/providers/cohere/errors.go index 5b5548c..5f8ea04 100644 --- a/pkg/providers/cohere/errors.go +++ b/pkg/providers/cohere/errors.go @@ -28,7 +28,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { if err != nil { m.tel.Logger.Error( "Failed to unmarshal chat response error", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.Error(err), zap.ByteString("rawResponse", bodyBytes), ) @@ -38,7 +38,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { m.tel.Logger.Error( "Chat request failed", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.Int("statusCode", resp.StatusCode), zap.String("response", string(bodyBytes)), zap.Any("headers", resp.Header), diff --git a/pkg/providers/config.go b/pkg/providers/config.go index f58ec04..466f07e 100644 --- a/pkg/providers/config.go +++ b/pkg/providers/config.go @@ -98,6 +98,7 @@ func (p DynLangProvider) validate() error { providerConfigUnmarshaller := func(providerConfig interface{}) error { configValue := p[providerID] + providerConfigBytes, err := yaml.Marshal(configValue) if err != nil { return err @@ -116,6 +117,7 @@ func (p DynLangProvider) validate() error { func (p *DynLangProvider) UnmarshalYAML(unmarshal func(interface{}) error) error { type plain DynLangProvider // to avoid recursion + temp := plain{} if err := unmarshal(&temp); err != nil { @@ -126,95 +128,3 @@ func (p *DynLangProvider) UnmarshalYAML(unmarshal func(interface{}) error) error return p.validate() } - -// TODO: Remove this old LangProviders struct - -//type LangProviders struct { -// // Add other providers like -// OpenAI *openai.Config `yaml:"openai,omitempty" json:"openai,omitempty"` -// AzureOpenAI *azureopenai.Config `yaml:"azureopenai,omitempty" json:"azureopenai,omitempty"` -// Cohere *cohere.Config `yaml:"cohere,omitempty" json:"cohere,omitempty"` -// OctoML *octoml.Config `yaml:"octoml,omitempty" json:"octoml,omitempty"` -// Anthropic *anthropic.Config `yaml:"anthropic,omitempty" json:"anthropic,omitempty"` -// Bedrock *bedrock.Config `yaml:"bedrock,omitempty" json:"bedrock,omitempty"` -// Ollama *ollama.Config `yaml:"ollama,omitempty" json:"ollama,omitempty"` -//} -// -//var _ ProviderConfig = (*LangProviders)(nil) - -// ToClient initializes the language model client based on the provided configuration. -// It takes a telemetry object as input and returns a LangModelProvider and an error. -//func (c LangProviders) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) { -// switch { -// case c.OpenAI != nil: -// return openai.NewClient(c.OpenAI, clientConfig, tel) -// case c.AzureOpenAI != nil: -// return azureopenai.NewClient(c.AzureOpenAI, clientConfig, tel) -// case c.Cohere != nil: -// return cohere.NewClient(c.Cohere, clientConfig, tel) -// case c.OctoML != nil: -// return octoml.NewClient(c.OctoML, clientConfig, tel) -// case c.Anthropic != nil: -// return anthropic.NewClient(c.Anthropic, clientConfig, tel) -// case c.Bedrock != nil: -// return bedrock.NewClient(c.Bedrock, clientConfig, tel) -// default: -// return nil, ErrProviderNotFound -// } -//} - -//func (c *LangProviders) validateOneProvider() error { -// providersConfigured := 0 -// -// if c.OpenAI != nil { -// providersConfigured++ -// } -// -// if c.AzureOpenAI != nil { -// providersConfigured++ -// } -// -// if c.Cohere != nil { -// providersConfigured++ -// } -// -// if c.OctoML != nil { -// providersConfigured++ -// } -// -// if c.Anthropic != nil { -// providersConfigured++ -// } -// -// if c.Bedrock != nil { -// providersConfigured++ -// } -// -// if c.Ollama != nil { -// providersConfigured++ -// } -// -// // check other providers here -// if providersConfigured == 0 { -// return ErrNoProviderConfigured -// } -// -// if providersConfigured > 1 { -// return fmt.Errorf( -// "exactly one provider must be configured, but %v are configured", -// providersConfigured, -// ) -// } -// -// return nil -//} - -//func (c *LangProviders) UnmarshalYAML(unmarshal func(interface{}) error) error { -// type plain LangProviders // to avoid recursion -// -// if err := unmarshal((*plain)(c)); err != nil { -// return err -// } -// -// return c.validateOneProvider() -//} diff --git a/pkg/providers/openai/chat.go b/pkg/providers/openai/chat.go index 0669829..86bce6f 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/providers/openai/chat.go @@ -126,7 +126,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche response := schemas.ChatResponse{ ID: chatCompletion.ID, Created: chatCompletion.Created, - Provider: ProviderOpenAI, + Provider: ProviderID, ModelName: chatCompletion.ModelName, Cached: false, ModelResponse: schemas.ModelResponse{ diff --git a/pkg/providers/openai/chat_stream.go b/pkg/providers/openai/chat_stream.go index 8fd0a61..ba219e3 100644 --- a/pkg/providers/openai/chat_stream.go +++ b/pkg/providers/openai/chat_stream.go @@ -112,7 +112,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { // TODO: use objectpool here return &schemas.ChatStreamChunk{ Cached: false, - Provider: ProviderOpenAI, + Provider: ProviderID, ModelName: completionChunk.ModelName, ModelResponse: schemas.ModelChunkResponse{ Metadata: &schemas.Metadata{ diff --git a/pkg/providers/openai/client.go b/pkg/providers/openai/client.go index 795d94f..30a0438 100644 --- a/pkg/providers/openai/client.go +++ b/pkg/providers/openai/client.go @@ -13,7 +13,7 @@ import ( ) const ( - ProviderOpenAI = "openai" + ProviderID = "openai" ) // Client is a client for accessing OpenAI API @@ -37,7 +37,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * } logger := tel.L().With( - zap.String("provider", ProviderOpenAI), + zap.String("provider", ProviderID), ) c := &Client{ @@ -62,7 +62,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * } func (c *Client) Provider() string { - return ProviderOpenAI + return ProviderID } func (c *Client) ModelName() string { diff --git a/pkg/providers/openai/errors.go b/pkg/providers/openai/errors.go index 640962c..d0389cb 100644 --- a/pkg/providers/openai/errors.go +++ b/pkg/providers/openai/errors.go @@ -28,7 +28,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { if err != nil { m.tel.Logger.Error( "Failed to unmarshal chat response error", - zap.String("provider", ProviderOpenAI), + zap.String("provider", ProviderID), zap.Error(err), zap.ByteString("rawResponse", bodyBytes), ) @@ -38,7 +38,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { m.tel.Logger.Error( "Chat request failed", - zap.String("provider", ProviderOpenAI), + zap.String("provider", ProviderID), zap.Int("statusCode", resp.StatusCode), zap.String("response", string(bodyBytes)), zap.Any("headers", resp.Header), diff --git a/pkg/providers/openai/register.go b/pkg/providers/openai/register.go index baf37ac..4435ac8 100644 --- a/pkg/providers/openai/register.go +++ b/pkg/providers/openai/register.go @@ -5,5 +5,5 @@ import ( ) func init() { - providers.LangRegistry.Register(ProviderOpenAI, &Config{}) + providers.LangRegistry.Register(ProviderID, &Config{}) } diff --git a/pkg/providers/testing/config.go b/pkg/providers/testing/config.go index 11237fc..dd7d085 100644 --- a/pkg/providers/testing/config.go +++ b/pkg/providers/testing/config.go @@ -18,7 +18,7 @@ type Config struct { APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` } -func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { +func (c *Config) ToClient(_ *telemetry.Telemetry, _ *clients.ClientConfig) (provider.LangProvider, error) { return NewProviderMock(nil, []RespMock{}), nil } diff --git a/pkg/routers/lang/config_test.go b/pkg/routers/lang/config_test.go index aa782f9..1ed4336 100644 --- a/pkg/routers/lang/config_test.go +++ b/pkg/routers/lang/config_test.go @@ -28,7 +28,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), Provider: &providers.DynLangProvider{ - openai.ProviderOpenAI: &openai.Config{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, @@ -46,7 +46,7 @@ func TestRouterConfig_BuildModels(t *testing.T) { ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), Provider: &providers.DynLangProvider{ - openai.ProviderOpenAI: &openai.Config{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, @@ -81,7 +81,7 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), Provider: &providers.DynLangProvider{ - openai.ProviderOpenAI: &openai.Config{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &openAIParams, }, @@ -93,8 +93,8 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.LangProviders{ - Cohere: &cohere.Config{ + Provider: &providers.DynLangProvider{ + cohere.ProviderID: &cohere.Config{ APIKey: "ABC", DefaultParams: &cohereParams, }, @@ -129,8 +129,8 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.LangProviders{ - OpenAI: &openai.Config{ + Provider: &providers.DynLangProvider{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, @@ -147,8 +147,8 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.LangProviders{ - OpenAI: &openai.Config{ + Provider: &providers.DynLangProvider{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, @@ -170,8 +170,8 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.LangProviders{ - OpenAI: &openai.Config{ + Provider: &providers.DynLangProvider{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, }, @@ -183,8 +183,8 @@ func TestRouterConfig_InvalidSetups(t *testing.T) { Client: clients.DefaultClientConfig(), ErrorBudget: health.DefaultErrorBudget(), Latency: latency.DefaultConfig(), - Provider: providers.LangProviders{ - OpenAI: &openai.Config{ + Provider: &providers.DynLangProvider{ + openai.ProviderID: &openai.Config{ APIKey: "ABC", DefaultParams: &defaultParams, },