Skip to content

Commit

Permalink
ChatSession and GenerativeModel implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
DellaBitta committed Jul 17, 2024
1 parent b73cb1a commit 20b16fb
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 40 deletions.
20 changes: 9 additions & 11 deletions common/api-review/generative-ai.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,14 @@ export interface CachedContentBase {

// @public
export class ChatSession {
constructor(apiKey: string, model: string, params?: StartChatParams, requestOptions?: RequestOptions);
constructor(apiKey: string, model: string, params?: StartChatParams, _requestOptions?: RequestOptions);
getHistory(): Promise<Content[]>;
// (undocumented)
model: string;
// (undocumented)
params?: StartChatParams;
// (undocumented)
requestOptions?: RequestOptions;
sendMessage(request: string | Array<string | Part>): Promise<GenerateContentResult>;
sendMessageStream(request: string | Array<string | Part>): Promise<GenerateContentStreamResult>;
sendMessage(request: string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
sendMessageStream(request: string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
}

// @public
Expand Down Expand Up @@ -462,16 +460,16 @@ export interface GenerativeContentBlob {

// @public
export class GenerativeModel {
constructor(apiKey: string, modelParams: ModelParams, requestOptions?: RequestOptions);
constructor(apiKey: string, modelParams: ModelParams, _requestOptions?: RequestOptions);
// (undocumented)
apiKey: string;
batchEmbedContents(batchEmbedContentRequest: BatchEmbedContentsRequest): Promise<BatchEmbedContentsResponse>;
batchEmbedContents(batchEmbedContentRequest: BatchEmbedContentsRequest, requestOptions?: SingleRequestOptions): Promise<BatchEmbedContentsResponse>;
// (undocumented)
cachedContent: CachedContent;
countTokens(request: CountTokensRequest | string | Array<string | Part>): Promise<CountTokensResponse>;
embedContent(request: EmbedContentRequest | string | Array<string | Part>): Promise<EmbedContentResponse>;
generateContent(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentResult>;
generateContentStream(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentStreamResult>;
countTokens(request: CountTokensRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<CountTokensResponse>;
embedContent(request: EmbedContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<EmbedContentResponse>;
generateContent(request: GenerateContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
generateContentStream(request: GenerateContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
// (undocumented)
generationConfig: GenerationConfig;
// (undocumented)
Expand Down
27 changes: 23 additions & 4 deletions packages/main/src/methods/chat-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
GenerateContentStreamResult,
Part,
RequestOptions,
SingleRequestOptions,
StartChatParams,
} from "../../types";
import { formatNewContent } from "../requests/request-helpers";
Expand Down Expand Up @@ -49,7 +50,7 @@ export class ChatSession {
apiKey: string,
public model: string,
public params?: StartChatParams,
public requestOptions?: RequestOptions,
private _requestOptions: RequestOptions = {},
) {
this._apiKey = apiKey;
if (params?.history) {
Expand All @@ -70,10 +71,15 @@ export class ChatSession {

/**
* Sends a chat message and receives a non-streaming
* {@link GenerateContentResult}
* {@link GenerateContentResult}.
*
* Fields set in the optional {@link SingleRequestOptions} parameter will
* take precedence over the {@link RequestOptions} values provided at the
* time of the {@link ChatSession} initialization.
*/
async sendMessage(
request: string | Array<string | Part>,
requestOptions: SingleRequestOptions = {},
): Promise<GenerateContentResult> {
await this._sendPromise;
const newContent = formatNewContent(request);
Expand All @@ -86,6 +92,10 @@ export class ChatSession {
cachedContent: this.params?.cachedContent,
contents: [...this._history, newContent],
};
const chatSessionRequestOptions: SingleRequestOptions = {
...this._requestOptions,
...requestOptions,
};
let finalResult;
// Add onto the chain.
this._sendPromise = this._sendPromise
Expand All @@ -94,7 +104,7 @@ export class ChatSession {
this._apiKey,
this.model,
generateContentRequest,
this.requestOptions,
chatSessionRequestOptions,
),
)
.then((result) => {
Expand Down Expand Up @@ -128,9 +138,14 @@ export class ChatSession {
* Sends a chat message and receives the response as a
* {@link GenerateContentStreamResult} containing an iterable stream
* and a response promise.
*
* Fields set in the optional {@link SingleRequestOptions} parameter will
* take precedence over the {@link RequestOptions} values provided at the
* time of the {@link ChatSession} initialization.
*/
async sendMessageStream(
request: string | Array<string | Part>,
requestOptions: SingleRequestOptions = {},
): Promise<GenerateContentStreamResult> {
await this._sendPromise;
const newContent = formatNewContent(request);
Expand All @@ -143,11 +158,15 @@ export class ChatSession {
cachedContent: this.params?.cachedContent,
contents: [...this._history, newContent],
};
const chatSessionRequestOptions: SingleRequestOptions = {
...this._requestOptions,
...requestOptions,
};
const streamPromise = generateContentStream(
this._apiKey,
this.model,
generateContentRequest,
this.requestOptions,
chatSessionRequestOptions,
);

// Add onto the chain.
Expand Down
6 changes: 3 additions & 3 deletions packages/main/src/methods/count-tokens.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,23 @@
import {
CountTokensRequest,
CountTokensResponse,
RequestOptions,
SingleRequestOptions,
} from "../../types";
import { Task, makeModelRequest } from "../requests/request";

export async function countTokens(
apiKey: string,
model: string,
params: CountTokensRequest,
requestOptions?: RequestOptions,
singleRequestOptions: SingleRequestOptions,
): Promise<CountTokensResponse> {
const response = await makeModelRequest(
model,
Task.COUNT_TOKENS,
apiKey,
false,
JSON.stringify(params),
requestOptions,
singleRequestOptions,
);
return response.json();
}
6 changes: 3 additions & 3 deletions packages/main/src/methods/generate-content.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {
GenerateContentResponse,
GenerateContentResult,
GenerateContentStreamResult,
RequestOptions,
SingleRequestOptions,
} from "../../types";
import { Task, makeModelRequest } from "../requests/request";
import { addHelpers } from "../requests/response-helpers";
Expand All @@ -30,7 +30,7 @@ export async function generateContentStream(
apiKey: string,
model: string,
params: GenerateContentRequest,
requestOptions?: RequestOptions,
requestOptions: SingleRequestOptions,
): Promise<GenerateContentStreamResult> {
const response = await makeModelRequest(
model,
Expand All @@ -47,7 +47,7 @@ export async function generateContent(
apiKey: string,
model: string,
params: GenerateContentRequest,
requestOptions?: RequestOptions,
requestOptions?: SingleRequestOptions,
): Promise<GenerateContentResult> {
const response = await makeModelRequest(
model,
Expand Down
67 changes: 56 additions & 11 deletions packages/main/src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import {
Part,
RequestOptions,
SafetySetting,
SingleRequestOptions,
StartChatParams,
Tool,
ToolConfig,
Expand Down Expand Up @@ -67,7 +68,7 @@ export class GenerativeModel {
constructor(
public apiKey: string,
modelParams: ModelParams,
requestOptions?: RequestOptions,
private _requestOptions: RequestOptions = {},
) {
if (modelParams.model.includes("/")) {
// Models may be named "models/model-name" or "tunedModels/model-name"
Expand All @@ -84,17 +85,25 @@ export class GenerativeModel {
modelParams.systemInstruction,
);
this.cachedContent = modelParams.cachedContent;
this.requestOptions = requestOptions || {};
}

/**
* Makes a single non-streaming call to the model
* and returns an object containing a single {@link GenerateContentResponse}.
*
* Fields set in the optional {@link SingleRequestOptions} parameter will
* take precedence over the {@link RequestOptions} values provided at the
* time of the {@link GenerativeModel} initialization.
*/
async generateContent(
request: GenerateContentRequest | string | Array<string | Part>,
requestOptions: SingleRequestOptions = {},
): Promise<GenerateContentResult> {
const formattedParams = formatGenerateContentInput(request);
const generativeModelRequestOptions: SingleRequestOptions = {
...this._requestOptions,
...requestOptions,
};
return generateContent(
this.apiKey,
this.model,
Expand All @@ -107,20 +116,29 @@ export class GenerativeModel {
cachedContent: this.cachedContent?.name,
...formattedParams,
},
this.requestOptions,
generativeModelRequestOptions,
);
}

/**
* Makes a single streaming call to the model
* and returns an object containing an iterable stream that iterates
* over all chunks in the streaming response as well as
* a promise that returns the final aggregated response.
* Makes a single streaming call to the model and returns an object
* containing an iterable stream that iterates over all chunks in the
* streaming response as well as a promise that returns the final
* aggregated response.
*
* Fields set in the optional {@link SingleRequestOptions} parameter will
* take precedence over the {@link RequestOptions} values provided at the
* time of the {@link GenerativeModel} initialization.
*/
async generateContentStream(
request: GenerateContentRequest | string | Array<string | Part>,
requestOptions: SingleRequestOptions = {},
): Promise<GenerateContentStreamResult> {
const formattedParams = formatGenerateContentInput(request);
const generativeModelRequestOptions: SingleRequestOptions = {
...this._requestOptions,
...requestOptions,
};
return generateContentStream(
this.apiKey,
this.model,
Expand All @@ -133,7 +151,7 @@ export class GenerativeModel {
cachedContent: this.cachedContent?.name,
...formattedParams,
},
this.requestOptions,
generativeModelRequestOptions,
);
}

Expand All @@ -160,9 +178,14 @@ export class GenerativeModel {

/**
* Counts the tokens in the provided request.
*
* Fields set in the optional {@link SingleRequestOptions} parameter will
* take precedence over the {@link RequestOptions} values provided at the
* time of the {@link GenerativeModel} initialization.
*/
async countTokens(
request: CountTokensRequest | string | Array<string | Part>,
requestOptions: SingleRequestOptions = {},
): Promise<CountTokensResponse> {
const formattedParams = formatCountTokensInput(request, {
model: this.model,
Expand All @@ -173,40 +196,62 @@ export class GenerativeModel {
systemInstruction: this.systemInstruction,
cachedContent: this.cachedContent,
});
const generativeModelRequestOptions: SingleRequestOptions = {
...this._requestOptions,
...requestOptions,
};
return countTokens(
this.apiKey,
this.model,
formattedParams,
this.requestOptions,
generativeModelRequestOptions,
);
}

/**
* Embeds the provided content.
*
* Fields set in the optional {@link SingleRequestOptions} parameter will
* take precedence over the {@link RequestOptions} values provided at the
* time of the {@link GenerativeModel} initialization.
*/
async embedContent(
request: EmbedContentRequest | string | Array<string | Part>,
requestOptions: SingleRequestOptions = {},
): Promise<EmbedContentResponse> {
const formattedParams = formatEmbedContentInput(request);
const generativeModelRequestOptions: SingleRequestOptions = {
...this._requestOptions,
...requestOptions,
};
return embedContent(
this.apiKey,
this.model,
formattedParams,
this.requestOptions,
generativeModelRequestOptions,
);
}

/**
* Embeds an array of {@link EmbedContentRequest}s.
*
* Fields set in the optional {@link SingleRequestOptions} parameter will
* take precedence over the {@link RequestOptions} values provided at the
* time of the {@link GenerativeModel} initialization.
*/
async batchEmbedContents(
batchEmbedContentRequest: BatchEmbedContentsRequest,
requestOptions: SingleRequestOptions = {},
): Promise<BatchEmbedContentsResponse> {
const generativeModelRequestOptions: SingleRequestOptions = {
...this._requestOptions,
...requestOptions,
};
return batchEmbedContents(
this.apiKey,
this.model,
batchEmbedContentRequest,
this.requestOptions,
generativeModelRequestOptions,
);
}
}
Loading

0 comments on commit 20b16fb

Please sign in to comment.