Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions packages/3-extensions/sql-orm-client/src/collection-contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,25 @@ export function resolvePrimaryKeyColumn(contract: Contract<SqlStorage>, tableNam
return contract.storage.tables[tableName]?.primaryKey?.columns[0] ?? 'id';
}

export function resolveRowIdentityColumns(
contract: Contract<SqlStorage>,
tableName: string,
): readonly string[] {
const table = contract.storage.tables[tableName];
if (!table) {
return [];
}
if (table.primaryKey && table.primaryKey.columns.length > 0) {
return table.primaryKey.columns;
}
for (const unique of table.uniques) {
if (unique.columns.length > 0) {
return unique.columns;
}
}
return [];
}

export function assertReturningCapability(contract: Contract<SqlStorage>, action: string): void {
if (hasContractCapability(contract, 'returning')) {
return;
Expand Down
83 changes: 75 additions & 8 deletions packages/3-extensions/sql-orm-client/src/collection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { Contract } from '@prisma-next/contract/types';
import { AsyncIterableResult } from '@prisma-next/framework-components/runtime';
import type { SqlStorage } from '@prisma-next/sql-contract/types';
import {
type AnyExpression,
BinaryExpr,
ColumnRef,
isWhereExpr,
Expand All @@ -25,6 +26,7 @@ import {
resolveModelTableName,
resolvePolymorphismInfo,
resolvePrimaryKeyColumn,
resolveRowIdentityColumns,
resolveUpsertConflictColumns,
} from './collection-contract';
import { dispatchCollectionRows } from './collection-dispatch';
Expand Down Expand Up @@ -107,6 +109,7 @@ import {
type RelatedModelName,
type RelationNames,
type ResolvedCreateInput,
type RuntimeQueryable,
type ShorthandWhereFilter,
type UniqueConstraintCriterion,
type VariantModelRow,
Expand Down Expand Up @@ -1038,12 +1041,20 @@ export class Collection<
return this.#reloadMutationRowByPrimaryKey(pkCriterion);
}

const rows = await this.updateAll(
data as State['hasWhere'] extends true
? Partial<DefaultModelRow<TContract, ModelName>>
: never,
);
return rows[0] ?? null;
return withMutationScope(this.ctx.runtime, async (scope) => {
const scoped = this.#withRuntime(scope);
const identityWhere = await scoped.#findFirstMatchingRowIdentityWhere();
if (!identityWhere) {
return null;
}
const narrowed = scoped.#clone({ filters: [identityWhere] });
const rows = await narrowed.updateAll(
data as State['hasWhere'] extends true
? Partial<DefaultModelRow<TContract, ModelName>>
: never,
);
return rows[0] ?? null;
});
}

updateAll(
Expand Down Expand Up @@ -1115,15 +1126,26 @@ export class Collection<
this: State['hasWhere'] extends true ? Collection<TContract, ModelName, Row, State> : never,
): Promise<Row | null> {
assertReturningCapability(this.contract, 'delete()');
const rows = await this.deleteAll().toArray();
return rows[0] ?? null;
return withMutationScope(this.ctx.runtime, async (scope) => {
const scoped = this.#withRuntime(scope);
const identityWhere = await scoped.#findFirstMatchingRowIdentityWhere();
if (!identityWhere) {
return null;
}
const narrowed = scoped.#clone({ filters: [identityWhere] });
const rows = await narrowed.#executeDeleteReturning().toArray();
return rows[0] ?? null;
});
}

deleteAll(
this: State['hasWhere'] extends true ? Collection<TContract, ModelName, Row, State> : never,
): AsyncIterableResult<Row> {
assertReturningCapability(this.contract, 'deleteAll()');
return this.#executeDeleteReturning();
}

#executeDeleteReturning(): AsyncIterableResult<Row> {
const parentJoinColumns = this.state.includes.map((include) => include.localColumn);
const { selectedForQuery: selectedForDelete, hiddenColumns } = augmentSelectionForJoinColumns(
this.state.selectedFields,
Expand Down Expand Up @@ -1188,6 +1210,41 @@ export class Collection<
return criterion;
}

async #findFirstMatchingRowIdentityWhere(): Promise<AnyExpression | null> {
const identityColumns = resolveRowIdentityColumns(this.contract, this.tableName);
if (identityColumns.length === 0) {
throw new Error(
`update()/delete() on model "${this.modelName}" requires the table to have a primary key or unique constraint`,
);
}
const firstRow = await this.#clone({
selectedFields: [...identityColumns],
includes: [],
}).first();
if (!firstRow) {
return null;
}
const columnToField = getColumnToFieldMap(this.contract, this.modelName);
const criterion: Record<string, unknown> = {};
for (const column of identityColumns) {
const fieldName = columnToField[column] ?? column;
const value = (firstRow as Record<string, unknown>)[fieldName];
if (value === undefined) {
throw new Error(
`Missing identity field "${fieldName}" while resolving single-row scope for model "${this.modelName}"`,
);
}
criterion[fieldName] = value;
}
return (
shorthandToWhereExpr(
this.ctx.context,
this.modelName,
criterion as ShorthandWhereFilter<TContract, ModelName>,
) ?? null
);
}

async #reloadMutationRowByPrimaryKey(criterion: Record<string, unknown>): Promise<Row | null> {
return this.#reloadMutationRowByCriterion(criterion, 'primary key');
}
Expand Down Expand Up @@ -1242,6 +1299,16 @@ export class Collection<
});
}

