Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infered types based on the tool result #431

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions components/stocks/events.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import { GetEventProps } from '@/lib/types'
import { format, parseISO } from 'lib/utils'

interface Event {
date: string
headline: string
description: string
}

export function Events({ props: events }: { props: Event[] }) {
export function Events({ events }: GetEventProps) {
return (
<div className="-mt-2 flex w-full flex-col gap-2 py-4">
{events.map(event => (
Expand Down
13 changes: 2 additions & 11 deletions components/stocks/stock-purchase.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,10 @@ import { useActions, useAIState, useUIState } from 'ai/rsc'
import { formatNumber } from '@/lib/utils'

import type { AI } from '@/lib/chat/actions'
import { PurchaseProps } from '@/lib/types'

interface Purchase {
numberOfShares?: number
symbol: string
price: number
status: 'requires_action' | 'completed' | 'expired'
}

export function Purchase({
props: { numberOfShares, symbol, price, status = 'expired' }
}: {
props: Purchase
}) {
export function Purchase({ numberOfShares, symbol, price, status = 'expired' }: PurchaseProps) {
const [value, setValue] = useState(numberOfShares || 100)
const [purchasingUI, setPurchasingUI] = useState<null | React.ReactNode>(null)
const [aiState, setAIState] = useAIState<typeof AI>()
Expand Down
8 changes: 2 additions & 6 deletions components/stocks/stock.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,8 @@
import { useState, useRef, useEffect, useId } from 'react'
import { subMonths, format } from 'lib/utils'
import { useAIState } from 'ai/rsc'
import { ShowStockPriceProps } from '@/lib/types'

interface Stock {
symbol: string
price: number
delta: number
}

function scaleLinear(domain: [number, number], range: [number, number]) {
const [d0, d1] = domain
Expand Down Expand Up @@ -47,7 +43,7 @@ function useResizeObserver<T extends HTMLElement = HTMLElement>(
return size
}

export function Stock({ props: { symbol, price, delta } }: { props: Stock }) {
export function Stock({ symbol, price, delta }: ShowStockPriceProps ) {
const [aiState, setAIState] = useAIState()
const id = useId()

Expand Down
25 changes: 8 additions & 17 deletions components/stocks/stocks.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,8 @@
import { useActions, useUIState } from 'ai/rsc'

import type { AI } from '@/lib/chat/actions'

interface Stock {
symbol: string
price: number
delta: number
}

export function Stocks({ props: stocks }: { props: Stock[] }) {
import { ListStockProps } from '@/lib/types'
export function Stocks({ stocks }: ListStockProps) {
const [, setMessages] = useUIState<typeof AI>()
const { submitUserMessage } = useActions()

Expand All @@ -27,9 +21,8 @@ export function Stocks({ props: stocks }: { props: Stock[] }) {
}}
>
<div
className={`text-xl ${
stock.delta > 0 ? 'text-green-600' : 'text-red-600'
} flex w-11 flex-row justify-center rounded-md bg-white/10 p-2`}
className={`text-xl ${stock.delta > 0 ? 'text-green-600' : 'text-red-600'
} flex w-11 flex-row justify-center rounded-md bg-white/10 p-2`}
>
{stock.delta > 0 ? '↑' : '↓'}
</div>
Expand All @@ -41,16 +34,14 @@ export function Stocks({ props: stocks }: { props: Stock[] }) {
</div>
<div className="ml-auto flex flex-col">
<div
className={`${
stock.delta > 0 ? 'text-green-600' : 'text-red-600'
} bold text-right uppercase`}
className={`${stock.delta > 0 ? 'text-green-600' : 'text-red-600'
} bold text-right uppercase`}
>
{` ${((stock.delta / stock.price) * 100).toExponential(1)}%`}
</div>
<div
className={`${
stock.delta > 0 ? 'text-green-700' : 'text-red-700'
} text-right text-base`}
className={`${stock.delta > 0 ? 'text-green-700' : 'text-red-700'
} text-right text-base`}
>
{stock.delta.toExponential(1)}
</div>
Expand Down
84 changes: 19 additions & 65 deletions lib/chat/actions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ import {
} from '@/lib/utils'
import { saveChat } from '@/app/actions'
import { SpinnerMessage, UserMessage } from '@/components/stocks/message'
import { Chat, Message } from '@/lib/types'
import { Chat, GetEventProps, ListStockProps, Message, PurchaseProps, ShowStockPriceProps } from '@/lib/types'
import { auth } from '@/auth'
import { getEventsSchema, listStockSchema, showStockPriceSchema, showStockPurchaseSchema } from '../schemas'

async function confirmPurchase(symbol: string, price: number, amount: number) {
'use server'
Expand Down Expand Up @@ -89,9 +90,8 @@ async function confirmPurchase(symbol: string, price: number, amount: number) {
{
id: nanoid(),
role: 'system',
content: `[User has purchased ${amount} shares of ${symbol} at ${price}. Total cost = ${
amount * price
}]`
content: `[User has purchased ${amount} shares of ${symbol} at ${price}. Total cost = ${amount * price
}]`
}
]
})
Expand Down Expand Up @@ -179,15 +179,7 @@ async function submitUserMessage(content: string) {
tools: {
listStocks: {
description: 'List three imaginary stocks that are trending.',
parameters: z.object({
stocks: z.array(
z.object({
symbol: z.string().describe('The symbol of the stock'),
price: z.number().describe('The price of the stock'),
delta: z.number().describe('The change in price of the stock')
})
)
}),
parameters: listStockSchema,
generate: async function* ({ stocks }) {
yield (
<BotCard>
Expand Down Expand Up @@ -232,23 +224,15 @@ async function submitUserMessage(content: string) {

return (
<BotCard>
<Stocks props={stocks} />
<Stocks stocks={stocks} />
</BotCard>
)
}
},
showStockPrice: {
description:
'Get the current stock price of a given stock or currency. Use this to show the price to the user.',
parameters: z.object({
symbol: z
.string()
.describe(
'The name or symbol of the stock or currency. e.g. DOGE/AAPL/USD.'
),
price: z.number().describe('The price of the stock.'),
delta: z.number().describe('The change in price of the stock')
}),
parameters: showStockPriceSchema,
generate: async function* ({ symbol, price, delta }) {
yield (
<BotCard>
Expand Down Expand Up @@ -293,28 +277,15 @@ async function submitUserMessage(content: string) {

return (
<BotCard>
<Stock props={{ symbol, price, delta }} />
<Stock symbol={symbol} price={price} delta={delta} />
</BotCard>
)
}
},
showStockPurchase: {
description:
'Show price and the UI to purchase a stock or currency. Use this if the user wants to purchase a stock or currency.',
parameters: z.object({
symbol: z
.string()
.describe(
'The name or symbol of the stock or currency. e.g. DOGE/AAPL/USD.'
),
price: z.number().describe('The price of the stock.'),
numberOfShares: z
.number()
.optional()
.describe(
'The **number of shares** for a stock or currency to purchase. Can be optional if the user did not specify it.'
)
}),
parameters: showStockPurchaseSchema,
generate: async function* ({ symbol, price, numberOfShares = 100 }) {
const toolCallId = nanoid()

Expand Down Expand Up @@ -400,12 +371,10 @@ async function submitUserMessage(content: string) {
return (
<BotCard>
<Purchase
props={{
numberOfShares,
symbol,
price: +price,
status: 'requires_action'
}}
price={+price}
status="requires_action"
symbol={symbol}
numberOfShares={numberOfShares}
/>
</BotCard>
)
Expand All @@ -415,17 +384,7 @@ async function submitUserMessage(content: string) {
getEvents: {
description:
'List funny imaginary events between user highlighted dates that describe stock activity.',
parameters: z.object({
events: z.array(
z.object({
date: z
.string()
.describe('The date of the event, in ISO-8601 format'),
headline: z.string().describe('The headline of the event'),
description: z.string().describe('The description of the event')
})
)
}),
parameters: getEventsSchema,
generate: async function* ({ events }) {
yield (
<BotCard>
Expand Down Expand Up @@ -470,7 +429,7 @@ async function submitUserMessage(content: string) {

return (
<BotCard>
<Events props={events} />
<Events events={events} />
</BotCard>
)
}
Expand Down Expand Up @@ -557,24 +516,19 @@ export const getUIStateFromAIState = (aiState: Chat) => {
message.content.map(tool => {
return tool.toolName === 'listStocks' ? (
<BotCard>
{/* TODO: Infer types based on the tool result*/}
{/* @ts-expect-error */}
<Stocks props={tool.result} />
<Stocks {...tool.result as ListStockProps} />
</BotCard>
) : tool.toolName === 'showStockPrice' ? (
<BotCard>
{/* @ts-expect-error */}
<Stock props={tool.result} />
<Stock {...tool.result as ShowStockPriceProps} />
</BotCard>
) : tool.toolName === 'showStockPurchase' ? (
<BotCard>
{/* @ts-expect-error */}
<Purchase props={tool.result} />
<Purchase {...tool.result as PurchaseProps} />
</BotCard>
) : tool.toolName === 'getEvents' ? (
<BotCard>
{/* @ts-expect-error */}
<Events props={tool.result} />
<Events events={tool.result as GetEventProps["events"]} />
</BotCard>
) : null
})
Expand Down
46 changes: 46 additions & 0 deletions lib/schemas/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { z } from "zod"
const getEventsSchema = z.object({
events: z.array(
z.object({
date: z
.string()
.describe('The date of the event, in ISO-8601 format'),
headline: z.string().describe('The headline of the event'),
description: z.string().describe('The description of the event')
})
)
})

const showStockPurchaseSchema = z.object({
symbol: z
.string()
.describe(
'The name or symbol of the stock or currency. e.g. DOGE/AAPL/USD.'
),
price: z.number().describe('The price of the stock.'),
numberOfShares: z
.number()
.optional()
.describe(
'The **number of shares** for a stock or currency to purchase. Can be optional if the user did not specify it.'
)
})
const showStockPriceSchema = z.object({
symbol: z
.string()
.describe(
'The name or symbol of the stock or currency. e.g. DOGE/AAPL/USD.'
),
price: z.number().describe('The price of the stock.'),
delta: z.number().describe('The change in price of the stock')
})
const listStockSchema = z.object({
stocks: z.array(
z.object({
symbol: z.string().describe('The symbol of the stock'),
price: z.number().describe('The price of the stock'),
delta: z.number().describe('The change in price of the stock')
})
)
})
export { getEventsSchema, showStockPurchaseSchema, showStockPriceSchema, listStockSchema }
14 changes: 12 additions & 2 deletions lib/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { CoreMessage } from 'ai'
import { z } from 'zod'
import { showStockPriceSchema, getEventsSchema, listStockSchema, showStockPurchaseSchema } from './schemas'

export type Message = CoreMessage & {
id: string
Expand All @@ -17,8 +19,8 @@ export interface Chat extends Record<string, any> {
export type ServerActionResult<Result> = Promise<
| Result
| {
error: string
}
error: string
}
>

export interface Session {
Expand All @@ -39,3 +41,11 @@ export interface User extends Record<string, any> {
password: string
salt: string
}
export type ShowStockPriceProps = z.infer<typeof showStockPriceSchema>
export type GetEventProps = z.infer<typeof getEventsSchema>
export type ListStockProps = z.infer<typeof listStockSchema>
export type ShowStockPurchaseProps = z.infer<typeof showStockPurchaseSchema>

export type PurchaseProps = {
status: 'requires_action' | 'completed' | 'expired'
} & ShowStockPurchaseProps