From 1956bdb461858cf5e0562434f47a3678d493b142 Mon Sep 17 00:00:00 2001 From: Yiming Date: Thu, 2 Jan 2025 16:39:29 +0800 Subject: [PATCH] fix(delegate): delegate model's guards are not properly including concrete models (#1932) --- .../src/plugins/enhancer/enhance/index.ts | 17 +--- .../enhancer/policy/expression-writer.ts | 24 +++--- .../enhancer/policy/policy-guard-generator.ts | 2 + .../src/plugins/prisma/schema-generator.ts | 5 +- packages/schema/src/utils/ast-utils.ts | 28 ++++++- .../with-delegate/policy-interaction.test.ts | 80 +++++++++++++++++++ tests/regression/tests/issue-1930.test.ts | 80 +++++++++++++++++++ 7 files changed, 207 insertions(+), 29 deletions(-) create mode 100644 tests/regression/tests/issue-1930.test.ts diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index ba8c50feb..689ddaf2c 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -24,7 +24,6 @@ import { isArrayExpr, isDataModel, isGeneratorDecl, - isReferenceExpr, isTypeDef, type Model, } from '@zenstackhq/sdk/ast'; @@ -45,6 +44,7 @@ import { } from 'ts-morph'; import { upperCaseFirst } from 'upper-case-first'; import { name } from '..'; +import { getConcreteModels, getDiscriminatorField } from '../../../utils/ast-utils'; import { execPackage } from '../../../utils/exec-utils'; import { CorePlugins, getPluginCustomOutputFolder } from '../../plugin-utils'; import { trackPrismaSchemaError } from '../../prisma'; @@ -407,9 +407,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara this.model.declarations .filter((d): d is DataModel => isDelegateModel(d)) .forEach((dm) => { - const concreteModels = this.model.declarations.filter( - (d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm) - ); + const concreteModels = getConcreteModels(dm); if (concreteModels.length > 0) { delegateInfo.push([dm, concreteModels]); } @@ -579,7 +577,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara const typeName = typeAlias.getName(); const payloadRecord = delegateInfo.find(([delegate]) => `$${delegate.name}Payload` === typeName); if (payloadRecord) { - const discriminatorDecl = this.getDiscriminatorField(payloadRecord[0]); + const discriminatorDecl = getDiscriminatorField(payloadRecord[0]); if (discriminatorDecl) { source = `${payloadRecord[1] .map( @@ -826,15 +824,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara .filter((n) => n.getName().startsWith(DELEGATE_AUX_RELATION_PREFIX)); } - private getDiscriminatorField(delegate: DataModel) { - const delegateAttr = getAttribute(delegate, '@@delegate'); - if (!delegateAttr) { - return undefined; - } - const arg = delegateAttr.args[0]?.value; - return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined; - } - private saveSourceFile(sf: SourceFile) { if (this.options.preserveTsFiles) { saveSourceFile(sf); diff --git a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts index 645e02cd1..0d792bdc1 100644 --- a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts +++ b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts @@ -839,16 +839,18 @@ export class ExpressionWriter { operation = this.options.operationContext; } - this.block(() => { - if (operation === 'postUpdate') { - // 'postUpdate' policies are not delegated to relations, just use constant `false` here - // e.g.: - // @@allow('all', check(author)) should not delegate "postUpdate" to author - this.writer.write(`${fieldRef.target.$refText}: ${FALSE}`); - } else { - const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation); - this.writer.write(`${fieldRef.target.$refText}: ${targetGuardFunc}(context, db)`); - } - }); + this.block(() => + this.writeFieldCondition(fieldRef, () => { + if (operation === 'postUpdate') { + // 'postUpdate' policies are not delegated to relations, just use constant `false` here + // e.g.: + // @@allow('all', check(author)) should not delegate "postUpdate" to author + this.writer.write(FALSE); + } else { + const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation); + this.writer.write(`${targetGuardFunc}(context, db)`); + } + }) + ); } } diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 8206f797b..9ffe41dcb 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -454,6 +454,8 @@ export class PolicyGenerator { writer: CodeBlockWriter, sourceFile: SourceFile ) { + // first handle several cases where a constant function can be used + if (kind === 'update' && allows.length === 0) { // no allow rule for 'update', policy is constant based on if there's // post-update counterpart diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 96a3b15f5..a0bde1769 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -57,6 +57,7 @@ import path from 'path'; import semver from 'semver'; import { name } from '.'; import { getStringLiteral } from '../../language-server/validator/utils'; +import { getConcreteModels } from '../../utils/ast-utils'; import { execPackage } from '../../utils/exec-utils'; import { isDefaultWithAuth } from '../enhancer/enhancer-utils'; import { @@ -320,9 +321,7 @@ export class PrismaSchemaGenerator { } // collect concrete models inheriting this model - const concreteModels = decl.$container.declarations.filter( - (d) => isDataModel(d) && d !== decl && d.superTypes.some((base) => base.ref === decl) - ); + const concreteModels = getConcreteModels(decl); // generate an optional relation field in delegate base model to each concrete model concreteModels.forEach((concrete) => { diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index a6fab7ea5..0e462547f 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -2,6 +2,7 @@ import { BinaryExpr, DataModel, DataModelAttribute, + DataModelField, Expression, InheritableNode, isBinaryExpr, @@ -9,12 +10,13 @@ import { isDataModelField, isInvocationExpr, isModel, + isReferenceExpr, isTypeDef, Model, ModelImport, TypeDef, } from '@zenstackhq/language/ast'; -import { getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk'; +import { getAttribute, getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk'; import { AstNode, copyAstNode, @@ -310,3 +312,27 @@ export function findUpInheritance(start: DataModel, target: DataModel): DataMode } return undefined; } + +/** + * Gets all concrete models that inherit from the given delegate model + */ +export function getConcreteModels(dataModel: DataModel): DataModel[] { + if (!isDelegateModel(dataModel)) { + return []; + } + return dataModel.$container.declarations.filter( + (d): d is DataModel => isDataModel(d) && d !== dataModel && d.superTypes.some((base) => base.ref === dataModel) + ); +} + +/** + * Gets the discriminator field for the given delegate model + */ +export function getDiscriminatorField(dataModel: DataModel) { + const delegateAttr = getAttribute(dataModel, '@@delegate'); + if (!delegateAttr) { + return undefined; + } + const arg = delegateAttr.args[0]?.value; + return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined; +} diff --git a/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts b/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts index d149a6392..67fc456af 100644 --- a/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts @@ -571,4 +571,84 @@ describe('Polymorphic Policy Test', () => { expect(foundPost2.foo).toBeUndefined(); expect(foundPost2.bar).toBeUndefined(); }); + + it('respects concrete policies when read as base optional relation', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + asset Asset? + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + type String + + @@delegate(type) + @@allow('all', true) + } + + model Post extends Asset { + title String + private Boolean + @@allow('create', true) + @@deny('read', private) + } + ` + ); + + const fullDb = enhance(undefined, { kinds: ['delegate'] }); + await fullDb.user.create({ data: { id: 1 } }); + await fullDb.post.create({ data: { title: 'Post1', private: true, user: { connect: { id: 1 } } } }); + await expect(fullDb.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({ + asset: expect.objectContaining({ type: 'Post' }), + }); + + const db = enhance(); + const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } }); + expect(read.asset).toBeTruthy(); + expect(read.asset.title).toBeUndefined(); + }); + + it('respects concrete policies when read as base required relation', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + asset Asset @relation(fields: [assetId], references: [id]) + assetId Int @unique + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + user User? + type String + + @@delegate(type) + @@allow('all', true) + } + + model Post extends Asset { + title String + private Boolean + @@deny('read', private) + } + ` + ); + + const fullDb = enhance(undefined, { kinds: ['delegate'] }); + await fullDb.post.create({ data: { id: 1, title: 'Post1', private: true, user: { create: { id: 1 } } } }); + await expect(fullDb.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({ + asset: expect.objectContaining({ type: 'Post' }), + }); + + const db = enhance(); + const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } }); + expect(read).toBeTruthy(); + expect(read.asset.title).toBeUndefined(); + }); }); diff --git a/tests/regression/tests/issue-1930.test.ts b/tests/regression/tests/issue-1930.test.ts new file mode 100644 index 000000000..762369321 --- /dev/null +++ b/tests/regression/tests/issue-1930.test.ts @@ -0,0 +1,80 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1930', () => { + it('regression', async () => { + const { enhance } = await loadSchema( + ` +model Organization { + id String @id @default(cuid()) + entities Entity[] + + @@allow('all', true) +} + +model Entity { + id String @id @default(cuid()) + org Organization? @relation(fields: [orgId], references: [id]) + orgId String? + contents EntityContent[] + entityType String + isDeleted Boolean @default(false) + + @@delegate(entityType) + + @@allow('all', !isDeleted) +} + +model EntityContent { + id String @id @default(cuid()) + entity Entity @relation(fields: [entityId], references: [id]) + entityId String + + entityContentType String + + @@delegate(entityContentType) + + @@allow('create', true) + @@allow('read', check(entity)) +} + +model Article extends Entity { +} + +model ArticleContent extends EntityContent { + body String? +} + +model OtherContent extends EntityContent { + data Int +} + ` + ); + + const fullDb = enhance(undefined, { kinds: ['delegate'] }); + const org = await fullDb.organization.create({ data: {} }); + const article = await fullDb.article.create({ + data: { org: { connect: { id: org.id } } }, + }); + + const db = enhance(); + + // normal create/read + await expect( + db.articleContent.create({ + data: { body: 'abc', entity: { connect: { id: article.id } } }, + }) + ).toResolveTruthy(); + await expect(db.article.findFirst({ include: { contents: true } })).resolves.toMatchObject({ + contents: expect.arrayContaining([expect.objectContaining({ body: 'abc' })]), + }); + + // deleted article's contents are not readable + const deletedArticle = await fullDb.article.create({ + data: { org: { connect: { id: org.id } }, isDeleted: true }, + }); + const content1 = await fullDb.articleContent.create({ + data: { body: 'bcd', entity: { connect: { id: deletedArticle.id } } }, + }); + await expect(db.articleContent.findUnique({ where: { id: content1.id } })).toResolveNull(); + }); +});