Skip to content

Commit

Permalink
#67: Picked more idiomatic naming for the ProviderID
Browse files Browse the repository at this point in the history
  • Loading branch information
roma-glushko committed Aug 8, 2024
1 parent e043e4d commit 339e3f9
Show file tree
Hide file tree
Showing 12 changed files with 34 additions and 124 deletions.
2 changes: 1 addition & 1 deletion pkg/providers/cohere/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
response := schemas.ChatResponse{
ID: cohereCompletion.ResponseID,
Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this
Provider: providerName,
Provider: ProviderID,
ModelName: c.config.ModelName,
Cached: false,
ModelResponse: schemas.ModelResponse{
Expand Down
10 changes: 5 additions & 5 deletions pkg/providers/cohere/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
if err != nil {
s.tel.L().Warn(
"Chat stream is unexpectedly disconnected",
zap.String("provider", providerName),
zap.String("provider", ProviderID),
zap.Error(err),
)

Expand All @@ -101,7 +101,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {

s.tel.L().Debug(
"Raw chat stream chunk",
zap.String("provider", providerName),
zap.String("provider", ProviderID),
zap.ByteString("rawChunk", rawChunk),
)

Expand All @@ -119,7 +119,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
if responseChunk.EventType != TextGenEvent && responseChunk.EventType != StreamEndEvent {
s.tel.L().Debug(
"Unsupported stream chunk type, skipping it",
zap.String("provider", providerName),
zap.String("provider", ProviderID),
zap.ByteString("chunk", rawChunk),
)

Expand All @@ -132,7 +132,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
// TODO: use objectpool here
return &schemas.ChatStreamChunk{
Cached: false,
Provider: providerName,
Provider: ProviderID,
ModelName: s.modelName,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
Expand All @@ -151,7 +151,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
// TODO: use objectpool here
return &schemas.ChatStreamChunk{
Cached: false,
Provider: providerName,
Provider: ProviderID,
ModelName: s.modelName,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
Expand Down
4 changes: 2 additions & 2 deletions pkg/providers/cohere/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

const (
providerName = "cohere"
ProviderID = "cohere"
)

// Client is a client for accessing Cohere API
Expand Down Expand Up @@ -54,7 +54,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
}

func (c *Client) Provider() string {
return providerName
return ProviderID
}

func (c *Client) ModelName() string {
Expand Down
4 changes: 2 additions & 2 deletions pkg/providers/cohere/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error {
if err != nil {
m.tel.Logger.Error(
"Failed to unmarshal chat response error",
zap.String("provider", providerName),
zap.String("provider", ProviderID),
zap.Error(err),
zap.ByteString("rawResponse", bodyBytes),
)
Expand All @@ -38,7 +38,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error {

m.tel.Logger.Error(
"Chat request failed",
zap.String("provider", providerName),
zap.String("provider", ProviderID),
zap.Int("statusCode", resp.StatusCode),
zap.String("response", string(bodyBytes)),
zap.Any("headers", resp.Header),
Expand Down
94 changes: 2 additions & 92 deletions pkg/providers/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ func (p DynLangProvider) validate() error {

providerConfigUnmarshaller := func(providerConfig interface{}) error {
configValue := p[providerID]

providerConfigBytes, err := yaml.Marshal(configValue)
if err != nil {
return err
Expand All @@ -116,6 +117,7 @@ func (p DynLangProvider) validate() error {

func (p *DynLangProvider) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain DynLangProvider // to avoid recursion

temp := plain{}

if err := unmarshal(&temp); err != nil {
Expand All @@ -126,95 +128,3 @@ func (p *DynLangProvider) UnmarshalYAML(unmarshal func(interface{}) error) error

return p.validate()
}

// TODO: Remove this old LangProviders struct

//type LangProviders struct {
// // Add other providers like
// OpenAI *openai.Config `yaml:"openai,omitempty" json:"openai,omitempty"`
// AzureOpenAI *azureopenai.Config `yaml:"azureopenai,omitempty" json:"azureopenai,omitempty"`
// Cohere *cohere.Config `yaml:"cohere,omitempty" json:"cohere,omitempty"`
// OctoML *octoml.Config `yaml:"octoml,omitempty" json:"octoml,omitempty"`
// Anthropic *anthropic.Config `yaml:"anthropic,omitempty" json:"anthropic,omitempty"`
// Bedrock *bedrock.Config `yaml:"bedrock,omitempty" json:"bedrock,omitempty"`
// Ollama *ollama.Config `yaml:"ollama,omitempty" json:"ollama,omitempty"`
//}
//
//var _ ProviderConfig = (*LangProviders)(nil)

// ToClient initializes the language model client based on the provided configuration.
// It takes a telemetry object as input and returns a LangModelProvider and an error.
//func (c LangProviders) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) {
// switch {
// case c.OpenAI != nil:
// return openai.NewClient(c.OpenAI, clientConfig, tel)
// case c.AzureOpenAI != nil:
// return azureopenai.NewClient(c.AzureOpenAI, clientConfig, tel)
// case c.Cohere != nil:
// return cohere.NewClient(c.Cohere, clientConfig, tel)
// case c.OctoML != nil:
// return octoml.NewClient(c.OctoML, clientConfig, tel)
// case c.Anthropic != nil:
// return anthropic.NewClient(c.Anthropic, clientConfig, tel)
// case c.Bedrock != nil:
// return bedrock.NewClient(c.Bedrock, clientConfig, tel)
// default:
// return nil, ErrProviderNotFound
// }
//}

//func (c *LangProviders) validateOneProvider() error {
// providersConfigured := 0
//
// if c.OpenAI != nil {
// providersConfigured++
// }
//
// if c.AzureOpenAI != nil {
// providersConfigured++
// }
//
// if c.Cohere != nil {
// providersConfigured++
// }
//
// if c.OctoML != nil {
// providersConfigured++
// }
//
// if c.Anthropic != nil {
// providersConfigured++
// }
//
// if c.Bedrock != nil {
// providersConfigured++
// }
//
// if c.Ollama != nil {
// providersConfigured++
// }
//
// // check other providers here
// if providersConfigured == 0 {
// return ErrNoProviderConfigured
// }
//
// if providersConfigured > 1 {
// return fmt.Errorf(
// "exactly one provider must be configured, but %v are configured",
// providersConfigured,
// )
// }
//
// return nil
//}

//func (c *LangProviders) UnmarshalYAML(unmarshal func(interface{}) error) error {
// type plain LangProviders // to avoid recursion
//
// if err := unmarshal((*plain)(c)); err != nil {
// return err
// }
//
// return c.validateOneProvider()
//}
2 changes: 1 addition & 1 deletion pkg/providers/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
response := schemas.ChatResponse{
ID: chatCompletion.ID,
Created: chatCompletion.Created,
Provider: ProviderOpenAI,
Provider: ProviderID,
ModelName: chatCompletion.ModelName,
Cached: false,
ModelResponse: schemas.ModelResponse{
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/openai/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
// TODO: use objectpool here
return &schemas.ChatStreamChunk{
Cached: false,
Provider: ProviderOpenAI,
Provider: ProviderID,
ModelName: completionChunk.ModelName,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
Expand Down
6 changes: 3 additions & 3 deletions pkg/providers/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

const (
ProviderOpenAI = "openai"
ProviderID = "openai"
)

// Client is a client for accessing OpenAI API
Expand All @@ -37,7 +37,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
}

logger := tel.L().With(
zap.String("provider", ProviderOpenAI),
zap.String("provider", ProviderID),
)

c := &Client{
Expand All @@ -62,7 +62,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
}

func (c *Client) Provider() string {
return ProviderOpenAI
return ProviderID
}

func (c *Client) ModelName() string {
Expand Down
4 changes: 2 additions & 2 deletions pkg/providers/openai/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error {
if err != nil {
m.tel.Logger.Error(
"Failed to unmarshal chat response error",
zap.String("provider", ProviderOpenAI),
zap.String("provider", ProviderID),
zap.Error(err),
zap.ByteString("rawResponse", bodyBytes),
)
Expand All @@ -38,7 +38,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error {

m.tel.Logger.Error(
"Chat request failed",
zap.String("provider", ProviderOpenAI),
zap.String("provider", ProviderID),
zap.Int("statusCode", resp.StatusCode),
zap.String("response", string(bodyBytes)),
zap.Any("headers", resp.Header),
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/openai/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ import (
)

func init() {
providers.LangRegistry.Register(ProviderOpenAI, &Config{})
providers.LangRegistry.Register(ProviderID, &Config{})
}
2 changes: 1 addition & 1 deletion pkg/providers/testing/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Config struct {
APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"`
}

func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) {
func (c *Config) ToClient(_ *telemetry.Telemetry, _ *clients.ClientConfig) (provider.LangProvider, error) {
return NewProviderMock(nil, []RespMock{}), nil
}

Expand Down
26 changes: 13 additions & 13 deletions pkg/routers/lang/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestRouterConfig_BuildModels(t *testing.T) {
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: &providers.DynLangProvider{
openai.ProviderOpenAI: &openai.Config{
openai.ProviderID: &openai.Config{
APIKey: "ABC",
DefaultParams: &defaultParams,
},
Expand All @@ -46,7 +46,7 @@ func TestRouterConfig_BuildModels(t *testing.T) {
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: &providers.DynLangProvider{
openai.ProviderOpenAI: &openai.Config{
openai.ProviderID: &openai.Config{
APIKey: "ABC",
DefaultParams: &defaultParams,
},
Expand Down Expand Up @@ -81,7 +81,7 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) {
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: &providers.DynLangProvider{
openai.ProviderOpenAI: &openai.Config{
openai.ProviderID: &openai.Config{
APIKey: "ABC",
DefaultParams: &openAIParams,
},
Expand All @@ -93,8 +93,8 @@ func TestRouterConfig_BuildModelsPerType(t *testing.T) {
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: providers.LangProviders{
Cohere: &cohere.Config{
Provider: &providers.DynLangProvider{
cohere.ProviderID: &cohere.Config{
APIKey: "ABC",
DefaultParams: &cohereParams,
},
Expand Down Expand Up @@ -129,8 +129,8 @@ func TestRouterConfig_InvalidSetups(t *testing.T) {
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: providers.LangProviders{
OpenAI: &openai.Config{
Provider: &providers.DynLangProvider{
openai.ProviderID: &openai.Config{
APIKey: "ABC",
DefaultParams: &defaultParams,
},
Expand All @@ -147,8 +147,8 @@ func TestRouterConfig_InvalidSetups(t *testing.T) {
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: providers.LangProviders{
OpenAI: &openai.Config{
Provider: &providers.DynLangProvider{
openai.ProviderID: &openai.Config{
APIKey: "ABC",
DefaultParams: &defaultParams,
},
Expand All @@ -170,8 +170,8 @@ func TestRouterConfig_InvalidSetups(t *testing.T) {
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: providers.LangProviders{
OpenAI: &openai.Config{
Provider: &providers.DynLangProvider{
openai.ProviderID: &openai.Config{
APIKey: "ABC",
DefaultParams: &defaultParams,
},
Expand All @@ -183,8 +183,8 @@ func TestRouterConfig_InvalidSetups(t *testing.T) {
Client: clients.DefaultClientConfig(),
ErrorBudget: health.DefaultErrorBudget(),
Latency: latency.DefaultConfig(),
Provider: providers.LangProviders{
OpenAI: &openai.Config{
Provider: &providers.DynLangProvider{
openai.ProviderID: &openai.Config{
APIKey: "ABC",
DefaultParams: &defaultParams,
},
Expand Down

0 comments on commit 339e3f9

Please sign in to comment.