diff --git a/packages/3-extensions/sql-orm-client/src/collection-contract.ts b/packages/3-extensions/sql-orm-client/src/collection-contract.ts index e139a58181..ee0569536d 100644 --- a/packages/3-extensions/sql-orm-client/src/collection-contract.ts +++ b/packages/3-extensions/sql-orm-client/src/collection-contract.ts @@ -322,6 +322,25 @@ export function resolvePrimaryKeyColumn(contract: Contract, tableNam return contract.storage.tables[tableName]?.primaryKey?.columns[0] ?? 'id'; } +export function resolveRowIdentityColumns( + contract: Contract, + tableName: string, +): readonly string[] { + const table = contract.storage.tables[tableName]; + if (!table) { + return []; + } + if (table.primaryKey && table.primaryKey.columns.length > 0) { + return table.primaryKey.columns; + } + for (const unique of table.uniques) { + if (unique.columns.length > 0) { + return unique.columns; + } + } + return []; +} + export function assertReturningCapability(contract: Contract, action: string): void { if (hasContractCapability(contract, 'returning')) { return; diff --git a/packages/3-extensions/sql-orm-client/src/collection.ts b/packages/3-extensions/sql-orm-client/src/collection.ts index 2367ae3523..29ee044b2a 100644 --- a/packages/3-extensions/sql-orm-client/src/collection.ts +++ b/packages/3-extensions/sql-orm-client/src/collection.ts @@ -2,6 +2,7 @@ import type { Contract } from '@prisma-next/contract/types'; import { AsyncIterableResult } from '@prisma-next/framework-components/runtime'; import type { SqlStorage } from '@prisma-next/sql-contract/types'; import { + type AnyExpression, BinaryExpr, ColumnRef, isWhereExpr, @@ -25,6 +26,7 @@ import { resolveModelTableName, resolvePolymorphismInfo, resolvePrimaryKeyColumn, + resolveRowIdentityColumns, resolveUpsertConflictColumns, } from './collection-contract'; import { dispatchCollectionRows } from './collection-dispatch'; @@ -107,6 +109,7 @@ import { type RelatedModelName, type RelationNames, type ResolvedCreateInput, + type RuntimeQueryable, type ShorthandWhereFilter, type UniqueConstraintCriterion, type VariantModelRow, @@ -1038,12 +1041,20 @@ export class Collection< return this.#reloadMutationRowByPrimaryKey(pkCriterion); } - const rows = await this.updateAll( - data as State['hasWhere'] extends true - ? Partial> - : never, - ); - return rows[0] ?? null; + return withMutationScope(this.ctx.runtime, async (scope) => { + const scoped = this.#withRuntime(scope); + const identityWhere = await scoped.#findFirstMatchingRowIdentityWhere(); + if (!identityWhere) { + return null; + } + const narrowed = scoped.#clone({ filters: [identityWhere] }); + const rows = await narrowed.updateAll( + data as State['hasWhere'] extends true + ? Partial> + : never, + ); + return rows[0] ?? null; + }); } updateAll( @@ -1115,15 +1126,26 @@ export class Collection< this: State['hasWhere'] extends true ? Collection : never, ): Promise { assertReturningCapability(this.contract, 'delete()'); - const rows = await this.deleteAll().toArray(); - return rows[0] ?? null; + return withMutationScope(this.ctx.runtime, async (scope) => { + const scoped = this.#withRuntime(scope); + const identityWhere = await scoped.#findFirstMatchingRowIdentityWhere(); + if (!identityWhere) { + return null; + } + const narrowed = scoped.#clone({ filters: [identityWhere] }); + const rows = await narrowed.#executeDeleteReturning().toArray(); + return rows[0] ?? null; + }); } deleteAll( this: State['hasWhere'] extends true ? Collection : never, ): AsyncIterableResult { assertReturningCapability(this.contract, 'deleteAll()'); + return this.#executeDeleteReturning(); + } + #executeDeleteReturning(): AsyncIterableResult { const parentJoinColumns = this.state.includes.map((include) => include.localColumn); const { selectedForQuery: selectedForDelete, hiddenColumns } = augmentSelectionForJoinColumns( this.state.selectedFields, @@ -1188,6 +1210,41 @@ export class Collection< return criterion; } + async #findFirstMatchingRowIdentityWhere(): Promise { + const identityColumns = resolveRowIdentityColumns(this.contract, this.tableName); + if (identityColumns.length === 0) { + throw new Error( + `update()/delete() on model "${this.modelName}" requires the table to have a primary key or unique constraint`, + ); + } + const firstRow = await this.#clone({ + selectedFields: [...identityColumns], + includes: [], + }).first(); + if (!firstRow) { + return null; + } + const columnToField = getColumnToFieldMap(this.contract, this.modelName); + const criterion: Record = {}; + for (const column of identityColumns) { + const fieldName = columnToField[column] ?? column; + const value = (firstRow as Record)[fieldName]; + if (value === undefined) { + throw new Error( + `Missing identity field "${fieldName}" while resolving single-row scope for model "${this.modelName}"`, + ); + } + criterion[fieldName] = value; + } + return ( + shorthandToWhereExpr( + this.ctx.context, + this.modelName, + criterion as ShorthandWhereFilter, + ) ?? null + ); + } + async #reloadMutationRowByPrimaryKey(criterion: Record): Promise { return this.#reloadMutationRowByCriterion(criterion, 'primary key'); } @@ -1242,6 +1299,16 @@ export class Collection< }); } + #withRuntime(runtime: RuntimeQueryable): Collection { + const Ctor = this.constructor as CollectionConstructor; + return new Ctor({ ...this.ctx, runtime }, this.modelName, { + tableName: this.tableName, + state: this.state, + registry: this.registry, + includeRefinementMode: this.includeRefinementMode, + }) as unknown as Collection; + } + #cloneWithRow( overrides: Partial, ): Collection { diff --git a/packages/3-extensions/sql-orm-client/test/collection-contract.test.ts b/packages/3-extensions/sql-orm-client/test/collection-contract.test.ts index f74ccbc6ce..5f0edc5326 100644 --- a/packages/3-extensions/sql-orm-client/test/collection-contract.test.ts +++ b/packages/3-extensions/sql-orm-client/test/collection-contract.test.ts @@ -7,6 +7,7 @@ import { resolveModelTableName, resolvePolymorphismInfo, resolvePrimaryKeyColumn, + resolveRowIdentityColumns, resolveUpsertConflictColumns, } from '../src/collection-contract'; import { buildMixedPolyContract, getTestContract } from './helpers'; @@ -235,6 +236,63 @@ describe('collection-contract capability detection', () => { expect(isToOneCardinality('M:N')).toBe(false); expect(isToOneCardinality(undefined)).toBe(false); }); + + describe('resolveRowIdentityColumns()', () => { + const buildContract = (table: { + primaryKey?: { columns: readonly string[] }; + uniques?: ReadonlyArray<{ columns: readonly string[] }>; + }) => + ({ + storage: { + tables: { + t: { + primaryKey: table.primaryKey, + uniques: table.uniques ?? [], + }, + }, + }, + }) as unknown as Parameters[0]; + + it('returns primary key columns when present', () => { + expect( + resolveRowIdentityColumns(buildContract({ primaryKey: { columns: ['id'] } }), 't'), + ).toEqual(['id']); + }); + + it('returns composite primary key columns when present', () => { + expect( + resolveRowIdentityColumns(buildContract({ primaryKey: { columns: ['a', 'b'] } }), 't'), + ).toEqual(['a', 'b']); + }); + + it('falls back to first unique constraint when no primary key', () => { + expect( + resolveRowIdentityColumns( + buildContract({ uniques: [{ columns: ['email'] }, { columns: ['handle'] }] }), + 't', + ), + ).toEqual(['email']); + }); + + it('returns composite unique columns when no primary key', () => { + expect( + resolveRowIdentityColumns( + buildContract({ uniques: [{ columns: ['tenant_id', 'slug'] }] }), + 't', + ), + ).toEqual(['tenant_id', 'slug']); + }); + + it('returns empty array when neither primary key nor uniques are defined', () => { + expect(resolveRowIdentityColumns(buildContract({}), 't')).toEqual([]); + }); + + it('returns empty array for unknown tables', () => { + expect( + resolveRowIdentityColumns(buildContract({ primaryKey: { columns: ['id'] } }), 'missing'), + ).toEqual([]); + }); + }); }); describe('resolvePolymorphismInfo()', () => { diff --git a/packages/3-extensions/sql-orm-client/test/integration/delete.test.ts b/packages/3-extensions/sql-orm-client/test/integration/delete.test.ts index 8c89c542fc..a0336d9b4a 100644 --- a/packages/3-extensions/sql-orm-client/test/integration/delete.test.ts +++ b/packages/3-extensions/sql-orm-client/test/integration/delete.test.ts @@ -65,6 +65,36 @@ describe('integration/delete', () => { timeouts.spinUpPpgDev, ); + it( + 'delete() affects only one row even when where() matches several', + async () => { + await withCollectionRuntime(async (runtime) => { + const users = createReturningUsersCollection(runtime); + + await seedUsers(runtime, [ + { id: 1, name: 'Remove', email: 'a@example.com' }, + { id: 2, name: 'Remove', email: 'b@example.com' }, + { id: 3, name: 'Keep', email: 'c@example.com' }, + ]); + + const returned = await users.where({ name: 'Remove' }).delete(); + + expect(returned).not.toBeNull(); + expect(returned?.name).toBe('Remove'); + expect([1, 2]).toContain(returned?.id); + + const rows = await runtime.query<{ id: number; name: string }>( + 'select id, name from users order by id', + ); + const remainingRemove = rows.filter((row) => row.name === 'Remove'); + expect(remainingRemove).toHaveLength(1); + expect(remainingRemove[0]?.id).not.toBe(returned?.id); + expect(rows).toContainEqual({ id: 3, name: 'Keep' }); + }); + }, + timeouts.spinUpPpgDev, + ); + it( 'deleteAll() returns all deleted rows', async () => { diff --git a/packages/3-extensions/sql-orm-client/test/integration/update.test.ts b/packages/3-extensions/sql-orm-client/test/integration/update.test.ts index 17dd9fda4f..084088bf7c 100644 --- a/packages/3-extensions/sql-orm-client/test/integration/update.test.ts +++ b/packages/3-extensions/sql-orm-client/test/integration/update.test.ts @@ -37,6 +37,38 @@ describe('integration/update', () => { timeouts.spinUpPpgDev, ); + it( + 'update() affects only one row even when where() matches several', + async () => { + await withCollectionRuntime(async (runtime) => { + const users = createReturningUsersCollection(runtime); + + await seedUsers(runtime, [ + { id: 1, name: 'Stale', email: 'a@example.com' }, + { id: 2, name: 'Stale', email: 'b@example.com' }, + { id: 3, name: 'Fresh', email: 'c@example.com' }, + ]); + + const returned = await users.where({ name: 'Stale' }).update({ name: 'Updated' }); + + expect(returned).not.toBeNull(); + expect(returned?.name).toBe('Updated'); + expect([1, 2]).toContain(returned?.id); + + const rows = await runtime.query<{ id: number; name: string }>( + 'select id, name from users order by id', + ); + const updatedRows = rows.filter((row) => row.name === 'Updated'); + const staleRows = rows.filter((row) => row.name === 'Stale'); + expect(updatedRows).toHaveLength(1); + expect(staleRows).toHaveLength(1); + expect(rows).toContainEqual({ id: 3, name: 'Fresh' }); + expect(updatedRows[0]?.id).toBe(returned?.id); + }); + }, + timeouts.spinUpPpgDev, + ); + it( 'updateAll() returns all updated rows', async () => {