Update UseOpenTelemetry for latest genai spec updates (#5532)
* Update UseOpenTelemetry for latest genai spec updates

- Event data is now expected to be emitted as fields on the log-record body, and the newly recommended way to achieve that is via ILogger, so UseOpenTelemetry now takes an optional logger that it uses to emit such data (see the sketch after this commit message).
- I restructured the implementation to reduce duplication.
- Added logging of response format and seed.
- Added ChatOptions.TopK, as it's one of the request parameters the spec calls out specifically (a usage sketch follows the ChatOptions diff below).
- Updated the Azure.AI.Inference provider name to "az.ai.inference" to match the spec's naming convention and what the library itself uses.
- Updated the OpenAI clients to report the provider name "openai" regardless of the kind of underlying client being used, per the spec's recommendation (a consumer-facing sketch follows the OpenAIChatClient diff below).

* Address PR feedback
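
Not part of the commit itself: a minimal sketch of how a consumer might wire up the updated telemetry pipeline. The builder wiring, the `sourceName` parameter, and whether the optional logger is passed as an ILogger or an ILoggerFactory are assumptions based on the commit description above, not a verbatim copy of the new signature (the UseOpenTelemetry extension itself is not among the files excerpted below).

```csharp
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;

static class TelemetrySetup
{
    // Assumed sketch: hand a logger factory to UseOpenTelemetry so that the
    // genai events described above can be emitted as log-record body fields.
    public static IChatClient WithTelemetry(IChatClient innerChatClient, ILoggerFactory loggerFactory) =>
        new ChatClientBuilder(innerChatClient)                           // builder shape assumed
            .UseOpenTelemetry(loggerFactory, sourceName: "MyApp.GenAI")  // parameter names/order assumed
            .Build();
}
```

In a host, the logger factory would typically be the application's own ILoggerFactory, so the emitted events flow through whatever logging/OpenTelemetry logs pipeline the app has configured.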
stephentoub authored Oct 17, 2024
1 parent 8eddb54 commit 8690e7a
Showing 18 changed files with 536 additions and 464 deletions.
@@ -18,6 +18,9 @@ public class ChatOptions
/// <summary>Gets or sets the "nucleus sampling" factor (or "top p") for generating chat responses.</summary>
public float? TopP { get; set; }

/// <summary>Gets or sets a count indicating how many of the most probable tokens the model should consider when generating the next part of the text.</summary>
public int? TopK { get; set; }

/// <summary>Gets or sets the frequency penalty for generating chat responses.</summary>
public float? FrequencyPenalty { get; set; }

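Not part of the diff: a small usage sketch of the new ChatOptions.TopK property alongside the existing sampling options; the prompt and the commented-out call are illustrative placeholders.

```csharp
using Microsoft.Extensions.AI;

// Sketch: TopK caps sampling at the K most probable tokens per step,
// complementing TopP (nucleus sampling) and Temperature.
var options = new ChatOptions
{
    Temperature = 0.7f,
    TopP = 0.9f,
    TopK = 40, // the property added in this commit
};

// e.g. (illustrative only):
// var response = await chatClient.CompleteAsync("Explain top-k sampling.", options);
```

The hunks that follow show where providers pick this up: the Azure.AI.Inference adapter passes it through as a `top_k` additional property, and the Ollama client maps it straight to its `top_k` request option.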
@@ -48,7 +48,7 @@ public AzureAIInferenceChatClient(ChatCompletionsClient chatCompletionsClient, s
var providerUrl = typeof(ChatCompletionsClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
?.GetValue(chatCompletionsClient) as Uri;

Metadata = new("AzureAIInference", providerUrl, modelId);
Metadata = new("az.ai.inference", providerUrl, modelId);
}

/// <summary>Gets or sets <see cref="JsonSerializerOptions"/> to use for any serialization activities related to tool call arguments and results.</summary>
@@ -296,13 +296,19 @@ private ChatCompletionsOptions ToAzureAIOptions(IList<ChatMessage> chatContents,
}
}

// These properties are strongly-typed on ChatOptions but not on ChatCompletionsOptions.
if (options.TopK is int topK)
{
result.AdditionalProperties["top_k"] = BinaryData.FromObjectAsJson(topK, JsonContext.Default.Options);
}

if (options.AdditionalProperties is { } props)
{
foreach (var prop in props)
{
switch (prop.Key)
{
// These properties are strongly-typed on the ChatCompletionsOptions class.
// These properties are strongly-typed on the ChatCompletionsOptions class but not on the ChatOptions class.
case nameof(result.Seed) when prop.Value is long seed:
result.Seed = seed;
break;
@@ -498,5 +504,6 @@ private static FunctionCallContent ParseCallContentFromJsonString(string json, s
[JsonSerializable(typeof(AzureAIChatToolJson))]
[JsonSerializable(typeof(IDictionary<string, object?>))]
[JsonSerializable(typeof(JsonElement))]
[JsonSerializable(typeof(int))]
private sealed partial class JsonContext : JsonSerializerContext;
}
@@ -259,7 +259,6 @@ private OllamaChatRequest ToOllamaChatRequest(IList<ChatMessage> chatMessages, C
TransferMetadataValue<float>(nameof(OllamaRequestOptions.repeat_penalty), (options, value) => options.repeat_penalty = value);
TransferMetadataValue<long>(nameof(OllamaRequestOptions.seed), (options, value) => options.seed = value);
TransferMetadataValue<float>(nameof(OllamaRequestOptions.tfs_z), (options, value) => options.tfs_z = value);
TransferMetadataValue<int>(nameof(OllamaRequestOptions.top_k), (options, value) => options.top_k = value);
TransferMetadataValue<float>(nameof(OllamaRequestOptions.typical_p), (options, value) => options.typical_p = value);
TransferMetadataValue<bool>(nameof(OllamaRequestOptions.use_mmap), (options, value) => options.use_mmap = value);
TransferMetadataValue<bool>(nameof(OllamaRequestOptions.use_mlock), (options, value) => options.use_mlock = value);
@@ -294,6 +293,11 @@
{
(request.Options ??= new()).top_p = topP;
}

if (options.TopK is int topK)
{
(request.Options ??= new()).top_k = topK;
}
}

return request;
@@ -50,11 +50,10 @@ public OpenAIChatClient(OpenAIClient openAIClient, string modelId)
// The endpoint isn't currently exposed, so use reflection to get at it, temporarily. Once packages
// implement the abstractions directly rather than providing adapters on top of the public APIs,
// the package can provide such implementations separate from what's exposed in the public API.
string providerName = openAIClient.GetType().Name.StartsWith("Azure", StringComparison.Ordinal) ? "azureopenai" : "openai";
Uri providerUrl = typeof(OpenAIClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
?.GetValue(openAIClient) as Uri ?? _defaultOpenAIEndpoint;

Metadata = new(providerName, providerUrl, modelId);
Metadata = new("openai", providerUrl, modelId);
}

/// <summary>Initializes a new instance of the <see cref="OpenAIChatClient"/> class for the specified <see cref="ChatClient"/>.</summary>
@@ -69,13 +68,12 @@ public OpenAIChatClient(ChatClient chatClient)
// The endpoint and model aren't currently exposed, so use reflection to get at them, temporarily. Once packages
// implement the abstractions directly rather than providing adapters on top of the public APIs,
// the package can provide such implementations separate from what's exposed in the public API.
string providerName = chatClient.GetType().Name.StartsWith("Azure", StringComparison.Ordinal) ? "azureopenai" : "openai";
Uri providerUrl = typeof(ChatClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
?.GetValue(chatClient) as Uri ?? _defaultOpenAIEndpoint;
string? model = typeof(ChatClient).GetField("_model", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
?.GetValue(chatClient) as string;

Metadata = new(providerName, providerUrl, model);
Metadata = new("openai", providerUrl, model);
}

/// <summary>Gets or sets <see cref="JsonSerializerOptions"/> to use for any serialization activities related to tool call arguments and results.</summary>
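Also not part of the diff: a sketch of how the provider name now surfaces to consumers, assuming the Metadata/ChatClientMetadata.ProviderName shape used by the constructors above; the API key and model id are placeholders.

```csharp
using System;
using Microsoft.Extensions.AI;
using OpenAI;

// Illustrative only: after this change the reported provider name is "openai"
// regardless of whether the wrapped client targets OpenAI or Azure OpenAI.
IChatClient client = new OpenAIChatClient(new OpenAIClient("<api-key>"), "gpt-4o-mini");
Console.WriteLine(client.Metadata.ProviderName); // "openai"
```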
@@ -52,12 +52,11 @@ public OpenAIEmbeddingGenerator(
// The endpoint isn't currently exposed, so use reflection to get at it, temporarily. Once packages
// implement the abstractions directly rather than providing adapters on top of the public APIs,
// the package can provide such implementations separate from what's exposed in the public API.
string providerName = openAIClient.GetType().Name.StartsWith("Azure", StringComparison.Ordinal) ? "azureopenai" : "openai";
string providerUrl = (typeof(OpenAIClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
?.GetValue(openAIClient) as Uri)?.ToString() ??
DefaultOpenAIEndpoint;

Metadata = CreateMetadata(dimensions, providerName, providerUrl, modelId);
Metadata = CreateMetadata("openai", providerUrl, modelId, dimensions);
}

/// <summary>Initializes a new instance of the <see cref="OpenAIEmbeddingGenerator"/> class.</summary>
@@ -78,19 +77,18 @@ public OpenAIEmbeddingGenerator(EmbeddingClient embeddingClient, int? dimensions
// The endpoint and model aren't currently exposed, so use reflection to get at them, temporarily. Once packages
// implement the abstractions directly rather than providing adapters on top of the public APIs,
// the package can provide such implementations separate from what's exposed in the public API.
string providerName = embeddingClient.GetType().Name.StartsWith("Azure", StringComparison.Ordinal) ? "azureopenai" : "openai";
string providerUrl = (typeof(EmbeddingClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
?.GetValue(embeddingClient) as Uri)?.ToString() ??
DefaultOpenAIEndpoint;

FieldInfo? modelField = typeof(EmbeddingClient).GetField("_model", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
string? model = modelField?.GetValue(embeddingClient) as string;

Metadata = CreateMetadata(dimensions, providerName, providerUrl, model);
Metadata = CreateMetadata("openai", providerUrl, model, dimensions);
}

/// <summary>Creates the <see cref="EmbeddingGeneratorMetadata"/> for this instance.</summary>
private static EmbeddingGeneratorMetadata CreateMetadata(int? dimensions, string providerName, string providerUrl, string? model) =>
private static EmbeddingGeneratorMetadata CreateMetadata(string providerName, string providerUrl, string? model, int? dimensions) =>
new(providerName, Uri.TryCreate(providerUrl, UriKind.Absolute, out Uri? providerUri) ? providerUri : null, model, dimensions);

/// <inheritdoc />