diff --git a/docs/docs.go b/docs/docs.go index 51a45f21..67918956 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -203,7 +203,11 @@ const docTemplate = `{ }, "role": { "description": "The role of the author of this message. One of system, user, or assistant.", - "type": "string" + "allOf": [ + { + "$ref": "#/definitions/schemas.Role" + } + ] } } }, @@ -308,6 +312,19 @@ const docTemplate = `{ } } }, + "schemas.Role": { + "type": "string", + "enum": [ + "system", + "user", + "assistant" + ], + "x-enum-varnames": [ + "RoleSystem", + "RoleUser", + "RoleAssistant" + ] + }, "schemas.RouterListSchema": { "type": "object", "properties": { diff --git a/docs/swagger.json b/docs/swagger.json index aa0d3b25..1912039e 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -200,7 +200,11 @@ }, "role": { "description": "The role of the author of this message. One of system, user, or assistant.", - "type": "string" + "allOf": [ + { + "$ref": "#/definitions/schemas.Role" + } + ] } } }, @@ -305,6 +309,19 @@ } } }, + "schemas.Role": { + "type": "string", + "enum": [ + "system", + "user", + "assistant" + ], + "x-enum-varnames": [ + "RoleSystem", + "RoleUser", + "RoleAssistant" + ] + }, "schemas.RouterListSchema": { "type": "object", "properties": { diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 82dcc4a6..4a683f2c 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -6,9 +6,10 @@ definitions: description: The content of the message. type: string role: + allOf: + - $ref: '#/definitions/schemas.Role' description: The role of the author of this message. One of system, user, or assistant. - type: string required: - content - role @@ -77,6 +78,16 @@ definitions: token_usage: $ref: '#/definitions/schemas.TokenUsage' type: object + schemas.Role: + enum: + - system + - user + - assistant + type: string + x-enum-varnames: + - RoleSystem + - RoleUser + - RoleAssistant schemas.RouterListSchema: properties: routers: diff --git a/pkg/api/schemas/chat.go b/pkg/api/schemas/chat.go index bb846043..2ca0af06 100644 --- a/pkg/api/schemas/chat.go +++ b/pkg/api/schemas/chat.go @@ -62,7 +62,7 @@ func (r *ChatRequest) Params(modelID string, modelName string) *ChatParams { func NewChatFromStr(message string) *ChatRequest { return &ChatRequest{ Message: ChatMessage{ - "user", + RoleUser, message, }, } @@ -93,10 +93,18 @@ type TokenUsage struct { TotalTokens int `json:"total_tokens"` } +type Role string + +const ( + RoleSystem Role = "system" + RoleUser Role = "user" + RoleAssistant Role = "assistant" +) + // ChatMessage is a message in a chat request. type ChatMessage struct { // The role of the author of this message. One of system, user, or assistant. - Role string `json:"role" validate:"required"` + Role Role `json:"role" validate:"required"` // The content of the message. Content string `json:"content" validate:"required"` } diff --git a/pkg/api/schemas/chat_stream.go b/pkg/api/schemas/chat_stream.go index f7cf8b27..bdcf8fcd 100644 --- a/pkg/api/schemas/chat_stream.go +++ b/pkg/api/schemas/chat_stream.go @@ -30,7 +30,7 @@ func NewChatStreamFromStr(message string) *ChatStreamRequest { return &ChatStreamRequest{ ChatRequest: &ChatRequest{ Message: ChatMessage{ - "user", + RoleUser, message, }, }, diff --git a/pkg/api/schemas/chat_test.go b/pkg/api/schemas/chat_test.go index 9b5ce407..f4cfe9fa 100644 --- a/pkg/api/schemas/chat_test.go +++ b/pkg/api/schemas/chat_test.go @@ -30,7 +30,7 @@ func TestChatRequest_DefaultParams(t *testing.T) { chatReq := ChatRequest{ Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: defaultMessage, }, MessageHistory: []ChatMessage{ @@ -42,7 +42,7 @@ func TestChatRequest_DefaultParams(t *testing.T) { OverrideParams: &map[string]ModelParamsOverride{ modelID: { Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: myModelMessage, }, }, @@ -66,7 +66,7 @@ func TestChatRequest_ModelIDOverride(t *testing.T) { chatReq := ChatRequest{ Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: defaultMessage, }, MessageHistory: []ChatMessage{ @@ -78,7 +78,7 @@ func TestChatRequest_ModelIDOverride(t *testing.T) { OverrideParams: &map[string]ModelParamsOverride{ modelID: { Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: myModelMessage, }, }, @@ -102,7 +102,7 @@ func TestChatRequest_ModelNameOverride(t *testing.T) { chatReq := ChatRequest{ Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: defaultMessage, }, MessageHistory: []ChatMessage{ @@ -114,7 +114,7 @@ func TestChatRequest_ModelNameOverride(t *testing.T) { OverrideParams: &map[string]ModelParamsOverride{ modelName: { Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: myModelMessage, }, }, @@ -139,7 +139,7 @@ func TestChatRequest_ModelNameIDOverride(t *testing.T) { chatReq := ChatRequest{ Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: defaultMessage, }, MessageHistory: []ChatMessage{ @@ -151,13 +151,13 @@ func TestChatRequest_ModelNameIDOverride(t *testing.T) { OverrideParams: &map[string]ModelParamsOverride{ modelName: { Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: myModelNameMessage, }, }, modelID: { Message: ChatMessage{ - Role: "user", + Role: RoleUser, Content: myModelIDMessage, }, }, diff --git a/pkg/providers/anthropic/chat.go b/pkg/providers/anthropic/chat.go index 80b45f2b..f0e14fb2 100644 --- a/pkg/providers/anthropic/chat.go +++ b/pkg/providers/anthropic/chat.go @@ -139,7 +139,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche ModelResponse: schemas.ModelResponse{ Metadata: map[string]string{}, Message: schemas.ChatMessage{ - Role: completion.Type, + Role: schemas.Role(completion.Type), Content: completion.Text, }, TokenUsage: schemas.TokenUsage{ diff --git a/pkg/providers/azureopenai/chat_stream_test.go b/pkg/providers/azureopenai/chat_stream_test.go index 5aade1f5..efb70a0c 100644 --- a/pkg/providers/azureopenai/chat_stream_test.go +++ b/pkg/providers/azureopenai/chat_stream_test.go @@ -72,7 +72,7 @@ func TestAzureOpenAIClient_ChatStreamRequest(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}} @@ -140,7 +140,7 @@ func TestAzureOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the biggest animal?", }}} diff --git a/pkg/providers/azureopenai/client_test.go b/pkg/providers/azureopenai/client_test.go index 1700bca0..3529413c 100644 --- a/pkg/providers/azureopenai/client_test.go +++ b/pkg/providers/azureopenai/client_test.go @@ -56,7 +56,7 @@ func TestAzureOpenAIClient_ChatRequest(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}} @@ -116,7 +116,7 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the dealio?", }}} diff --git a/pkg/providers/bedrock/chat.go b/pkg/providers/bedrock/chat.go index 658c1769..dda0637e 100644 --- a/pkg/providers/bedrock/chat.go +++ b/pkg/providers/bedrock/chat.go @@ -104,7 +104,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche Cached: false, ModelResponse: schemas.ModelResponse{ Message: schemas.ChatMessage{ - Role: "assistant", + Role: schemas.RoleAssistant, Content: modelResult.OutputText, }, TokenUsage: schemas.TokenUsage{ diff --git a/pkg/providers/bedrock/client_test.go b/pkg/providers/bedrock/client_test.go index cdae1f68..57056150 100644 --- a/pkg/providers/bedrock/client_test.go +++ b/pkg/providers/bedrock/client_test.go @@ -62,7 +62,7 @@ func TestBedrockClient_ChatRequest(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the biggest animal?", }}} diff --git a/pkg/providers/bedrock/testdata/chat.req.json b/pkg/providers/bedrock/testdata/chat.req.json index c2e941d2..9466eda7 100644 --- a/pkg/providers/bedrock/testdata/chat.req.json +++ b/pkg/providers/bedrock/testdata/chat.req.json @@ -2,7 +2,7 @@ "model": "amazon.titan-text-express-v1", "messages": [ { - "role": "user", + "role": schemas.RoleUser, "content": "What's the biggest animal?" } ], diff --git a/pkg/providers/cohere/chat.go b/pkg/providers/cohere/chat.go index ddf75680..969d7da9 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/providers/cohere/chat.go @@ -127,7 +127,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche "responseId": cohereCompletion.ResponseID, }, Message: schemas.ChatMessage{ - Role: "assistant", + Role: payload.Role, Content: cohereCompletion.Text, }, TokenUsage: schemas.TokenUsage{ diff --git a/pkg/providers/cohere/schemas.go b/pkg/providers/cohere/schemas.go index 9dc9bb09..2c31ce13 100644 --- a/pkg/providers/cohere/schemas.go +++ b/pkg/providers/cohere/schemas.go @@ -2,7 +2,7 @@ package cohere import "github.com/EinStack/glide/pkg/api/schemas" -// Cohere Chat Response +// ChatCompletion Cohere Chat Response type ChatCompletion struct { Text string `json:"text"` GenerationID string `json:"generation_id"` @@ -92,6 +92,7 @@ type FinalResponse struct { type ChatRequest struct { Model string `json:"model"` Message string `json:"message"` + Role schemas.Role `json:"role"` ChatHistory []schemas.ChatMessage `json:"chat_history"` Temperature float64 `json:"temperature,omitempty"` Preamble string `json:"preamble,omitempty"` @@ -112,8 +113,26 @@ func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { message := params.Messages[len(params.Messages)-1] messageHistory := params.Messages[:len(params.Messages)-1] - // TODO: Map chat message roles to Cohere roles: CHATBOT, SYSTEM, USER - + mapRole := func(role schemas.Role) string { + switch role { + case schemas.RoleSystem: + return "SYSTEM" + case schemas.RoleUser: + return "USER" + case schemas.RoleAssistant: + return "CHATBOT" + default: + return "USER" + } + } + + for i := range messageHistory { + messageHistory[i].Role = schemas.Role(mapRole(messageHistory[i].Role)) + } + + message.Role = schemas.Role(mapRole(message.Role)) + + r.Role = message.Role r.Message = message.Content r.ChatHistory = messageHistory } diff --git a/pkg/providers/octoml/client_test.go b/pkg/providers/octoml/client_test.go index f35de1f7..353dbe9d 100644 --- a/pkg/providers/octoml/client_test.go +++ b/pkg/providers/octoml/client_test.go @@ -121,7 +121,7 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) { // Create a chat request payload chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the dealeo?", }}} diff --git a/pkg/providers/ollama/chat.go b/pkg/providers/ollama/chat.go index b93f5b10..40356d7b 100644 --- a/pkg/providers/ollama/chat.go +++ b/pkg/providers/ollama/chat.go @@ -172,7 +172,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche Cached: false, ModelResponse: schemas.ModelResponse{ Message: schemas.ChatMessage{ - Role: ollamaCompletion.Message.Role, + Role: schemas.Role(ollamaCompletion.Message.Role), Content: ollamaCompletion.Message.Content, }, TokenUsage: schemas.TokenUsage{ diff --git a/pkg/providers/ollama/client_test.go b/pkg/providers/ollama/client_test.go index e6c584cf..e7af71d9 100644 --- a/pkg/providers/ollama/client_test.go +++ b/pkg/providers/ollama/client_test.go @@ -57,7 +57,7 @@ func TestOllamaClient_ChatRequest(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the biggest animal?", }}} @@ -85,7 +85,7 @@ func TestOllamaClient_ChatRequest_Non200Response(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}} @@ -122,7 +122,7 @@ func TestOllamaClient_ChatRequest_SuccessfulResponse(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}} @@ -130,6 +130,6 @@ func TestOllamaClient_ChatRequest_SuccessfulResponse(t *testing.T) { require.NoError(t, err) require.NotNil(t, response) - require.Equal(t, "assistant", response.ModelResponse.Message.Role) + require.Equal(t, schemas.RoleAssistant, response.ModelResponse.Message.Role) require.Equal(t, "London", response.ModelResponse.Message.Content) } diff --git a/pkg/providers/openai/chat_stream.go b/pkg/providers/openai/chat_stream.go index 08ca2b21..659d8b8d 100644 --- a/pkg/providers/openai/chat_stream.go +++ b/pkg/providers/openai/chat_stream.go @@ -120,7 +120,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { "generated_at": completionChunk.Created, }, Message: schemas.ChatMessage{ - Role: "assistant", // doesn't present in all chunks + Role: schemas.RoleAssistant, // doesn't present in all chunks Content: responseChunk.Delta.Content, }, }, diff --git a/pkg/providers/openai/chat_stream_test.go b/pkg/providers/openai/chat_stream_test.go index 1ab8483b..236b9d6f 100644 --- a/pkg/providers/openai/chat_stream_test.go +++ b/pkg/providers/openai/chat_stream_test.go @@ -72,7 +72,7 @@ func TestOpenAIClient_ChatStreamRequest(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}} @@ -140,7 +140,7 @@ func TestOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}} diff --git a/pkg/providers/openai/chat_test.go b/pkg/providers/openai/chat_test.go index 3109f150..209507ad 100644 --- a/pkg/providers/openai/chat_test.go +++ b/pkg/providers/openai/chat_test.go @@ -57,7 +57,7 @@ func TestOpenAIClient_ChatRequest(t *testing.T) { require.NoError(t, err) chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ - Role: "user", + Role: schemas.RoleUser, Content: "What's the capital of the United Kingdom?", }}}