diff --git a/package-lock.json b/package-lock.json index 48571a5..0968a84 100644 --- a/package-lock.json +++ b/package-lock.json @@ -12,6 +12,7 @@ "@chakra-ui/react": "^2.5.1", "@emotion/react": "^11.10.6", "@emotion/styled": "^11.10.6", + "@google/generative-ai": "^0.1.1", "framer-motion": "^9.0.4", "highlightjs-solidity": "^2.0.6", "mixpanel-browser": "^2.46.0", @@ -1745,6 +1746,14 @@ "node": ">=12" } }, + "node_modules/@google/generative-ai": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.1.1.tgz", + "integrity": "sha512-cbzKa8mT9YkTrT4XUuENIuvlqiJjwDgcD2Ks4L99Az9dWLgdXn8xnETEAZLOpqzoGx+1PuATZqlUnVRAeLbMgA==", + "engines": { + "node": ">=18.0.0" + } + }, "node_modules/@motionone/animation": { "version": "10.15.1", "resolved": "https://registry.npmjs.org/@motionone/animation/-/animation-10.15.1.tgz", @@ -5896,6 +5905,11 @@ "dev": true, "optional": true }, + "@google/generative-ai": { + "version": "0.1.1", + "resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.1.1.tgz", + "integrity": "sha512-cbzKa8mT9YkTrT4XUuENIuvlqiJjwDgcD2Ks4L99Az9dWLgdXn8xnETEAZLOpqzoGx+1PuATZqlUnVRAeLbMgA==" + }, "@motionone/animation": { "version": "10.15.1", "resolved": "https://registry.npmjs.org/@motionone/animation/-/animation-10.15.1.tgz", diff --git a/package.json b/package.json index 80565fc..fa4e80e 100644 --- a/package.json +++ b/package.json @@ -13,6 +13,7 @@ "@chakra-ui/react": "^2.5.1", "@emotion/react": "^11.10.6", "@emotion/styled": "^11.10.6", + "@google/generative-ai": "^0.1.1", "framer-motion": "^9.0.4", "highlightjs-solidity": "^2.0.6", "mixpanel-browser": "^2.46.0", diff --git a/src/components/App.tsx b/src/components/App.tsx index 5512ceb..84e2e71 100644 --- a/src/components/App.tsx +++ b/src/components/App.tsx @@ -4,8 +4,11 @@ import { Column, Row } from "../utils/chakra"; import { copySnippetToClipboard } from "../utils/clipboard"; import { getFluxNodeTypeColor, getFluxNodeTypeDarkColor } from "../utils/color"; import { getPlatformModifierKey, getPlatformModifierKeyText } from "../utils/platform"; +import { GoogleGenerativeAI } from "@google/generative-ai"; + import { API_KEY_LOCAL_STORAGE_KEY, + GOOGLE_API_KEY_LOCAL_STORAGE_KEY, DEFAULT_SETTINGS, FIT_VIEW_SETTINGS, HOTKEY_CONFIG, @@ -283,6 +286,7 @@ function App() { const submitPrompt = async (overrideExistingIfPossible: boolean) => { takeSnapshot(); + const apiProvider = settings.model.includes("Gemini") ? 
"gemini" : "openai"; const responses = settings.n; const temp = settings.temp; const model = settings.model; @@ -361,74 +365,109 @@ function App() { if (firstCompletionId === undefined) throw new Error("No first completion id!"); - (async () => { - const stream = await OpenAI( - "chat", - { - model, - n: responses, - temperature: temp, - messages: messagesFromLineage(parentNodeLineage, settings), - }, - { apiKey: apiKey!, mode: "raw" } - ); - - const DECODER = new TextDecoder(); - - const abortController = new AbortController(); + if (apiProvider === "openai") { + // Existing OpenAI logic + (async () => { + const stream = await OpenAI( + "chat", + { + model, + n: responses, + temperature: temp, + messages: messagesFromLineage(parentNodeLineage, settings), + }, + { apiKey: apiKey!, mode: "raw" } + ); - for await (const chunk of yieldStream(stream, abortController)) { - if (abortController.signal.aborted) break; + const DECODER = new TextDecoder(); - try { - const decoded = JSON.parse(DECODER.decode(chunk)); + const abortController = new AbortController(); - if (decoded.choices === undefined) - throw new Error( - "No choices in response. Decoded response: " + JSON.stringify(decoded) - ); + for await (const chunk of yieldStream(stream, abortController)) { + if (abortController.signal.aborted) break; - const choice: CreateChatCompletionStreamResponseChoicesInner = - decoded.choices[0]; + try { + const decoded = JSON.parse(DECODER.decode(chunk)); + + if (decoded.choices === undefined) + throw new Error( + "No choices in response. Decoded response: " + JSON.stringify(decoded) + ); + + const choice: CreateChatCompletionStreamResponseChoicesInner = + decoded.choices[0]; + + if (choice.index === undefined) + throw new Error( + "No index in choice. Decoded choice: " + JSON.stringify(choice) + ); + + const correspondingNodeId = + // If we re-used a node we have to pull it from children array. + overrideExistingIfPossible && choice.index < currentNodeChildren.length + ? currentNodeChildren[choice.index].id + : newNodes[newNodes.length - responses + choice.index].id; + + // The ChatGPT API will start by returning a + // choice with only a role delta and no content. + if (choice.delta?.content) { + setNodes((newerNodes) => { + try { + return appendTextToFluxNodeAsGPT(newerNodes, { + id: correspondingNodeId, + text: choice.delta?.content ?? UNDEFINED_RESPONSE_STRING, + streamId, // This will cause a throw if the streamId has changed. + }); + } catch (e: any) { + // If the stream id does not match, + // it is stale and we should abort. + abortController.abort(e.message); + + return newerNodes; + } + }); + } - if (choice.index === undefined) - throw new Error( - "No index in choice. Decoded choice: " + JSON.stringify(choice) - ); + // We cannot return within the loop, and we do + // not want to execute the code below, so we break. + if (abortController.signal.aborted) break; - const correspondingNodeId = - // If we re-used a node we have to pull it from children array. - overrideExistingIfPossible && choice.index < currentNodeChildren.length - ? currentNodeChildren[choice.index].id - : newNodes[newNodes.length - responses + choice.index].id; - - // The ChatGPT API will start by returning a - // choice with only a role delta and no content. - if (choice.delta?.content) { - setNodes((newerNodes) => { - try { - return appendTextToFluxNodeAsGPT(newerNodes, { + // If the choice has a finish reason, then it's the final + // choice and we can mark it as no longer animated right now. 
+ if (choice.finish_reason !== null) { + // Reset the stream id. + setNodes((nodes) => + setFluxNodeStreamId(nodes, { id: correspondingNodeId, - text: choice.delta?.content ?? UNDEFINED_RESPONSE_STRING, - streamId, // This will cause a throw if the streamId has changed. - }); - } catch (e: any) { - // If the stream id does not match, - // it is stale and we should abort. - abortController.abort(e.message); - - return newerNodes; - } - }); + streamId: undefined, + }) + ); + + setEdges((edges) => + modifyFluxEdge(edges, { + source: parentNode.id, + target: correspondingNodeId, + animated: false, + }) + ); + } + } catch (err) { + console.error(err); } + } - // We cannot return within the loop, and we do - // not want to execute the code below, so we break. - if (abortController.signal.aborted) break; + // If the stream wasn't aborted or was aborted due to a cancelation. + if ( + !abortController.signal.aborted || + abortController.signal.reason === STREAM_CANCELED_ERROR_MESSAGE + ) { + // Mark all the edges as no longer animated. + for (let i = 0; i < responses; i++) { + const correspondingNodeId = + overrideExistingIfPossible && i < currentNodeChildren.length + ? currentNodeChildren[i].id + : newNodes[newNodes.length - responses + i].id; - // If the choice has a finish reason, then it's the final - // choice and we can mark it as no longer animated right now. - if (choice.finish_reason !== null) { // Reset the stream id. setNodes((nodes) => setFluxNodeStreamId(nodes, { id: correspondingNodeId, streamId: undefined }) @@ -442,44 +481,92 @@ function App() { }) ); } - } catch (err) { - console.error(err); } + })().catch((err) => + toast({ + title: err.toString(), + status: "error", + ...TOAST_CONFIG, + }) + ); + } else if (apiProvider === "gemini") { + const googleApiKey = localStorage.getItem(GOOGLE_API_KEY_LOCAL_STORAGE_KEY); + if (!googleApiKey) { + throw new Error("Google API key is not set."); } + // TODO: decodes weird escape characters. gotta fix this later. + const decodedApiKey = decodeURIComponent(googleApiKey); + const fixedApiKey = decodedApiKey.replace(/^"|"$/g, ""); + const genAI = new GoogleGenerativeAI(fixedApiKey); + const model = genAI.getGenerativeModel({ model: "gemini-pro" }); + const historyMessages = messagesFromLineage(parentNodeLineage, settings) + .filter((message) => message.role !== "system") + .map((message) => ({ + role: message.role === "assistant" ? "model" : message.role, + parts: message.content.toString(), + })); + const trimmedHistoryMessages = historyMessages.slice(0, -1); + + const uniqueHistoryMessages = Array.from( + new Set(trimmedHistoryMessages.map((message) => JSON.stringify(message))) + ).map((message) => JSON.parse(message)); + console.log(uniqueHistoryMessages); + const chat = model.startChat({ + history: uniqueHistoryMessages, + generationConfig: { + maxOutputTokens: 32000, + }, + }); - // If the stream wasn't aborted or was aborted due to a cancelation. - if ( - !abortController.signal.aborted || - abortController.signal.reason === STREAM_CANCELED_ERROR_MESSAGE - ) { - // Mark all the edges as no longer animated. - for (let i = 0; i < responses; i++) { - const correspondingNodeId = - overrideExistingIfPossible && i < currentNodeChildren.length - ? currentNodeChildren[i].id - : newNodes[newNodes.length - responses + i].id; - - // Reset the stream id. 
- setNodes((nodes) => - setFluxNodeStreamId(nodes, { id: correspondingNodeId, streamId: undefined }) - ); + const abortController = new AbortController(); - setEdges((edges) => - modifyFluxEdge(edges, { - source: parentNode.id, - target: correspondingNodeId, - animated: false, - }) - ); - } + for (let i = 0; i < responses; i++) { + const msg = parentNode.data.text; + const correspondingNodeId = + overrideExistingIfPossible && i < currentNodeChildren.length + ? currentNodeChildren[i].id + : newNodes[newNodes.length - responses + i].id; + + const streamPromise = chat.sendMessageStream(msg); + + streamPromise.then((stream) => { + (async () => { + let text = ""; + for await (const chunk of stream.stream) { + if (abortController.signal.aborted) break; + const chunkText = chunk.text(); + text += chunkText; + + setNodes((newerNodes) => { + return modifyFluxNodeText(newerNodes, { + asHuman: false, + id: correspondingNodeId, + text: text, + }); + }); + } + + setEdges((edges) => + modifyFluxEdge(edges, { + source: parentNode.id, + target: correspondingNodeId, + animated: false, + }) + ); + + setNodes((nodes) => + setFluxNodeStreamId(nodes, { id: correspondingNodeId, streamId: undefined }) + ); + })().catch((err) => { + toast({ + title: err.toString(), + status: "error", + ...TOAST_CONFIG, + }); + }); + }); } - })().catch((err) => - toast({ - title: err.toString(), - status: "error", - ...TOAST_CONFIG, - }) - ); + } setNodes(markOnlyNodeAsSelected(newNodes, firstCompletionId!)); @@ -490,8 +577,6 @@ function App() { let newEdges = [...edges]; for (let i = 0; i < responses; i++) { - // Update the links between - // re-used nodes if necessary. if (overrideExistingIfPossible && i < currentNodeChildren.length) { const childId = currentNodeChildren[i].id; @@ -504,11 +589,8 @@ function App() { animated: true, }; } else { - // The new nodes are added to the end of the array, so we need to - // subtract responses from and add i to length of the array to access. const childId = newNodes[newNodes.length - responses + i].id; - // Otherwise, add a new edge. newEdges.push( newFluxEdge({ source: parentNode.id, @@ -527,6 +609,7 @@ function App() { if (MIXPANEL_TOKEN) mixpanel.track("Submitted Prompt"); // KPI }; + // The completeNextWords function remains unchanged const completeNextWords = () => { takeSnapshot(); @@ -844,11 +927,15 @@ function App() { if (rawSettings !== null) { return JSON.parse(rawSettings) as Settings; } else { - return DEFAULT_SETTINGS; + return { + ...DEFAULT_SETTINGS, + apiProvider: "openai", + }; } }); const isGPT4 = settings.model.includes("gpt-4"); + const isGemini = settings.model.includes("Google Gemini"); // Auto save. const isSavingSettings = useDebouncedEffect( @@ -864,6 +951,9 @@ function App() { //////////////////////////////////////////////////////////////*/ const [apiKey, setApiKey] = useLocalStorage(API_KEY_LOCAL_STORAGE_KEY); + const [googleApiKey, setGoogleApiKey] = useLocalStorage( + GOOGLE_API_KEY_LOCAL_STORAGE_KEY + ); const [availableModels, setAvailableModels] = useState(null); @@ -889,13 +979,19 @@ function App() { } if (modelsLoadIndex !== modelsLoadCounter.current) return; + if (googleApiKey) { + modelList.push("Google Gemini"); + } + if (modelList.length === 0) modelList.push(settings.model); setAvailableModels(modelList); if (!modelList.includes(settings.model)) { const oldModel = settings.model; - const newModel = modelList.includes(DEFAULT_SETTINGS.model) ? DEFAULT_SETTINGS.model : modelList[0]; + const newModel = modelList.includes(DEFAULT_SETTINGS.model) + ? 
DEFAULT_SETTINGS.model + : modelList[0]; setSettings((settings) => ({ ...settings, model: newModel })); @@ -911,7 +1007,7 @@ function App() { }, [apiKey]); const isAnythingSaving = isSavingReactFlow || isSavingSettings; - const isAnythingLoading = isAnythingSaving || (availableModels === null); + const isAnythingLoading = isAnythingSaving || availableModels === null; useBeforeunload((event: BeforeUnloadEvent) => { // Prevent leaving the page before saving. @@ -1047,6 +1143,8 @@ function App() { onClose={onCloseSettingsModal} apiKey={apiKey} setApiKey={setApiKey} + googleApiKey={googleApiKey} + setGoogleApiKey={setGoogleApiKey} availableModels={availableModels} /> void; newConnectedToSelectedNode: (type: FluxNodeType) => void; isGPT4: boolean; + isGemini: boolean; settings: Settings; setSettings: (settings: Settings) => void; apiKey: string | null; @@ -273,8 +275,8 @@ export function Prompt({   {promptNodeType === FluxNodeType.User - ? displayNameFromFluxNodeType(FluxNodeType.GPT, isGPT4) - : displayNameFromFluxNodeType(FluxNodeType.User, isGPT4)} + ? displayNameFromFluxNodeType(FluxNodeType.GPT, isGPT4, isGemini) + : displayNameFromFluxNodeType(FluxNodeType.User, isGPT4, isGemini)}   response diff --git a/src/components/modals/APIKeyModal.tsx b/src/components/modals/APIKeyModal.tsx index 7f3bf52..3b981d2 100644 --- a/src/components/modals/APIKeyModal.tsx +++ b/src/components/modals/APIKeyModal.tsx @@ -38,7 +38,12 @@ export function APIKeyModal({ - + We will never upload, log, or store your API key outside of your browser's local storage. Verify for yourself{" "} diff --git a/src/components/modals/SettingsModal.tsx b/src/components/modals/SettingsModal.tsx index 6861cdc..0552dea 100644 --- a/src/components/modals/SettingsModal.tsx +++ b/src/components/modals/SettingsModal.tsx @@ -26,7 +26,9 @@ export const SettingsModal = memo(function SettingsModal({ setSettings, apiKey, setApiKey, - availableModels + googleApiKey, + setGoogleApiKey, + availableModels, }: { isOpen: boolean; onClose: () => void; @@ -34,6 +36,8 @@ export const SettingsModal = memo(function SettingsModal({ setSettings: (settings: Settings) => void; apiKey: string | null; setApiKey: (apiKey: string) => void; + googleApiKey: string | null; + setGoogleApiKey: (apiKey: string) => void; availableModels: string[] | null; }) { const reset = () => { @@ -88,7 +92,26 @@ export const SettingsModal = memo(function SettingsModal({ }} /> - + {/* OpenAI */} + + + {/* Google Gemini */} + void; + label: string; + placeholder: string; + link?: string; } & BoxProps) { return (
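
For reviewers who haven't used the @google/generative-ai SDK, the Gemini branch added to submitPrompt boils down to the pattern below. This is a minimal sketch against SDK 0.1.1 (the version pinned in package.json); ChatMessage, streamGeminiReply, and onToken are illustrative names, and the real code threads each chunk through modifyFluxNodeText and setFluxNodeStreamId rather than a callback.

import { GoogleGenerativeAI } from "@google/generative-ai";

type ChatMessage = { role: "system" | "user" | "assistant"; content: string };

// Gemini's chat API has no system role and calls the assistant "model",
// so system messages are dropped and assistant turns are renamed.
const toGeminiHistory = (messages: ChatMessage[]) =>
  messages
    .filter((m) => m.role !== "system")
    .map((m) => ({ role: m.role === "assistant" ? "model" : "user", parts: m.content }));

async function streamGeminiReply(
  apiKey: string,
  lineage: ChatMessage[], // full node lineage; the last entry is the prompt being answered
  onToken: (textSoFar: string) => void
): Promise<string> {
  const genAI = new GoogleGenerativeAI(apiKey);
  const model = genAI.getGenerativeModel({ model: "gemini-pro" });

  // Everything before the prompt becomes chat history.
  const chat = model.startChat({ history: toGeminiHistory(lineage.slice(0, -1)) });

  const result = await chat.sendMessageStream(lineage[lineage.length - 1].content);

  let text = "";
  for await (const chunk of result.stream) {
    text += chunk.text(); // each chunk carries only the newly generated text
    onToken(text);
  }

  return text;
}

Unlike the OpenAI branch, where a single request with n: responses streams every completion at once, sendMessageStream produces one reply per call, which is why the diff loops responses times and sends the same parent-node text once per child node.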
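
The "TODO: decodes weird escape characters" workaround exists because the Google key is written through the useLocalStorage hook but read back inside submitPrompt with a raw localStorage.getItem, so (judging by the quote-stripping regex in the diff) the retrieved string still carries the quotes added when the value was serialised. A hypothetical illustration follows; the literal key name "googleApiKey" and the value are made up for the example.

// Hypothetical: assumes GOOGLE_API_KEY_LOCAL_STORAGE_KEY resolves to "googleApiKey"
// and that the hook stores the value JSON-serialised.
localStorage.setItem("googleApiKey", JSON.stringify("AIza-example-key")); // stored as '"AIza-example-key"'

const raw = localStorage.getItem("googleApiKey")!;            // '"AIza-example-key"'
const apiKey = decodeURIComponent(raw).replace(/^"|"$/g, ""); // "AIza-example-key"

Reading the key from the googleApiKey state that useLocalStorage already exposes, instead of going back to localStorage inside submitPrompt, would make the decode-and-strip step unnecessary.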