diff --git a/packages/repo/src/util.ts b/packages/repo/src/util.ts index 8ec5239fa4f..42a59f6ea82 100644 --- a/packages/repo/src/util.ts +++ b/packages/repo/src/util.ts @@ -10,6 +10,7 @@ import { check, schema, cidForCbor, + byteIterableToStream, } from '@atproto/common' import { ipldToLex, lexToIpld, LexValue, RepoRecord } from '@atproto/lexicon' @@ -33,6 +34,7 @@ import BlockMap from './block-map' import { MissingBlocksError } from './error' import * as parse from './parse' import { Keypair } from '@atproto/crypto' +import { Readable } from 'stream' export async function* verifyIncomingCarBlocks( car: AsyncIterable, @@ -43,16 +45,31 @@ export async function* verifyIncomingCarBlocks( } } -export const writeCar = ( +// we have to turn the car writer output into a stream in order to properly handle errors +export function writeCarStream( root: CID | null, fn: (car: BlockWriter) => Promise, -): AsyncIterable => { +): Readable { const { writer, out } = root !== null ? CarWriter.create(root) : CarWriter.create() - fn(writer).finally(() => writer.close()) + const stream = byteIterableToStream(out) + fn(writer) + .catch((err) => { + stream.destroy(err) + }) + .finally(() => writer.close()) + return stream +} - return out +export async function* writeCar( + root: CID | null, + fn: (car: BlockWriter) => Promise, +): AsyncIterable { + const stream = writeCarStream(root, fn) + for await (const chunk of stream) { + yield chunk + } } export const blocksToCarStream = ( diff --git a/packages/repo/tests/util.test.ts b/packages/repo/tests/util.test.ts new file mode 100644 index 00000000000..f341cadfea9 --- /dev/null +++ b/packages/repo/tests/util.test.ts @@ -0,0 +1,21 @@ +import { dataToCborBlock, wait } from '@atproto/common' +import { writeCar } from '../src' + +describe('Utils', () => { + describe('writeCar()', () => { + it('propagates errors', async () => { + const iterate = async () => { + const iter = writeCar(null, async (car) => { + await wait(1) + const block = await dataToCborBlock({ test: 1 }) + await car.put(block) + throw new Error('Oops!') + }) + for await (const bytes of iter) { + // no-op + } + } + await expect(iterate).rejects.toThrow('Oops!') + }) + }) +}) diff --git a/packages/xrpc-server/src/server.ts b/packages/xrpc-server/src/server.ts index 2c30ae8a440..f8dc5c82f8c 100644 --- a/packages/xrpc-server/src/server.ts +++ b/packages/xrpc-server/src/server.ts @@ -235,6 +235,7 @@ export class Server { } else if (output?.body instanceof Readable) { res.header('Content-Type', output.encoding) res.status(200) + res.once('error', (err) => res.destroy(err)) forwardStreamErrors(output.body, res) output.body.pipe(res) } else if (output) { diff --git a/packages/xrpc-server/tests/responses.test.ts b/packages/xrpc-server/tests/responses.test.ts new file mode 100644 index 00000000000..0eaccba0633 --- /dev/null +++ b/packages/xrpc-server/tests/responses.test.ts @@ -0,0 +1,77 @@ +import * as http from 'http' +import getPort from 'get-port' +import xrpc, { ServiceClient } from '@atproto/xrpc' +import { byteIterableToStream } from '@atproto/common' +import { createServer, closeServer } from './_util' +import * as xrpcServer from '../src' + +const LEXICONS = [ + { + lexicon: 1, + id: 'io.example.readableStream', + defs: { + main: { + type: 'query', + parameters: { + type: 'params', + properties: { + shouldErr: { type: 'boolean' }, + }, + }, + output: { + encoding: 'application/vnd.ipld.car', + }, + }, + }, + }, +] + +describe('Responses', () => { + let s: http.Server + const server = xrpcServer.createServer(LEXICONS) + server.method( + 'io.example.readableStream', + async (ctx: { params: xrpcServer.Params }) => { + async function* iter(): AsyncIterable { + for (let i = 0; i < 5; i++) { + yield new Uint8Array([i]) + } + if (ctx.params.shouldErr) { + throw new Error('error') + } + } + return { + encoding: 'application/vnd.ipld.car', + body: byteIterableToStream(iter()), + } + }, + ) + xrpc.addLexicons(LEXICONS) + + let client: ServiceClient + let url: string + beforeAll(async () => { + const port = await getPort() + s = await createServer(port, server) + url = `http://localhost:${port}` + client = xrpc.service(url) + }) + afterAll(async () => { + await closeServer(s) + }) + + it('returns readable streams of bytes', async () => { + const res = await client.call('io.example.readableStream', { + shouldErr: false, + }) + const expected = new Uint8Array([0, 1, 2, 3, 4]) + expect(res.data).toEqual(expected) + }) + + it('handles errs on readable streams of bytes', async () => { + const attempt = client.call('io.example.readableStream', { + shouldErr: true, + }) + await expect(attempt).rejects.toThrow() + }) +})