Skip to content

Commit

Permalink
feat: Add support for nested transaction rollbacks in SQL databases
Browse files Browse the repository at this point in the history
This change adds support for handling rollbacks in nested transactions
in SQL databases. Specifically, the inner transaction should be rolled
back if the outer transaction fails.

To do this we keep track of the transaction ID and transaction depth so we can
re-use an existing open transaction in the underlying engine. This change also
allows the use of the `$transaction` method on an interactive transaction client.

depends-on: prisma/prisma-engines#4375
  • Loading branch information
LucianBuzzo committed Aug 29, 2024
1 parent 7955bca commit 4a19155
Show file tree
Hide file tree
Showing 19 changed files with 7,515 additions and 3,312 deletions.
7 changes: 7 additions & 0 deletions packages/adapter-d1/src/d1.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ class D1Transaction extends D1Queryable<StdClient> implements Transaction {
super(client)
}

// eslint-disable-next-line @typescript-eslint/require-await
async begin(): Promise<Result<void>> {
debug(`[js::begin]`)

return ok(undefined)
}

// eslint-disable-next-line @typescript-eslint/require-await
async commit(): Promise<Result<void>> {
debug(`[js::commit]`)
Expand Down
9 changes: 9 additions & 0 deletions packages/adapter-libsql/src/libsql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ class LibSqlTransaction extends LibSqlQueryable<TransactionClient> implements Tr
super(client)
}

// eslint-disable-next-line @typescript-eslint/require-await
async begin(): Promise<Result<void>> {
debug(`[js::commit]`)

throw new Error('Method not implemented.')

return ok(undefined)
}

async commit(): Promise<Result<void>> {
debug(`[js::commit]`)

Expand Down
7 changes: 7 additions & 0 deletions packages/adapter-neon/src/neon.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ class NeonTransaction extends NeonWsQueryable<neon.PoolClient> implements Transa
super(client)
}

async begin(): Promise<Result<void>> {
debug(`[js::begin]`)

this.client.release()
return Promise.resolve(ok(undefined))
}

async commit(): Promise<Result<void>> {
debug(`[js::commit]`)

Expand Down
7 changes: 7 additions & 0 deletions packages/adapter-pg-worker/src/pg.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ class PgTransaction extends PgQueryable<TransactionClient> implements Transactio
super(client)
}

async begin(): Promise<Result<void>> {
debug(`[js::begin]`)

this.client.release()
return ok(undefined)
}

async commit(): Promise<Result<void>> {
debug(`[js::commit]`)

Expand Down
7 changes: 7 additions & 0 deletions packages/adapter-pg/src/pg.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@ class PgTransaction extends PgQueryable<TransactionClient> implements Transactio
super(client)
}

async begin(): Promise<Result<void>> {
debug(`[js::begin]`)

this.client.release()
return ok(undefined)
}

async commit(): Promise<Result<void>> {
debug(`[js::commit]`)

Expand Down
7 changes: 7 additions & 0 deletions packages/adapter-planetscale/src/planetscale.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ class PlanetScaleTransaction extends PlanetScaleQueryable<planetScale.Transactio
super(tx)
}

async begin(): Promise<Result<void>> {
debug(`[js::begin]`)

this.txDeferred.resolve()
return Promise.resolve(ok(await this.txResultPromise))
}

async commit(): Promise<Result<void>> {
debug(`[js::commit]`)

Expand Down
6 changes: 6 additions & 0 deletions packages/client/src/runtime/RequestHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ export class RequestHandler {
const interactiveTransaction =
request.transaction?.kind === 'itx' ? getItxTransactionOptions(request.transaction) : undefined

if (interactiveTransaction) {
interactiveTransaction.payload = {
new_tx_id: interactiveTransaction?.id,
}
}

const response = await this.client._engine.request(request.protocolQuery, {
traceparent: this.client._tracingHelper.getTraceParent(),
interactiveTransaction,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,7 @@ You very likely have the wrong "binaryTarget" defined in the schema.prisma file.
max_wait: arg.maxWait,
timeout: arg.timeout,
isolation_level: arg.isolationLevel,
new_tx_id: arg?.newTxId,
})

const result = await Connection.onHttpError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export type Options = {
maxWait?: number
timeout?: number
isolationLevel?: IsolationLevel
newTxId?: string
}

export type InteractiveTransactionInfo<Payload = unknown> = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ export class DataProxyEngine implements Engine<DataProxyTxInfoPayload> {
max_wait: arg.maxWait,
timeout: arg.timeout,
isolation_level: arg.isolationLevel,
new_tx_id: arg?.newTxId,
})

const url = await this.url('transaction/start')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ export class LibraryEngine implements Engine<undefined> {
max_wait: arg.maxWait,
timeout: arg.timeout,
isolation_level: arg.isolationLevel,
new_tx_id: arg?.newTxId,
})

result = await this.engine?.startTransaction(jsonOptions, headerStr)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
const denylist = ['$connect', '$disconnect', '$on', '$transaction', '$use', '$extends'] as const
const denylist = ['$connect', '$disconnect', '$on', '$use', '$extends'] as const

export const itxClientDenyList = denylist as ReadonlyArray<string | symbol>

