diff --git a/src/helpers/inherit-metedata.ts b/src/helpers/inherit-metedata.ts new file mode 100644 index 000000000..5ca81f005 --- /dev/null +++ b/src/helpers/inherit-metedata.ts @@ -0,0 +1,186 @@ +// Inspired by @nestjs/mapped-types +import { ClassType } from "../interfaces"; + +export function applyIsOptionalDecorator(targetClass: Function, propertyKey: string) { + if (!isClassValidatorAvailable()) { + return; + } + const classValidator: typeof import("class-validator") = require("class-validator"); + const decoratorFactory = classValidator.IsOptional(); + decoratorFactory(targetClass.prototype, propertyKey); +} + +export function inheritValidationMetadata( + parentClass: ClassType, + targetClass: Function, + isPropertyInherited?: (key: string) => boolean, +) { + if (!isClassValidatorAvailable()) { + return; + } + try { + const classValidator: typeof import("class-validator") = require("class-validator"); + const metadataStorage: import("class-validator").MetadataStorage = (classValidator as any) + .getMetadataStorage + ? (classValidator as any).getMetadataStorage() + : classValidator.getFromContainer(classValidator.MetadataStorage); + + const getTargetValidationMetadatasArgs = [parentClass, null!, false, false]; + const targetMetadata: ReturnType< + typeof metadataStorage.getTargetValidationMetadatas + > = (metadataStorage.getTargetValidationMetadatas as Function)( + ...getTargetValidationMetadatasArgs, + ); + targetMetadata + .filter(({ propertyName }) => !isPropertyInherited || isPropertyInherited(propertyName)) + .map(value => { + const originalType = Reflect.getMetadata( + "design:type", + parentClass.prototype, + value.propertyName, + ); + if (originalType) { + // @ts-ignore + Reflect.defineMetadata( + "design:type", + originalType, + targetClass.prototype, + value.propertyName, + ); + } + + metadataStorage.addValidationMetadata({ + ...value, + target: targetClass, + }); + return value.propertyName; + }); + } catch (err) { + if (err.code !== "EEXIST") { + throw err; + } + } +} + +type TransformMetadataKey = + | "_excludeMetadatas" + | "_exposeMetadatas" + | "_typeMetadatas" + | "_transformMetadatas"; + +export function inheritTransformationMetadata( + parentClass: ClassType, + targetClass: Function, + isPropertyInherited?: (key: string) => boolean, +) { + if (!isClassTransformerAvailable()) { + return; + } + try { + const transformMetadataKeys: TransformMetadataKey[] = [ + "_excludeMetadatas", + "_exposeMetadatas", + "_transformMetadatas", + "_typeMetadatas", + ]; + transformMetadataKeys.forEach(key => + inheritTransformerMetadata(key, parentClass, targetClass, isPropertyInherited), + ); + } catch (err) { + if (err.code !== "EEXIST") { + throw err; + } + } +} + +function inheritTransformerMetadata( + key: TransformMetadataKey, + parentClass: ClassType, + targetClass: Function, + isPropertyInherited?: (key: string) => boolean, +) { + let classTransformer: any; + try { + /** "class-transformer" >= v0.3.x */ + classTransformer = require("class-transformer/cjs/storage"); + } catch { + /** "class-transformer" <= v0.3.x */ + classTransformer = require("class-transformer/storage"); + } + const metadataStorage /*: typeof import('class-transformer/types/storage').defaultMetadataStorage */ = + classTransformer.defaultMetadataStorage; + + while (parentClass && parentClass !== Object) { + if (metadataStorage[key].has(parentClass)) { + const metadataMap = metadataStorage[key] as Map>; + const parentMetadata = metadataMap.get(parentClass); + + const targetMetadataEntries: Iterable<[string, any]> = Array.from(parentMetadata!.entries()) + .filter(([keyInEntries]) => !isPropertyInherited || isPropertyInherited(keyInEntries)) + .map(([keyInEntries, metadata]) => { + if (Array.isArray(metadata)) { + // "_transformMetadatas" is an array of elements + const targetMetadata = metadata.map(item => ({ + ...item, + target: targetClass, + })); + return [keyInEntries, targetMetadata]; + } + return [keyInEntries, { ...metadata, target: targetClass }]; + }); + + if (metadataMap.has(targetClass)) { + const existingRules = metadataMap.get(targetClass)!.entries(); + metadataMap.set(targetClass, new Map([...existingRules, ...targetMetadataEntries])); + } else { + metadataMap.set(targetClass, new Map(targetMetadataEntries)); + } + } + parentClass = Object.getPrototypeOf(parentClass); + } +} + +function isClassValidatorAvailable() { + try { + require("class-validator"); + return true; + } catch { + return false; + } +} + +function isClassTransformerAvailable() { + try { + require("class-transformer"); + return true; + } catch { + return false; + } +} + +export function inheritPropertyInitializers( + target: Record, + sourceClass: ClassType, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + isPropertyInherited = (key: string) => true, +) { + try { + const tempInstance = new sourceClass(); + const propertyNames = Object.getOwnPropertyNames(tempInstance); + + propertyNames + .filter( + propertyName => + typeof tempInstance[propertyName] !== "undefined" && + typeof target[propertyName] === "undefined", + ) + .filter(propertyName => isPropertyInherited(propertyName)) + .forEach(propertyName => { + target[propertyName] = tempInstance[propertyName]; + }); + } catch (err) { + if (err.code !== "EEXIST") { + throw err; + } + } +} diff --git a/src/utils/index.ts b/src/utils/index.ts index fae4db30f..caf200380 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -11,3 +11,10 @@ export { defaultPrintSchemaOptions, } from "./emitSchemaDefinitionFile"; export { ContainerType, ContainerGetter } from "./container"; +export { + PartialType, + PickType, + RequiredType, + OmitType, + IntersectionType, +} from "./types-transformation"; diff --git a/src/utils/types-transformation.ts b/src/utils/types-transformation.ts new file mode 100644 index 000000000..9bb1c323a --- /dev/null +++ b/src/utils/types-transformation.ts @@ -0,0 +1,136 @@ +import { ObjectType, InputType, InterfaceType } from "../decorators"; +import { + inheritValidationMetadata, + inheritTransformationMetadata, + applyIsOptionalDecorator, +} from "../helpers/inherit-metedata"; +import { ClassType } from "../interfaces"; +import { getMetadataStorage } from "../metadata"; + +export function PartialType(BaseClass: ClassType): ClassType> { + const PartialClass = abstractClass(); + inheritValidationMetadata(BaseClass, PartialClass); + inheritTransformationMetadata(BaseClass, PartialClass); + + const fields = getMetadataStorage().fields.filter( + f => f.target === BaseClass || BaseClass.prototype instanceof f.target, + ); + + fields.forEach(field => { + getMetadataStorage().collectClassFieldMetadata({ + ...field, + typeOptions: { ...field.typeOptions, nullable: true }, + target: PartialClass, + }); + applyIsOptionalDecorator(PartialClass, field.name); + }); + + return PartialClass as ClassType>; +} + +export function RequiredType(BaseClass: ClassType): ClassType> { + const RequiredClass = abstractClass(); + inheritValidationMetadata(BaseClass, RequiredClass); + inheritTransformationMetadata(BaseClass, RequiredClass); + + const fields = getMetadataStorage().fields.filter( + f => f.target === BaseClass || BaseClass.prototype instanceof f.target, + ); + + fields.forEach(field => { + getMetadataStorage().collectClassFieldMetadata({ + ...field, + typeOptions: { ...field.typeOptions, nullable: false }, + target: RequiredClass, + }); + }); + return RequiredClass as ClassType>; +} + +export function PickType( + BaseClass: ClassType, + ...pickFields: K[] +): ClassType> { + const PickClass = abstractClass(); + + const isInheritedPredicate = (propertyKey: string) => pickFields.includes(propertyKey as K); + inheritValidationMetadata(BaseClass, PickClass, isInheritedPredicate); + inheritTransformationMetadata(BaseClass, PickClass, isInheritedPredicate); + + const fields = getMetadataStorage().fields.filter( + f => + (f.target === BaseClass || BaseClass.prototype instanceof f.target) && + pickFields.includes(f.name as K), + ); + + fields.forEach(field => { + getMetadataStorage().collectClassFieldMetadata({ + ...field, + target: PickClass, + }); + }); + return PickClass as ClassType>; +} + +export function OmitType( + BaseClass: ClassType, + ...omitFields: K[] +): ClassType> { + const OmitClass = abstractClass(); + + const isInheritedPredicate = (propertyKey: string) => !omitFields.includes(propertyKey as K); + inheritValidationMetadata(BaseClass, OmitClass, isInheritedPredicate); + inheritTransformationMetadata(BaseClass, OmitClass, isInheritedPredicate); + + const fields = getMetadataStorage().fields.filter( + f => + (f.target === BaseClass || BaseClass.prototype instanceof f.target) && + !omitFields.includes(f.name as K), + ); + + fields.forEach(field => { + getMetadataStorage().collectClassFieldMetadata({ + ...field, + target: OmitClass, + }); + }); + return OmitClass as ClassType>; +} + +export function IntersectionType(BaseClassA: ClassType, BaseClassB: ClassType) { + const IntersectionClass = abstractClass(); + inheritValidationMetadata(BaseClassA, IntersectionClass); + inheritTransformationMetadata(BaseClassA, IntersectionClass); + inheritValidationMetadata(BaseClassB, IntersectionClass); + inheritTransformationMetadata(BaseClassB, IntersectionClass); + + const fields = getMetadataStorage().fields.filter( + f => + f.target === BaseClassB || + BaseClassB.prototype instanceof f.target || + f.target === BaseClassA || + BaseClassA.prototype instanceof f.target, + ); + + fields.forEach(field => { + getMetadataStorage().collectClassFieldMetadata({ + ...field, + target: IntersectionClass, + }); + }); + + return IntersectionClass as ClassType; +} + +function abstractClass() { + class AbstractClass {} + InputType({ isAbstract: true })(AbstractClass); + ObjectType({ isAbstract: true })(AbstractClass); + InterfaceType({ isAbstract: true })(AbstractClass); + getMetadataStorage().collectArgsMetadata({ + name: AbstractClass.name, + isAbstract: true, + target: AbstractClass, + }); + return AbstractClass; +} diff --git a/tests/functional/types-transformation.ts b/tests/functional/types-transformation.ts new file mode 100644 index 000000000..ef5206046 --- /dev/null +++ b/tests/functional/types-transformation.ts @@ -0,0 +1,384 @@ +import "reflect-metadata"; +import { + graphql, + IntrospectionInputObjectType, + IntrospectionNamedTypeRef, + IntrospectionNonNullTypeRef, + IntrospectionObjectType, + TypeKind, +} from "graphql"; +import { + Arg, + Args, + ArgsType, + ClassType, + Field, + getMetadataStorage, + InputType, + InterfaceType, + ObjectType, + OmitType, + PartialType, + PickType, + Query, + RequiredType, + Resolver, + IntersectionType, + buildSchema, + Mutation, + ArgumentValidationError, +} from "../../src"; +import { getSchemaInfo } from "../helpers/getSchemaInfo"; +import { MaxLength, Max, Min } from "class-validator"; + +describe("Types transformation utils", () => { + beforeEach(() => { + getMetadataStorage().clear(); + }); + + it("PartialType should set all fields to nullable", async () => { + @ObjectType() + class BaseObject { + @Field({ nullable: true }) + baseFieldA: string; + + @Field({ nullable: false }) + baseFieldB: string; + + @Field() + baseFieldC: string; + } + + @ObjectType() + class SampleObject extends PartialType(BaseObject) {} + + const sampleObjectType = await getSampleObjectType(SampleObject); + + const baseFieldA = sampleObjectType.fields.find(field => field.name === "baseFieldA")!; + expect(baseFieldA.type.kind).toEqual(TypeKind.SCALAR); + const baseFieldB = sampleObjectType.fields.find(field => field.name === "baseFieldB")!; + expect(baseFieldB.type.kind).toEqual(TypeKind.SCALAR); + const baseFieldC = sampleObjectType.fields.find(field => field.name === "baseFieldC")!; + expect(baseFieldC.type.kind).toEqual(TypeKind.SCALAR); + }); + + it("RequiredType should set all fields to NON_NULL", async () => { + @ObjectType() + class BaseObject { + @Field({ nullable: true }) + baseFieldA: string; + + @Field({ nullable: false }) + baseFieldB: string; + + @Field() + baseFieldC: string; + } + + @ObjectType() + class SampleObject extends RequiredType(BaseObject) {} + + const sampleObjectType = await getSampleObjectType(SampleObject); + + const baseFieldA = sampleObjectType.fields.find(field => field.name === "baseFieldA")!; + expect(baseFieldA.type.kind).toEqual(TypeKind.NON_NULL); + const baseFieldB = sampleObjectType.fields.find(field => field.name === "baseFieldB")!; + expect(baseFieldB.type.kind).toEqual(TypeKind.NON_NULL); + const baseFieldC = sampleObjectType.fields.find(field => field.name === "baseFieldC")!; + expect(baseFieldC.type.kind).toEqual(TypeKind.NON_NULL); + }); + + it("PickType should only define specified field", async () => { + @ObjectType() + class BaseObject { + @Field({ nullable: true }) + baseFieldA: string; + + @Field({ nullable: false }) + baseFieldB: string; + + @Field() + baseFieldC: string; + } + + @ObjectType() + class SampleObject extends PickType(BaseObject, "baseFieldA") {} + + const sampleObjectType = await getSampleObjectType(SampleObject); + + const baseFieldA = sampleObjectType.fields.find(field => field.name === "baseFieldA")!; + expect(baseFieldA).toBeDefined(); + const baseFieldB = sampleObjectType.fields.find(field => field.name === "baseFieldB")!; + expect(baseFieldB).toBeUndefined(); + const baseFieldC = sampleObjectType.fields.find(field => field.name === "baseFieldC")!; + expect(baseFieldC).toBeUndefined(); + }); + + it("OmitType should omit specified field", async () => { + @ObjectType() + class BaseObject { + @Field({ nullable: true }) + baseFieldA: string; + + @Field({ nullable: false }) + baseFieldB: string; + + @Field() + baseFieldC: string; + } + + @ObjectType() + class SampleObject extends OmitType(BaseObject, "baseFieldA", "baseFieldB") {} + + const sampleObjectType = await getSampleObjectType(SampleObject); + + const baseFieldA = sampleObjectType.fields.find(field => field.name === "baseFieldA")!; + expect(baseFieldA).toBeUndefined(); + const baseFieldB = sampleObjectType.fields.find(field => field.name === "baseFieldB")!; + expect(baseFieldB).toBeUndefined(); + const baseFieldC = sampleObjectType.fields.find(field => field.name === "baseFieldC")!; + expect(baseFieldC).toBeDefined(); + }); + + it("IntersectionType should combines two types into one new type without error", async () => { + @ObjectType() + class BaseObjectA { + @Field() + baseFieldA: number; + } + + @ObjectType() + class BaseObjectB { + @Field() + baseFieldB: string; + } + + @ObjectType() + class BaseObjectC { + @Field() + baseFieldC: string; + } + + @ObjectType() + class SampleObject extends IntersectionType( + BaseObjectA, + IntersectionType(BaseObjectB, BaseObjectC), + ) {} + + const sampleObjectType = await getSampleObjectType(SampleObject); + + const baseFieldA = sampleObjectType.fields.find(field => field.name === "baseFieldA")!; + const baseFieldAType = (baseFieldA.type as IntrospectionNonNullTypeRef) + .ofType as IntrospectionNamedTypeRef; + expect(baseFieldA).toBeDefined(); + expect(baseFieldAType.name).toEqual("Float"); + const baseFieldB = sampleObjectType.fields.find(field => field.name === "baseFieldB")!; + const baseFieldBType = (baseFieldB.type as IntrospectionNonNullTypeRef) + .ofType as IntrospectionNamedTypeRef; + expect(baseFieldB).toBeDefined(); + expect(baseFieldBType.name).toEqual("String"); + const baseFieldC = sampleObjectType.fields.find(field => field.name === "baseFieldC")!; + const baseFieldCType = (baseFieldC.type as IntrospectionNonNullTypeRef) + .ofType as IntrospectionNamedTypeRef; + expect(baseFieldC).toBeDefined(); + expect(baseFieldCType.name).toEqual("String"); + }); + + it("should composable", async () => { + @InputType() + class PartialObject { + @Field() + nullableStringField: string; + } + + @ArgsType() + class RequiredObject { + @Field() + nonNullStringField: string; + } + + @InterfaceType() + class PickedObject { + @Field() + pickedStringField: string; + } + + @ObjectType() + class OmittedObject { + @Field() + OmittedStringField: string; + } + + @ObjectType() + class SampleObject extends IntersectionType( + IntersectionType(PartialType(PartialObject), RequiredType(RequiredObject)), + IntersectionType( + PickType(PickedObject, "pickedStringField"), + OmitType(OmittedObject, "OmittedStringField"), + ), + ) {} + + const sampleObjectType = await getSampleObjectType(SampleObject); + + const nullableStringField = sampleObjectType.fields.find( + f => f.name === "nullableStringField", + )!; + expect(nullableStringField).toBeDefined(); + expect(nullableStringField.type.kind).toEqual(TypeKind.SCALAR); + + const nonNullStringField = sampleObjectType.fields.find(f => f.name === "nonNullStringField")!; + expect(nonNullStringField).toBeDefined(); + expect(nonNullStringField.type.kind).toEqual(TypeKind.NON_NULL); + + const OmittedStringField = sampleObjectType.fields.find(f => f.name === "OmittedStringField")!; + expect(OmittedStringField).toBeUndefined(); + + const pickedStringField = sampleObjectType.fields.find(f => f.name === "pickedStringField")!; + expect(pickedStringField).toBeDefined(); + }); + + it("should generate correct input type", async () => { + @ObjectType() + class BaseObject { + @Field({ nullable: false }) + stringField: string; + } + + @InputType() + class SampleArgs extends PartialType(BaseObject) {} + + @Resolver() + class SampleResolver { + @Query() + sampleQuery(@Arg("sample") _args: SampleArgs): String { + return ""; + } + } + + const schemaInfo = await getSchemaInfo({ + resolvers: [SampleResolver], + }); + const schemaIntrospection = schemaInfo.schemaIntrospection; + const sampleInputType = schemaIntrospection.types.find( + type => type.name === "SampleArgs", + ) as IntrospectionInputObjectType; + + const stringField = sampleInputType.inputFields.find(f => f.name === "stringField")!; + expect(stringField).toBeDefined(); + expect(stringField.type.kind).toEqual(TypeKind.SCALAR); + }); + + it("should generate correct args type", async () => { + @ObjectType() + class BaseObject { + @Field({ nullable: false }) + stringField: string; + } + + @ArgsType() + class SampleArgs extends PartialType(BaseObject) {} + + @Resolver() + class SampleResolver { + @Query() + sampleQuery(@Args() _args: SampleArgs): String { + return ""; + } + } + + const schemaInfo = await getSchemaInfo({ + resolvers: [SampleResolver], + }); + + const sampleQuery = schemaInfo.queryType.fields.find(f => f.name === "sampleQuery")!; + const stringField = sampleQuery.args[0]; + console.log("sampleQuery: \n", sampleQuery); + + expect(stringField).toBeDefined(); + expect(stringField.type.kind).toEqual(TypeKind.SCALAR); + }); + + it("should throw validation error when input is incorrect", async () => { + @ObjectType() + class SampleObject { + @Field({ nullable: true }) + field?: string; + } + + @InputType() + class BaseInputA { + @Field() + @MaxLength(5) + stringField: string; + + @Field() + @Max(5) + numberField: number; + } + + @InputType() + class BaseInputB { + @Field({ nullable: true }) + @Min(5) + optionalField?: number; + } + + @InputType() + class SampleInput extends IntersectionType(BaseInputA, BaseInputB) {} + + @Resolver(of => SampleObject) + class SampleResolver { + @Mutation() + sampleMutation(@Arg("input") input: SampleInput): SampleObject { + return {}; + } + + @Query() + sampleQuery(): SampleObject { + return {}; + } + } + + const schema = await buildSchema({ + resolvers: [SampleResolver], + validate: true, + }); + + const mutation = `mutation { + sampleMutation(input: { + stringField: "12345", + numberField: 15, + }) { + field + } + }`; + + const result = await graphql(schema, mutation); + expect(result.data).toBeNull(); + expect(result.errors).toHaveLength(1); + + const validationError = result.errors![0].originalError! as ArgumentValidationError; + expect(validationError).toBeInstanceOf(ArgumentValidationError); + expect(validationError.validationErrors).toHaveLength(1); + expect(validationError.validationErrors[0].property).toEqual("numberField"); + }); +}); + +async function getSampleObjectType(SampleObject: SampleObject) { + @Resolver() + class SampleResolver { + @Query(() => SampleObject) + sampleQuery(): SampleObject { + return {} as SampleObject; + } + } + + const schemaInfo = await getSchemaInfo({ + resolvers: [SampleResolver], + }); + const schemaIntrospection = schemaInfo.schemaIntrospection; + const sampleObjectType = schemaIntrospection.types.find( + type => type.name === "SampleObject", + ) as IntrospectionObjectType; + return sampleObjectType; +}