Skip to content

Commit

Permalink
fix(delegate): delegate model's guards are not properly including con…
Browse files Browse the repository at this point in the history
…crete models (#1932)
  • Loading branch information
ymc9 authored Jan 2, 2025
1 parent 2eecae5 commit 1956bdb
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 29 deletions.
17 changes: 3 additions & 14 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import {
isArrayExpr,
isDataModel,
isGeneratorDecl,
isReferenceExpr,
isTypeDef,
type Model,
} from '@zenstackhq/sdk/ast';
Expand All @@ -45,6 +44,7 @@ import {
} from 'ts-morph';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '..';
import { getConcreteModels, getDiscriminatorField } from '../../../utils/ast-utils';
import { execPackage } from '../../../utils/exec-utils';
import { CorePlugins, getPluginCustomOutputFolder } from '../../plugin-utils';
import { trackPrismaSchemaError } from '../../prisma';
Expand Down Expand Up @@ -407,9 +407,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
this.model.declarations
.filter((d): d is DataModel => isDelegateModel(d))
.forEach((dm) => {
const concreteModels = this.model.declarations.filter(
(d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm)
);
const concreteModels = getConcreteModels(dm);
if (concreteModels.length > 0) {
delegateInfo.push([dm, concreteModels]);
}
Expand Down Expand Up @@ -579,7 +577,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
const typeName = typeAlias.getName();
const payloadRecord = delegateInfo.find(([delegate]) => `$${delegate.name}Payload` === typeName);
if (payloadRecord) {
const discriminatorDecl = this.getDiscriminatorField(payloadRecord[0]);
const discriminatorDecl = getDiscriminatorField(payloadRecord[0]);
if (discriminatorDecl) {
source = `${payloadRecord[1]
.map(
Expand Down Expand Up @@ -826,15 +824,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
.filter((n) => n.getName().startsWith(DELEGATE_AUX_RELATION_PREFIX));
}

private getDiscriminatorField(delegate: DataModel) {
const delegateAttr = getAttribute(delegate, '@@delegate');
if (!delegateAttr) {
return undefined;
}
const arg = delegateAttr.args[0]?.value;
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
}

private saveSourceFile(sf: SourceFile) {
if (this.options.preserveTsFiles) {
saveSourceFile(sf);
Expand Down
24 changes: 13 additions & 11 deletions packages/schema/src/plugins/enhancer/policy/expression-writer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -839,16 +839,18 @@ export class ExpressionWriter {
operation = this.options.operationContext;
}

this.block(() => {
if (operation === 'postUpdate') {
// 'postUpdate' policies are not delegated to relations, just use constant `false` here
// e.g.:
// @@allow('all', check(author)) should not delegate "postUpdate" to author
this.writer.write(`${fieldRef.target.$refText}: ${FALSE}`);
} else {
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
this.writer.write(`${fieldRef.target.$refText}: ${targetGuardFunc}(context, db)`);
}
});
this.block(() =>
this.writeFieldCondition(fieldRef, () => {
if (operation === 'postUpdate') {
// 'postUpdate' policies are not delegated to relations, just use constant `false` here
// e.g.:
// @@allow('all', check(author)) should not delegate "postUpdate" to author
this.writer.write(FALSE);
} else {
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
this.writer.write(`${targetGuardFunc}(context, db)`);
}
})
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,8 @@ export class PolicyGenerator {
writer: CodeBlockWriter,
sourceFile: SourceFile
) {
// first handle several cases where a constant function can be used

if (kind === 'update' && allows.length === 0) {
// no allow rule for 'update', policy is constant based on if there's
// post-update counterpart
Expand Down
5 changes: 2 additions & 3 deletions packages/schema/src/plugins/prisma/schema-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import path from 'path';
import semver from 'semver';
import { name } from '.';
import { getStringLiteral } from '../../language-server/validator/utils';
import { getConcreteModels } from '../../utils/ast-utils';
import { execPackage } from '../../utils/exec-utils';
import { isDefaultWithAuth } from '../enhancer/enhancer-utils';
import {
Expand Down Expand Up @@ -320,9 +321,7 @@ export class PrismaSchemaGenerator {
}

// collect concrete models inheriting this model
const concreteModels = decl.$container.declarations.filter(
(d) => isDataModel(d) && d !== decl && d.superTypes.some((base) => base.ref === decl)
);
const concreteModels = getConcreteModels(decl);

// generate an optional relation field in delegate base model to each concrete model
concreteModels.forEach((concrete) => {
Expand Down
28 changes: 27 additions & 1 deletion packages/schema/src/utils/ast-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@ import {
BinaryExpr,
DataModel,
DataModelAttribute,
DataModelField,
Expression,
InheritableNode,
isBinaryExpr,
isDataModel,
isDataModelField,
isInvocationExpr,
isModel,
isReferenceExpr,
isTypeDef,
Model,
ModelImport,
TypeDef,
} from '@zenstackhq/language/ast';
import { getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
import { getAttribute, getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
import {
AstNode,
copyAstNode,
Expand Down Expand Up @@ -310,3 +312,27 @@ export function findUpInheritance(start: DataModel, target: DataModel): DataMode
}
return undefined;
}

/**
* Gets all concrete models that inherit from the given delegate model
*/
export function getConcreteModels(dataModel: DataModel): DataModel[] {
if (!isDelegateModel(dataModel)) {
return [];
}
return dataModel.$container.declarations.filter(
(d): d is DataModel => isDataModel(d) && d !== dataModel && d.superTypes.some((base) => base.ref === dataModel)
);
}

/**
* Gets the discriminator field for the given delegate model
*/
export function getDiscriminatorField(dataModel: DataModel) {
const delegateAttr = getAttribute(dataModel, '@@delegate');
if (!delegateAttr) {
return undefined;
}
const arg = delegateAttr.args[0]?.value;
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
}
Original file line number Diff line number Diff line change
Expand Up @@ -571,4 +571,84 @@ describe('Polymorphic Policy Test', () => {
expect(foundPost2.foo).toBeUndefined();
expect(foundPost2.bar).toBeUndefined();
});

it('respects concrete policies when read as base optional relation', async () => {
const { enhance } = await loadSchema(
`
model User {
id Int @id @default(autoincrement())
asset Asset?
@@allow('all', true)
}
model Asset {
id Int @id @default(autoincrement())
user User @relation(fields: [userId], references: [id])
userId Int @unique
type String
@@delegate(type)
@@allow('all', true)
}
model Post extends Asset {
title String
private Boolean
@@allow('create', true)
@@deny('read', private)
}
`
);

const fullDb = enhance(undefined, { kinds: ['delegate'] });
await fullDb.user.create({ data: { id: 1 } });
await fullDb.post.create({ data: { title: 'Post1', private: true, user: { connect: { id: 1 } } } });
await expect(fullDb.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({
asset: expect.objectContaining({ type: 'Post' }),
});

const db = enhance();
const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } });
expect(read.asset).toBeTruthy();
expect(read.asset.title).toBeUndefined();
});

it('respects concrete policies when read as base required relation', async () => {
const { enhance } = await loadSchema(
`
model User {
id Int @id @default(autoincrement())
asset Asset @relation(fields: [assetId], references: [id])
assetId Int @unique
@@allow('all', true)
}
model Asset {
id Int @id @default(autoincrement())
user User?
type String
@@delegate(type)
@@allow('all', true)
}
model Post extends Asset {
title String
private Boolean
@@deny('read', private)
}
`
);

const fullDb = enhance(undefined, { kinds: ['delegate'] });
await fullDb.post.create({ data: { id: 1, title: 'Post1', private: true, user: { create: { id: 1 } } } });
await expect(fullDb.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({
asset: expect.objectContaining({ type: 'Post' }),
});

const db = enhance();
const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } });
expect(read).toBeTruthy();
expect(read.asset.title).toBeUndefined();
});
});
80 changes: 80 additions & 0 deletions tests/regression/tests/issue-1930.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import { loadSchema } from '@zenstackhq/testtools';

describe('issue 1930', () => {
it('regression', async () => {
const { enhance } = await loadSchema(
`
model Organization {
id String @id @default(cuid())
entities Entity[]
@@allow('all', true)
}
model Entity {
id String @id @default(cuid())
org Organization? @relation(fields: [orgId], references: [id])
orgId String?
contents EntityContent[]
entityType String
isDeleted Boolean @default(false)
@@delegate(entityType)
@@allow('all', !isDeleted)
}
model EntityContent {
id String @id @default(cuid())
entity Entity @relation(fields: [entityId], references: [id])
entityId String
entityContentType String
@@delegate(entityContentType)
@@allow('create', true)
@@allow('read', check(entity))
}
model Article extends Entity {
}
model ArticleContent extends EntityContent {
body String?
}
model OtherContent extends EntityContent {
data Int
}
`
);

const fullDb = enhance(undefined, { kinds: ['delegate'] });
const org = await fullDb.organization.create({ data: {} });
const article = await fullDb.article.create({
data: { org: { connect: { id: org.id } } },
});

const db = enhance();

// normal create/read
await expect(
db.articleContent.create({
data: { body: 'abc', entity: { connect: { id: article.id } } },
})
).toResolveTruthy();
await expect(db.article.findFirst({ include: { contents: true } })).resolves.toMatchObject({
contents: expect.arrayContaining([expect.objectContaining({ body: 'abc' })]),
});

// deleted article's contents are not readable
const deletedArticle = await fullDb.article.create({
data: { org: { connect: { id: org.id } }, isDeleted: true },
});
const content1 = await fullDb.articleContent.create({
data: { body: 'bcd', entity: { connect: { id: deletedArticle.id } } },
});
await expect(db.articleContent.findUnique({ where: { id: content1.id } })).toResolveNull();
});
});

0 comments on commit 1956bdb

Please sign in to comment.