Expand Down
9 changes: 6 additions & 3 deletions packages/client/src/runtime/getPrismaClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -782,17 +782,21 @@ Or read our docs at https://www.prisma.io/docs/concepts/components/prisma-client
*/
async _transactionWithCallback({
callback,
options,
options = {},
}: {
callback: (client: Client) => Promise<unknown>
options?: Options
options?: Options & { newTxId?: string }
}) {
if (this[TX_ID]) {
options.newTxId = this[TX_ID]
}
const headers = { traceparent: this._tracingHelper.getTraceParent() }

const optionsWithDefaults: Options = {
maxWait: options?.maxWait ?? this._engineConfig.transactionOptions.maxWait,
timeout: options?.timeout ?? this._engineConfig.transactionOptions.timeout,
isolationLevel: options?.isolationLevel ?? this._engineConfig.transactionOptions.isolationLevel,
newTxId: options.newTxId,
}
const info = await this._engine.transaction('start', headers, optionsWithDefaults)

Expand All @@ -803,7 +807,6 @@ Or read our docs at https://www.prisma.io/docs/concepts/components/prisma-client

result = await callback(this._createItxClient(transaction))

// it went well, then we commit the transaction
await this._engine.transaction('commit', headers, info)
} catch (e: any) {
// it went bad, then we rollback the transaction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ function itxWithinGenericExtension() {

void xclient.$transaction((tx) => {
expectTypeOf(tx).toHaveProperty('helperMethod')
expectTypeOf(tx).not.toHaveProperty('$transaction')
expectTypeOf(tx).not.toHaveProperty('$extends')
return Promise.resolve()
})
Expand Down
2 changes: 1 addition & 1 deletion packages/client/tests/functional/extensions/itx.ts
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ testMatrix.setupTestSuite(
if (isTransaction) {
expect(ctx.$connect).toBeUndefined()
expect(ctx.$disconnect).toBeUndefined()
expect(ctx.$transaction).toBeUndefined()
expect(ctx.$transaction).toBeDefined()
expect(ctx.$extends).toBeUndefined()
} else {
expect(ctx.$connect).toBeDefined()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { faker } from '@faker-js/faker'
import { ClientEngineType } from '@prisma/internals'
import { copycat } from '@snaplet/copycat'

Expand Down Expand Up @@ -196,11 +197,77 @@ testMatrix.setupTestSuite(
await expect(result).resolves.toHaveLength(2)
})

/**
* If a parent transaction is rolled back, the child transaction should also rollback
* - This is only supported in SQL derived servers
*/
testIf(provider === Providers.POSTGRESQL)('sql: nested rollback', async () => {
const rand1 = Math.floor(Math.random() * 1000)
const rand2 = rand1 + 1
const email1 = 'user_' + rand1 + '@website.com'
const email2 = 'user_' + rand2 + '@website.com'
const client = prisma
await expect(
client.$transaction(async (tx) => {
await tx.user.create({
data: {
email: email1,
},
})

await tx.$transaction(async (tx2) => {
await tx2.user.create({
data: {
email: email2,
},
})
})

// Abort the outer transaction
throw new Error('Rollback')
}),
).rejects.toThrow(/Rollback/)

const result = await prisma.user.findMany({
where: {
email: {
in: [email1, email2],
},
},
})

// Both transactions should rollback
expect(result).toHaveLength(0)
})

testIf(provider === Providers.POSTGRESQL)('sql: multiple interactive transactions', async () => {
const existingEmail = faker.internet.email()

await prisma.$transaction(async (tx) => {
await tx.user.create({ data: { email: existingEmail } })
})

await prisma.$transaction(async (tx) => {
await tx.user.create({ data: { email: existingEmail + 1 } })
})

const result = await prisma.user.findMany({
where: {
email: {
in: [existingEmail, existingEmail + 1],
},
},
})

// Both transactions should succeed
expect(result).toHaveLength(2)
})

/**
* We don't allow certain methods to be called in a transaction
*/
test('forbidden', async () => {
const forbidden = ['$connect', '$disconnect', '$on', '$transaction', '$use']
const forbidden = ['$connect', '$disconnect', '$on', '$use']
expect.assertions(forbidden.length + 1)

const result = prisma.$transaction((prisma) => {
Expand Down
1 change: 1 addition & 0 deletions packages/driver-adapter-utils/src/binder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ const bindTransaction = (errorRegistry: ErrorRegistryInternal, transaction: Tran
options: transaction.options,
queryRaw: wrapAsync(errorRegistry, transaction.queryRaw.bind(transaction)),
executeRaw: wrapAsync(errorRegistry, transaction.executeRaw.bind(transaction)),
begin: wrapAsync(errorRegistry, transaction.begin.bind(transaction)),
commit: wrapAsync(errorRegistry, transaction.commit.bind(transaction)),
rollback: wrapAsync(errorRegistry, transaction.rollback.bind(transaction)),
}
Expand Down
4 changes: 4 additions & 0 deletions packages/driver-adapter-utils/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ export interface Transaction extends Queryable {
* Transaction options.
*/
readonly options: TransactionOptions
/**
* Begin the transaction.
*/
begin(): Promise<Result<void>>
/**
* Commit the transaction.
*/
Expand Down
Loading

0 comments on commit 4a19155

Please sign in to comment.