#withRuntime(runtime: RuntimeQueryable): Collection<TContract, ModelName, Row, State> {
const Ctor = this.constructor as CollectionConstructor<TContract>;
return new Ctor({ ...this.ctx, runtime }, this.modelName, {
tableName: this.tableName,
state: this.state,
registry: this.registry,
includeRefinementMode: this.includeRefinementMode,
}) as unknown as Collection<TContract, ModelName, Row, State>;
}

#cloneWithRow<NextRow, NextState extends CollectionTypeState = State>(
overrides: Partial<CollectionState>,
): Collection<TContract, ModelName, NextRow, NextState> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
resolveModelTableName,
resolvePolymorphismInfo,
resolvePrimaryKeyColumn,
resolveRowIdentityColumns,
resolveUpsertConflictColumns,
} from '../src/collection-contract';
import { buildMixedPolyContract, getTestContract } from './helpers';
Expand Down Expand Up @@ -235,6 +236,63 @@ describe('collection-contract capability detection', () => {
expect(isToOneCardinality('M:N')).toBe(false);
expect(isToOneCardinality(undefined)).toBe(false);
});

describe('resolveRowIdentityColumns()', () => {
const buildContract = (table: {
primaryKey?: { columns: readonly string[] };
uniques?: ReadonlyArray<{ columns: readonly string[] }>;
}) =>
({
storage: {
tables: {
t: {
primaryKey: table.primaryKey,
uniques: table.uniques ?? [],
},
},
},
}) as unknown as Parameters<typeof resolveRowIdentityColumns>[0];

it('returns primary key columns when present', () => {
expect(
resolveRowIdentityColumns(buildContract({ primaryKey: { columns: ['id'] } }), 't'),
).toEqual(['id']);
});

it('returns composite primary key columns when present', () => {
expect(
resolveRowIdentityColumns(buildContract({ primaryKey: { columns: ['a', 'b'] } }), 't'),
).toEqual(['a', 'b']);
});

it('falls back to first unique constraint when no primary key', () => {
expect(
resolveRowIdentityColumns(
buildContract({ uniques: [{ columns: ['email'] }, { columns: ['handle'] }] }),
't',
),
).toEqual(['email']);
});

it('returns composite unique columns when no primary key', () => {
expect(
resolveRowIdentityColumns(
buildContract({ uniques: [{ columns: ['tenant_id', 'slug'] }] }),
't',
),
).toEqual(['tenant_id', 'slug']);
});

it('returns empty array when neither primary key nor uniques are defined', () => {
expect(resolveRowIdentityColumns(buildContract({}), 't')).toEqual([]);
});

it('returns empty array for unknown tables', () => {
expect(
resolveRowIdentityColumns(buildContract({ primaryKey: { columns: ['id'] } }), 'missing'),
).toEqual([]);
});
});
});

describe('resolvePolymorphismInfo()', () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,36 @@ describe('integration/delete', () => {
timeouts.spinUpPpgDev,
);

it(
'delete() affects only one row even when where() matches several',
async () => {
await withCollectionRuntime(async (runtime) => {
const users = createReturningUsersCollection(runtime);

await seedUsers(runtime, [
{ id: 1, name: 'Remove', email: 'a@example.com' },
{ id: 2, name: 'Remove', email: 'b@example.com' },
{ id: 3, name: 'Keep', email: 'c@example.com' },
]);

const returned = await users.where({ name: 'Remove' }).delete();

expect(returned).not.toBeNull();
expect(returned?.name).toBe('Remove');
expect([1, 2]).toContain(returned?.id);

const rows = await runtime.query<{ id: number; name: string }>(
'select id, name from users order by id',
);
const remainingRemove = rows.filter((row) => row.name === 'Remove');
expect(remainingRemove).toHaveLength(1);
expect(remainingRemove[0]?.id).not.toBe(returned?.id);
expect(rows).toContainEqual({ id: 3, name: 'Keep' });
});
},
timeouts.spinUpPpgDev,
);

it(
'deleteAll() returns all deleted rows',
async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,38 @@ describe('integration/update', () => {
timeouts.spinUpPpgDev,
);

it(
'update() affects only one row even when where() matches several',
async () => {
await withCollectionRuntime(async (runtime) => {
const users = createReturningUsersCollection(runtime);

await seedUsers(runtime, [
{ id: 1, name: 'Stale', email: 'a@example.com' },
{ id: 2, name: 'Stale', email: 'b@example.com' },
{ id: 3, name: 'Fresh', email: 'c@example.com' },
]);

const returned = await users.where({ name: 'Stale' }).update({ name: 'Updated' });

expect(returned).not.toBeNull();
expect(returned?.name).toBe('Updated');
expect([1, 2]).toContain(returned?.id);

const rows = await runtime.query<{ id: number; name: string }>(
'select id, name from users order by id',
);
const updatedRows = rows.filter((row) => row.name === 'Updated');
const staleRows = rows.filter((row) => row.name === 'Stale');
expect(updatedRows).toHaveLength(1);
expect(staleRows).toHaveLength(1);
expect(rows).toContainEqual({ id: 3, name: 'Fresh' });
expect(updatedRows[0]?.id).toBe(returned?.id);
});
},
timeouts.spinUpPpgDev,
);

it(
'updateAll() returns all updated rows',
async () => {
Expand Down
Loading