diff --git a/src/cli.ts b/src/cli.ts index 41cf4aad..741a2711 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -69,8 +69,11 @@ const main = async () => { })) let success = true - for (const { outputFileName, paths } of results) { - let content = args.format === 'ts' ? tsString(paths) : jsonString(paths) + for (const { outputFileName, paths, components } of results) { + let content = + args.format === 'ts' + ? tsString(paths, components) + : jsonString(paths, components) if (args.prettify) { content = await runPrettier(outputFileName, content) } @@ -133,15 +136,22 @@ const writeOutput = (fileName: string, content: string): void => { fs.writeFileSync(fileName, content) } -const tsString = (paths: OpenAPIV3.PathsObject): string => `\ +const tsString = ( + paths: OpenAPIV3.PathsObject, + components: OpenAPIV3.ComponentsObject +): string => `\ import { OpenAPIV3 } from 'openapi-types' -const spec: { paths: OpenAPIV3.PathsObject } = ${JSON.stringify({ paths })}; +const spec: { paths: OpenAPIV3.PathsObject, components: OpenAPIV3.ComponentsObject } = ${JSON.stringify( + { paths, components } +)}; export default spec; ` -const jsonString = (paths: OpenAPIV3.PathsObject): string => - JSON.stringify({ paths }) +const jsonString = ( + paths: OpenAPIV3.PathsObject, + components: OpenAPIV3.ComponentsObject +): string => JSON.stringify({ paths, components }) main() diff --git a/src/components.ts b/src/components.ts new file mode 100644 index 00000000..eef8b9b0 --- /dev/null +++ b/src/components.ts @@ -0,0 +1,96 @@ +import * as ts from 'typescript' +import { OpenAPIV3 } from 'openapi-types' +import { isInterface, isTypeAlias } from './utils' + +export class Components { + // undefined acts as a placeholder + #schemas: Map + #symbolSchemas: Map + + constructor() { + this.#schemas = new Map() + this.#symbolSchemas = new Map() + } + + withSymbol( + symbol: ts.Symbol | undefined, + run: ( + addComponent: () => void + ) => OpenAPIV3.SchemaObject | OpenAPIV3.ReferenceObject | undefined + ): OpenAPIV3.SchemaObject | OpenAPIV3.ReferenceObject | undefined { + const ref = symbol && this.#getRefForSymbol(symbol) + if (ref) { + return { $ref: ref } + } + + if (symbol && (isInterface(symbol) || isTypeAlias(symbol))) { + let added = false + const schema = run(() => { + this.#addSymbol(symbol) + added = true + }) + if (added) { + if (schema === undefined || '$ref' in schema) { + this.#deleteSymbol(symbol) + return schema + } else { + return { $ref: this.#addSchema(symbol, schema) } + } + } else { + return schema + } + } else { + return run(() => {}) + } + } + + #getRefForSymbol(symbol: ts.Symbol): string | undefined { + const schemaName = this.#symbolSchemas.get(symbol) + return schemaName !== undefined + ? `#/components/schemas/${schemaName}` + : undefined + } + + #addSymbol(symbol: ts.Symbol) { + if (this.#symbolSchemas.has(symbol)) return + + let name = symbol.name + for (let i = 2; ; i++) { + if (this.#schemas.has(name)) { + name = `${symbol.name}${i}` + } else { + break + } + } + + this.#schemas.set(name, undefined) + this.#symbolSchemas.set(symbol, name) + } + + #deleteSymbol(symbol: ts.Symbol) { + const name = this.#symbolSchemas.get(symbol) + if (name === undefined) return + + this.#schemas.delete(name) + this.#symbolSchemas.delete(symbol) + } + + #addSchema(symbol: ts.Symbol, schema: OpenAPIV3.SchemaObject): string { + const name = this.#symbolSchemas.get(symbol) + if (name === undefined) + throw new Error(`No schema has been added for symbol ${symbol.name}`) + this.#schemas.set(name, schema) + return `#/components/schemas/${name}` + } + + build(): OpenAPIV3.ComponentsObject { + const schemas = Object.fromEntries( + [...this.#schemas.entries()].flatMap(([k, v]) => + v !== undefined ? [[k, v]] : [] + ) + ) + return { + ...(this.#schemas.size > 0 ? { schemas } : undefined), + } + } +} diff --git a/src/generate.ts b/src/generate.ts index 4576b989..a1802802 100644 --- a/src/generate.ts +++ b/src/generate.ts @@ -21,6 +21,7 @@ import { getBrandedType, getPromisePayloadType, } from './utils' +import { Components } from './components' interface GenerateOptions { log: Logger @@ -29,6 +30,7 @@ interface GenerateOptions { interface Result { fileName: string paths: OpenAPIV3.PathsObject + components: OpenAPIV3.ComponentsObject } export const generate = ( @@ -47,12 +49,18 @@ export const generate = ( if (sourceFile.isDeclarationFile) continue ts.forEachChild(sourceFile, (node) => { + const components = new Components() const paths = visitTopLevelNode( context(checker, sourceFile, log, node), + components, node ) if (paths) { - result.push({ fileName: sourceFile.fileName, paths }) + result.push({ + fileName: sourceFile.fileName, + paths, + components: components.build(), + }) } }) } @@ -62,6 +70,7 @@ export const generate = ( const visitTopLevelNode = ( ctx: Context, + components: Components, node: ts.Node ): OpenAPIV3.PathsObject | undefined => { if (ts.isExportAssignment(node) && !node.isExportEquals) { @@ -86,6 +95,7 @@ const visitTopLevelNode = ( } const routeDeclaration = getRouteDeclaration( withLocation(ctx, location), + components, symbol ) if (routeDeclaration) { @@ -125,6 +135,7 @@ const getRouterCallArgSymbols = ( const getRouteDeclaration = ( ctx: Context, + components: Components, symbol: ts.Symbol ): [string, Method, OpenAPIV3.OperationObject] | undefined => { const description = getDescriptionFromComment(ctx, symbol) @@ -152,12 +163,12 @@ const getRouteDeclaration = ( ? ctx.checker.typeToString(contentType).replace(/"/g, '') : undefined - const responses = getResponseTypes(ctx, symbol) + const responses = getResponseTypes(ctx, components, symbol) if (!responses) return const requestBody = requestNode && body - ? typeToSchema(withLocation(ctx, requestNode), body) + ? typeToSchema(withLocation(ctx, requestNode), components, body) : undefined const parameters = [ @@ -207,7 +218,7 @@ const getRouteTags = (symbol: ts.Symbol): string[] | undefined => .map((tag) => tag.trim()) const operationRequestBody = ( - contentSchema: OpenAPIV3.SchemaObject | undefined, + contentSchema: OpenAPIV3.SchemaObject | OpenAPIV3.ReferenceObject | undefined, contentType = 'application/json' ): { requestBody: OpenAPIV3.RequestBodyObject } | undefined => { if (!contentSchema) return @@ -392,6 +403,7 @@ const getRouteInput = ( const getResponseTypes = ( ctx: Context, + components: Components, symbol: ts.Symbol ): OpenAPIV3.ResponsesObject | undefined => { const descriptions = getResponseDescriptions(symbol) @@ -436,11 +448,21 @@ const getResponseTypes = ( const result: OpenAPIV3.ResponsesObject = {} if (isObjectType(responseType)) { - const responseDef = getResponseDefinition(ctx, descriptions, responseType) + const responseDef = getResponseDefinition( + ctx, + components, + descriptions, + responseType + ) if (responseDef) result[responseDef.status] = responseDef.response } else if (responseType.isUnion()) { responseType.types.forEach((type) => { - const responseDef = getResponseDefinition(ctx, descriptions, type) + const responseDef = getResponseDefinition( + ctx, + components, + descriptions, + type + ) if (responseDef) result[responseDef.status] = responseDef.response }) } @@ -473,6 +495,7 @@ const getResponseDescriptions = ( const getResponseDefinition = ( ctx: Context, + components: Components, responseDescriptions: Partial>, responseType: ts.Type ): { status: string; response: OpenAPIV3.ResponseObject } | undefined => { @@ -507,11 +530,9 @@ const getResponseDefinition = ( return } - // TODO: If bodyType is an interface (or type alias?), generate a schema - // component object and a reference to it? - let bodySchema: OpenAPIV3.SchemaObject | undefined + let bodySchema: OpenAPIV3.SchemaObject | OpenAPIV3.ReferenceObject | undefined if (!isUndefinedType(bodyType)) { - bodySchema = typeToSchema(ctx, bodyType) + bodySchema = typeToSchema(ctx, components, bodyType) if (!bodySchema) return } @@ -600,143 +621,156 @@ const getBaseSchema = ( const typeToSchema = ( ctx: Context, + components: Components, type: ts.Type, - options: { symbol?: ts.Symbol; optional?: boolean } = {} -): OpenAPIV3.SchemaObject | undefined => { - let base = getBaseSchema(ctx, options.symbol) - - if (type.isUnion()) { - let elems = type.types + options: { propSymbol?: ts.Symbol; optional?: boolean } = {} +): OpenAPIV3.SchemaObject | OpenAPIV3.ReferenceObject | undefined => + components.withSymbol(type.aliasSymbol ?? type.getSymbol(), (addComponent) => { + let base = getBaseSchema(ctx, options.propSymbol) - if (options.optional) { - elems = type.types.filter((elem) => !isUndefinedType(elem)) - } + if (type.isUnion()) { + let elems = type.types - if (elems.some(isNullType)) { - // One of the union elements is null - base = { ...base, nullable: true } - elems = elems.filter((elem) => !isNullType(elem)) - } + if (options.optional) { + elems = type.types.filter((elem) => !isUndefinedType(elem)) + } - if (elems.every(isBooleanLiteralType)) { - // All elements are boolean literals => boolean - return { type: 'boolean', ...base } - } else if (elems.every(isNumberLiteralType)) { - // All elements are number literals => enum - return { - type: 'number', - enum: elems.map((elem) => elem.value), - ...base, + if (elems.some(isNullType)) { + // One of the union elements is null + base = { ...base, nullable: true } + elems = elems.filter((elem) => !isNullType(elem)) } - } else if (elems.every(isStringLiteralType)) { - // All elements are string literals => enum - return { - type: 'string', - enum: elems.map((elem) => elem.value), - ...base, + + if (elems.every(isBooleanLiteralType)) { + // All elements are boolean literals => boolean + return { type: 'boolean', ...base } } - } else if (elems.length >= 2) { - // 2 or more types remain => anyOf - return { - anyOf: elems.map((elem) => typeToSchema(ctx, elem)).filter(isDefined), - ...base, + + if (elems.every(isNumberLiteralType)) { + // All elements are number literals => enum + addComponent() + return { + type: 'number', + enum: elems.map((elem) => elem.value), + ...base, + } + } else if (elems.every(isStringLiteralType)) { + // All elements are string literals => enum + addComponent() + return { + type: 'string', + enum: elems.map((elem) => elem.value), + ...base, + } + } else if (elems.length >= 2) { + // 2 or more types remain => anyOf + addComponent() + return { + anyOf: elems + .map((elem) => typeToSchema(ctx, components, elem)) + .filter(isDefined), + ...base, + } + } else { + // Only one element left in the union. Fall through and consider it as the + // sole type. + type = elems[0] } - } else { - // Only one element left in the union. Fall through and consider it as the - // sole type. - type = elems[0] } - } - if (isArrayType(type)) { - const elemType = type.getNumberIndexType() - if (!elemType) { - ctx.log('warn', 'Could not get array element type') - return - } - const elemSchema = typeToSchema(ctx, elemType) - if (!elemSchema) return + if (isArrayType(type)) { + const elemType = type.getNumberIndexType() + if (!elemType) { + ctx.log('warn', 'Could not get array element type') + return + } + const elemSchema = typeToSchema(ctx, components, elemType) + if (!elemSchema) return - return { type: 'array', items: elemSchema, ...base } - } + return { type: 'array', items: elemSchema, ...base } + } - if (isDateType(type)) { - // TODO: dates are always represented as date-time strings. It should be - // possible to override this. - return { type: 'string', format: 'date-time', ...base } - } + if (isDateType(type)) { + // TODO: dates are always represented as date-time strings. It should be + // possible to override this. + return { type: 'string', format: 'date-time', ...base } + } - if (isBufferType(type)) { - return { type: 'string', format: 'binary', ...base } - } + if (isBufferType(type)) { + return { type: 'string', format: 'binary', ...base } + } - if ( - isObjectType(type) || - (type.isIntersection() && type.types.every((part) => isObjectType(part))) - ) { - const props = ctx.checker.getPropertiesOfType(type) - return { - type: 'object', - required: props - .filter((prop) => !isOptional(prop)) - .map((prop) => prop.name), - ...base, - properties: Object.fromEntries( - props - .map((prop) => { - const propType = ctx.checker.getTypeOfSymbolAtLocation( - prop, - ctx.location - ) - if (!propType) { - ctx.log('warn', 'Could not get type for property', prop.name) - return - } - const propSchema = typeToSchema(ctx, propType, { - symbol: prop, - optional: isOptional(prop), + if ( + isObjectType(type) || + (type.isIntersection() && type.types.every((part) => isObjectType(part))) + ) { + addComponent() + const props = ctx.checker.getPropertiesOfType(type) + return { + type: 'object', + required: props + .filter((prop) => !isOptional(prop)) + .map((prop) => prop.name), + ...base, + properties: Object.fromEntries( + props + .map((prop) => { + const propType = ctx.checker.getTypeOfSymbolAtLocation( + prop, + ctx.location + ) + if (!propType) { + ctx.log('warn', 'Could not get type for property', prop.name) + return + } + const propSchema = typeToSchema(ctx, components, propType, { + propSymbol: prop, + optional: isOptional(prop), + }) + if (!propSchema) { + ctx.log('warn', 'Could not get schema for property', prop.name) + return + } + return [prop.name, propSchema] }) - if (!propSchema) { - ctx.log('warn', 'Could not get schema for property', prop.name) - return - } - return [prop.name, propSchema] - }) - .filter(isDefined) - ), + .filter(isDefined) + ), + } } - } - if (isStringType(type)) { - return { type: 'string', ...base } - } - if (isNumberType(type)) { - return { type: 'number', ...base } - } - if (isBooleanType(type)) { - return { type: 'boolean', ...base } - } - if (isStringLiteralType(type)) { - return { type: 'string', enum: [type.value], ...base } - } - if (isNumberLiteralType(type)) { - return { type: 'number', enum: [type.value], ...base } - } + if (isStringType(type)) { + return { type: 'string', ...base } + } + if (isNumberType(type)) { + return { type: 'number', ...base } + } + if (isBooleanType(type)) { + return { type: 'boolean', ...base } + } + if (isStringLiteralType(type)) { + return { type: 'string', enum: [type.value], ...base } + } + if (isNumberLiteralType(type)) { + return { type: 'number', enum: [type.value], ...base } + } - const branded = getBrandedType(ctx, type) - if (branded) { - // io-ts branded type - const { brandName, brandedType } = branded + const branded = getBrandedType(ctx, type) + if (branded) { + // io-ts branded type + const { brandName, brandedType } = branded - if (brandName === 'Brand') { - // io-ts Int - return { type: 'integer', ...base } - } + if (brandName === 'Brand') { + // io-ts Int + return { type: 'integer', ...base } + } - // other branded type - return typeToSchema(ctx, brandedType, options) - } + // other branded type + return typeToSchema(ctx, components, brandedType, options) + } - ctx.log('warn', `Ignoring an unknown type: ${ctx.checker.typeToString(type)}`) - return -} + ctx.log( + 'warn', + `Ignoring an unknown type: ${ctx.checker.typeToString(type)}` + ) + return + }) diff --git a/src/utils.ts b/src/utils.ts index f192fa91..0698b77c 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -28,6 +28,10 @@ export const isUndefinedType = (type: ts.Type): boolean => !!(type.flags & ts.TypeFlags.Undefined) export const isNullType = (type: ts.Type): boolean => !!(type.flags & ts.TypeFlags.Null) +export const isInterface = (symbol: ts.Symbol): boolean => + !!(symbol.flags & ts.SymbolFlags.Interface) +export const isTypeAlias = (symbol: ts.Symbol): boolean => + !!(symbol.flags & ts.SymbolFlags.TypeAlias) // Check for a specific object type based on type name and property names const duckTypeChecker = diff --git a/tests/__snapshots__/generate.spec.ts.snap b/tests/__snapshots__/generate.spec.ts.snap index c97dd308..64be5ee9 100644 --- a/tests/__snapshots__/generate.spec.ts.snap +++ b/tests/__snapshots__/generate.spec.ts.snap @@ -13,6 +13,116 @@ Array [ exports[`generate works 1`] = ` Array [ Object { + "components": Object { + "schemas": Object { + "DirectRecursiveIntersection": Object { + "properties": Object { + "children": Object { + "$ref": "#/components/schemas/DirectRecursiveIntersection", + }, + "id": Object { + "type": "string", + }, + }, + "required": Array [ + "id", + "children", + ], + "type": "object", + }, + "DirectRecursiveType": Object { + "properties": Object { + "children": Object { + "items": Object { + "$ref": "#/components/schemas/DirectRecursiveType", + }, + "type": "array", + }, + "id": Object { + "type": "string", + }, + }, + "required": Array [ + "id", + "children", + ], + "type": "object", + }, + "DocumentedInterface": Object { + "properties": Object { + "outputField": Object { + "description": "Output field description here", + "type": "string", + }, + }, + "required": Array [ + "outputField", + ], + "type": "object", + }, + "IndirectRecursiveType": Object { + "properties": Object { + "hello": Object { + "type": "string", + }, + "items": Object { + "items": Object { + "$ref": "#/components/schemas/MutuallyRecursive", + }, + "type": "array", + }, + }, + "required": Array [ + "hello", + "items", + ], + "type": "object", + }, + "MutuallyRecursive": Object { + "properties": Object { + "other": Object { + "$ref": "#/components/schemas/IndirectRecursiveType", + }, + }, + "required": Array [ + "other", + ], + "type": "object", + }, + "User": Object { + "properties": Object { + "petName": Object { + "nullable": true, + "type": "string", + }, + "shoeSize": Object { + "type": "number", + }, + "updated": Object { + "format": "date-time", + "type": "string", + }, + }, + "required": Array [ + "shoeSize", + "petName", + "updated", + ], + "type": "object", + }, + "User2": Object { + "properties": Object { + "name": Object { + "type": "string", + }, + }, + "required": Array [ + "name", + ], + "type": "object", + }, + }, + }, "fileName": "tests/test-routes.ts", "paths": Object { "/binary-response": Object { @@ -216,25 +326,7 @@ Array [ "application/json": Object { "schema": Object { "items": Object { - "properties": Object { - "petName": Object { - "nullable": true, - "type": "string", - }, - "shoeSize": Object { - "type": "number", - }, - "updated": Object { - "format": "date-time", - "type": "string", - }, - }, - "required": Array [ - "shoeSize", - "petName", - "updated", - ], - "type": "object", + "$ref": "#/components/schemas/User", }, "type": "array", }, @@ -252,25 +344,7 @@ Array [ "content": Object { "application/json": Object { "schema": Object { - "properties": Object { - "petName": Object { - "nullable": true, - "type": "string", - }, - "shoeSize": Object { - "type": "number", - }, - "updated": Object { - "format": "date-time", - "type": "string", - }, - }, - "required": Array [ - "shoeSize", - "petName", - "updated", - ], - "type": "object", + "$ref": "#/components/schemas/User", }, }, }, @@ -316,9 +390,9 @@ Array [ "responses": Object { "200": Object { "content": Object { - "text/plain": Object { + "application/json": Object { "schema": Object { - "type": "string", + "$ref": "#/components/schemas/User2", }, }, }, @@ -365,6 +439,22 @@ Array [ }, }, }, + "/recursive-types": Object { + "get": Object { + "responses": Object { + "200": Object { + "content": Object { + "application/json": Object { + "schema": Object { + "$ref": "#/components/schemas/IndirectRecursiveType", + }, + }, + }, + "description": "OK", + }, + }, + }, + }, "/request-body": Object { "post": Object { "description": "This one has request body and two possible successful responses and multiple tags", @@ -620,16 +710,7 @@ Array [ "content": Object { "application/json": Object { "schema": Object { - "properties": Object { - "outputField": Object { - "description": "Output field description here", - "type": "string", - }, - }, - "required": Array [ - "outputField", - ], - "type": "object", + "$ref": "#/components/schemas/DocumentedInterface", }, }, }, diff --git a/tests/exported-routes.ts b/tests/exported-routes.ts index aea163b1..76487c50 100644 --- a/tests/exported-routes.ts +++ b/tests/exported-routes.ts @@ -1,9 +1,14 @@ import { Response, Route, route } from 'typera-express' -export const otherFileExport: Route> = route +// This interface also exists in test-routes.ts => it should be named `User2` in component schemas. +interface User { + name: string +} + +export const otherFileExport: Route> = route .get('/other-file-export') .handler(async () => { - return Response.ok('hello') + return Response.ok({ name: 'hello' }) }) export default route.get('/other-file-default-export').handler(async () => { diff --git a/tests/test-routes.ts b/tests/test-routes.ts index cc75aee1..12d53c62 100644 --- a/tests/test-routes.ts +++ b/tests/test-routes.ts @@ -273,6 +273,32 @@ const withContentTypeMiddleware: Route< return Response.ok(request.body.a) }) +interface DirectRecursiveType { + id: string + children: DirectRecursiveType[] +} + +type DirectRecursiveIntersection = { id: string } & { + children: DirectRecursiveIntersection +} + +interface IndirectRecursiveType { + hello: string + items: MutuallyRecursive[] +} + +interface MutuallyRecursive { + other: IndirectRecursiveType +} + +const recursiveTypes: Route< + | Response.Ok + | Response.Ok + | Response.Ok +> = route.get('/recursive-types').handler(async () => { + return Response.ok({ id: 'hell', children: [] }) +}) + export default router( constant, directRouteCall, @@ -295,6 +321,7 @@ export default router( handlerNotInline, typeAlias, withContentTypeMiddleware, + recursiveTypes, otherFileExport, // export from another module otherFileDefaultExport // default export from another module )