diff --git a/tools/packages/bench/bin/bench.ts b/tools/packages/bench/bin/bench.ts index c91974982..6309363a5 100644 --- a/tools/packages/bench/bin/bench.ts +++ b/tools/packages/bench/bin/bench.ts @@ -1,7 +1,7 @@ import { WGSLLinker } from "@use-gpu/shader"; import fs from "fs/promises"; import path from "node:path"; -import { link, parseWESL } from "wesl"; +import { link, ModulePath, parseWESL } from "wesl"; import { WgslReflect } from "wgsl_reflect"; import yargs from "yargs"; @@ -105,7 +105,7 @@ function runNTimes(n: number, variant: ParserVariant, file: BenchTest): void { interface BenchTest { name: string; /** Path to the main file */ - mainFile: string; + rootModule: ModulePath; /** All relevant files (file paths and their contents) */ files: Map; } @@ -115,34 +115,44 @@ async function loadAllFiles(): Promise { const reduceBuffer = await loadBench( "reduceBuffer", examplesDir, - "./reduceBuffer.wgsl", + ["./reduceBuffer.wgsl"], + ["package", "reduceBuffer"], + ); + const particle = await loadBench( + "particle", + examplesDir, + ["./particle.wgsl"], + ["package", "particle"], ); - const particle = await loadBench("particle", examplesDir, "./particle.wgsl"); const rasterize = await loadBench( "rasterize", examplesDir, - "./rasterize_05_fine.wgsl", + ["./rasterize_05_fine.wgsl"], + ["package", "rasterize_05_fine"], ); const boat = await loadBench( "unity_webgpu_0000026E5689B260", examplesDir, - "./unity_webgpu_000002B8376A5020.fs.wgsl", + ["./unity_webgpu_000002B8376A5020.fs.wgsl"], + ["package", "unity_webgpu_000002B8376A5020"], ); const imports_only = await loadBench( "imports_only", examplesDir, - "./imports_only.wgsl", + ["./imports_only.wgsl"], + ["package", "imports_only"], ); const bevy_deferred_lighting = await loadBench( "bevy_deferred_lighting", "./src/examples/bevy", - "./bevy_generated_deferred_lighting.wgsl", + ["./bevy_generated_deferred_lighting.wgsl"], + ["package", "bevy_generated_deferred_lighting"], ); const bevy_linking = await loadBench( "bevy_linking", "./src/examples/naga_oil_example", - "./pbr.wgsl", [ + "./pbr.wgsl", "./clustered_forward.wgsl", "./mesh_bindings.wgsl", "./mesh_types.wgsl", @@ -156,6 +166,7 @@ async function loadAllFiles(): Promise { "./shadows.wgsl", "./utils.wgsl", ], + ["package", "pbr"], ); return [ reduceBuffer, @@ -171,18 +182,17 @@ async function loadAllFiles(): Promise { async function loadBench( name: string, cwd: string, - mainFile: string, - extraFiles: string[] = [], + filePaths: string[] = [], + rootModule: ModulePath, ): Promise { const files = new Map(); const addFile = async (p: string) => files.set(p, await fs.readFile(path.join(cwd, p), { encoding: "utf8" })); - await addFile(mainFile); - for (const path of extraFiles) { + for (const path of filePaths) { await addFile(path); } - return { name, mainFile, files }; + return { name, rootModule, files }; } function runOnce(parserVariant: ParserVariant, test: BenchTest): void { @@ -193,7 +203,7 @@ function runOnce(parserVariant: ParserVariant, test: BenchTest): void { } else if (parserVariant === "wesl-link") { link({ weslSrc: Object.fromEntries(test.files.entries()), - rootModuleName: test.mainFile, + rootModulePath: test.rootModule, }); } else if (parserVariant === "wgsl_reflect") { for (const [path, text] of test.files) { diff --git a/tools/packages/bulk-test/src/stripWgsl.ts b/tools/packages/bulk-test/src/stripWgsl.ts index 08d28e7b7..effda2cf6 100644 --- a/tools/packages/bulk-test/src/stripWgsl.ts +++ b/tools/packages/bulk-test/src/stripWgsl.ts @@ -1,34 +1,14 @@ import { WeslStream } from "wesl"; -/** Remove extra bits from WGSL for test comparisons. - * - * removes: - * . extra whitespace, - * . comments, - * . trailing commas in brackets, paren, and array containers - */ +/** Removes extra whitespace and comments from WGSL */ export function stripWesl(text: string): string { const stream = new WeslStream(text); const firstToken = stream.nextToken(); if (firstToken === null) return ""; - let result = firstToken.text; while (true) { const token = stream.nextToken(); if (token === null) return result; - - if (token.text === ",") { - const nextToken = stream.nextToken(); - const nextText = nextToken?.text; - if (nextText === "}" || nextText === "]" || nextText === ")") { - // Ignore trailing comma - result += " "; - result += nextText; - } else { - result += ", " + (nextToken?.text ?? ""); - } - } else { - result += " " + token.text; - } + result += " " + token.text; } } diff --git a/tools/packages/bulk-test/src/test/StripWgsl.test.ts b/tools/packages/bulk-test/src/test/StripWgsl.test.ts deleted file mode 100644 index fc85c7844..000000000 --- a/tools/packages/bulk-test/src/test/StripWgsl.test.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { expect, test } from "vitest"; -import { stripWesl } from "../stripWgsl.ts"; - -test("strip trailing commas", () => { - const withComma = ` - struct A { a: f32, } - `; - const noComma = ` - struct A { - a: f32 - } - `; - - expect(stripWesl(withComma)).toMatchInlineSnapshot(`"struct A { a : f32 }"`); - expect(stripWesl(noComma)).toMatchInlineSnapshot(`"struct A { a : f32 }"`); - expect(stripWesl(noComma)).equals(stripWesl(withComma)); -}); diff --git a/tools/packages/bulk-test/src/testWgslFiles.ts b/tools/packages/bulk-test/src/testWgslFiles.ts index e52aedf1a..d18faf5f2 100644 --- a/tools/packages/bulk-test/src/testWgslFiles.ts +++ b/tools/packages/bulk-test/src/testWgslFiles.ts @@ -19,13 +19,13 @@ export async function testWgslFiles(namedPaths: NamedPath[]) { const config = { plugins: [bindingStructsPlugin()] }; namedPaths.forEach(({ name, filePath }) => { - const shortPath = "./" + name; test(name, async () => { const orig = await fs.readFile(filePath, { encoding: "utf8" }); const result = await expectNoLogAsync(() => { - const weslSrc = { [shortPath]: orig }; - const rootModuleName = noSuffix(name); - return link({ weslSrc, rootModuleName, config }); + const shortPath = "./" + name; + const weslSrc = { [shortPath]: text }; + const rootModulePath = ["package", ...noSuffix(name).split("/")]; + return link({ weslSrc, rootModulePath, config }); }); expect(stripWesl(result.dest)).toBe(stripWesl(orig)); }); diff --git a/tools/packages/mini-parse/src/Parser.ts b/tools/packages/mini-parse/src/Parser.ts index a91e8b704..2fc13e2e1 100644 --- a/tools/packages/mini-parse/src/Parser.ts +++ b/tools/packages/mini-parse/src/Parser.ts @@ -18,7 +18,8 @@ import { tracing, withTraceLogging, } from "./ParserTracing.js"; -import { Stream, Token, TypedToken } from "./Stream.js"; +import { Span } from "./Span.js"; +import { peekToken, Stream, Token, TypedToken } from "./Stream.js"; export interface AppState { /** @@ -215,6 +216,18 @@ export class Parser { return map(this, fn); } + /** map results to a new value. + */ + mapSpanned(fn: (value: T, span: Span) => U): Parser { + return mapSpanned(this, fn); + } + + /** map results to a new value, or backtracks + */ + verifyMap(fn: (value: T) => OptParserResult): Parser { + return verifyMap(this, fn); + } + /** Queue a function that runs later, typically to collect AST elements from the parse. * when a commit() is parsed. * Collection functions are dropped with parser backtracking, so @@ -363,6 +376,43 @@ function map(p: Parser, fn: (value: T) => U): Parser { return mapParser; } +/** return a parser that maps the current results */ +function mapSpanned( + p: Parser, + fn: (value: T, span: Span) => U, +): Parser { + const mapSpannedParser = parser( + `mapSpanned`, + function _mapSpanned(ctx: ParserContext): OptParserResult { + const start = peekToken(ctx.stream)?.span?.[0] ?? null; + const result = p._run(ctx); + if (result === null) return null; + const end = ctx.stream.checkpoint(); + return { value: fn(result.value, [start ?? end, end]) }; + }, + ); + + trackChildren(mapSpannedParser, p); + return mapSpannedParser; +} + +function verifyMap( + p: Parser, + fn: (value: T) => OptParserResult, +): Parser { + const verifyMapParser = parser( + `verifyMap`, + function _verifyMap(ctx: ParserContext): OptParserResult { + const result = p._run(ctx); + if (result === null) return null; + return fn(result.value); + }, + ); + + trackChildren(verifyMapParser, p); + return verifyMapParser; +} + type ToParserFn = (results: ParserResult) => Parser | null; function toParser( diff --git a/tools/packages/mini-parse/src/ParserLogging.ts b/tools/packages/mini-parse/src/ParserLogging.ts index 318f84439..41fe0816d 100644 --- a/tools/packages/mini-parse/src/ParserLogging.ts +++ b/tools/packages/mini-parse/src/ParserLogging.ts @@ -1,17 +1,13 @@ import { ParserContext } from "./Parser.js"; import { parserLog, tracePos, tracing } from "./ParserTracing.js"; -import { SrcMap, SrcWithPath } from "./SrcMap.js"; +import { Span } from "./Span.js"; import { log } from "./WrappedLog.js"; /** log an message along with the source line and a caret indicating the error position in the line * @param pos is the position the source string, or if src is a SrcMap, then * pos is the position in the dest (e.g. preprocessed) text */ -export function srcLog( - src: string | SrcMap, - pos: number | [number, number], - ...msgs: any[] -): void { +export function srcLog(src: string, pos: number | Span, ...msgs: any[]): void { logInternal(log, src, pos, ...msgs); } @@ -21,7 +17,7 @@ export function quotedText(text?: string): string { /** log a message along with src line, but only if tracing is active in the current parser */ export function srcTrace( - src: string | SrcMap, + src: string, pos: number | [number, number], ...msgs: any[] ): void { @@ -38,45 +34,17 @@ export function ctxLog(ctx: ParserContext, ...msgs: any[]): void { */ function logInternal( log: typeof console.log, - srcOrSrcMap: string | SrcMap, - destPos: number | [number, number], + src: string, + destPos: number | Span, ...msgs: any[] ): void { - if (typeof srcOrSrcMap === "string") { - logInternalSrc(log, srcOrSrcMap, destPos, ...msgs); - return; - } - const { src, positions } = mapSrcPositions(srcOrSrcMap, destPos); - - logInternalSrc(log, src.text, positions, ...msgs); -} - -interface SrcPositions { - positions: number | [number, number]; - src: SrcWithPath; -} - -function mapSrcPositions( - srcMap: SrcMap, - destPos: number | [number, number], -): SrcPositions { - const srcPos = srcMap.mapPositions(...[destPos].flat()); - const { src } = srcPos[0]; - - let positions: [number, number] | number; - if (srcPos[1]?.src?.path === src.path && srcPos[1]?.src?.text === src.text) { - positions = srcPos.map(p => p.position) as [number, number]; - } else { - positions = srcPos[0].position; - } - - return { src, positions }; + logInternalSrc(log, src, destPos, ...msgs); } function logInternalSrc( log: typeof console.log, src: string, - pos: number | [number, number], + pos: number | Span, ...msgs: any[] ): void { log(...msgs); @@ -120,10 +88,7 @@ interface SrcLine { } /** return the line in the src containing a given character postion */ -export function srcLine( - src: string, - position: number | [number, number], -): SrcLine { +export function srcLine(src: string, position: number | Span): SrcLine { let pos: number; let pos2: number | undefined; if (typeof position === "number") { diff --git a/tools/packages/mini-parse/src/Span.ts b/tools/packages/mini-parse/src/Span.ts index d5998cbde..e29853c32 100644 --- a/tools/packages/mini-parse/src/Span.ts +++ b/tools/packages/mini-parse/src/Span.ts @@ -2,3 +2,7 @@ * An range, from start (inclusive) to end (exclusive). */ export type Span = readonly [number, number]; + +export function isSpan(span: any): span is Span { + return Array.isArray(span) && span.length === 2; +} diff --git a/tools/packages/mini-parse/src/SrcMap.ts b/tools/packages/mini-parse/src/SrcMap.ts index 21bfce503..e2c51dc87 100644 --- a/tools/packages/mini-parse/src/SrcMap.ts +++ b/tools/packages/mini-parse/src/SrcMap.ts @@ -1,16 +1,15 @@ -/** A source map file, and a path for debug purposes. */ +import { Span } from "./Span.ts"; + export interface SrcWithPath { /** User friendly path */ path?: string; text: string; } -export interface SrcMapEntry { - src: SrcWithPath; - srcStart: number; - srcEnd: number; - destStart: number; - destEnd: number; +export interface SrcMapBuilder { + add(fragment: string, srcSpan: Span, isName?: boolean): void; + addRange(fragment: string, srcStart: number, isName?: boolean): void; + addSynthetic(fragment: string): void; } export interface SrcPosition { @@ -18,129 +17,270 @@ export interface SrcPosition { position: number; } -/** map text ranges in multiple src texts to a single dest text */ -export class SrcMap { - entries: SrcMapEntry[]; - dest: SrcWithPath; +/** A source span. It's possible that the end of the generated span doesn't map to a sensible source. */ +export interface SrcSpan { + src: SrcWithPath; + span: [number, number | null]; +} - constructor(dest: SrcWithPath, entries: SrcMapEntry[] = []) { - this.dest = dest; - this.entries = entries; +export class SrcMap { + private sources: SrcWithPath[] = []; + private dest: string = ""; + private entries: MinSrcMapEntry[] = []; + constructor() {} + builderFor(source: SrcWithPath): SrcMapBuilder { + const sourceId = this.addSource(source); + const self = this; + return { + add(fragment, srcSpan, isName = false) { + const isRange = source.text.slice(srcSpan[0], srcSpan[1]) === fragment; + self.add({ + fragment, + srcSpan, + source: sourceId, + isName, + isRange, + }); + }, + addRange(fragment, srcStart, isName = false) { + const srcText = source.text.slice(srcStart, srcStart + fragment.length); + const isRange = srcText === fragment; + if (!isRange) { + throw new Error( + `${fragment} is not a range, the underlying text is ${srcText}`, + ); + } + self.add({ + fragment, + srcSpan: [srcStart, srcStart + fragment.length], + source: sourceId, + isName, + isRange, + }); + }, + addSynthetic(fragment) { + self.add({ + fragment, + srcSpan: null, + source: sourceId, + isName: false, + isRange: false, + }); + }, + }; } - - /** add a new mapping from src to dest ranges. - * entries must be non-overlapping in the destination - */ - addEntries(newEntries: SrcMapEntry[]): void { - this.entries.push(...newEntries); + addSource(source: SrcWithPath): number { + const srcId = this.sources.length; + this.sources.push(source); + return srcId; } + add(opts: { + fragment: string; + srcSpan: Span | null; + source: number; + isName: boolean; + isRange: boolean; + }) { + let flags = 0; + if (opts.isName) { + flags |= IsNameFlag; + } + if (opts.isRange) { + flags |= IsRangeFlag; + } + if (opts.srcSpan === null) { + flags |= IsSyntheticFlag; + } + const srcSpan = opts.srcSpan ?? [0, 0]; + + // We do not merge normal source map entries, since that'd lose information + // When remapping, we map destination locations to clamped source locations + // We can, however, losslessly merge range entries, and synthetic entries. + if ((opts.isRange || opts.srcSpan === null) && this.entries.length > 0) { + const lastEntry = this.entries[this.entries.length - 1]; + const canBeMerged = + (lastEntry.flags === flags && lastEntry.srcId === opts.source, + lastEntry.srcEnd === srcSpan[0]); + if (canBeMerged) { + lastEntry.srcEnd = srcSpan[1]; + this.dest += opts.fragment; + return; + } + } - /** given positions in the dest string, - * @return corresponding positions in the src strings */ - mapPositions(...positions: number[]): SrcPosition[] { - return positions.map(p => this.destToSrc(p)); + this.entries.push({ + srcId: opts.source, + srcStart: srcSpan[0], + srcEnd: srcSpan[1], + destStart: this.dest.length, + flags, + }); + this.dest += opts.fragment; } - /** internally compress adjacent entries where possible */ - compact(): void { - if (!this.entries.length) return; - let prev = this.entries[0]; - const newEntries: SrcMapEntry[] = [prev]; + /** Gets a source map entry. Filters out synthetic entries. */ + private getEntry(destPos: number): SrcMapEntry | null { + if ( + this.entries.length === 0 || + destPos < 0 || + destPos > this.dest.length + ) { + return null; + } + // LATER use correct binary search + // e.g. https://github.com/stefnotch/typestef/blob/a705b8a37ced3757ce0c613f75b0ea66fe71e932/src/array-utils.ts#L7 + let nextEntryIndex = this.entries.findIndex(e => destPos < e.destStart); + if (nextEntryIndex === 0) { + // The first entry already rejects us + return null; + } + let entryIndex = + nextEntryIndex === -1 ? this.entries.length - 1 : nextEntryIndex - 1; + let entry = this.entries[entryIndex]; - for (let i = 1; i < this.entries.length; i++) { - const e = this.entries[i]; + if ((entry.flags & IsSyntheticFlag) !== 0) { + // Attempt to fall back to the nearest non-synthetic entry + const previousEntry = this.entries.at(entryIndex - 1); if ( - e.src.path === prev.src.path && - e.src.text === prev.src.text && - prev.destEnd === e.destStart && - prev.srcEnd === e.srcStart + destPos === entry.destStart && + previousEntry !== undefined && + (previousEntry.flags & IsSyntheticFlag) === 0 ) { - // combine adjacent range entries into one - prev.destEnd = e.destEnd; - prev.srcEnd = e.srcEnd; + entryIndex = entryIndex - 1; + entry = previousEntry; } else { - newEntries.push(e); - prev = e; + return null; } } - this.entries = newEntries; - } - /** sort in destination order */ - sort(): void { - this.entries.sort((a, b) => a.destStart - b.destStart); - } + const nextEntryStart = + this.entries.at(entryIndex + 1)?.destStart ?? this.dest.length; - /** This SrcMap's destination is a src for the other srcmap, - * so combine the two and return the result. - */ - merge(other: SrcMap): SrcMap { - if (other === this) return this; - - const mappedEntries = other.entries.filter( - e => e.src.path === this.dest.path && e.src.text === this.dest.text, - ); - if (mappedEntries.length === 0) { - console.log("other source map does not link to this one"); - // dlog({ this: this }); - // dlog({ other }); - return other; - } - sortSrc(mappedEntries); - const newEntries = mappedEntries.map(e => { - const { src, position: srcStart } = this.destToSrc(e.srcStart); - const { src: endSrc, position: srcEnd } = this.destToSrc(e.srcEnd); - if (endSrc !== src) throw new Error("NYI, need to split"); - const newEntry: SrcMapEntry = { - src, - srcStart, - srcEnd, - destStart: e.destStart, - destEnd: e.destEnd, - }; - // dlog({ newEntry }); - return newEntry; - }); - - const otherSources = other.entries.filter( - e => e.src.path !== this.dest.path || e.src.text !== this.dest.text, - ); - - const newMap = new SrcMap(other.dest, [...otherSources, ...newEntries]); - newMap.sort(); - return newMap; + return { + index: entryIndex, + src: this.sources[entry.srcId], + srcSpan: [entry.srcStart, entry.srcEnd], + destSpan: [entry.destStart, nextEntryStart], + flags: entry.flags, + }; } /** - * @param entries should be sorted in destStart order * @return the source position corresponding to a provided destination position - * */ - destToSrc(destPos: number): SrcPosition { - const entry = this.entries.find( - e => e.destStart <= destPos && e.destEnd >= destPos, - ); - if (!entry) { - /* LATER: @stefnotch will replace with the reworked version - Original error: - this console.log triggers during debugging, now that preprocessing doesn't produce a real srcMap. - remove the warning or fix the reason for the warning? - */ - - // console.log(`no SrcMapEntry for dest position: ${destPos}`); - return { - src: this.dest, - position: destPos, - }; + destToSrc(destPos: number): SrcPosition | null { + const entry = this.getEntry(destPos); + if (entry === null) { + return null; + } + const position = mapPosition(entry, destPos, true); + if (position === null) { + return null; } return { + position, src: entry.src, - position: entry.srcStart + destPos - entry.destStart, }; } + + /** @return a source span corresponding to the provided destination span */ + destSpanToSrc(destSpan: Span): SrcSpan | null { + const startEntry = this.getEntry(destSpan[0]); + if (startEntry === null) { + return null; + } + const startPosition = mapPosition(startEntry, destSpan[0], true); + if (startPosition === null) { + return null; + } + + let endPosition: number | null; + // destSpan[1] is an exclusive range. + // also, account for the possibility of it being mapped to a wildly different place + if (destSpan[1] === destSpan[0]) { + endPosition = startPosition; + } else if (destSpan[1] <= startEntry.destSpan[1]) { + endPosition = mapPosition(startEntry, destSpan[1], false); + } else { + const endEntry = this.getEntry(destSpan[1]); + if (endEntry === null) { + endPosition = null; + } else { + if (endEntry.src !== startEntry.src) { + return null; + } else { + endPosition = mapPosition(endEntry, destSpan[1], false); + if (endPosition! < startPosition) { + // Nonsensical end position + endPosition = null; + } + } + } + } + + return { + src: startEntry.src, + span: [startPosition, endPosition], + }; + } + + /** @returns the destination text */ + get code(): string { + return this.dest; + } } -/** sort entries in place by src start position */ -function sortSrc(entries: SrcMapEntry[]): void { - entries.sort((a, b) => a.srcStart - b.srcStart); +type SrcMapEntryFlag = number; +/** Used for identifiers that have a meaning in the original source. */ +const IsNameFlag: SrcMapEntryFlag = 1 << 0; +/** Guarantees that every character in the dest can be mapped to the corresponding character in src. */ +const IsRangeFlag: SrcMapEntryFlag = 1 << 1; +/** A generated source that cannot be mapped back to the original source. */ +const IsSyntheticFlag: SrcMapEntryFlag = 2 << 1; + +/** Based more closely off of the src map specification */ +interface MinSrcMapEntry { + srcId: number; + /** Synthetic entries are mapped to [0,0], and get a special flag. */ + srcStart: number; + /** Extra field compared to the spec */ + srcEnd: number; + /** + * Dest *end* is the dest start of the next entry. + * If it's a range, then the src length and the dest length will be equal. + */ + destStart: number; + flags: SrcMapEntryFlag; +} + +/** A decompressed source map entry */ +interface SrcMapEntry { + index: number; + src: SrcWithPath; + srcSpan: Span; + destSpan: Span; + flags: SrcMapEntryFlag; +} + +function mapPosition( + entry: SrcMapEntry | null, + destPos: number, + clampStart: boolean, +): number | null { + if (entry === null) { + return null; + } + + if ((entry.flags & IsRangeFlag) !== 0) { + // Ranges get mapped to exact positions + return entry.srcSpan[0] + Math.max(0, destPos - entry.destSpan[0]); + } else if (destPos === entry.destSpan[0]) { + return entry.srcSpan[0]; + } else if (destPos === entry.destSpan[1]) { + return entry.srcSpan[1]; + } else if (clampStart) { + return entry.srcSpan[0]; + } else { + return entry.srcSpan[1]; + } } diff --git a/tools/packages/mini-parse/src/SrcMapBuilder.ts b/tools/packages/mini-parse/src/SrcMapBuilder.ts deleted file mode 100644 index 3c1792fc4..000000000 --- a/tools/packages/mini-parse/src/SrcMapBuilder.ts +++ /dev/null @@ -1,84 +0,0 @@ -import { SrcMap, SrcMapEntry, SrcWithPath } from "./SrcMap.js"; - -/** - * Incrementally append to a string, tracking source references - */ -export class SrcMapBuilder { - #fragments: string[] = []; - #destLength = 0; - #entries: SrcMapEntry[] = []; - - constructor(public source: SrcWithPath) {} - - /** append a string fragment to the destination string */ - add(fragment: string, srcStart: number, srcEnd: number): void { - // dlog({fragment}) - const destStart = this.#destLength; - this.#destLength += fragment.length; - const destEnd = this.#destLength; - - this.#fragments.push(fragment); - this.#entries.push({ - src: this.source, - srcStart, - srcEnd, - destStart, - destEnd, - }); - } - - /** - * Append a fragment to the destination string, - * mapping source to the pervious, - * and guessing that the source fragment is just as long as the the dest fragment. - * (LATER we plan to drop or make optional src end positions) - */ - appendNext(fragment: string): void { - const lastEnd = this.#entries.at(-1)?.destEnd ?? 0; - this.add(fragment, lastEnd, lastEnd + fragment.length); - } - - addSynthetic( - fragment: string, - syntheticSource: string, - srcStart: number, - srcEnd: number, - ): void { - // dlog({fragment}) - const destStart = this.#destLength; - this.#destLength += fragment.length; - const destEnd = this.#destLength; - - this.#fragments.push(fragment); - this.#entries.push({ - src: { text: syntheticSource }, - srcStart, - srcEnd, - destStart, - destEnd, - }); - } - - /** append a synthetic newline, mapped to previous source location */ - addNl(): void { - const lastEntry = this.#entries.at(-1) ?? { srcStart: 0, srcEnd: 0 }; - const { srcStart, srcEnd } = lastEntry; - this.add("\n", srcStart, srcEnd); - } - - /** copy a string fragment from the src to the destination string */ - addCopy(srcStart: number, srcEnd: number): void { - const fragment = this.source.text.slice(srcStart, srcEnd); - this.add(fragment, srcStart, srcEnd); - } - - /** return a SrcMap */ - static build(builders: SrcMapBuilder[]): SrcMap { - const map = new SrcMap( - { text: builders.map(b => b.#fragments.join("")).join("") }, - builders.flatMap(b => b.#entries), - ); - map.compact(); - return map; - } -} diff --git a/tools/packages/mini-parse/src/index.ts b/tools/packages/mini-parse/src/index.ts index b86174798..72738c98b 100644 --- a/tools/packages/mini-parse/src/index.ts +++ b/tools/packages/mini-parse/src/index.ts @@ -6,7 +6,6 @@ export * from "./ParserToString.js"; export * from "./ParserTracing.js"; export * from "./Span.js"; export * from "./SrcMap.js"; -export * from "./SrcMapBuilder.js"; export * from "./Stream.js"; export { CachingStream } from "./stream/CachingStream.ts"; export { FilterStream } from "./stream/FilterStream.ts"; diff --git a/tools/packages/mini-parse/src/test-util/TestParse.ts b/tools/packages/mini-parse/src/test-util/TestParse.ts index 075da772d..e15460b49 100644 --- a/tools/packages/mini-parse/src/test-util/TestParse.ts +++ b/tools/packages/mini-parse/src/test-util/TestParse.ts @@ -35,34 +35,31 @@ export const testMatcher = new RegexMatchers({ ws: /\s+/, }); -export interface TestParseResult { +export interface TestParseResult { parsed: OptParserResult; position: number; - stable: S; } /** utility for testing parsers */ -export function testParse( +export function testParse( p: Parser, T>, src: string, tokenMatcher: RegexMatchers = testMatcher, - appState: AppState = { context: {} as C, stable: [] as S }, -): TestParseResult { +): TestParseResult { const stream = new FilterStream( new MatchersStream(src, tokenMatcher), t => t.kind !== "ws", ); - const parsed = p.parse({ stream, appState }); - return { parsed, position: stream.checkpoint(), stable: appState.stable }; + const parsed = p.parse({ stream }); + return { parsed, position: stream.checkpoint() }; } -export function testParseWithStream( +export function testParseWithStream( p: Parser, T>, stream: Stream, - appState: AppState = { context: {} as C, stable: [] as S }, -): TestParseResult { - const parsed = p.parse({ stream, appState: appState }); - return { parsed, position: stream.checkpoint(), stable: appState.stable }; +): TestParseResult { + const parsed = p.parse({ stream }); + return { parsed, position: stream.checkpoint() }; } /** run a test function and expect that no error logs are produced */ diff --git a/tools/packages/mini-parse/src/test/ParserCombinator.test.ts b/tools/packages/mini-parse/src/test/ParserCombinator.test.ts index c48adc0e7..faee80af9 100644 --- a/tools/packages/mini-parse/src/test/ParserCombinator.test.ts +++ b/tools/packages/mini-parse/src/test/ParserCombinator.test.ts @@ -127,9 +127,12 @@ test("recurse with fn()", () => { repeat(or(kind(m.word), () => p)).map(v => v.flat()), "}", ); - const wrap = or(p).mapExtended(r => r.app.stable.push(r.value)); - const { stable } = testParse(wrap, src); - expect(stable[0]).toEqual(["a", "b"]); + const wrap = or(p).mapExtended(r => { + r.app.stable.push(r.value); + return r.value; + }); + const { parsed } = testParse(wrap, src); + expect(parsed?.value).toEqual(["a", "b"]); }); test("tracing", () => { diff --git a/tools/packages/mini-parse/src/test/SrcMap.test.ts b/tools/packages/mini-parse/src/test/SrcMap.test.ts index e90acbbad..54546bb3c 100644 --- a/tools/packages/mini-parse/src/test/SrcMap.test.ts +++ b/tools/packages/mini-parse/src/test/SrcMap.test.ts @@ -1,75 +1,46 @@ import { expect, test } from "vitest"; import { SrcMap } from "../SrcMap.js"; -test("compact", () => { - const src = "a b"; - const dest = "|" + src + " d"; +test("map 1:1 correspondence", () => { + const source = { text: "let foo;" }; + const srcMap = new SrcMap(); + const builder = srcMap.builderFor(source); + builder.add("let", [0, 3]); + builder.addSynthetic(" "); + builder.addName("foo", [4, 7]); - const srcMap = new SrcMap({ text: dest }); - srcMap.addEntries([ - { src: { text: src }, srcStart: 0, srcEnd: 2, destStart: 1, destEnd: 3 }, - { src: { text: src }, srcStart: 2, srcEnd: 3, destStart: 3, destEnd: 4 }, - ]); - srcMap.compact(); - expect(srcMap.entries).toMatchInlineSnapshot(` - [ - { - "destEnd": 4, - "destStart": 1, - "src": { - "text": "a b", - }, - "srcEnd": 3, - "srcStart": 0, - }, - ] - `); + expect(srcMap.destToSrc(0)).toEqual({ + src: source, + position: 0, + }); + expect(srcMap.destToSrc(1)).toEqual({ + src: source, + position: 1, + }); + expect(srcMap.destToSrc(3)?.position).toBe(3); + expect(srcMap.destToSrc(4)?.position).toBe(4); + expect(srcMap.destToSrc(5)?.position).toBe(5); + expect(srcMap.destToSrc(7)?.position).toBe(7); + expect(srcMap.destToSrc(8)).toBe(null); }); -test("merge", () => { - const src = "a b"; - const src2 = "d"; - const mid = "|" + src + " " + src2; - const dest = "xx" + mid + " z"; - /* - mid: - 01234567890 - |a b d - dest: - 01234567890 - xx|a b d z - */ +test("map shifted and renamed", () => { + const source = { text: "let foo;" }; + const srcMap = new SrcMap(); + const builder = srcMap.builderFor(source); + builder.add("const", [0, 3]); + builder.addSynthetic(" "); + builder.addName("nya", [4, 7]); + builder.addSynthetic(";"); - const map1 = new SrcMap({ text: mid }, [ - { src: { text: src }, srcStart: 0, srcEnd: 3, destStart: 1, destEnd: 4 }, - ]); + expect(srcMap.code).toBe("const nya;"); - const map2 = new SrcMap({ text: dest }, [ - { src: { text: mid }, srcStart: 1, srcEnd: 4, destStart: 3, destEnd: 6 }, - { src: { text: src2 }, srcStart: 0, srcEnd: 1, destStart: 8, destEnd: 9 }, - ]); - - const merged = map1.merge(map2); - expect(merged.entries).toMatchInlineSnapshot(` - [ - { - "destEnd": 6, - "destStart": 3, - "src": { - "text": "a b", - }, - "srcEnd": 3, - "srcStart": 0, - }, - { - "destEnd": 9, - "destStart": 8, - "src": { - "text": "d", - }, - "srcEnd": 1, - "srcStart": 0, - }, - ] - `); + expect(srcMap.destToSrc(0)?.position).toBe(0); + expect(srcMap.destToSrc(1)?.position).toBe(0); + expect(srcMap.destToSrc(3)?.position).toBe(0); + expect(srcMap.destToSrc(5)?.position).toBe(3); + expect(srcMap.destToSrc(6)?.position).toBe(4); + expect(srcMap.destToSrc(7)?.position).toBe(4); + expect(srcMap.destToSrc(9)?.position).toBe(7); + expect(srcMap.destToSrc(10)).toBe(null); }); diff --git a/tools/packages/wesl-link/src/cli.ts b/tools/packages/wesl-link/src/cli.ts index e523eb94f..8d551a831 100644 --- a/tools/packages/wesl-link/src/cli.ts +++ b/tools/packages/wesl-link/src/cli.ts @@ -2,7 +2,7 @@ import { createTwoFilesPatch } from "diff"; import fs from "fs"; import { enableTracing, log } from "mini-parse"; import path from "path"; -import { astToString, link, noSuffix, scopeToString } from "wesl"; +import { astToString, link, noSuffix } from "wesl"; import yargs from "yargs"; import { parsedRegistry, @@ -63,13 +63,13 @@ async function linkNormally(paths: string[]): Promise { const relativePath = path.relative(weslRoot, f); return [toUnixPath(relativePath), text]; }); - const rootModuleName = noSuffix(path.relative(weslRoot, paths[0])); + const rootModulePath = ["package", ...noSuffix(pathAndTexts[0][0])]; const weslSrc = Object.fromEntries(pathAndTexts); // TODO conditions // TODO external defines if (argv.emit) { - const linked = await link({ weslSrc, rootModuleName }); + const linked = await link({ weslSrc, rootModulePath }); if (argv.emit) log(linked.dest); } if (argv.details) { @@ -79,12 +79,10 @@ async function linkNormally(paths: string[]): Promise { } catch (e) { console.error(e); } - Object.entries(registry.modules).forEach(([modulePath, ast]) => { - log(`---\n${modulePath}`); + registry.getModules().forEach(module => { + log(`---\n${module.srcModule.modulePath}`); log(`\n->ast`); - log(astToString(ast.moduleElem)); - log(`\n->scope`); - log(scopeToString(ast.rootScope)); + log(astToString(module.moduleElem)); log(); }); } diff --git a/tools/packages/wesl-plugin/src/BindingLayoutExtension.ts b/tools/packages/wesl-plugin/src/BindingLayoutExtension.ts index 33f117c48..4c7c33927 100644 --- a/tools/packages/wesl-plugin/src/BindingLayoutExtension.ts +++ b/tools/packages/wesl-plugin/src/BindingLayoutExtension.ts @@ -1,8 +1,8 @@ import { bindAndTransform, bindingStructsPlugin, LinkConfig } from "wesl"; import { - bindingGroupLayoutTs, - reportBindingStructsPlugin -} from "../../wesl/src/Reflection.ts"; // TODO fix import + bindingGroupLayoutTs, + reportBindingStructsPlugin, +} from "../../wesl/src/Reflection.ts"; // TODO fix import import { PluginExtension, PluginExtensionApi } from "./PluginExtension.ts"; export const bindingLayoutExtension: PluginExtension = { @@ -29,7 +29,7 @@ async function bindingLayoutJs( }), ], }; - bindAndTransform({ registry, rootModuleName: main, config }); + bindAndTransform({ rootModuleName: main, config }, registry); return structsJs; } diff --git a/tools/packages/wesl-plugin/test/linkExtension/LinkExtension.test.ts b/tools/packages/wesl-plugin/test/linkExtension/LinkExtension.test.ts index 461a71d11..34b52e68c 100644 --- a/tools/packages/wesl-plugin/test/linkExtension/LinkExtension.test.ts +++ b/tools/packages/wesl-plugin/test/linkExtension/LinkExtension.test.ts @@ -7,9 +7,9 @@ import linkParams from "./shaders/app.wesl?link"; test("verify ?link", async () => { expectTypeOf(linkParams).toMatchTypeOf(); - const { rootModuleName, debugWeslRoot, weslSrc, libs } = + const { rootModulePath, debugWeslRoot, weslSrc, libs } = linkParams as LinkParams; - expect(rootModuleName).toMatchInlineSnapshot(`"app"`); + expect(rootModulePath).toMatchInlineSnapshot(`"app"`); dlog("fixme", { debugWeslRoot }); // TODO this result can't be right... weslRoot should be relative to the tomlDir probably. diff --git a/tools/packages/wesl-reflect/src/SimpleReflectExtension.ts b/tools/packages/wesl-reflect/src/SimpleReflectExtension.ts index 4959a4816..43f4ba0c5 100644 --- a/tools/packages/wesl-reflect/src/SimpleReflectExtension.ts +++ b/tools/packages/wesl-reflect/src/SimpleReflectExtension.ts @@ -41,9 +41,11 @@ function makeReflect(options: SimpleReflectOptions) { ): Promise { const registry = await api.weslRegistry(); - const astStructs = Object.entries(registry.modules).flatMap(([, module]) => - module.moduleElem.contents.filter(e => e.kind === "struct"), - ); + const astStructs = registry + .getModules() + .flatMap(module => + module.moduleElem.contents.filter(e => e.kind === "struct"), + ); const jsStructs = weslStructs(astStructs); diff --git a/tools/packages/wesl/src/AbstractElems.ts b/tools/packages/wesl/src/AbstractElems.ts deleted file mode 100644 index e2b52e0f7..000000000 --- a/tools/packages/wesl/src/AbstractElems.ts +++ /dev/null @@ -1,446 +0,0 @@ -import { Span } from "mini-parse"; -import { DeclIdent, RefIdent, SrcModule } from "./Scope.ts"; - -/** - * Structures to describe the 'interesting' parts of a WESL source file. - * - * The parts of the source that need to analyze further in the linker - * are pulled out into these structures. - * - * The parts that are uninteresting the the linker are recorded - * as 'TextElem' nodes, which are generally just copied to the output WGSL - * along with their containing element. - */ -export type AbstractElem = GrammarElem | SyntheticElem; - -export type GrammarElem = ContainerElem | TerminalElem; - -export type ContainerElem = - | AttributeElem - | AliasElem - | ConstAssertElem - | ConstElem - | UnknownExpressionElem - | SimpleMemberRef - | FnElem - | TypedDeclElem - | GlobalVarElem - | LetElem - | ModuleElem - | OverrideElem - | FnParamElem - | StructElem - | StructMemberElem - | StuffElem - | TypeRefElem - | VarElem - | StatementElem - | SwitchClauseElem; - -/** Inspired by https://github.com/wgsl-tooling-wg/wesl-rs/blob/3b2434eac1b2ebda9eb8bfb25f43d8600d819872/crates/wgsl-parse/src/syntax.rs#L364 */ -export type ExpressionElem = - | Literal - | TranslateTimeFeature - | RefIdentElem - | ParenthesizedExpression - | ComponentExpression - | ComponentMemberExpression - | UnaryExpression - | BinaryExpression - | FunctionCallExpression; - -export type TerminalElem = - | DirectiveElem - | DeclIdentElem // - | NameElem - | RefIdentElem - | TextElem - | ImportElem; - -export type GlobalDeclarationElem = - | AliasElem - | ConstElem - | FnElem - | GlobalVarElem - | OverrideElem - | StructElem; - -export type DeclarationElem = GlobalDeclarationElem | FnParamElem | VarElem; - -export type ElemWithAttributes = Extract; - -export interface AbstractElemBase { - kind: AbstractElem["kind"]; - start: number; - end: number; -} - -export interface ElemWithContentsBase extends AbstractElemBase { - contents: AbstractElem[]; -} - -export interface HasAttributes { - attributes?: AttributeElem[]; -} - -/* ------ Terminal Elements (don't contain other elements) ------ */ - -/** - * a raw bit of text in WESL source that's typically copied to the linked WGSL. - * e.g. a keyword like 'var' - * or a phrase we needn't analyze further like '@diagnostic(off,derivative_uniformity)' - */ -export interface TextElem extends AbstractElemBase { - kind: "text"; - srcModule: SrcModule; -} - -/** a name that doesn't need to be an Ident - * e.g. - * - a struct member name - * - a diagnostic rule name - * - an enable-extension name - * - an interpolation sampling name - */ -export interface NameElem extends AbstractElemBase { - kind: "name"; - name: string; -} - -/** an identifier that 'refers to' a declaration (aka a symbol reference) */ -export interface RefIdentElem extends AbstractElemBase { - kind: RefIdent["kind"]; - ident: RefIdent; - srcModule: SrcModule; -} - -/** a declaration identifier (aka a symbol declaration) */ -export interface DeclIdentElem extends AbstractElemBase { - kind: DeclIdent["kind"]; - ident: DeclIdent; - srcModule: SrcModule; -} - -/** Holds an import statement, and has a span */ -export interface ImportElem extends AbstractElemBase { - kind: "import"; - imports: ImportStatement; -} - -/** - * An import statement, which is tree shaped. - * `import foo::bar::{baz, cat as neko}; - */ -export interface ImportStatement { - kind: "import-statement"; - segments: ImportSegment[]; - finalSegment: ImportCollection | ImportItem; -} - -/** - * A collection of import trees. - * `{baz, cat as neko}` - */ -export interface ImportSegment { - kind: "import-segment"; - name: string; -} - -/** - * A primitive segment in an import statement. - * `foo` - */ -export interface ImportCollection { - kind: "import-collection"; - subtrees: ImportStatement[]; -} - -/** - * A renamed item at the end of an import statement. - * `cat as neko` - */ -export interface ImportItem { - kind: "import-item"; - name: string; - as?: string; -} - -/* ------ Synthetic element (for transformations, not produced by grammar) ------ */ - -/** generated element, produced after parsing and binding */ -export interface SyntheticElem { - kind: "synthetic"; - text: string; -} - -/* ------ Container Elements (contain other elements) ------ */ - -/** a declaration identifer with a possible type */ -export interface TypedDeclElem extends ElemWithContentsBase { - kind: "typeDecl"; - decl: DeclIdentElem; - typeRef?: TypeRefElem; // TODO Consider a variant for fn params and alias where typeRef is required -} - -/** an alias statement */ -export interface AliasElem extends ElemWithContentsBase, HasAttributes { - kind: "alias"; - name: DeclIdentElem; - typeRef: TypeRefElem; -} - -/** an attribute like '@compute' or '@binding(0)' */ -export interface AttributeElem extends ElemWithContentsBase { - kind: "attribute"; - attribute: Attribute; -} - -export type Attribute = - | StandardAttribute - | InterpolateAttribute - | BuiltinAttribute - | DiagnosticAttribute - | IfAttribute; - -export interface StandardAttribute { - kind: "@attribute"; - name: string; - params?: UnknownExpressionElem[]; -} - -export interface InterpolateAttribute { - kind: "@interpolate"; - params: NameElem[]; -} - -export interface BuiltinAttribute { - kind: "@builtin"; - param: NameElem; -} - -export interface DiagnosticAttribute { - kind: "@diagnostic"; - severity: NameElem; - rule: [NameElem, NameElem | null]; -} - -export interface IfAttribute { - kind: "@if"; - param: TranslateTimeExpressionElem; -} - -/** a const_assert statement */ -export interface ConstAssertElem extends ElemWithContentsBase, HasAttributes { - kind: "assert"; -} - -/** a const declaration */ -export interface ConstElem extends ElemWithContentsBase, HasAttributes { - kind: "const"; - name: TypedDeclElem; -} - -/** an expression w/o special handling, used inside attribute parameters */ -export interface UnknownExpressionElem extends ElemWithContentsBase { - kind: "expression"; -} - -/** an expression that can be safely evaluated at compile time */ -export interface TranslateTimeExpressionElem { - kind: "translate-time-expression"; - expression: ExpressionElem; - span: Span; -} - -/** A literal value in WESL source. A boolean or a number. */ -export interface Literal { - kind: "literal"; - value: string; - span: Span; -} - -/** `words`s inside `@if` */ -export interface TranslateTimeFeature { - kind: "translate-time-feature"; - name: string; - span: Span; -} - -/** (expr) */ -export interface ParenthesizedExpression { - kind: "parenthesized-expression"; - expression: ExpressionElem; -} - -/** `foo[expr]` */ -export interface ComponentExpression { - kind: "component-expression"; - base: ExpressionElem; - access: ExpressionElem; -} - -/** `foo.member` */ -export interface ComponentMemberExpression { - kind: "component-member-expression"; - base: ExpressionElem; - access: NameElem; -} - -/** `+foo` */ -export interface UnaryExpression { - kind: "unary-expression"; - operator: UnaryOperator; - expression: ExpressionElem; -} - -/** `foo + bar` */ -export interface BinaryExpression { - kind: "binary-expression"; - operator: BinaryOperator; - left: ExpressionElem; - right: ExpressionElem; -} - -/** `foo(arg, arg)` */ -export interface FunctionCallExpression { - kind: "call-expression"; - function: RefIdentElem; - arguments: ExpressionElem[]; -} - -export interface UnaryOperator { - value: "!" | "&" | "*" | "-" | "~"; - span: Span; -} - -export interface BinaryOperator { - value: - | ("||" | "&&" | "+" | "-" | "*" | "/" | "%" | "==") - | ("!=" | "<" | "<=" | ">" | ">=" | "|" | "&" | "^") - | ("<<" | ">>"); - span: Span; -} - -export type DirectiveVariant = - | DiagnosticDirective - | EnableDirective - | RequiresDirective; - -export interface DirectiveElem extends AbstractElemBase, HasAttributes { - kind: "directive"; - directive: DirectiveVariant; -} - -export interface DiagnosticDirective { - kind: "diagnostic"; - severity: NameElem; - rule: [NameElem, NameElem | null]; -} - -export interface EnableDirective { - kind: "enable"; - extensions: NameElem[]; -} - -export interface RequiresDirective { - kind: "requires"; - extensions: NameElem[]; -} - -/** a function declaration */ -export interface FnElem extends ElemWithContentsBase, HasAttributes { - // LATER doesn't need contents - kind: "fn"; - name: DeclIdentElem; - params: FnParamElem[]; - body: StatementElem; - returnAttributes?: AttributeElem[]; - returnType?: TypeRefElem; -} - -/** a global variable declaration (at the root level) */ -export interface GlobalVarElem extends ElemWithContentsBase, HasAttributes { - kind: "gvar"; - name: TypedDeclElem; -} - -/** an entire file */ -export interface ModuleElem extends ElemWithContentsBase { - kind: "module"; -} - -/** an override declaration */ -export interface OverrideElem extends ElemWithContentsBase, HasAttributes { - kind: "override"; - name: TypedDeclElem; -} - -/** a parameter in a function declaration */ -export interface FnParamElem extends ElemWithContentsBase, HasAttributes { - kind: "param"; - name: TypedDeclElem; -} - -/** simple references to structures, like myStruct.bar - * (used for transforming refs to binding structs) */ -export interface SimpleMemberRef extends ElemWithContentsBase { - kind: "memberRef"; - name: RefIdentElem; - member: NameElem; - extraComponents?: StuffElem; -} - -/** a struct declaration */ -export interface StructElem extends ElemWithContentsBase, HasAttributes { - kind: "struct"; - name: DeclIdentElem; - members: StructMemberElem[]; - bindingStruct?: true; // used later during binding struct transformation -} - -/** generic container of other elements */ -export interface StuffElem extends ElemWithContentsBase { - kind: "stuff"; -} - -/** a struct declaration that's been marked as a bindingStruct */ -export interface BindingStructElem extends StructElem { - bindingStruct: true; - entryFn?: FnElem; -} - -/** a member of a struct declaration */ -export interface StructMemberElem extends ElemWithContentsBase, HasAttributes { - kind: "member"; - name: NameElem; - typeRef: TypeRefElem; - mangledVarName?: string; // root name if transformed to a var (for binding struct transformation) -} - -export type TypeTemplateParameter = TypeRefElem | UnknownExpressionElem; - -/** a reference to a type, like 'f32', or 'MyStruct', or 'ptr, read_only>' */ -export interface TypeRefElem extends ElemWithContentsBase { - kind: "type"; - name: RefIdent; - templateParams?: TypeTemplateParameter[]; -} - -/** a variable declaration */ -export interface VarElem extends ElemWithContentsBase, HasAttributes { - kind: "var"; - name: TypedDeclElem; -} - -export interface LetElem extends ElemWithContentsBase, HasAttributes { - kind: "let"; - name: TypedDeclElem; -} - -export interface StatementElem extends ElemWithContentsBase, HasAttributes { - kind: "statement"; -} - -export interface SwitchClauseElem extends ElemWithContentsBase, HasAttributes { - kind: "switch-clause"; -} diff --git a/tools/packages/wesl/src/AstVisitor.ts b/tools/packages/wesl/src/AstVisitor.ts new file mode 100644 index 000000000..85e4567df --- /dev/null +++ b/tools/packages/wesl/src/AstVisitor.ts @@ -0,0 +1,280 @@ +import { + AttributeElem, + GlobalDeclarationElem, + LhsExpression, + ModuleElem, + Statement, +} from "./parse/WeslElems.ts"; +import { ExpressionElem } from "./parse/ExpressionElem.ts"; +import { DirectiveElem } from "./parse/DirectiveElem.ts"; +import { ImportElem } from "./parse/ImportElems.ts"; +import { assertUnreachable } from "./Assertions.ts"; + +export abstract class AstVisitor { + module(module: ModuleElem) { + walkModule(module, this); + } + import(importElem: ImportElem) { + walkImport(importElem, this); + } + directive(directive: DirectiveElem) { + walkDirective(directive, this); + } + attribute(attribute: AttributeElem) { + walkAttribute(attribute, this); + } + /** A global declaration and its attributes */ + globalDeclaration(declaration: GlobalDeclarationElem) { + walkGlobalDeclaration(declaration, this); + } + /** A global declaration after the attributes */ + globalDeclarationInner(declaration: GlobalDeclarationElem) { + walkGlobalDeclarationInner(declaration, this); + } + /** A statement and its attributes */ + statement(statement: Statement) { + walkStatement(statement, this); + } + /** A statement after the attributes */ + statementInner(statement: Statement) { + walkStatementInner(statement, this); + } + expression(expression: ExpressionElem): void { + walkExpression(expression, this); + } + lhsExpression(expression: LhsExpression): void { + walkLhsExpression(expression, this); + } +} + +export function walkModule(module: ModuleElem, visitor: AstVisitor) { + for (const importElem of module.imports) { + visitor.import(importElem); + } + for (const directive of module.directives) { + visitor.directive(directive); + } + for (const declaration of module.declarations) { + visitor.globalDeclaration(declaration); + } +} + +export function walkImport(importElem: ImportElem, visitor: AstVisitor) { + visitAttributes(importElem.attributes, visitor); +} + +export function walkDirective(directive: DirectiveElem, visitor: AstVisitor) { + visitAttributes(directive.attributes, visitor); +} + +/** Helper function so I don't have to write this out every time */ +function visitAttributes( + attributes: AttributeElem[] | undefined, + visitor: AstVisitor, +) { + attributes?.forEach(attribute => visitor.attribute(attribute)); +} + +function walkAttribute(attribute: AttributeElem, visitor: AstVisitor) { + if (attribute.attribute.kind === "attribute") { + attribute.attribute.params.forEach(v => visitor.expression(v)); + } +} + +export function walkGlobalDeclaration( + declaration: GlobalDeclarationElem, + visitor: AstVisitor, +) { + visitAttributes(declaration.attributes, visitor); + visitor.globalDeclarationInner(declaration); +} + +export function walkGlobalDeclarationInner( + declaration: GlobalDeclarationElem, + visitor: AstVisitor, +) { + const kind = declaration.kind; + if (kind === "alias") { + visitor.expression(declaration.type); + } else if (kind === "assert") { + visitor.expression(declaration.expression); + } else if (kind === "declaration") { + if (declaration.variant.kind === "var") { + declaration.variant.template?.forEach(v => visitor.expression(v)); + } + if (declaration.type) { + visitor.expression(declaration.type); + } + if (declaration.initializer !== undefined) { + visitor.expression(declaration.initializer); + } + } else if (kind === "function") { + declaration.params.forEach(p => { + visitAttributes(p.attributes, visitor); + visitor.expression(p.type); + }); + visitAttributes(declaration.returnAttributes, visitor); + if (declaration.returnType) { + visitor.expression(declaration.returnType); + } + visitor.statement(declaration.body); + } else if (kind === "struct") { + declaration.members.forEach(member => { + declaration.attributes?.forEach(attribute => + visitor.attribute?.(attribute), + ); + visitor.expression(member.type); + }); + } else { + assertUnreachable(kind); + } +} + +export function walkExpression( + expression: ExpressionElem, + visitor: AstVisitor, +): void { + const kind = expression.kind; + if (kind === "binary-expression") { + visitor.expression(expression.left); + visitor.expression(expression.right); + } else if (kind === "call-expression") { + visitor.expression(expression.function); + expression.arguments.forEach(arg => visitor.expression(arg)); + } else if (kind === "component-expression") { + visitor.expression(expression.base); + visitor.expression(expression.access); + } else if (kind === "component-member-expression") { + visitor.expression(expression.base); + } else if (kind === "literal") { + // Nothing to do + } else if (kind === "parenthesized-expression") { + visitor.expression(expression.expression); + } else if (kind === "templated-ident") { + expression.template?.forEach(v => visitor.expression(v)); + } else if (kind === "unary-expression") { + visitor.expression(expression.expression); + } else { + assertUnreachable(kind); + } +} + +export function walkLhsExpression( + expression: LhsExpression, + visitor: AstVisitor, +): void { + const kind = expression.kind; + if (kind === "component-expression") { + visitor.lhsExpression(expression.base); + visitor.expression(expression.access); + } else if (kind === "component-member-expression") { + visitor.lhsExpression(expression.base); + } else if (kind === "lhs-ident") { + // Nothing to do + } else if (kind === "parenthesized-expression") { + visitor.lhsExpression(expression.expression); + } else if (kind === "unary-expression") { + visitor.lhsExpression(expression.expression); + } else { + assertUnreachable(kind); + } +} + +export function walkStatement(statement: Statement, visitor: AstVisitor) { + visitAttributes(statement.attributes, visitor); + visitor.statementInner(statement); +} + +export function walkStatementInner(statement: Statement, visitor: AstVisitor) { + const kind = statement.kind; + if (kind === "assert") { + visitor.expression(statement.expression); + } else if (kind === "assignment-statement") { + if (statement.left.kind !== "discard-expression") { + visitor.lhsExpression(statement.left); + } + visitor.expression(statement.right); + } else if (kind === "break-statement") { + // Nothing to do + } else if (kind === "call-statement") { + visitor.expression(statement.function); + statement.arguments.forEach(v => visitor.expression(v)); + } else if (kind === "compound-statement") { + statement.body.forEach(v => visitor.statement(v)); + } else if (kind === "continue-statement") { + // Nothing to do + } else if (kind === "declaration") { + if (statement.variant.kind === "var") { + statement.variant.template?.forEach(v => visitor.expression(v)); + } + if (statement.type !== undefined) { + visitor.expression(statement.type); + } + if (statement.initializer !== undefined) { + visitor.expression(statement.initializer); + } + } else if (kind === "postfix-statement") { + visitor.lhsExpression(statement.expression); + } else if (kind === "discard-statement") { + // Nothing to do + } else if (kind === "for-statement") { + if (statement.initializer !== undefined) { + visitor.statement(statement.initializer); + } + if (statement.condition !== undefined) { + visitor.expression(statement.condition); + } + if (statement.update !== undefined) { + visitor.statement(statement.update); + } + visitor.statement(statement.body); + } else if (kind === "if-else-statement") { + let current = statement.main; + while (true) { + visitor.expression(current.condition); + visitor.statement(current.accept); + if (current.reject === undefined) { + break; + } else if (current.reject.kind === "if-clause") { + current = current.reject; + } else { + visitor.statement(current.reject); + break; + } + } + } else if (kind === "loop-statement") { + visitor.statement(statement.body); + if (statement.continuing !== undefined) { + visitAttributes(statement.continuing.attributes, visitor); + visitor.statement(statement.continuing.body); + const breakIf = statement.continuing.breakIf; + if (breakIf !== undefined) { + visitAttributes(breakIf.attributes, visitor); + visitor.expression(breakIf.expression); + } + } + } else if (kind === "return-statement") { + if (statement.expression !== undefined) { + visitor.expression(statement.expression); + } + } else if (kind === "switch-statement") { + visitor.expression(statement.selector); + visitAttributes(statement.bodyAttributes, visitor); + for (const clause of statement.clauses) { + visitAttributes(clause.attributes, visitor); + for (const switchCase of clause.cases) { + if (switchCase.expression === "default") { + // Nothing to do + } else { + visitor.expression(switchCase.expression); + } + } + visitor.statement(clause.body); + } + } else if (kind === "while-statement") { + visitor.expression(statement.condition); + visitor.statement(statement.body); + } else { + assertUnreachable(kind); + } +} diff --git a/tools/packages/wesl/src/BindIdents.ts b/tools/packages/wesl/src/BindIdents.ts index 4e742b249..c87798976 100644 --- a/tools/packages/wesl/src/BindIdents.ts +++ b/tools/packages/wesl/src/BindIdents.ts @@ -11,7 +11,7 @@ import { FlatImport } from "./FlattenTreeImport.ts"; import { LinkRegistryParams, VirtualLibraryFn } from "./Linker.ts"; import { LiveDecls, makeLiveDecls } from "./LiveDeclarations.ts"; import { ManglerFn, minimalMangle } from "./Mangler.ts"; -import { ParsedRegistry } from "./ParsedRegistry.ts"; +import { TranslationUnit } from "./lower/TranslationUnit.ts"; import { flatImports, parseSrcModule, WeslAST } from "./ParseWESL.ts"; import { Conditions, @@ -155,7 +155,7 @@ export function findValidRootDecls( /** state used during the recursive scope tree walk to bind references to declarations */ interface BindContext { - registry: ParsedRegistry; + registry: TranslationUnit; /** live runtime conditions currently defined by the user */ conditions: Record; @@ -406,7 +406,7 @@ function findDeclInModule( * or via an inline qualified ident e.g. foo::bar() */ function findQualifiedImport( refIdent: RefIdent, - parsed: ParsedRegistry, + parsed: TranslationUnit, conditions: Conditions, virtuals?: VirtualLibrarySet, ): FoundDecl | undefined { @@ -451,7 +451,7 @@ interface FoundDecl { function findExport( modulePathParts: string[], srcModule: SrcModule, - parsed: ParsedRegistry, + parsed: TranslationUnit, conditions: Conditions = {}, virtuals?: VirtualLibrarySet, ): FoundDecl | undefined { diff --git a/tools/packages/wesl/src/Conditions.ts b/tools/packages/wesl/src/Conditions.ts index 479aee220..f9fc8754d 100644 --- a/tools/packages/wesl/src/Conditions.ts +++ b/tools/packages/wesl/src/Conditions.ts @@ -1,74 +1,71 @@ -import { - AttributeElem, - ElemWithAttributes, - ExpressionElem, - IfAttribute, -} from "./AbstractElems.ts"; -import { assertThatDebug, assertUnreachable } from "./Assertions.ts"; -import { Conditions, Scope } from "./Scope.ts"; -import { findMap } from "./Util.ts"; +import { assertThat, assertUnreachable } from "../../mini-parse/src/Assertions"; +import { ExpressionElem } from "./parse/ExpressionElem"; +import { AttributeElem, IfAttribute } from "./parse/WeslElems"; -/** @return true if the element is valid under current Conditions */ -export function elementValid( - elem: ElemWithAttributes, +/** Maps every condition to a value. A condition being missing is an error. */ +export type Conditions = Map; + +export function evaluateIfAttribute( conditions: Conditions, + attributes: AttributeElem[] | undefined, ): boolean { - const attributes = elem.attributes; - if (!attributes) return true; - const ifAttr = findMap(attributes, extractIfAttribute); - return !ifAttr || evaluateIfAttribute(ifAttr, conditions); -} - -/** @return true if the scope is valid under current conditions */ -export function scopeValid(scope: Scope, conditions: Conditions): boolean { - const { ifAttribute } = scope; - if (!ifAttribute) return true; - const result = evaluateIfAttribute(ifAttribute, conditions); // LATER cache? - return result; -} - -/** @return return IfAttribute if AttributeElem contains an IfAttribute */ -function extractIfAttribute(elem: AttributeElem): IfAttribute | undefined { - const { attribute } = elem; - return attribute.kind === "@if" ? attribute : undefined; + const condAttribute = attributes?.find(v => v.attribute.kind === "@if"); + if (condAttribute === undefined) return true; + return evaluateConditions(conditions, condAttribute.attribute as IfAttribute); } -/** @return true if the @if attribute is valid with current Conditions */ -function evaluateIfAttribute( - ifAttribute: IfAttribute, +export function evaluateConditions( conditions: Conditions, + ifAttribute: IfAttribute, ): boolean { - return evaluateIfExpression(ifAttribute.param.expression, conditions); + return evaluateExpression(conditions, ifAttribute.param.expression); } -/** Evaluate an @if expression based on current runtime Conditions - * @return true if the expression is true */ -function evaluateIfExpression( - expression: ExpressionElem, +function evaluateExpression( conditions: Conditions, + expression: ExpressionElem, ): boolean { - const { kind } = expression; - if (kind == "unary-expression") { - assertThatDebug(expression.operator.value === "!"); - return !evaluateIfExpression(expression.expression, conditions); - } else if (kind == "binary-expression") { - const op = expression.operator.value; - assertThatDebug(op === "||" || op === "&&"); - const leftResult = evaluateIfExpression(expression.left, conditions); - if (op === "||") { - return leftResult || evaluateIfExpression(expression.right, conditions); - } else if (op === "&&") { - return leftResult && evaluateIfExpression(expression.right, conditions); + if (expression.kind == "binary-expression") { + const operator = expression.operator.value; + assertThat(operator === "||" || operator === "&&"); + const left = evaluateExpression(conditions, expression.left); + if (left && operator === "||") { + return true; + } else if (!left && operator === "||") { + return evaluateExpression(conditions, expression.right); + } else if (left && operator === "&&") { + return evaluateExpression(conditions, expression.right); + } else if (!left && operator === "&&") { + return false; } else { - assertUnreachable(op); + assertUnreachable(operator as never); + } + } else if (expression.kind == "call-expression") { + throw new Error("Function calls are not supported in an @if()"); + } else if (expression.kind == "component-expression") { + throw new Error("Component access is not supported in an @if()"); + } else if (expression.kind == "component-member-expression") { + throw new Error("Component access is not supported in an @if()"); + } else if (expression.kind == "literal") { + assertThat(expression.value === "true" || expression.value === "false"); + return expression.value === "true" ? true : false; + } else if (expression.kind == "parenthesized-expression") { + return evaluateExpression(conditions, expression.expression); + } else if (expression.kind == "templated-ident") { + assertThat(expression.ident.segments.length === 1); + assertThat( + expression.template === undefined || expression.template.length === 0, + ); + const name = expression.ident.segments[0]; + const condition = conditions.get(name); + if (condition === undefined) { + throw new Error(`Condition ${name} has not been defined`); } - } else if (kind == "literal") { - const { value } = expression; - assertThatDebug(value === "true" || value === "false"); - return value === "true"; - } else if (kind == "parenthesized-expression") { - return evaluateIfExpression(expression.expression, conditions); + return condition; + } else if (expression.kind == "unary-expression") { + assertThat(expression.operator.value === "!"); + return !evaluateExpression(conditions, expression.expression); } else { - throw new Error("unexpected @if expression ${expression}"); + assertUnreachable(expression); } } diff --git a/tools/packages/wesl/src/FlattenTreeImport.ts b/tools/packages/wesl/src/FlattenTreeImport.ts deleted file mode 100644 index dc93a2b05..000000000 --- a/tools/packages/wesl/src/FlattenTreeImport.ts +++ /dev/null @@ -1,55 +0,0 @@ -import { - ImportCollection, - ImportItem, - ImportSegment, - ImportStatement, -} from "./AbstractElems"; -import { assertUnreachable } from "./Assertions"; - -export interface FlatImport { - importPath: string[]; - modulePath: string[]; -} - -/** - * Simplify importTree into a flattened map from import paths to module paths. - * - * @return map from import path (with 'as' renaming) to module Path - */ -export function flattenTreeImport(imp: ImportStatement): FlatImport[] { - return recursiveResolve([], [], imp.segments, imp.finalSegment); - - /** recurse through segments of path, producing */ - function recursiveResolve( - resolvedImportPath: string[], - resolvedExportPath: string[], - remainingPath: ImportSegment[], - finalSegment: ImportCollection | ImportItem, - ): FlatImport[] { - if (remainingPath.length > 0) { - const [segment, ...rest] = remainingPath; - const importPath = [...resolvedImportPath, segment.name]; - const modulePath = [...resolvedExportPath, segment.name]; - return recursiveResolve(importPath, modulePath, rest, finalSegment); - } else if (finalSegment.kind === "import-collection") { - // resolve path with each element in the list - return finalSegment.subtrees.flatMap(elem => { - return recursiveResolve( - resolvedImportPath, - resolvedExportPath, - elem.segments, - elem.finalSegment, - ); - }); - } else if (finalSegment.kind === "import-item") { - const importPath = [ - ...resolvedImportPath, - finalSegment.as ?? finalSegment.name, - ]; - const modulePath = [...resolvedExportPath, finalSegment.name]; - return [{ importPath, modulePath }]; - } else { - assertUnreachable(finalSegment); - } - } -} diff --git a/tools/packages/wesl/src/LinkedWesl.ts b/tools/packages/wesl/src/LinkedWesl.ts index c8e751587..5621271c1 100644 --- a/tools/packages/wesl/src/LinkedWesl.ts +++ b/tools/packages/wesl/src/LinkedWesl.ts @@ -1,12 +1,12 @@ import { SrcMap } from "mini-parse"; -import { assertThat } from "../../mini-parse/src/Assertions"; -import { errorHighlight, offsetToLineNumber } from "./Util"; -import type { WeslDevice } from "./WeslDevice"; +import type { WeslDevice } from "./WeslDevice.ts"; +import { offsetToLineNumber, str } from "./Util.ts"; +import { assertThat } from "./Assertions.ts"; /** Results of shader compilation. Has {@link WeslGPUCompilationMessage} * which are aware of the WESL module that an error was thrown from. */ export interface WeslGPUCompilationInfo extends GPUCompilationInfo { - messages: WeslGPUCompilationMessage[]; + messages: (WeslGPUCompilationMessage | GPUCompilationMessage)[]; } export interface WeslGPUCompilationMessage extends GPUCompilationMessage { @@ -96,7 +96,7 @@ export class LinkedWesl { * better error reporting experience. */ get dest() { - return this.sourceMap.dest.text; + return this.sourceMap.code; } /** Turns raw compilation info into compilation info @@ -114,32 +114,34 @@ export class LinkedWesl { private mapGPUCompilationMessage( message: GPUCompilationMessage, - ): WeslGPUCompilationMessage { + ): WeslGPUCompilationMessage | GPUCompilationMessage { const srcMap = this.sourceMap; - const srcPosition = srcMap.destToSrc(message.offset); - // LATER what if this gets mapped to a completely different place? - const srcEndPosition = - message.length > 0 ? - srcMap.destToSrc(message.offset + message.length) - : srcPosition; - const length = srcEndPosition.position - srcPosition.position; + const srcSpan = srcMap.destSpanToSrc([ + message.offset, + message.offset + message.length, + ]); + if (srcSpan === null) { + return message; + } + const length = + srcSpan.span[1] !== null ? srcSpan.span[1] - srcSpan.span[0] : 0; let [lineNum, linePos] = offsetToLineNumber( - srcPosition.position, - srcPosition.src.text, + srcSpan.span[0], + srcSpan.src.text, ); return { __brand: message.__brand, type: message.type, message: message.message, - offset: srcPosition.position, + offset: srcSpan.span[0], length, lineNum, linePos, module: { - url: srcPosition.src.path ?? "", - text: srcPosition.src.text, + url: srcSpan.src.path ?? "", + text: srcSpan.src.text, }, }; } @@ -167,9 +169,9 @@ function compilationInfoToErrorMessage( } for (const message of compilationInfo.messages) { const { lineNum, linePos } = message; - - result += `${message.module.url}:${lineNum}:${linePos}`; - result += ` ${message.type}: ${message.message}\n`; + const module = "module" in message ? message.module : null; + result += str`${module?.url ?? ""}:${lineNum}:${linePos}`; + result += str` ${message.type}: ${message.message}\n`; // LATER unmangle code snippets in the message const source = message.module.text; diff --git a/tools/packages/wesl/src/Linker.ts b/tools/packages/wesl/src/Linker.ts index 0be3ce1f4..d4c5d7d27 100644 --- a/tools/packages/wesl/src/Linker.ts +++ b/tools/packages/wesl/src/Linker.ts @@ -1,20 +1,18 @@ import { SrcMap, SrcMapBuilder, tracing } from "mini-parse"; -import { AbstractElem, ModuleElem } from "./AbstractElems.ts"; -import { bindIdents, EmittableElem } from "./BindIdents.ts"; -import { LinkedWesl } from "./LinkedWesl.ts"; -import { lowerAndEmit } from "./LowerAndEmit.ts"; +import { ModuleElem } from "./parse/WeslElems.ts"; +import { lowerAndEmit } from "./lower/LowerAndEmit.ts"; import { ManglerFn } from "./Mangler.ts"; import { parsedRegistry, - ParsedRegistry, + TranslationUnit, parseIntoRegistry, - parseLibsIntoRegistry, - selectModule, -} from "./ParsedRegistry.ts"; -import { WeslAST } from "./ParseWESL.ts"; -import { Conditions, DeclIdent, SrcModule } from "./Scope.ts"; +} from "./lower/TranslationUnit.ts"; import { filterMap, mapValues } from "./Util.ts"; import { WgslBundle } from "./WgslBundle.ts"; +import { LinkedWesl } from "./LinkedWesl.ts"; +import { Conditions } from "./Conditions.ts"; +import { assertThat } from "./Assertions.ts"; +import { WeslAST } from "./Module.ts"; type LinkerTransform = (boundAST: TransformedAST) => TransformedAST; @@ -22,8 +20,7 @@ export interface WeslJsPlugin { transform?: LinkerTransform; } -export interface TransformedAST - extends Pick { +export interface TransformedAST extends WeslAST { globalNames: Set; notableElems: Record; } @@ -49,9 +46,9 @@ export interface LinkParams { /** name of root wesl module * for an app, the root module normally contains the '@compute', '@vertex' or '@fragment' entry points * for a library, the root module defines the public api fo the library - * can be specified as file path (./main.wesl), a module path (package::main), or just a module name (main) + * is a module path like ["package", "main"] */ - rootModuleName?: string; + rootModulePath: string[]; /** For debug logging. Will be prepended to file paths. */ debugWeslRoot?: string; @@ -93,22 +90,19 @@ export async function link(params: LinkParams): Promise { const { weslSrc, debugWeslRoot, libs = [] } = params; const registry = parsedRegistry(); parseIntoRegistry(weslSrc, registry, "package", debugWeslRoot); - parseLibsIntoRegistry(libs, registry); - return new LinkedWesl(linkRegistry({ registry, ...params })); + libs.forEach(lib => parseIntoRegistry(lib.modules, registry, lib.name)); + return new LinkedWesl(linkRegistry(params, registry)); } -export interface LinkRegistryParams - extends Pick< - LinkParams, - | "rootModuleName" - | "conditions" - | "virtualLibs" - | "config" - | "constants" - | "mangler" - > { - registry: ParsedRegistry; -} +type LinkRegistryParams = Pick< + LinkParams, + | "rootModulePath" + | "conditions" + | "virtualLibs" + | "config" + | "constants" + | "mangler" +>; /** Link wesl from a registry of already parsed modules. * @@ -117,9 +111,12 @@ export interface LinkRegistryParams * each time, or perhaps to produce multiple wgsl shaders * that share some sources.) */ -export function linkRegistry(params: LinkRegistryParams): SrcMap { - const bound = bindAndTransform(params); - const { transformedAst, newDecls, newStatements } = bound; +export function linkRegistry( + params: LinkRegistryParams, + registry: TranslationUnit, +): SrcMap { + const bound = bindAndTransform(params, registry); + const { transformedAst, newDecls } = bound; return SrcMapBuilder.build( emitWgsl( @@ -141,10 +138,16 @@ export interface BoundAndTransformed { /** bind identifers and apply any transform plugins */ export function bindAndTransform( params: LinkRegistryParams, + registry: TranslationUnit, ): BoundAndTransformed { - const { registry, mangler } = params; - const { rootModuleName = "main", conditions = {} } = params; - const rootAst = getRootModule(registry, rootModuleName); + const { mangler } = params; + const { rootModulePath, conditions = {} } = params; + assertThat( + rootModulePath[0] === "package", + "Root module must be inside the package", + ); + assertThat(rootModulePath.length >= 2, "Root module must point at a module"); + const rootAst = registry.getModule(rootModulePath); // setup virtual modules from code generation or host constants provided by the user const { constants, config } = params; @@ -173,22 +176,6 @@ function constantsGenerator( .join("\n"); } -/** get a reference to the root module, selecting by module name */ -function getRootModule( - parsed: ParsedRegistry, - rootModuleName: string, -): WeslAST { - const rootModule = selectModule(parsed, rootModuleName); - if (!rootModule) { - if (tracing) { - console.log(`parsed modules: ${Object.keys(parsed.modules)}`); - console.log(`root module not found: ${rootModuleName}`); - } - throw new Error(`Root module not found: ${rootModuleName}`); - } - return rootModule; -} - /** run any plugins that transform the AST */ function applyTransformPlugins( rootModule: WeslAST, @@ -198,7 +185,12 @@ function applyTransformPlugins( const { moduleElem, srcModule } = rootModule; // for now only transform the root module - const startAst = { moduleElem, srcModule, globalNames, notableElems: {} }; + const startAst: TransformedAST = { + moduleElem, + srcModule, + globalNames, + notableElems: {}, + }; const plugins = config?.plugins ?? []; const transforms = filterMap(plugins, plugin => plugin.transform); const transformedAst = transforms.reduce( diff --git a/tools/packages/wesl/src/LinkerUtil.ts b/tools/packages/wesl/src/LinkerUtil.ts deleted file mode 100644 index 580dd9dcd..000000000 --- a/tools/packages/wesl/src/LinkerUtil.ts +++ /dev/null @@ -1,29 +0,0 @@ -import { srcLog } from "mini-parse"; -import { - AbstractElem, - ContainerElem, - DeclIdentElem, - RefIdentElem, -} from "./AbstractElems.ts"; - -export function visitAst( - elem: AbstractElem, - visitor: (elem: AbstractElem) => void, -) { - visitor(elem); - if ((elem as ContainerElem).contents) { - const container = elem as ContainerElem; - container.contents.forEach(child => visitAst(child, visitor)); - } -} - -export function identElemLog( - identElem: DeclIdentElem | RefIdentElem, - ...messages: any[] -): void { - srcLog( - identElem.srcModule.src, - [identElem.start, identElem.end], - ...messages, - ); -} diff --git a/tools/packages/wesl/src/LiveDeclarations.ts b/tools/packages/wesl/src/LiveDeclarations.ts deleted file mode 100644 index 949be40b8..000000000 --- a/tools/packages/wesl/src/LiveDeclarations.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { identToString } from "./debug/ScopeToString.ts"; -import { DeclIdent } from "./Scope.ts"; - -/** decls currently visible in this scope */ -export interface LiveDecls { - /** decls currently visible in this scope */ - decls: Map; - - /** live decls in the parent scope. null for the modue root scope */ - parent?: LiveDecls | null; -} - -/** create a LiveDecls */ -export function makeLiveDecls(parent: LiveDecls | null = null): LiveDecls { - return { decls: new Map(), parent }; -} - -/** debug routine for logging LiveDecls */ -export function liveDeclsToString(liveDecls: LiveDecls): string { - const { decls, parent } = liveDecls; - const declsStr = Array.from(decls.entries()) - .map(([name, decl]) => `${name}:${identToString(decl)}`) - .join(", "); - const parentStr = parent ? liveDeclsToString(parent) : "null"; - return `decls: { ${declsStr} }, parent: ${parentStr}`; -} - -/* -LATER try not creating a map for small scopes. -Instead just track the current live index in the scope array. -*/ diff --git a/tools/packages/wesl/src/LowerAndEmit.ts b/tools/packages/wesl/src/LowerAndEmit.ts deleted file mode 100644 index 53e9d03d1..000000000 --- a/tools/packages/wesl/src/LowerAndEmit.ts +++ /dev/null @@ -1,413 +0,0 @@ -import { srcLog, SrcMapBuilder } from "mini-parse"; -import { - AbstractElem, - AttributeElem, - ContainerElem, - DeclIdentElem, - DirectiveElem, - ElemWithAttributes, - ExpressionElem, - FnElem, - NameElem, - RefIdentElem, - StructElem, - SyntheticElem, - TextElem, -} from "./AbstractElems.ts"; -import { - assertThatDebug, - assertUnreachable, - assertUnreachableSilent, -} from "./Assertions.ts"; -import { isGlobal } from "./BindIdents.ts"; -import { elementValid } from "./Conditions.ts"; -import { identToString } from "./debug/ScopeToString.ts"; -import { Conditions, DeclIdent, Ident } from "./Scope.ts"; - -/** passed to the emitters */ -interface EmitContext { - srcBuilder: SrcMapBuilder; // constructing the linked output - conditions: Conditions; // settings for conditional compilation - extracting: boolean; // are we extracting or copying the root module -} - -/** traverse the AST, starting from root elements, emitting wgsl for each */ -export function lowerAndEmit( - srcBuilder: SrcMapBuilder, - rootElems: AbstractElem[], - conditions: Conditions, - extracting = true, -): void { - const emitContext: EmitContext = { conditions, srcBuilder, extracting }; - // rootElems.forEach(r => console.log(astToString(r) + "\n")); - rootElems.forEach(e => lowerAndEmitElem(e, emitContext)); -} - -export function lowerAndEmitElem(e: AbstractElem, ctx: EmitContext): void { - if (!conditionsValid(e, ctx.conditions)) return; - - switch (e.kind) { - // import statements are dropped from from emitted text - case "import": - return; - - // terminal elements copy strings to the output - case "text": - return emitText(e, ctx); - case "name": - return emitName(e, ctx); - case "synthetic": - return emitSynthetic(e, ctx); - - // identifiers are copied to the output, but with potentially mangled names - case "ref": - return emitRefIdent(e, ctx); - case "decl": - return emitDeclIdent(e, ctx); - - // container elements just emit their child elements - case "param": - case "var": - case "typeDecl": - case "let": - case "module": - case "member": - case "memberRef": - case "expression": - case "type": - case "statement": - case "stuff": - case "switch-clause": - return emitContents(e, ctx); - - // root level container elements get some extra newlines to make the output prettier - case "override": - case "const": - case "assert": - case "alias": - case "gvar": - emitRootElemNl(ctx); - return emitContents(e, ctx); - - case "fn": - emitRootElemNl(ctx); - return emitFn(e, ctx); - - case "struct": - emitRootElemNl(ctx); - return emitStruct(e, ctx); - - case "attribute": - return emitAttribute(e, ctx); - case "directive": - return emitDirective(e, ctx); - - default: - assertUnreachable(e); - } -} - -/** emit root elems with a blank line inbetween */ -function emitRootElemNl(ctx: EmitContext): void { - if (ctx.extracting) { - ctx.srcBuilder.addNl(); - ctx.srcBuilder.addNl(); - } -} - -export function emitText(e: TextElem, ctx: EmitContext): void { - ctx.srcBuilder.addCopy(e.start, e.end); -} - -export function emitName(e: NameElem, ctx: EmitContext): void { - ctx.srcBuilder.add(e.name, e.start, e.end); -} - -/** emit function explicitly so we can control commas between conditional parameters */ -export function emitFn(e: FnElem, ctx: EmitContext): void { - const { attributes, name, params, returnAttributes, returnType, body } = e; - const { conditions, srcBuilder: builder } = ctx; - - emitAttributes(attributes, ctx); - - builder.add("fn ", name.start - 3, name.start); - emitDeclIdent(name, ctx); - - builder.appendNext("("); - const validParams = params.filter(p => conditionsValid(p, conditions)); - validParams.forEach((p, i) => { - emitContentsNoWs(p, ctx); - if (i < validParams.length - 1) { - builder.appendNext(", "); - } - }); - builder.appendNext(") "); - - if (returnType) { - builder.appendNext("-> "); - emitAttributes(returnAttributes, ctx); - emitContents(returnType, ctx); - builder.appendNext(" "); - } - - emitContents(body, ctx); -} - -function emitAttributes( - attributes: AttributeElem[] | undefined, - ctx: EmitContext, -): void { - attributes?.forEach(a => { - emitAttribute(a, ctx); - ctx.srcBuilder.add(" ", a.start, a.end); - }); -} - -/** emit structs explicitly so we can control commas between conditional members */ -export function emitStruct(e: StructElem, ctx: EmitContext): void { - const { name, members, start, end } = e; - const { srcBuilder } = ctx; - - const validMembers = members.filter(m => conditionsValid(m, ctx.conditions)); - const validLength = validMembers.length; - - if (validLength === 0) { - warnEmptyStruct(e); - return; - } - - srcBuilder.add("struct ", start, name.start); - emitDeclIdent(name, ctx); - - if (validLength === 1) { - srcBuilder.add(" { ", name.end, members[0].start); - emitContentsNoWs(validMembers[0], ctx); - srcBuilder.add(" }\n", end - 1, end); - } else { - srcBuilder.add(" {\n", name.end, members[0].start); - - validMembers.forEach(m => { - srcBuilder.add(" ", m.start - 1, m.start); - emitContentsNoWs(m, ctx); - srcBuilder.add(",", m.end, m.end + 1); - srcBuilder.addNl(); - }); - - srcBuilder.add("}\n", end - 1, end); - } -} - -function warnEmptyStruct(e: StructElem): void { - const { name, members } = e; - const condStr = members.length ? "(with current conditions)" : ""; - const { debugFilePath: filePath } = name.srcModule; - srcLog( - name.srcModule.src, - e.start, - `struct ${name.ident.originalName} in ${filePath} has no members ${condStr}`, - ); -} - -export function emitSynthetic(e: SyntheticElem, ctx: EmitContext): void { - const { text } = e; - ctx.srcBuilder.addSynthetic(text, text, 0, text.length); -} - -export function emitContents(elem: ContainerElem, ctx: EmitContext): void { - elem.contents.forEach(e => lowerAndEmitElem(e, ctx)); -} - -/** emit contents w/o white space */ -function emitContentsNoWs(elem: ContainerElem, ctx: EmitContext): void { - elem.contents.forEach(e => { - if (e.kind === "text") { - const { srcModule, start, end } = e; - const text = srcModule.src.slice(start, end); - if (text.trim() === "") { - return; - } - } - lowerAndEmitElem(e, ctx); - }); -} - -export function emitRefIdent(e: RefIdentElem, ctx: EmitContext): void { - if (e.ident.std) { - ctx.srcBuilder.add(e.ident.originalName, e.start, e.end); - } else { - const declIdent = findDecl(e.ident); - const mangledName = displayName(declIdent); - ctx.srcBuilder.add(mangledName!, e.start, e.end); - } -} - -export function emitDeclIdent(e: DeclIdentElem, ctx: EmitContext): void { - const mangledName = displayName(e.ident); - ctx.srcBuilder.add(mangledName!, e.start, e.end); -} - -function emitAttribute(e: AttributeElem, ctx: EmitContext): void { - const { kind } = e.attribute; - // LATER emit more precise source map info by making use of all the spans - // Like the first case does - if (kind === "@attribute") { - const { params } = e.attribute; - if (!params || params.length === 0) { - ctx.srcBuilder.add("@" + e.attribute.name, e.start, e.end); - } else { - ctx.srcBuilder.add( - "@" + e.attribute.name + "(", - e.start, - params[0].start, - ); - for (let i = 0; i < params.length; i++) { - emitContents(params[i], ctx); - if (i < params.length - 1) { - ctx.srcBuilder.add(",", params[i].end, params[i + 1].start); - } - } - ctx.srcBuilder.add(")", params[params.length - 1].end, e.end); - } - } else if (kind === "@builtin") { - ctx.srcBuilder.add( - "@builtin(" + e.attribute.param.name + ")", - e.start, - e.end, - ); - } else if (kind === "@diagnostic") { - ctx.srcBuilder.add( - "@diagnostic" + - diagnosticControlToString(e.attribute.severity, e.attribute.rule), - e.start, - e.end, - ); - } else if (kind === "@if") { - // (@if is wesl only, dropped from wgsl) - } else if (kind === "@interpolate") { - ctx.srcBuilder.add( - `@interpolate(${e.attribute.params.map(v => v.name).join(", ")})`, - e.start, - e.end, - ); - } else { - assertUnreachable(kind); - } -} - -export function diagnosticControlToString( - severity: NameElem, - rule: [NameElem, NameElem | null], -): string { - const ruleStr = rule[0].name + (rule[1] !== null ? "." + rule[1].name : ""); - return `(${severity.name}, ${ruleStr})`; -} - -export function expressionToString(elem: ExpressionElem): string { - const { kind } = elem; - if (kind === "binary-expression") { - return `${expressionToString(elem.left)} ${elem.operator.value} ${expressionToString(elem.right)}`; - } else if (kind === "unary-expression") { - return `${elem.operator.value}${expressionToString(elem.expression)}`; - } else if (kind === "ref") { - return elem.ident.originalName; - } else if (kind === "literal") { - return elem.value; - } else if (kind === "translate-time-feature") { - return elem.name; - } else if (kind === "parenthesized-expression") { - return `(${expressionToString(elem.expression)})`; - } else if (kind === "component-expression") { - return `${expressionToString(elem.base)}[${elem.access}]`; - } else if (kind === "component-member-expression") { - return `${expressionToString(elem.base)}.${elem.access}`; - } else if (kind === "call-expression") { - return `${elem.function.ident.originalName}(${elem.arguments.map(expressionToString).join(", ")})`; - } else { - assertUnreachable(kind); - } -} - -function emitDirective(e: DirectiveElem, ctx: EmitContext): void { - const { directive } = e; - const { kind } = directive; - if (kind === "diagnostic") { - ctx.srcBuilder.add( - `diagnostic${diagnosticControlToString(directive.severity, directive.rule)};`, - e.start, - e.end, - ); - } else if (kind === "enable") { - ctx.srcBuilder.add( - `enable ${directive.extensions.map(v => v.name).join(", ")};`, - e.start, - e.end, - ); - } else if (kind === "requires") { - ctx.srcBuilder.add( - `requires ${directive.extensions.map(v => v.name).join(", ")};`, - e.start, - e.end, - ); - } else { - assertUnreachable(kind); - } -} - -function displayName(declIdent: DeclIdent): string { - if (isGlobal(declIdent)) { - assertThatDebug( - declIdent.mangledName, - `ERR: mangled name not found for decl ident ${identToString(declIdent)}`, - ); - // mangled name was set in binding step - return declIdent.mangledName!; - } - - return declIdent.mangledName || declIdent.originalName; -} - -/** trace through refersTo links in reference Idents until we find the declaration - * expects that bindIdents has filled in all refersTo: links - */ -export function findDecl(ident: Ident): DeclIdent { - let i: Ident | undefined = ident; - do { - if (i.kind === "decl") { - return i; - } - i = i.refersTo; - } while (i); - - // TODO show source position if this can happen in a non buggy linker. - throw new Error(`unresolved identifer: ${ident.originalName}`); -} - -/** check if the element is visible with the current current conditional compilation settings */ -export function conditionsValid( - elem: AbstractElem, - conditions: Conditions, -): true | false | undefined { - const attrElem = elem as ElemWithAttributes; - const { kind } = attrElem; - - switch (kind) { - case "alias": - case "assert": - case "const": - case "directive": - case "member": - case "var": - case "let": - case "statement": - case "switch-clause": - case "override": - case "gvar": - case "fn": - case "struct": - case "param": - return elementValid(attrElem, conditions); - default: - assertUnreachableSilent(kind); - } - return true; -} diff --git a/tools/packages/wesl/src/Mangler.ts b/tools/packages/wesl/src/Mangler.ts index 3ebd128b9..f0294b867 100644 --- a/tools/packages/wesl/src/Mangler.ts +++ b/tools/packages/wesl/src/Mangler.ts @@ -1,4 +1,3 @@ -import { DeclIdent, SrcModule } from "./Scope.ts"; /** * A function for constructing a unique identifier name for a global declaration. * Global names must be unique in the linked wgsl. @@ -10,13 +9,10 @@ import { DeclIdent, SrcModule } from "./Scope.ts"; */ export type ManglerFn = ( /** global declaration that needs a name */ - decl: DeclIdent, + decl: string, /** module that contains the declaration */ - srcModule: SrcModule, - - /** name at use site (possibly import as renamed from declaration) */ - proposedName: string, + modulePath: string[], /** current set of mangled root level names for the linked result (read only) */ globalNames: Set, @@ -27,12 +23,8 @@ export type ManglerFn = ( * module path separated by underscores. * Corresponds to "Underscore-count mangling" from [NameMangling.md](https://github.com/wgsl-tooling-wg/wesl-spec/blob/main/NameMangling.md) */ -export function underscoreMangle( - decl: DeclIdent, - srcModule: SrcModule, -): string { - const { modulePath } = srcModule; - return [...modulePath.split("::"), decl.originalName] +export function underscoreMangle(decl: string, modulePath: string[]): string { + return [...modulePath, decl] .map(v => { const underscoreCount = (v.match(/_/g) ?? []).length; if (underscoreCount > 0) { @@ -47,17 +39,11 @@ export function underscoreMangle( /** * Construct a globally unique name based on the declaration */ -export function lengthPrefixMangle( - decl: DeclIdent, - srcModule: SrcModule, -): string { +export function lengthPrefixMangle(decl: string, modulePath: string[]): string { function codepointCount(text: string): number { return [...text].length; } - const qualifiedIdent = [ - ...srcModule.modulePath.split("::"), - decl.originalName, - ]; + const qualifiedIdent = [...modulePath, decl]; return "_" + qualifiedIdent.map(v => codepointCount(v) + v).join(""); } @@ -66,12 +52,11 @@ export function lengthPrefixMangle( * using the requested name plus a uniquing number suffix if necessary */ export function minimalMangle( - _d: DeclIdent, - _s: SrcModule, - proposedName: string, + decl: string, + _modulePath: string[], globalNames: Set, ): string { - return minimallyMangledName(proposedName, globalNames); + return minimallyMangledName(decl, globalNames); } /** @@ -79,15 +64,15 @@ export function minimalMangle( * and appending a number suffix necessary */ export function minimallyMangledName( - proposedName: string, + name: string, globalNames: Set, ): string { - let renamed = proposedName; + let renamed = name; let conflicts = 0; // create a unique name while (globalNames.has(renamed)) { - renamed = proposedName + conflicts++; + renamed = name + conflicts++; } return renamed; diff --git a/tools/packages/wesl/src/Module.ts b/tools/packages/wesl/src/Module.ts new file mode 100644 index 000000000..cf564516a --- /dev/null +++ b/tools/packages/wesl/src/Module.ts @@ -0,0 +1,38 @@ +import { assertThat } from "./Assertions.ts"; +import { ModuleElem } from "./parse/WeslElems.ts"; + +export type ModulePathString = string & { __modulePath: never }; + +/** An absolute path to a module. Is unique. */ +export class ModulePath { + constructor(public path: string[]) { + assertThat(path.length > 0); + } + + toString(): ModulePathString { + let result: string = this.path.join("::"); + return result as ModulePathString; + } +} + +/** + * result of a parse for one wesl module (e.g. one .wesl file) + */ +export interface WeslAST { + /** source text for this module */ + srcModule: SrcModule; + + /** root module element */ + moduleElem: ModuleElem; +} + +export interface SrcModule { + /** module path "rand_pkg::sub::foo", or "package::main" */ + modulePath: ModulePath; + + /** file path to the module for user error reporting e.g "rand_pkg:sub/foo.wesl", or "./sub/foo.wesl" */ + debugFilePath: string; + + /** original src for module */ + src: string; +} diff --git a/tools/packages/wesl/src/ParseWESL.ts b/tools/packages/wesl/src/ParseWESL.ts deleted file mode 100644 index db6bc1eb1..000000000 --- a/tools/packages/wesl/src/ParseWESL.ts +++ /dev/null @@ -1,158 +0,0 @@ -import { AppState, ParserInit, SrcMap } from "mini-parse"; -import { - ConstAssertElem, - ImportStatement, - ModuleElem, -} from "./AbstractElems.ts"; -import { FlatImport, flattenTreeImport } from "./FlattenTreeImport.ts"; -import { weslRoot } from "./parse/WeslGrammar.ts"; -import { WeslStream } from "./parse/WeslStream.ts"; -import { emptyScope, Scope, SrcModule } from "./Scope.ts"; -import { OpenElem } from "./WESLCollect.ts"; -import { ParseError } from "mini-parse"; -import { errorHighlight, offsetToLineNumber } from "./Util.ts"; -import { throwClickableError } from "./WeslDevice.ts"; - -/** result of a parse for one wesl module (e.g. one .wesl file) - * - * The parser constructs the AST constructed into three sections - * for convenient access by the binding stage. - * - import statements - * - language elements (fn, struct, etc) - * - scopes - * - */ -export interface WeslAST { - /** source text for this module */ - srcModule: SrcModule; - - /** root module element */ - moduleElem: ModuleElem; - - /** root scope for this module */ - rootScope: Scope; - - /** imports found in this module */ - imports: ImportStatement[]; - - /** module level const_assert statements */ - moduleAsserts?: ConstAssertElem[]; -} - -/** an extended version of the AST */ -export interface BindingAST extends WeslAST { - /* a flattened version of the import statements constructed on demand from import trees, and cached here */ - _flatImports?: FlatImport[]; -} - -/** stable and unstable state used during parsing */ -export interface WeslParseState - extends AppState { - context: WeslParseContext; - stable: StableState; -} - -/** stable values used or accumulated during parsing */ -export type StableState = WeslAST; - -/** unstable values used during parse collection */ -export interface WeslParseContext { - scope: Scope; // current scope (points somewhere in rootScope) - openElems: OpenElem[]; // elems that are collecting their contents -} - -/** - * An error when parsing WESL fails. Designed to be human-readable. - */ -export class WeslParseError extends Error { - position: number; - src: SrcModule; - constructor(opts: { cause: ParseError; src: SrcModule }) { - const source = opts.src.src; - const [lineNum, linePos] = offsetToLineNumber(opts.cause.position, source); - let message = `${opts.src.debugFilePath}:${lineNum}:${linePos}`; - message += ` error: ${opts.cause.message}\n`; - message += errorHighlight(source, [ - opts.cause.position, - opts.cause.position + 1, - ]).join("\n"); - super(message, { - cause: opts.cause, - }); - this.position = opts.cause.position; - this.src = opts.src; - } -} - -/** Parse a WESL file. Throws on error. */ -export function parseSrcModule(srcModule: SrcModule, srcMap?: SrcMap): WeslAST { - const stream = new WeslStream(srcModule.src); - - const appState = blankWeslParseState(srcModule); - - const init: ParserInit = { stream, appState }; - try { - const parseResult = weslRoot.parse(init); - if (parseResult === null) { - throw new Error("parseWESL failed"); - } - } catch (e) { - if (e instanceof ParseError) { - const [lineNumber, lineColumn] = offsetToLineNumber( - e.position, - srcModule.src, - ); - const error = new WeslParseError({ cause: e, src: srcModule }); - throwClickableError({ - url: srcModule.debugFilePath, - text: srcModule.src, - error, - lineNumber, - lineColumn, - length: 1, - }); - } else { - throw e; - } - } - - return appState.stable as WeslAST; -} - -export function parseWESL(src: string, srcMap?: SrcMap): WeslAST { - const srcModule: SrcModule = { - modulePath: "package::test", // TODO this ought not be used outside of tests - debugFilePath: "./test.wesl", - src, - }; - - return parseSrcModule(srcModule, srcMap); -} - -export function blankWeslParseState(srcModule: SrcModule): WeslParseState { - const rootScope = emptyScope(null); - const moduleElem = null as any; // we'll fill this in later - return { - context: { scope: rootScope, openElems: [] }, - stable: { srcModule, imports: [], rootScope, moduleElem }, - }; -} - -export function syntheticWeslParseState(): WeslParseState { - const srcModule: SrcModule = { - modulePath: "package::test", - debugFilePath: "./test.wesl", - src: "", - }; - - return blankWeslParseState(srcModule); -} - -/** @return a flattened form of the import tree for convenience in binding idents. */ -export function flatImports(ast: BindingAST): FlatImport[] { - if (ast._flatImports) return ast._flatImports; - - const flat = ast.imports.flatMap(flattenTreeImport); - ast._flatImports = flat; - return flat; -} diff --git a/tools/packages/wesl/src/ParsedRegistry.ts b/tools/packages/wesl/src/ParsedRegistry.ts deleted file mode 100644 index 633b162f2..000000000 --- a/tools/packages/wesl/src/ParsedRegistry.ts +++ /dev/null @@ -1,120 +0,0 @@ -import { WgslBundle } from "wesl"; -import { parseSrcModule, parseWESL, WeslAST } from "./ParseWESL.ts"; -import { normalize, noSuffix } from "./PathUtil.ts"; -import { resetScopeIds, SrcModule } from "./Scope.ts"; - -export interface ParsedRegistry { - modules: Record; // key is module path, e.g. "rand_pkg::foo::bar" -} - -export function parsedRegistry(): ParsedRegistry { - resetScopeIds(); // for debug - return { modules: {} }; -} - -/** for debug */ -export function registryToString(registry: ParsedRegistry): string { - return `modules: ${[...Object.keys(registry.modules)]}`; -} - -/** - * Parse WESL each src module (file) into AST elements and a Scope tree. - * @param src keys are module paths, values are wesl src strings - */ -export function parseWeslSrc(src: Record): ParsedRegistry { - const parsedEntries = Object.entries(src).map(([path, src]) => { - const weslAST = parseWESL(src); - return [path, weslAST]; - }); - return { modules: Object.fromEntries(parsedEntries) }; -} - -/** Look up a module with a flexible selector. - * :: separated module path, package::util - * / separated file path ./util.wesl (or ./util) - * - note: a file path should not include a weslRoot prefix, e.g. not ./shaders/util.wesl - * simpleName util - */ -export function selectModule( - parsed: ParsedRegistry, - selectPath: string, - packageName = "package", -): WeslAST | undefined { - // dlog({reg: [...Object.keys(parsed.modules)]}); - let modulePath: string; - if (selectPath.includes("::")) { - modulePath = selectPath; - } else if ( - selectPath.includes("/") || - selectPath.endsWith(".wesl") || - selectPath.endsWith(".wgsl") - ) { - modulePath = fileToModulePath(selectPath, packageName); - } else { - modulePath = packageName + "::" + selectPath; - } - - return parsed.modules[modulePath]; -} - -/** - * @param srcFiles map of source strings by file path - * key is '/' separated relative path (relative to srcRoot, not absolute file path ) - * value is wesl source string - * @param registry add parsed modules to this registry - * @param packageName name of package - */ -export function parseIntoRegistry( - srcFiles: Record, - registry: ParsedRegistry, - packageName: string = "package", - debugWeslRoot?: string, -): void { - if (debugWeslRoot === undefined) { - debugWeslRoot = ""; - } else if (!debugWeslRoot.endsWith("/")) { - debugWeslRoot += "/"; - } - const srcModules: SrcModule[] = Object.entries(srcFiles).map( - ([filePath, src]) => { - const modulePath = fileToModulePath(filePath, packageName); - return { modulePath, debugFilePath: debugWeslRoot + filePath, src }; - }, - ); - srcModules.forEach(mod => { - const parsed = parseSrcModule(mod, undefined); - if (registry.modules[mod.modulePath]) { - throw new Error(`duplicate module path: '${mod.modulePath}'`); - } - registry.modules[mod.modulePath] = parsed; - }); -} - -export function parseLibsIntoRegistry( - libs: WgslBundle[], - registry: ParsedRegistry, -): void { - libs.forEach(({ modules, name }) => - parseIntoRegistry(modules, registry, name), - ); -} - -const libRegex = /^lib\.w[eg]sl$/i; - -/** convert a file path (./foo/bar.wesl) - * to a module path (package::foo::bar) */ -function fileToModulePath(filePath: string, packageName: string): string { - if (filePath.includes("::")) { - // already a module path - return filePath; - } - if (packageName !== "package" && libRegex.test(filePath)) { - // special case for lib.wesl files in external packages - return packageName; - } - - const strippedPath = noSuffix(normalize(filePath)); - const moduleSuffix = strippedPath.replaceAll("/", "::"); - const modulePath = packageName + "::" + moduleSuffix; - return modulePath; -} diff --git a/tools/packages/wesl/src/RawEmit.ts b/tools/packages/wesl/src/RawEmit.ts index ec822c81a..b28463d6f 100644 --- a/tools/packages/wesl/src/RawEmit.ts +++ b/tools/packages/wesl/src/RawEmit.ts @@ -1,33 +1,23 @@ -import { - AttributeElem, - NameElem, - StuffElem, - TranslateTimeExpressionElem, - TypeRefElem, - TypeTemplateParameter, - UnknownExpressionElem, -} from "./AbstractElems.ts"; +import { AttributeElem, TypeTemplateParameter } from "./parse/WeslElems.ts"; import { assertUnreachable } from "./Assertions.ts"; import { diagnosticControlToString, expressionToString, - findDecl, -} from "./LowerAndEmit.ts"; -import { RefIdent } from "./Scope.ts"; +} from "./lower/LowerAndEmit.ts"; -// LATER DRY emitting elements like this with LowerAndEmit? +// TODO: Completely remove this export function attributeToString(e: AttributeElem): string { const { kind } = e.attribute; // LATER emit more precise source map info by making use of all the spans // Like the first case does - if (kind === "@attribute") { + if (kind === "attribute") { const { params } = e.attribute; if (params === undefined || params.length === 0) { return "@" + e.attribute.name; } else { return `@${e.attribute.name}(${params - .map(param => contentsToString(param)) + .map(param => expressionToString(param)) .join(", ")})`; } } else if (kind === "@builtin") { @@ -47,56 +37,5 @@ export function attributeToString(e: AttributeElem): string { } export function typeListToString(params: TypeTemplateParameter[]): string { - return `<${params.map(typeParamToString).join(", ")}>`; -} - -export function typeParamToString(param?: TypeTemplateParameter): string { - if (param === undefined) return "?"; - if (typeof param === "string") return param; - - if (param.kind === "expression") return contentsToString(param); - if (param.kind === "type") return typeRefToString(param); - assertUnreachable(param); -} - -export function typeRefToString(t?: TypeRefElem): string { - if (!t) return "?"; - const { name, templateParams } = t; - const params = templateParams ? typeListToString(templateParams) : ""; - return `${refToString(name)}${params}`; -} - -function refToString(ref: RefIdent | string): string { - if (typeof ref === "string") return ref; - if (ref.std) return ref.originalName; - const decl = findDecl(ref); - return decl.mangledName || decl.originalName; -} - -export function contentsToString( - elem: - | TranslateTimeExpressionElem - | UnknownExpressionElem - | NameElem - | StuffElem, -): string { - if (elem.kind === "translate-time-expression") { - throw new Error("Not supported"); - } else if (elem.kind === "expression" || elem.kind === "stuff") { - const parts = elem.contents.map(c => { - const { kind } = c; - if (kind === "text") { - return c.srcModule.src.slice(c.start, c.end); - } else if (kind === "ref") { - return refToString(c.ident); - } else { - return `?${c.kind}?`; - } - }); - return parts.join(" "); - } else if (elem.kind === "name") { - return elem.name; - } else { - assertUnreachable(elem); - } + return `<${params.map(expressionToString).join(", ")}>`; } diff --git a/tools/packages/wesl/src/Reflection.ts b/tools/packages/wesl/src/Reflection.ts index ca6e21341..4bc073467 100644 --- a/tools/packages/wesl/src/Reflection.ts +++ b/tools/packages/wesl/src/Reflection.ts @@ -4,10 +4,9 @@ import { NameElem, StructMemberElem, TextElem, - TranslateTimeExpressionElem, + ConditionalExpressionElem, TypeRefElem, - UnknownExpressionElem, -} from "./AbstractElems.ts"; +} from "./parse/WeslElems.ts"; import { assertThat } from "./Assertions.ts"; import { TransformedAST, WeslJsPlugin } from "./Linker.ts"; import { identElemLog } from "./LinkerUtil.ts"; @@ -18,6 +17,7 @@ import { textureStorageTypes, } from "./StandardTypes.ts"; import { findMap } from "./Util.ts"; +import { expressionToString } from "./lower/LowerAndEmit.ts"; export type BindingStructReportFn = (structs: BindingStructElem[]) => void; export const textureStorage = matchOneOf(textureStorageTypes); @@ -122,21 +122,21 @@ function shaderVisiblity(struct: BindingStructElem): string { const { attributes = [] } = entryFn; if ( attributes.find( - ({ attribute: a }) => a.kind === "@attribute" && a.name === "compute", + ({ attribute: a }) => a.kind === "attribute" && a.name === "compute", ) ) { return "GPUShaderStage.COMPUTE"; } if ( attributes.find( - ({ attribute: a }) => a.kind === "@attribute" && a.name === "vertex", + ({ attribute: a }) => a.kind === "attribute" && a.name === "vertex", ) ) { return "GPUShaderStage.VERTEX"; } if ( attributes.find( - ({ attribute: a }) => a.kind === "@attribute" && a.name === "fragment", + ({ attribute: a }) => a.kind === "attribute" && a.name === "fragment", ) ) { return "GPUShaderStage.FRAGMENT"; @@ -155,9 +155,9 @@ function memberToLayoutEntry( visibility: string, ): string { const bindingParam = findMap(member.attributes ?? [], ({ attribute: a }) => - a.kind === "@attribute" && a.name === "binding" ? a : undefined, + a.kind === "attribute" && a.name === "binding" ? a : undefined, )?.params?.[0]; - const binding = bindingParam ? paramText(bindingParam) : "?"; + const binding = bindingParam ? expressionToString(bindingParam) : "?"; const src = ` { @@ -173,7 +173,7 @@ function memberToLayoutEntry( * references to WGSL samplers become 'sampler' GPUSamplerBindingLayout instances, etc. */ function layoutEntry(member: StructMemberElem): string { - const { typeRef } = member; + const { type: typeRef } = member; let entry: string | undefined; const { name: typeName } = typeRef; entry = ptrLayoutEntry(typeRef) ?? storageTextureLayoutEntry(typeRef); @@ -269,17 +269,6 @@ function externalTextureLayoutEntry(typeRef: TypeRefElem): string | undefined { return undefined; } -function paramText( - expression: UnknownExpressionElem | NameElem | TranslateTimeExpressionElem, -): string { - assertThat( - expression.kind === "expression", - "Only expression elements are supported in this position", - ); - const text = expression.contents[0] as TextElem; - return text.srcModule.src.slice(expression.start, expression.end); -} - export function formatToTextureSampleType( format: GPUTextureFormat, float32Filterable = false, diff --git a/tools/packages/wesl/src/Scope.ts b/tools/packages/wesl/src/Scope.ts deleted file mode 100644 index 6fa80d360..000000000 --- a/tools/packages/wesl/src/Scope.ts +++ /dev/null @@ -1,162 +0,0 @@ -import { DeclarationElem, IfAttribute, RefIdentElem } from "./AbstractElems.ts"; -import { assertThatDebug } from "./Assertions.ts"; -import { scopeValid } from "./Conditions.ts"; -import { WeslAST } from "./ParseWESL.ts"; - -export interface SrcModule { - /** module path "rand_pkg::sub::foo", or "package::main" */ - modulePath: string; - - /** file path to the module for user error reporting e.g "rand_pkg:sub/foo.wesl", or "./sub/foo.wesl" */ - debugFilePath: string; - - /** original src for module */ - src: string; -} - -/** a src declaration or reference to an ident */ -export type Ident = DeclIdent | RefIdent; - -/** LATER change this to a Map, so that `toString` isn't accidentally a condition */ -export type Conditions = Record; - -interface IdentBase { - originalName: string; // name in the source code for ident matching (may be mangled in the output) - id?: number; // for debugging -} - -export interface RefIdent extends IdentBase { - kind: "ref"; - - // LATER these fields are set during binding, not parsing. Make a naming scheme _refersTo or a separate interface (BindingRefIdent) to make that clear - refersTo?: Ident; // import or decl ident in scope to which this ident refers. undefined before binding - std?: true; // true if this is a standard wgsl identifier (like sin, or u32) - - // TODO consider tracking the current ast in BindIdents so that this field is unnecessary - ast: WeslAST; // AST from module that contains this ident (to find imports during decl binding) - - refIdentElem: RefIdentElem; // for error reporting and mangling -} - -export interface DeclIdent extends IdentBase { - kind: "decl"; - mangledName?: string; // name in the output code - declElem?: DeclarationElem; // link to AST so that we can traverse scopes and know what elems to emit // LATER make separate GlobalDecl kind with this required - scope: Scope; // scope for the references within this declaration - isGlobal: boolean; // true if this is a global declaration (e.g. not a local variable) - srcModule: SrcModule; // To figure out which module this declaration is from. -} - -/** tree of ident references, organized by lexical scope and partialScope . */ -export type Scope = LexicalScope | PartialScope; - -/** A wgsl scope */ -export interface LexicalScope extends ScopeBase { - kind: "scope"; - - /** @if condition for conditionally translating this scope */ - ifAttribute?: IfAttribute; - - /** - * Efficient access to declarations in this scope. - * constructed on demand, for module root scopes only */ // LATER consider make a special kind for root scopes - scopeDecls?: Map; -} - -/** A synthetic partial scope to contain @if conditioned idents. - * PartialScope idents are considered to be in the wgsl lexical scope of their parent. */ -export interface PartialScope extends ScopeBase { - kind: "partial"; - - /** @if condition for conditionally translating this scope */ - ifAttribute?: IfAttribute; // LATER this is required, consider changing type to reflect that -} - -/** common scope elements */ -interface ScopeBase { - /** id for debugging */ - id: number; - - /** null for root scope in a module */ - parent: Scope | null; - - /* Child scopes and idents in lexical order */ - contents: (Ident | Scope)[]; - - /** @if conditions for conditionally translating this scope */ - ifAttribute?: IfAttribute; -} - -/** Combine two scope siblings. - * The first scope is mutated to append the contents of the second. */ -export function mergeScope(a: Scope, b: Scope): void { - assertThatDebug(a.kind === b.kind); - assertThatDebug(a.parent === b.parent); - assertThatDebug(!b.ifAttribute); - a.contents = a.contents.concat(b.contents); -} - -/** reset scope and ident debugging ids */ -export function resetScopeIds() { - scopeId = 0; - identId = 0; -} - -let scopeId = 0; -let identId = 0; - -export function nextIdentId(): number { - return identId++; -} - -/** make a new Scope object */ -export function emptyScope( - parent: Scope | null, - kind: Scope["kind"] = "scope", -): Omit { - const id = scopeId++; - return { id, kind, parent, contents: [] }; -} - -/** For debugging, - * @return true if a scope is in the rootScope tree somewhere */ -export function containsScope(rootScope: Scope, scope: Scope): boolean { - if (scope === rootScope) { - return true; - } - for (const child of rootScope.contents) { - if (childScope(child) && containsScope(child, scope)) { - return true; - } - } - return false; -} - -/** @returns true if the provided element of a Scope - * is itself a Scope (and not an Ident) */ -export function childScope(child: Scope | Ident): child is Scope { - const { kind } = child; - return kind === "partial" || kind === "scope"; -} - -/** @returns true if the provided element of a Scope - * is an Ident (and not a child Scope) */ -export function childIdent(child: Scope | Ident): child is Ident { - return !childScope(child); -} - -/** find a public declaration with the given original name */ -export function publicDecl( - scope: Scope, - name: string, - conditions: Conditions, -): DeclIdent | undefined { - for (const elem of scope.contents) { - if (elem.kind === "decl" && elem.originalName === name) { - return elem; - } else if (elem.kind === "partial" && scopeValid(elem, conditions)) { - const found = publicDecl(elem, name, conditions); - if (found) return found; - } - } -} diff --git a/tools/packages/wesl/src/TransformBindingStructs.ts b/tools/packages/wesl/src/TransformBindingStructs.ts index 4a5f62546..340e461bd 100644 --- a/tools/packages/wesl/src/TransformBindingStructs.ts +++ b/tools/packages/wesl/src/TransformBindingStructs.ts @@ -4,28 +4,36 @@ import { AttributeElem, BindingStructElem, DeclarationElem, - FnElem, ModuleElem, - SimpleMemberRef, StructElem, StructMemberElem, SyntheticElem, TypeTemplateParameter, -} from "./AbstractElems.ts"; +} from "./parse/WeslElems.ts"; import { TransformedAST, WeslJsPlugin } from "./Linker.ts"; -import { visitAst } from "./LinkerUtil.ts"; -import { findDecl } from "./LowerAndEmit.ts"; import { minimallyMangledName } from "./Mangler.ts"; -import { - attributeToString, - contentsToString, - typeListToString, - typeParamToString, -} from "./RawEmit.ts"; +import { attributeToString, typeListToString } from "./RawEmit.ts"; import { textureStorage } from "./Reflection.ts"; import { DeclIdent, RefIdent } from "./Scope.ts"; import { filterMap } from "./Util.ts"; +/** trace through refersTo links in reference Idents until we find the declaration + * expects that bindIdents has filled in all refersTo: links + */ +function findDecl(ident: DeclIdent | RefIdent): DeclIdent { + let i: DeclIdent | RefIdent | undefined = ident; + do { + if (i.kind === "decl") { + return i; + } + i = i.refersTo; + } while (i); + + throw new Error( + `unresolved ident: ${ident.originalName} (bug in bindIdents?)`, + ); +} + export function bindingStructsPlugin(): WeslJsPlugin { return { transform: lowerBindingStructs, @@ -142,7 +150,7 @@ function bindingAttribute(attributes?: AttributeElem[]): boolean { if (!attributes) return false; return attributes.some( ({ attribute }) => - attribute.kind === "@attribute" && + attribute.kind === "attribute" && (attribute.name === "binding" || attribute.name === "group"), ); } @@ -153,7 +161,7 @@ export function transformBindingStruct( globalNames: Set, ): SyntheticElem[] { return s.members.map(member => { - const { typeRef, name: memberName } = member; + const { type: typeRef, name: memberName } = member; const { name: typeName } = typeRef!; // members should always have a typeRef.. TODO fix typing to show this const typeParameters = typeRef?.templateParams; @@ -309,10 +317,10 @@ export function transformBindingReference( if (tracing) console.log(`missing mangledVarName for ${refName}`); return { kind: "synthetic", text: refName }; } - const { extraComponents } = memberRef; - const extraText = extraComponents ? contentsToString(extraComponents) : ""; + // const { extraComponents } = memberRef; + // const extraText = extraComponents ? contentsToString(extraComponents) : ""; - const text = structMember.mangledVarName + extraText; + const text = structMember.mangledVarName + ""; const synthElem: SyntheticElem = { kind: "synthetic", text }; memberRef.contents = [synthElem]; return synthElem; diff --git a/tools/packages/wesl/src/Util.ts b/tools/packages/wesl/src/Util.ts index 0f9fdcff6..26cefcc83 100644 --- a/tools/packages/wesl/src/Util.ts +++ b/tools/packages/wesl/src/Util.ts @@ -1,5 +1,6 @@ +import { assertThat } from "./Assertions.ts"; +import { ModulePath } from "./Module.ts"; import { Span } from "mini-parse"; - export function multiKeySet( m: Map>, a: A, @@ -169,6 +170,29 @@ export function offsetToLineNumber( } } +/** Types that can be turned into a human-readable string */ +export type FmtDisplay = string | number | ModulePath; + +/** + * It is really easy to accidentally pass something that does not have a sensible toString function to a tagged template. + * This guards against that. + * + * Example + * ```ts + * let name = "cat"; + * let foo = str`hello ${name}`; + * ``` + */ +export function str(template: TemplateStringsArray, ...params: FmtDisplay[]) { + assertThat(template.length === params.length - 1); + let result = template[0]; + for (let i = 0; i < params.length; i++) { + result += params[i]; + result += template[i + 1]; + } + return result; +} + /** Highlights an error. * * Returns a string with the line, and a string with the ^^^^ carets diff --git a/tools/packages/wesl/src/VirtualFilesystem.ts b/tools/packages/wesl/src/VirtualFilesystem.ts new file mode 100644 index 000000000..40380b0b2 --- /dev/null +++ b/tools/packages/wesl/src/VirtualFilesystem.ts @@ -0,0 +1,110 @@ +import { assertThat } from "./Assertions.ts"; +import { ModulePath } from "./Module.ts"; +import { isIdent } from "./parse/WeslStream.ts"; +import { normalize, noSuffix } from "./PathUtil.ts"; + +/** A async virtual filesystem, can be backed by an in-memory map, or by a real filesystem, or by HTTP requests */ +export interface VirtualFilesystem { + /** + * The WESL module loading uses this to load files. + */ + readFile(path: ModulePath): Promise; + + debugFilePath(path: ModulePath): string; +} + +/** Creates a static filesystem from relative, normalized, Linux-style paths. + * All file and folder names must be valid WGSL identifiers. + * `.wgsl` and `.wesl` are the supported file extensions. + */ +export function staticFilesystem( + packageName: string, + files: Record, +): VirtualFilesystem { + // Avoid accidentally having a file called `toString` + const filesystemMap = new Map( + Object.entries(files).map(([path, contents]) => [ + fileToModulePath(path, packageName), + contents, + ]), + ); + function readFile(path: ModulePath): Promise { + const file = filesystemMap.get(path) ?? null; + return Promise.resolve(file); + } + return { + readFile, + debugFilePath(path) { + return path.path.join("/"); + }, + }; +} + +const fileNameRegex = /^(?[^.]+)\.(?wgsl|wesl)$/; + +const libRegex = /^lib\.w[eg]sl$/i; + +/** convert a file path (./foo/bar.wesl) + * to a module path (package::foo::bar) */ +function fileToModulePath(filePath: string, packageName: string): ModulePath { + if (filePath.includes("::")) { + // already a module path + return new ModulePath(filePath.split("::")); + } + if (packageName !== "package" && libRegex.test(filePath)) { + // special case for lib.wesl files in external packages + return new ModulePath([packageName]); + } + + const strippedPath = noSuffix(normalize(filePath)); + const moduleSuffix = strippedPath.split("/"); + return new ModulePath([packageName, ...moduleSuffix]); +} + +// LATER: Replace the above with this more strict verson +function fileToModulePath2(path: string): ModulePath { + if (path.startsWith("/")) { + throw new Error( + `Paths must be relative, but absolute path was found ${path}`, + ); + } + if (path.includes("\\")) { + throw new Error(`Paths must be Linux-style, but \\ was found ${path}`); + } + if (path.includes("..")) { + throw new Error(`Paths must be normalized, but .. was found ${path}`); + } + + const segments = path.split("/"); + if (segments[0] === ".") { + segments.shift(); + } + const lastSegment = segments.pop(); + if (lastSegment === undefined) { + throw new Error(`Path is missing a file name ${path}`); + } else { + const matches = lastSegment.match(fileNameRegex); + if (matches === null) { + throw new Error( + `Expected a valid file name, but ${lastSegment} is not one ${path}`, + ); + } + assertThat(matches.groups !== undefined); + const { name } = matches.groups; + if (!isIdent(name)) { + throw new Error( + `Path must only contain valid WGSL idents, but ${name} is not one ${path}`, + ); + } + } + + for (const segment of segments) { + if (!isIdent(segment)) { + throw new Error( + `Path must only contain valid WGSL idents, but ${segment} is not one ${path}`, + ); + } + } + + return new ModulePath(segments); +} diff --git a/tools/packages/wesl/src/WESLCollect.ts b/tools/packages/wesl/src/WESLCollect.ts deleted file mode 100644 index b6b588fb5..000000000 --- a/tools/packages/wesl/src/WESLCollect.ts +++ /dev/null @@ -1,614 +0,0 @@ -import { dlog } from "berry-pretty"; -import { CollectContext, CollectPair, srcLog, tracing } from "mini-parse"; -import { - AbstractElem, - AliasElem, - Attribute, - AttributeElem, - ConstAssertElem, - ConstElem, - ContainerElem, - DeclarationElem, - DeclIdentElem, - DirectiveElem, - DirectiveVariant, - FnElem, - FnParamElem, - GlobalVarElem, - GrammarElem, - HasAttributes, - IfAttribute, - ImportElem, - LetElem, - ModuleElem, - NameElem, - OverrideElem, - RefIdentElem, - SimpleMemberRef, - StandardAttribute, - StatementElem, - StructElem, - StructMemberElem, - StuffElem, - SwitchClauseElem, - TextElem, - TypedDeclElem, - TypeRefElem, - UnknownExpressionElem, - VarElem, -} from "./AbstractElems.ts"; -import { - StableState, - WeslAST, - WeslParseContext, - WeslParseState, -} from "./ParseWESL.ts"; -import { - DeclIdent, - emptyScope, - Ident, - mergeScope, - nextIdentId, - PartialScope, - RefIdent, - Scope, -} from "./Scope.ts"; -import { filterMap } from "./Util.ts"; - -export function importElem(cc: CollectContext) { - const importElems = cc.tags.owo?.[0] as ImportElem[]; // LATER ts typing - for (const importElem of importElems) { - (cc.app.stable as StableState).imports.push(importElem.imports); - addToOpenElem(cc, importElem as AbstractElem); - } -} - -/** add an elem to the .contents array of the currently containing element */ -function addToOpenElem(cc: CollectContext, elem: AbstractElem): void { - const weslContext: WeslParseContext = cc.app.context; - const { openElems } = weslContext; - if (openElems && openElems.length) { - const open = openElems[openElems.length - 1]; - open.contents.push(elem); - } -} - -/** create reference Ident and add to context */ -export function refIdent(cc: CollectContext): RefIdentElem { - const { src, start, end } = cc; - const app = cc.app as WeslParseState; - const { srcModule } = app.stable; - const originalName = src.slice(start, end); - - const kind = "ref"; - const ident: RefIdent = { - kind, - originalName, - ast: cc.app.stable, - id: nextIdentId(), - refIdentElem: null as any, // set below - }; - const identElem: RefIdentElem = { kind, start, end, srcModule, ident }; - ident.refIdentElem = identElem; - - saveIdent(cc, identElem); - addToOpenElem(cc, identElem); - return identElem; -} - -/** create declaration Ident and add to context */ -export function declCollect(cc: CollectContext): DeclIdentElem { - return declCollectInternal(cc, false); -} - -/** create global declaration Ident and add to context */ -export function globalDeclCollect(cc: CollectContext): DeclIdentElem { - return declCollectInternal(cc, true); -} - -function declCollectInternal( - cc: CollectContext, - isGlobal: boolean, -): DeclIdentElem { - const { src, start, end } = cc; - const app = cc.app as WeslParseState; - const { scope } = app.context; - const { srcModule } = app.stable; - const originalName = src.slice(start, end); - - const kind = "decl"; - const declElem = null as any; // we'll set declElem later - const ident: DeclIdent = { - declElem, - kind, - originalName, - scope, - isGlobal, - id: nextIdentId(), - srcModule, - }; - const identElem: DeclIdentElem = { kind, start, end, srcModule, ident }; - - saveIdent(cc, identElem); - addToOpenElem(cc, identElem); - return identElem; -} - -export const typedDecl = collectElem( - "typeDecl", - (cc: CollectContext, openElem: PartElem) => { - const decl = cc.tags.decl_elem?.[0] as DeclIdentElem; - const typeRef = cc.tags.typeRefElem?.[0] as TypeRefElem | undefined; - const partial: TypedDeclElem = { ...openElem, decl, typeRef }; - const elem = withTextCover(partial, cc); - - return elem; - }, -); - -/** add Ident to current open scope, add IdentElem to current open element */ -function saveIdent( - cc: CollectContext, - identElem: RefIdentElem | DeclIdentElem, -) { - const { ident } = identElem; - ident.id = nextIdentId(); - const weslContext: WeslParseContext = cc.app.context; - weslContext.scope.contents.push(ident); -} - -/** start a new child lexical Scope */ -function startScope(cc: CollectContext) { - startSomeScope("scope", cc); -} - -/** start a new child partial Scope */ -function startPartialScope(cc: CollectContext) { - startSomeScope("partial", cc); -} - -/** start a new lexical or partial scope */ -function startSomeScope(kind: Scope["kind"], cc: CollectContext): void { - const { scope } = cc.app.context as WeslParseContext; - const newScope = emptyScope(scope, kind); - - scope.contents.push(newScope); - cc.app.context.scope = newScope; -} - -/* close current Scope and set current scope to parent */ -function completeScope(cc: CollectContext): Scope { - const weslContext = cc.app.context as WeslParseContext; - const completedScope = weslContext.scope; - const ifAttributes = collectIfAttributes(cc); - - const { parent } = completedScope; - if (parent) { - weslContext.scope = parent; - } else if (tracing) { - console.log("ERR: completeScope, no parent scope", completedScope.contents); - } - completedScope.ifAttribute = ifAttributes?.[0]; - return completedScope; -} - -/** return @if attributes from the 'attribute' tag */ -function collectIfAttributes(cc: CollectContext): IfAttribute[] | undefined { - const attributes = cc.tags.attribute as AttributeElem[] | undefined; - return filterIfAttributes(attributes); -} - -function filterIfAttributes( - attributes?: AttributeElem[], -): IfAttribute[] | undefined { - if (!attributes) return; - return filterMap(attributes, a => - a.attribute.kind === "@if" ? a.attribute : undefined, - ); -} - -// prettier-ignore -export type OpenElem = - Pick< T, "kind" | "contents">; - -// prettier-ignore -export type PartElem = - Pick< T, "kind" | "start" | "end" | "contents"> ; - -// prettier-ignore -type VarLikeElem = - | GlobalVarElem - | VarElem - | LetElem - | ConstElem - | OverrideElem; - -export function collectVarLike( - kind: E["kind"], -): CollectPair { - return collectElem(kind, (cc: CollectContext, openElem: PartElem) => { - const name = cc.tags.var_name?.[0] as TypedDeclElem; - const decl_scope = cc.tags.decl_scope?.[0] as Scope; - const attributes = cc.tags.attribute as AttributeElem[] | undefined; - const partElem = { ...openElem, name, attributes } as E; - const varElem = withTextCover(partElem, cc); - (name.decl.ident as DeclIdent).declElem = varElem as DeclarationElem; - name.decl.ident.scope = decl_scope; - return varElem; - }); -} - -export const aliasCollect = collectElem( - "alias", - (cc: CollectContext, openElem: PartElem) => { - const name = cc.tags.alias_name?.[0] as DeclIdentElem; - const alias_scope = cc.tags.alias_scope?.[0] as Scope; - const typeRef = cc.tags.typeRefElem?.[0] as TypeRefElem; - const attributes: AttributeElem[] = cc.tags.attributes?.flat() ?? []; - const partElem: AliasElem = { ...openElem, name, attributes, typeRef }; - const aliasElem = withTextCover(partElem, cc); - name.ident.scope = alias_scope; - name.ident.declElem = aliasElem; - return aliasElem; - }, -); - -/** - * Collect a FnElem and associated scopes. - * - * Scope definition is a bit complicated in wgsl and wesl for fns. - * Here's what we collect for scopes for this example function: - * @if(true) fn foo(a: u32) -> @location(x) R { let y = a; } - * - * -{ // partial scope in case the whole shebang is prefixed by an `@if` - * %foo - * - * {<=%foo // foo decl references this header+returnType+body scope (for tracing dependencies from decls) - * x // for @location(x) (contains no decls, so ok to merge for tracing) - * %a u32 // merged from header scope - * R // merged from return type (contains no decls, so ok to merge for tracing) - * %y a // merged body scope - * } - * } - */ -export const fnCollect = collectElem( - "fn", - (cc: CollectContext, openElem: PartElem) => { - // extract tags we care about - const ourTags = fnTags(cc); - const { name, headerScope, returnScope, bodyScope, body, params } = ourTags; - const { attributes, returnAttributes, returnType, fnScope } = ourTags; - - // create the fn element - const fnElem: FnElem = { - ...openElem, - ...{ name, attributes, params, returnAttributes, body, returnType }, - }; - - // --- setup the various scopes -- - - // attach ifAttributes to outermost partial scope - fnScope.ifAttribute = filterIfAttributes(attributes)?.[0]; - - // merge the header, return and body scopes into the one scope - const mergedScope = headerScope; - returnScope && mergeScope(mergedScope, returnScope); - mergeScope(mergedScope, bodyScope); - - // rewrite scope contents to remove old scopes and add merged scope - const filtered: (Ident | Scope)[] = []; - for (const e of fnScope.contents) { - if (e === headerScope || e == returnScope) { - continue; - } else if (e === bodyScope) { - filtered.push(mergedScope); - } else { - filtered.push(e); - } - } - fnScope.contents = filtered; - - name.ident.declElem = fnElem; - name.ident.scope = mergedScope; - - return fnElem; - }, -); - -/** Fetch and cast the collection tags for fnCollect - * LATER typechecking for collect! */ -function fnTags(cc: CollectContext) { - const { fn_attributes, fn_name, fn_param, return_attributes } = cc.tags; - const { return_type } = cc.tags; - const { header_scope, return_scope, body_scope, body_statement } = cc.tags; - const { fn_partial_scope } = cc.tags; - - const name = fn_name?.[0] as DeclIdentElem; - const headerScope = header_scope?.[0] as Scope; - const returnScope = return_scope?.[0] as Scope | undefined; - const bodyScope = body_scope?.[0] as Scope; - const body = body_statement?.[0] as StatementElem; - const params: FnParamElem[] = fn_param?.flat(3) ?? []; - const attributes: AttributeElem[] | undefined = fn_attributes?.flat(); - const returnAttributes: AttributeElem[] | undefined = - return_attributes?.flat(); - const returnType: TypeRefElem | undefined = return_type?.flat(3)[0]; - const fnScope = fn_partial_scope?.[0] as PartialScope; - - return { - ...{ name, headerScope, returnScope, bodyScope, body, params }, - ...{ attributes, returnAttributes, returnType, fnScope }, - }; -} - -export const collectFnParam = collectElem( - "param", - (cc: CollectContext, openElem: PartElem) => { - const name = cc.tags.param_name?.[0]! as TypedDeclElem; - const attributes: AttributeElem[] = cc.tags.attributes?.flat() ?? []; - const elem: FnParamElem = { ...openElem, name, attributes }; - const paramElem = withTextCover(elem, cc); - name.decl.ident.declElem = paramElem; // TODO is this right? - - return paramElem; - }, -); - -export const collectStruct = collectElem( - "struct", - (cc: CollectContext, openElem: PartElem) => { - const name = cc.tags.type_name?.[0] as DeclIdentElem; - const members = cc.tags.members as StructMemberElem[]; - const attributes: AttributeElem[] = cc.tags.attributes?.flat() ?? []; - name.ident.scope = cc.tags.struct_scope?.[0] as Scope; - const structElem = { ...openElem, name, attributes, members }; - const elem = withTextCover(structElem, cc); - (name.ident as DeclIdent).declElem = elem as DeclarationElem; - - return elem; - }, -); - -export const collectStructMember = collectElem( - "member", - (cc: CollectContext, openElem: PartElem) => { - const name = cc.tags.nameElem?.[0]!; - const typeRef = cc.tags.typeRefElem?.[0]; - const attributes = cc.tags.attribute?.flat(3) as AttributeElem[]; - const partElem = { ...openElem, name, attributes, typeRef }; - return withTextCover(partElem, cc); - }, -); - -export const specialAttribute = collectElem( - "attribute", - (cc: CollectContext, openElem: PartElem) => { - const attribute = cc.tags.attr_variant?.[0] as Attribute; - const attrElem: AttributeElem = { ...openElem, attribute }; - return attrElem; - }, -); - -/** debug routine to log tags at collect() */ -export function logCollect(msg?: string): (cc: CollectContext) => void { - return function _log(cc: CollectContext) { - dlog(msg ?? "log", { tags: [...Object.keys(cc.tags)] }); - }; -} - -export const assertCollect = attrElemCollect("assert"); -export const statementCollect = attrElemCollect("statement"); -export const switchClauseCollect = - attrElemCollect("switch-clause"); - -/** @return a collector for container elem types that have only an attributes field */ -function attrElemCollect( - kind: T["kind"], -): CollectPair { - return collectElem(kind, (cc: CollectContext, openElem: PartElem) => { - const attributes = cc.tags.attribute?.flat(3) as AttributeElem[]; - const partElem = { ...openElem, attributes }; - return withTextCover(partElem as T, cc); - }); -} - -export const collectAttribute = collectElem( - "attribute", - (cc: CollectContext, openElem: PartElem) => { - const params = cc.tags.attrParam as UnknownExpressionElem[] | undefined; - const name = cc.tags.name?.[0]! as string; - const kind = "@attribute"; - const stdAttribute: StandardAttribute = { kind, name, params }; - const attrElem: AttributeElem = { ...openElem, attribute: stdAttribute }; - return attrElem; - }, -); - -export const typeRefCollect = collectElem( - "type", - // @ts-ignore - (cc: CollectContext, openElem: PartElem) => { - let templateParamsTemp: any[] | undefined = cc.tags.templateParam?.flat(3); - - const typeRef = cc.tags.typeRefName?.[0] as string | RefIdentElem; - const name = typeof typeRef === "string" ? typeRef : typeRef.ident; - const partElem = { - ...openElem, - name, - templateParams: templateParamsTemp as any[], - }; - // @ts-ignore - return withTextCover(partElem, cc); - }, -); - -// LATER This creates useless unknown-expression elements -export const expressionCollect = collectElem( - "expression", - (cc: CollectContext, openElem: PartElem) => { - const partElem = { ...openElem }; - return withTextCover(partElem, cc); - }, -); - -export function globalAssertCollect(cc: CollectContext): void { - const globalAssert = cc.tags.const_assert?.flat()[0]; - const ast = cc.app.stable as WeslAST; - if (!ast.moduleAsserts) ast.moduleAsserts = []; - ast.moduleAsserts.push(globalAssert); -} - -export const stuffCollect = collectElem( - "stuff", - (cc: CollectContext, openElem: PartElem) => { - const partElem = { ...openElem }; - return withTextCover(partElem, cc); - }, -); - -export const memberRefCollect = collectElem( - "memberRef", - (cc: CollectContext, openElem: PartElem) => { - const { component, structRef, extra_components } = cc.tags; - const member = component![0] as NameElem; - const name = structRef?.flat()[0] as RefIdentElem; - const extraComponents = extra_components?.flat()[0] as StuffElem; - - const partElem: SimpleMemberRef = { - ...openElem, - name, - member, - extraComponents, - }; - return withTextCover(partElem, cc) as any; - }, -); - -export function nameCollect(cc: CollectContext): NameElem { - const { start, end, src, app } = cc; - const name = src.slice(start, end); - const elem: NameElem = { kind: "name", start, end, name }; - addToOpenElem(cc, elem); - return elem; -} - -export const collectModule = collectElem( - "module", - (cc: CollectContext, openElem: PartElem) => { - const ccComplete = { ...cc, start: 0, end: cc.src.length }; // force module to cover entire source despite ws skipping - const moduleElem: ModuleElem = withTextCover(openElem, ccComplete); - const weslState: StableState = cc.app.stable; - weslState.moduleElem = moduleElem; - return moduleElem; - }, -); - -export function directiveCollect(cc: CollectContext): DirectiveElem { - const { start, end } = cc; - const directive: DirectiveVariant = cc.tags.directive?.flat()[0]; - const attributes: AttributeElem[] | undefined = cc.tags.attribute?.flat(); - - const kind = "directive"; - const elem: DirectiveElem = { kind, attributes, start, end, directive }; - addToOpenElem(cc, elem); - return elem; -} - -/** - * Collect a LexicalScope. - * - * The scope starts encloses all idents and subscopes inside the parser to which - * .collect is attached - */ -export const scopeCollect: CollectPair = { - before: startScope, - after: completeScope, -}; - -/** - * Collect a PartialScope. - * - * The scope starts encloses all idents and subscopes inside the parser to which - * .collect is attached - */ -export const partialScopeCollect: CollectPair = { - before: startPartialScope, - after: completeScope, -}; - -/** utility to collect an ElemWithContents - * starts the new element as the collection point corresponding - * to the start of the attached grammar and completes - * the element in the at the end of the grammar. - * - * In between the start and the end, the new element is available - * as an 'open' element in the collection context. While this element - * is 'open', other collected are added to the 'contents' field of this - * open element. - */ -function collectElem( - kind: V["kind"], - fn: (cc: CollectContext, partialElem: PartElem) => V, -): CollectPair { - return { - before: (cc: CollectContext) => { - const partialElem = { kind, contents: [] }; - const weslContext: WeslParseContext = cc.app.context; - weslContext.openElems.push(partialElem); - }, - after: (cc: CollectContext) => { - // TODO refine start? - const weslContext: WeslParseContext = cc.app.context; - const partialElem = weslContext.openElems.pop()!; - console.assert(partialElem && partialElem.kind === kind); - const elem = fn(cc, { ...partialElem, start: cc.start, end: cc.end }); - if (elem) addToOpenElem(cc, elem as AbstractElem); - return elem; - }, - }; -} - -/** - * @return a copy of the element with contents extended - * to include TextElems to cover the entire range. - */ -function withTextCover( - elem: T, - cc: CollectContext, -): T { - const contents = coverWithText(cc, elem); - return { ...elem, contents }; -} - -/** cover the entire source range with Elems by creating TextElems to - * cover any parts of the source that are not covered by other elems - * @returns the existing elems combined with any new TextElems, in src order */ -function coverWithText(cc: CollectContext, elem: ContainerElem): GrammarElem[] { - let { start: pos } = cc; - const ast: WeslAST = cc.app.stable; - const { contents, end } = elem; - const sorted = (contents as GrammarElem[]).sort((a, b) => a.start - b.start); - - const elems: GrammarElem[] = []; - for (const elem of sorted) { - if (pos < elem.start) { - elems.push(makeTextElem(elem.start)); - } - elems.push(elem); - pos = elem.end; - } - if (pos < end) { - elems.push(makeTextElem(end)); - } - - return elems; - - function makeTextElem(end: number): TextElem { - return { kind: "text", start: pos, end, srcModule: ast.srcModule }; - } -} - -function collectLog(cc: CollectContext, ...messages: any[]): void { - const { src, start, end } = cc; - srcLog(src, [start, end], ...messages); -} diff --git a/tools/packages/wesl/src/WeslDevice.ts b/tools/packages/wesl/src/WeslDevice.ts index 880fc0842..14996269a 100644 --- a/tools/packages/wesl/src/WeslDevice.ts +++ b/tools/packages/wesl/src/WeslDevice.ts @@ -1,4 +1,4 @@ -import { ExtendedGPUValidationError } from "./LinkedWesl"; +import { ExtendedGPUValidationError } from "./LinkedWesl.ts"; import { encodeVlq } from "./vlq/vlq"; /** diff --git a/tools/packages/wesl/src/debug/ASTtoString.ts b/tools/packages/wesl/src/debug/ASTtoString.ts index 7cc157361..47f814816 100644 --- a/tools/packages/wesl/src/debug/ASTtoString.ts +++ b/tools/packages/wesl/src/debug/ASTtoString.ts @@ -1,290 +1,345 @@ -import { assertUnreachable } from "../../../mini-parse/src/Assertions.ts"; import { - AbstractElem, Attribute, AttributeElem, - DirectiveElem, - FnElem, - StuffElem, - TypedDeclElem, - TypeRefElem, - TypeTemplateParameter, - UnknownExpressionElem, -} from "../AbstractElems.ts"; + ConstAssertElem, + DeclarationElem, + FunctionDeclarationElem, + GlobalDeclarationElem, + IfClause, + ModuleElem, + Statement, + SwitchCaseSelector, +} from "../parse/WeslElems.ts"; +import { assertThat, assertUnreachable } from "../Assertions.ts"; import { diagnosticControlToString, expressionToString, -} from "../LowerAndEmit.ts"; + lhsExpressionToString, + templatedIdentToString, +} from "../lower/LowerAndEmit.ts"; +import { + DiagnosticDirective, + DirectiveElem, + EnableDirective, + RequiresDirective, +} from "../parse/DirectiveElem.ts"; +import { ExpressionElem } from "../parse/ExpressionElem.ts"; +import { ImportElem } from "../parse/ImportElems.ts"; import { importToString } from "./ImportToString.ts"; import { LineWrapper } from "./LineWrapper.ts"; const maxLineLength = 150; -export function astToString(elem: AbstractElem, indent = 0): string { - const { kind } = elem; - const str = new LineWrapper(indent, maxLineLength); - str.add(kind); - addElemFields(elem, str); - let childStrings: string[] = []; - if ("contents" in elem) { - childStrings = elem.contents.map(e => astToString(e, indent + 2)); +export function astToString(ast: ModuleElem, indent = 0): string { + const str = new LineWrapper(indent); + str.add`module`; + str.nl(); + const moduleStr = str.indentedBlock(2); + for (const importElem of ast.imports) { + printImportElem(importElem, moduleStr); } - if (childStrings.length) { - str.nl(); - str.addBlock(childStrings.join("\n"), false); + for (const directive of ast.directives) { + printDirectiveElem(directive, moduleStr); } + for (const decl of ast.declarations) { + printGlobalDecl(decl, moduleStr); + } + return str.print(maxLineLength); +} - return str.result; +export function globalDeclToString( + elem: GlobalDeclarationElem, + indent = 0, +): string { + const str = new LineWrapper(indent); + printGlobalDecl(elem, str); + return str.print(maxLineLength); } -// LATER rewrite to be shorter and easier to read -function addElemFields(elem: AbstractElem, str: LineWrapper): void { +function printGlobalDecl(elem: GlobalDeclarationElem, str: LineWrapper): void { const { kind } = elem; - if (kind === "text") { - const { srcModule, start, end } = elem; - str.add(` '${srcModule.src.slice(start, end)}'`); - } else if ( - kind === "var" || - kind === "let" || - kind === "gvar" || - kind === "const" || - kind === "override" - ) { - addTypedDeclIdent(elem.name, str); - listAttributeElems(elem.attributes, str); + printAttributes(elem.attributes, str); + if (kind === "alias") { + str.add`alias ${elem.name.name}`; + str.add` = ${templatedIdentToString(elem.type)}`; + str.nl(); + } else if (kind === "assert") { + printConstAssert(elem, str); + str.nl(); + } else if (kind === "declaration") { + printDeclaration(elem, str); + str.nl(); + } else if (kind === "function") { + printFunction(elem, str); } else if (kind === "struct") { - str.add(" " + elem.name.ident.originalName); - } else if (kind === "member") { - const { name, typeRef, attributes } = elem; - listAttributeElems(attributes, str); - str.add(" " + name.name); - str.add(": " + typeRefElemToString(typeRef)); - } else if (kind === "name") { - str.add(" " + elem.name); - } else if (kind === "memberRef") { - const { extraComponents } = elem; - const extraText = - extraComponents ? debugContentsToString(extraComponents) : ""; - str.add(` ${elem.name.ident.originalName}.${elem.member.name}${extraText}`); - } else if (kind === "fn") { - addFnFields(elem, str); - } else if (kind === "alias") { - const { name, typeRef } = elem; - const prefix = name.ident.kind === "decl" ? "%" : ""; - str.add(" " + prefix + name.ident.originalName); - str.add("=" + typeRefElemToString(typeRef)); - } else if (kind === "attribute") { - addAttributeFields(elem.attribute, str); - } else if (kind === "expression") { - const contents = elem.contents - .map(e => { - if (e.kind === "text") { - return "'" + e.srcModule.src.slice(e.start, e.end) + "'"; - } else { - return astToString(e); - } - }) - .join(" "); - str.add(" " + contents); - } else if (kind === "type") { - const { name } = elem; - const nameStr = typeof name === "string" ? name : name.originalName; - str.add(" " + nameStr); - - if (elem.templateParams !== undefined) { - const paramStrs = elem.templateParams - .map(templateParamToString) - .join(", "); - str.add("<" + paramStrs + ">"); + str.add`struct ${elem.name.name}`; + str.nl(); + const childPrinter = str.indentedBlock(2); + for (const member of elem.members) { + printAttributes(member.attributes, childPrinter); + childPrinter.add`${member.name.name}`; + childPrinter.add`: ${templatedIdentToString(member.type)}`; + childPrinter.nl(); } - } else if (kind === "synthetic") { - str.add(` '${elem.text}'`); - } else if (kind === "import") { - str.add(" " + importToString(elem.imports)); - } else if (kind === "ref") { - str.add(" " + elem.ident.originalName); - } else if (kind === "typeDecl") { - addTypedDeclIdent(elem, str); - } else if (kind === "decl") { - const { ident } = elem; - str.add(" %" + ident.originalName); - } else if (kind === "assert") { - // Nothing to do for now - } else if (kind === "module") { - // Ignore this kind of elem - } else if (kind === "param") { - // LATER This branch shouldn't exist - } else if (kind === "stuff") { - // Ignore - } else if (kind === "directive") { - addDirective(elem, str); - } else if (kind === "statement") { - listAttributeElems(elem.attributes, str); - } else if (kind === "switch-clause") { - // Nothing to do for now } else { assertUnreachable(kind); } } -function addAttributeFields(attr: Attribute, str: LineWrapper) { - const { kind } = attr; - if (kind === "@attribute") { - const { name, params } = attr; - str.add(" @" + name); - if (params && params.length > 0) { - str.add("("); - str.add(params.map(unknownExpressionToString).join(", ")); - str.add(")"); +/** Does not include the new line */ +function printConstAssert(elem: ConstAssertElem, str: LineWrapper) { + str.add`const_assert `; + str.add`${expressionToString(elem.expression)}`; +} + +function printImportElem(elem: ImportElem, str: LineWrapper) { + printAttributes(elem.attributes, str); + str.add`${importToString(elem.imports)}`; + str.nl(); +} + +function printAttributes(elems: AttributeElem[] | undefined, str: LineWrapper) { + if (elems === undefined || elems.length === 0) return; + for (let i = 0; i < elems.length - 1; i++) { + printAttribute(elems[i].attribute, str); + str.add` `; + } + printAttribute(elems[elems.length - 1].attribute, str); + str.nl(); +} + +function printAttribute(elem: Attribute, str: LineWrapper) { + const { kind } = elem; + if (kind === "attribute") { + const { name, params } = elem; + if (params.length > 0) { + str.add`@${name}(`; + printExpressions(params, str); + str.add`)`; + } else { + str.add`@${name}`; } } else if (kind === "@builtin") { - str.add(` @builtin(${attr.param.name})`); + str.add`@builtin(${elem.param.name})`; } else if (kind === "@diagnostic") { - str.add( - ` @diagnostic${diagnosticControlToString(attr.severity, attr.rule)}`, - ); + str.add` @diagnostic${diagnosticControlToString(elem.severity, elem.rule)}`; } else if (kind === "@if") { - str.add(" @if"); - str.add("("); - str.add(expressionToString(attr.param.expression)); - str.add(")"); + str.add`@if`; + str.add`(`; + str.add`${expressionToString(elem.param.expression)}`; + str.add`)`; } else if (kind === "@interpolate") { - str.add(` @interpolate(${attr.params.map(v => v.name).join(", ")})`); + str.add`@interpolate(${elem.params.map(v => v.name).join(", ")})`; } else { assertUnreachable(kind); } } -/** @return string representation of an attribute (for test/debug) */ -export function attributeToString(attr: Attribute): string { - const str = new LineWrapper(0, maxLineLength); - addAttributeFields(attr, str); - return str.result; -} - -function addTypedDeclIdent(elem: TypedDeclElem, str: LineWrapper) { - const { decl, typeRef } = elem; - str.add(" %" + decl.ident.originalName); - if (typeRef) { - str.add(" : " + typeRefElemToString(typeRef)); +function printExpressions(expressions: ExpressionElem[], str: LineWrapper) { + if (expressions.length === 0) return; + for (let i = 0; i < expressions.length - 1; i++) { + str.add`${expressionToString(expressions[i])}, `; } + printExpression(expressions[expressions.length - 1], str); +} +function printExpression(expression: ExpressionElem, str: LineWrapper) { + str.add`${expressionToString(expression)}`; } -function addFnFields(elem: FnElem, str: LineWrapper) { - const { name, params, returnType, attributes } = elem; - - str.add(" " + name.ident.originalName); - - str.add("("); - const paramStrs = params - .map( - ( - p, // LATER DRY - ) => { - const { name } = p; - const { originalName } = name.decl.ident; - const typeRef = typeRefElemToString(name.typeRef!); - return originalName + ": " + typeRef; - }, - ) - .join(", "); - str.add(paramStrs); - str.add(")"); - - listAttributeElems(attributes, str); - - if (returnType) { - str.add(" -> " + typeRefElemToString(returnType)); - } +function printDirectiveElem(elem: DirectiveElem, str: LineWrapper) { + printAttributes(elem.attributes, str); + printDirective(elem.directive, str); + str.nl(); } -/** show attribute names in short form to verify collection */ -function listAttributeElems( - attributes: AttributeElem[] | undefined, +function printDirective( + elem: DiagnosticDirective | EnableDirective | RequiresDirective, str: LineWrapper, ) { - attributes?.forEach(a => str.add(" " + attributeName(a.attribute))); -} - -function attributeName(attr: Attribute): string { - const { kind } = attr; - if (kind === "@attribute") { - return "@" + attr.name; - } else { - return kind; - } -} - -function addDirective(elem: DirectiveElem, str: LineWrapper) { - const { directive, attributes } = elem; - const { kind } = directive; + const { kind } = elem; if (kind === "diagnostic") { - const { severity, rule } = directive; - const control = diagnosticControlToString(severity, rule); - str.add(` diagnostic${control}`); - } else if (kind === "enable" || kind === "requires") { - str.add(` ${kind} ${directive.extensions.map(v => v.name).join(", ")}`); + str.add`diagnostic${diagnosticControlToString(elem.severity, elem.rule)}`; + } else if (kind === "enable") { + str.add`enable ${elem.extensions.map(v => v.name).join(", ")}`; + } else if (kind === "requires") { + str.add`requires${elem.extensions.map(v => v.name).join(", ")}`; } else { assertUnreachable(kind); } - listAttributeElems(attributes, str); } -function unknownExpressionToString(elem: UnknownExpressionElem): string { - // LATER Temp hack while I clean up the expression parsing - if ("contents" in elem) { - // @ts-ignore - const contents = elem.contents - // @ts-ignore - .map(e => { - if (e.kind === "text") { - return "'" + e.srcModule.src.slice(e.start, e.end) + "'"; - } else { - return astToString(e); - } - }) - .join(" "); - return contents; +/** Does not include the new line */ +function printDeclaration(elem: DeclarationElem, str: LineWrapper) { + str.add`${elem.variant.kind}`; + str.add` ${elem.name.name}`; + if (elem.type) { + str.add` : ${templatedIdentToString(elem.type)}`; + } + if (elem.initializer) { + str.add` = `; + printExpression(elem.initializer, str); } - return astToString(elem); } -function templateParamToString(p: TypeTemplateParameter): string { - if (typeof p === "string") { - return p; - } else if (p.kind === "type") { - return typeRefElemToString(p); - } else if (p.kind === "expression") { - return unknownExpressionToString(p); +function printFunction(elem: FunctionDeclarationElem, str: LineWrapper) { + str.add`fn ${elem.name.name}`; + str.add`(`; + if ( + elem.params.some(v => v.attributes !== undefined && v.attributes.length > 0) + ) { + // Switch to long param printing mode iff there are attributes + str.nl(); + const paramsStr = str.indentedBlock(2); + for (const p of elem.params) { + printAttributes(p.attributes, paramsStr); + paramsStr.add`${p.name.name}: ${templatedIdentToString(p.type)}`; + paramsStr.nl(); + } } else { - console.log("unknown template parameter type", p); - return "??"; + const paramsStr = elem.params + .map(p => p.name.name + ": " + templatedIdentToString(p.type)) + .join(", "); + str.add`${paramsStr}`; } + + str.add`)`; + printAttributes(elem.returnAttributes, str); + if (elem.returnType) { + str.add` -> ${templatedIdentToString(elem.returnType)}`; + } + str.nl(); + printStatement(elem.body, str); } -function typeRefElemToString(elem: TypeRefElem): string { - if (!elem) return "?type?"; - const { name } = elem; - const nameStr = typeof name === "string" ? name : name.originalName; +function printStatement(stmt: Statement, str: LineWrapper) { + printAttributes(stmt.attributes, str); + if (stmt.kind === "compound-statement") { + if (stmt.body.length > 0) { + const bodyStr = str.indentedBlock(2); + stmt.body.forEach(v => printStatement(v, bodyStr)); + } + return; // Skip printing the final newline + } - let params = ""; - if (elem.templateParams !== undefined) { - const paramStrs = elem.templateParams.map(templateParamToString).join(", "); - params = "<" + paramStrs + ">"; + if (stmt.kind === "assert") { + printConstAssert(stmt, str); + } else if (stmt.kind === "assignment-statement") { + if (stmt.left.kind === "discard-expression") { + str.add`_`; + } else { + str.add`${lhsExpressionToString(stmt.left)}`; + } + str.add` ${stmt.operator.value} `; + printExpression(stmt.right, str); + } else if (stmt.kind === "call-statement") { + str.add`${templatedIdentToString(stmt.function)}(`; + printExpressions(stmt.arguments, str); + str.add`)`; + } else if (stmt.kind === "declaration") { + printDeclaration(stmt, str); + } else if ( + stmt.kind === "break-statement" || + stmt.kind === "continue-statement" || + stmt.kind === "discard-statement" + ) { + str.add`${stmt.kind}`; + } else if (stmt.kind === "return-statement") { + if (stmt.expression) { + str.add`return ${expressionToString(stmt.expression)}`; + } else { + str.add`return`; + } + } else if (stmt.kind === "postfix-statement") { + str.add`${lhsExpressionToString(stmt.expression)}`; + str.add`${stmt.operator.value}`; + } else if (stmt.kind === "for-statement") { + str.add`for(`; + str.nl(); + const childStr = str.indentedBlock(2); + if (stmt.initializer !== undefined) { + printStatement(stmt.initializer, childStr); + } + if (stmt.condition !== undefined) { + printExpression(stmt.condition, childStr); + childStr.nl(); + } + if (stmt.update !== undefined) { + printStatement(stmt.update, childStr); + } + str.add`)`; + str.nl(); + printStatement(stmt.body, str); + } else if (stmt.kind === "if-else-statement") { + printIfClause(stmt.main, str); + } else if (stmt.kind === "loop-statement") { + str.add`loop`; + str.nl(); + printStatement(stmt.body, str); + if (stmt.continuing !== undefined) { + const bodyStr = str.indentedBlock(2); + printAttributes(stmt.continuing.attributes, bodyStr); + bodyStr.add`continuing`; + bodyStr.nl(); + printStatement(stmt.continuing.body, bodyStr); + + const breakIf = stmt.continuing.breakIf; + if (breakIf !== undefined) { + const continuingStr = str.indentedBlock(2); + printAttributes(breakIf.attributes, continuingStr); + continuingStr.add`break if ${expressionToString(breakIf.expression)}`; + continuingStr.nl(); + } + } + } else if (stmt.kind === "switch-statement") { + str.add`switch `; + printExpression(stmt.selector, str); + str.nl(); + printAttributes(stmt.bodyAttributes, str); + const clauseStr = str.indentedBlock(2); + for (const clause of stmt.clauses) { + printAttributes(clause.attributes, clauseStr); + assertThat(clause.cases.length > 0); + clauseStr.add`case `; + printSwitchCase(clause.cases[0], clauseStr); + for (let i = 1; i < clause.cases.length; i++) { + clauseStr.add`, `; + printSwitchCase(clause.cases[i], clauseStr); + } + clauseStr.add`:`; + clauseStr.nl(); + printStatement(clause.body, clauseStr); + } + } else if (stmt.kind === "while-statement") { + str.add`while `; + printExpression(stmt.condition, str); + str.nl(); + printStatement(stmt.body, str); + } else { + assertUnreachable(stmt); } - return nameStr + params; + str.nl(); } -export function debugContentsToString(elem: StuffElem): string { - const parts = elem.contents.map(c => { - const { kind } = c; - if (kind === "text") { - return c.srcModule.src.slice(c.start, c.end); - } else if (kind === "ref") { - return c.ident.originalName; // not using the mapped to decl name, so this can be used for debug.. +function printIfClause(clause: IfClause, str: LineWrapper) { + str.add`if `; + printExpression(clause.condition, str); + str.nl(); + printStatement(clause.accept, str); + if (clause.reject !== undefined) { + if (clause.reject.kind === "compound-statement") { + str.add`else`; + str.nl(); + printStatement(clause.reject, str); } else { - return `?${c.kind}?`; + str.add`else `; + printIfClause(clause.reject, str); } - }); - return parts.join(" "); + } +} +function printSwitchCase(switchCase: SwitchCaseSelector, str: LineWrapper) { + if (switchCase.expression === "default") { + str.add`default`; + } else { + printExpression(switchCase.expression, str); + } } diff --git a/tools/packages/wesl/src/debug/ImportToString.ts b/tools/packages/wesl/src/debug/ImportToString.ts index f171c3cb1..9e790446c 100644 --- a/tools/packages/wesl/src/debug/ImportToString.ts +++ b/tools/packages/wesl/src/debug/ImportToString.ts @@ -1,12 +1,12 @@ -import { assertUnreachable } from "../../../mini-parse/src/Assertions"; +import { assertUnreachable } from "../Assertions.ts"; import { ImportCollection, ImportItem, ImportStatement, -} from "../AbstractElems"; +} from "../parse/ImportElems.ts"; export function importToString(tree: ImportStatement): string { - return importToStringImpl(tree) + ";"; + return "import " + importToStringImpl(tree) + ";"; } function importToStringImpl(tree: ImportStatement): string { diff --git a/tools/packages/wesl/src/debug/LineWrapper.ts b/tools/packages/wesl/src/debug/LineWrapper.ts index fe6f7e489..d05cdd1ca 100644 --- a/tools/packages/wesl/src/debug/LineWrapper.ts +++ b/tools/packages/wesl/src/debug/LineWrapper.ts @@ -1,67 +1,75 @@ -/** debug utility for constructing strings that wrap at a fixed column width - * text beyond the column width is wrapped to start on the next line - */ +import { FmtDisplay, str } from "../Util.ts"; + +type IndentedBlock = { + indent: number; + text: (string | IndentedBlock)[]; +}; + export class LineWrapper { - #fragments: string[] = []; - #column = 0; - #spc: string; - #oneLine = true; - #isHanging = false; - #hangingSpc: string; + block: IndentedBlock; + + constructor(indent: number) { + this.block = { + indent, + text: [], + }; + } - constructor( - readonly indent = 0, - readonly maxWidth = 60, - readonly hangingIndent = 2, - ) { - this.#spc = " ".repeat(indent); - this.#hangingSpc = " ".repeat(hangingIndent); + indentedBlock(indent: number): LineWrapper { + const newWrapper = new LineWrapper(indent); + this.block.text.push(newWrapper.block); + return newWrapper; } /** add a new line to the constructed string */ nl() { - this.#fragments.push("\n"); - this.#column = 0; - this.#oneLine = false; - this.#isHanging = false; + this.block.text.push("\n"); } /** add a string, wrapping to the next line if necessary */ - add(s: string) { - if (this.#column + firstLineLength(s) > this.maxWidth) { - this.hangingNl(); - } - if (this.#column === 0) { - this.#fragments.push(this.#spc); - if (this.#isHanging) { - this.#fragments.push(this.#hangingSpc); - } - this.#column = this.indent; - } - this.#fragments.push(s); - this.#column += s.length; - } - - /** add a raw block of text with no wrapping */ - addBlock(s: string, andNewLine = true) { - this.#fragments.push(s); - if (andNewLine) this.nl(); + add(template: TemplateStringsArray, ...params: FmtDisplay[]) { + this.block.text.push(str(template, ...params)); } /** @return the constructed string */ - get result(): string { - return this.#fragments.join(""); + print(maxWidth = 60, hangingIndent = 2): string { + return printBlock(this.block, maxWidth, hangingIndent); } - - /** true if the result contains no newlines */ - get oneLine(): boolean { - return this.#oneLine; +} +function printBlock( + block: IndentedBlock, + maxWidth = 60, + hangingIndent = 2, +): string { + let result = ""; + const spc = " ".repeat(block.indent); + const hangingSpc = " ".repeat(hangingIndent); + for (const s of block.text) { + const column = getColumn(result); + if (typeof s === "string") { + if (column + firstLineLength(s) > maxWidth) { + result += "\n"; + result += spc; + result += hangingSpc; + } else if (column === 0) { + result += spc; + } + result += s; + } else { + // A nested block + result += printBlock( + { indent: block.indent + s.indent, text: s.text }, + maxWidth, + hangingIndent, + ); + } } + return result; +} - private hangingNl() { - this.nl(); - this.#isHanging = true; - } +function getColumn(text: string): number { + let afterLastNewline = text.lastIndexOf("\n") + 1; + return text.length - afterLastNewline; } function firstLineLength(s: string): number { diff --git a/tools/packages/wesl/src/debug/ScopeToString.ts b/tools/packages/wesl/src/debug/ScopeToString.ts deleted file mode 100644 index c98496dca..000000000 --- a/tools/packages/wesl/src/debug/ScopeToString.ts +++ /dev/null @@ -1,79 +0,0 @@ -import { childScope, Ident, Scope } from "../Scope.ts"; -import { attributeToString } from "./ASTtoString.ts"; -import { LineWrapper } from "./LineWrapper.ts"; - -/** A debugging print of the scope tree with identifiers in nested brackets */ -export function scopeToString( - scope: Scope, - indent = 0, - shortIdents = true, -): string { - const { contents, kind, ifAttribute } = scope; - - const str = new LineWrapper(indent); - const attrStrings = ifAttribute && attributeToString(ifAttribute); - if (attrStrings) str.add(attrStrings + " "); - if (kind === "partial") str.add("-"); - str.add("{ "); - - const last = contents.length - 1; - let lastWasScope = false; - let hasBlock = false; - contents.forEach((elem, i) => { - if (childScope(elem)) { - const childScope: Scope = elem; - const childBlock = scopeToString(childScope, indent + 2, shortIdents); - !lastWasScope && str.nl(); - str.addBlock(childBlock); - lastWasScope = true; - hasBlock = true; - } else { - lastWasScope && str.add(" "); - lastWasScope = false; - const ident: Ident = elem; - if (shortIdents) { - str.add(identShortString(ident)); - } else { - str.add(identToString(ident)); - } - if (i < last) str.add(" "); - } - }); - - if (!hasBlock && str.oneLine) { - str.add(" }"); - } else { - if (hasBlock && !lastWasScope) str.nl(); - str.add("}"); - } - - str.add(` #${scope.id}`); - - return str.result; -} - -/** A debug print of the scope tree with identifiers in long form in nested brackets */ -export function scopeToStringLong(scope: Scope): string { - return scopeToString(scope, 0, false); -} - -/** name of an identifier, with decls prefixed with '%' */ -function identShortString(ident: Ident): string { - const { kind, originalName } = ident; - const prefix = kind === "decl" ? "%" : ""; - return `${prefix}${originalName}`; -} - -export function identToString(ident?: Ident): string { - if (!ident) return JSON.stringify(ident); - const { kind, originalName } = ident; - const idStr = ident.id ? `#${ident.id}` : ""; - if (kind === "ref") { - const ref = identToString(ident.refersTo!); - return `${originalName} ${idStr} -> ${ref}`; - } else { - const { mangledName } = ident; - const mangled = mangledName ? `(${mangledName})` : ""; - return `%${originalName}${mangled} ${idStr} `; - } -} diff --git a/tools/packages/wesl/src/index.ts b/tools/packages/wesl/src/index.ts index 52587f860..85a8dd9c0 100644 --- a/tools/packages/wesl/src/index.ts +++ b/tools/packages/wesl/src/index.ts @@ -1,11 +1,12 @@ -export * from "./debug/ASTtoString.js"; -export * from "./debug/ScopeToString.js"; -export * from "./LinkedWesl.js"; -export * from "./Linker.js"; -export { WeslStream } from "./parse/WeslStream.js"; -export * from "./ParsedRegistry.js"; -export * from "./ParseWESL.js"; -export * from "./PathUtil.js"; -export * from "./TransformBindingStructs.js"; -export * from "./WeslDevice.js"; -export * from "./WgslBundle.js"; +export * from "./debug/ASTtoString.ts"; +export * from "./lower/TranslationUnit.ts"; +export { WeslStream } from "./parse/WeslStream.ts"; +export * from "./Linker.ts"; +export * from "./LinkedWesl.ts"; +export type { Conditions } from "./Conditions.ts"; +export * from "./Mangler.ts"; +export * from "./Module.ts"; +export * from "./TransformBindingStructs.ts"; +export * from "./VirtualFilesystem.ts"; +export * from "./WeslDevice.ts"; +export * from "./WgslBundle.ts"; diff --git a/tools/packages/wesl/src/lower/CompileToWgsl.ts b/tools/packages/wesl/src/lower/CompileToWgsl.ts new file mode 100644 index 000000000..87c7d4966 --- /dev/null +++ b/tools/packages/wesl/src/lower/CompileToWgsl.ts @@ -0,0 +1,237 @@ +import { SrcMap, SrcMapBuilder } from "mini-parse"; +import { assertThat } from "../Assertions.ts"; +import { Conditions } from "../Conditions.ts"; +import { LinkedWesl } from "../LinkedWesl.ts"; +import { WeslJsPlugin } from "../Linker.ts"; +import { ManglerFn } from "../Mangler.ts"; +import { ModulePath, ModulePathString, WeslAST } from "../Module.ts"; +import { + bindSymbols, + ExportedDeclarations, + SymbolReference, + SymbolTable, +} from "../pass/SymbolsTablePass.ts"; +import { str } from "../Util.ts"; +import { lowerAndEmit } from "./LowerAndEmit.ts"; + +export interface CompilationOptions { + conditions: Conditions; + mangler: ManglerFn; + + plugins?: WeslJsPlugin[]; +} + +interface ItemInModule { + modulePath: ModulePath; + item: string; + symbolTableId: number; + symbolRef: SymbolReference; +} + +/** + * Everything has been fetched at this point. Link the files together. + * + * The modules are identified by their absolute path. + * With re-exports, an import path could be different from the absolute path. + */ +export function compileToWgsl( + rootModulePath: ModulePath, + modules: ReadonlyMap, + opts: CompilationOptions, +): LinkedWesl { + const { compiledModules, symbolTables } = compileModules( + rootModulePath, + modules, + opts, + ); + // No dead code elimination by default, this gives the user more type checking of their shader. + // Dead code elimination should + // - be benchmarked. I suspect that doing dead code elimination might be slower than not doing it. + // - be implemented in an optional pass. + // The user should be able to choose between "precompile shader with max space savings" and + // "fully debug shader with naga" + + let result = new SrcMap(); + + // For the name mangling + const globalNames = new Set(); + + // I could do the entire symbol table pass right now + // HOWEVER compiledModules includes more modules than will actually be emitted + opts.mangler(dependency.item, dependency.modulePath.path, globalNames); + + const emittedModules = new Set(); + function emitModules(modulePathString: ModulePathString) { + if (emittedModules.has(modulePathString)) return; + emittedModules.add(modulePathString); + + const compiledModule = compiledModules.get(modulePathString); + assertThat(compiledModule !== undefined); + + const { moduleElem, srcModule } = compiledModule.result; + lowerAndEmit( + moduleElem, + result.builderFor({ + text: srcModule.src, + path: srcModule.debugFilePath, + }), + { + conditions: opts.conditions, + isRoot: emittedModules.size === 1, + tables: symbolTables, + tableId: compiledModule.symbolTableId, + }, + ); + + // Rely on the dependencies set being sorted + for (const dependency of compiledModule.dependencies) { + emitModules(dependency); + } + } + emitModules(rootModulePath.toString()); + + return new LinkedWesl(result); +} + +interface CompiledModule extends CompiledSingleModule { + symbolTableId: number; + /** Which modules does this module depend on */ + dependencies: Set; +} + +/** + * Compiles all modules that the root module needs. + * Also resolves the symbol table. No more imports. + */ +function compileModules( + rootModulePath: ModulePath, + modules: ReadonlyMap, + opts: CompilationOptions, +) { + const packageNames = getPackageNames(modules); + const compiledModules = new Map(); + + const resolveCache = new Map(); + function resolveAndCompile(modulePath: string[], item: string): ItemInModule { + let cacheKey = modulePath.toString() + "::" + item; + let cachedValue = resolveCache.get(cacheKey); + if (cachedValue !== undefined) { + return cachedValue; + } + + for (let i = 1; i < modulePath.length; i++) { + // query if it's a valid module, starting with the root + let partPath = modulePath + .slice(0, i) + .join("::") satisfies string as ModulePathString; + if (!modules.has(partPath)) { + continue; + } + let partModule = compileModule(partPath); + let exportedDecl = partModule.exportedDeclarations.get(item); + if (exportedDecl !== undefined) { + // We don't have re-exports yet, so we can just check whether we are at the end + if (i === modulePath.length - 1) { + const result: ItemInModule = { + modulePath: partModule.result.srcModule.modulePath, + item, + symbolTableId: partModule.symbolTableId, + symbolRef: exportedDecl.symbol, + }; + resolveCache.set(cacheKey, result); + return result; + } else { + throw new Error( + str`Item ${item} already found in ${partPath}, but ${modulePath.join("::")} was requested.`, + ); + } + } + } + throw new Error(str`Item ${item} not found in ${modulePath.join("::")}`); + } + + const symbolTables: SymbolTable[] = []; + + /** Compile a module and its dependencies */ + function compileModule(modulePathString: ModulePathString): CompiledModule { + const cached = compiledModules.get(modulePathString); + if (cached !== undefined) { + return cached; + } + const module = modules.get(modulePathString); + if (module === undefined) { + throw new Error(str`Could not find module ${modulePathString}`); + } + const symbolTableId = symbolTables.length; + const dependencies = new Set(); + + const compiled: CompiledModule = { + ...compileSingleModule(module, opts, packageNames), + symbolTableId, + dependencies, + }; + symbolTables.push(compiled.symbolTable); + compiledModules.set(modulePathString, compiled); + + for (let i = 0; i < compiled.symbolTable.length; i++) { + const symbol = compiled.symbolTable[i]; + if (symbol.kind === "import") { + const dependency = resolveAndCompile(symbol.module, symbol.value); + dependencies.add(dependency.modulePath.toString()); + compiled.symbolTable[i] = { + kind: "extern", + table: dependency.symbolTableId, + index: dependency.symbolRef, + }; + } + } + return compiled; + } + compileModule(rootModulePath.toString()); + + return { + compiledModules, + symbolTables, + }; +} + +interface CompiledSingleModule { + /** The symbols table for this module */ + symbolTable: SymbolTable; + + /** The public declarations in this module */ + exportedDeclarations: ExportedDeclarations; + + /** The mutated AST */ + result: WeslAST; +} + +function compileSingleModule( + module: WeslAST, + opts: CompilationOptions, + packageNames: string[], +): CompiledSingleModule { + module = structuredClone(module); + const boundModule = bindSymbols( + module.moduleElem, + opts.conditions, + packageNames, + ); + // TODO: apply the passes (including the plugins) + + return { + symbolTable: boundModule.symbolsTable, + exportedDeclarations: boundModule.exportedDeclarations, + result: module, + }; +} + +function getPackageNames( + modules: ReadonlyMap, +): string[] { + return [ + ...new Set( + Iterator.from(modules.values()).map(v => v.srcModule.modulePath.path[0]), + ), + ]; +} diff --git a/tools/packages/wesl/src/lower/LowerAndEmit.ts b/tools/packages/wesl/src/lower/LowerAndEmit.ts new file mode 100644 index 000000000..d8cea5152 --- /dev/null +++ b/tools/packages/wesl/src/lower/LowerAndEmit.ts @@ -0,0 +1,369 @@ +import { SrcMapBuilder } from "mini-parse"; +import { Conditions, evaluateIfAttribute } from "../Conditions"; +import { WeslAST } from "../Module"; +import { + AliasElem, + AttributeElem, + ConstAssertElem, + DeclarationElem, + DeclIdent, + FunctionDeclarationElem, + GlobalDeclarationElem, + ModuleElem, + StructElem, +} from "../parse/WeslElems"; +import { ManglerFn } from "../Mangler"; +import { DirectiveElem } from "../parse/DirectiveElem"; +import { assertUnreachable } from "../Assertions"; +import { getSymbol, SymbolTable } from "../pass/SymbolsTablePass"; +import { assertThat } from "../../../mini-parse/src/Assertions"; +import { str } from "../Util"; + +/** passed to the emitters */ +interface EmitContext { + /** constructing the linked output */ + srcBuilder: SrcMapBuilder; + opts: EmitOptions; +} + +export interface EmitOptions { + conditions: Conditions; + isRoot: boolean; + tables: SymbolTable[]; + tableId: number; +} + +/** traverse the AST, starting from root elements, emitting wgsl for each */ +export function lowerAndEmit( + module: ModuleElem, + srcBuilder: SrcMapBuilder, + opts: EmitOptions, +) { + emitModule(module, { srcBuilder, opts }); +} + +function emitModule(module: ModuleElem, ctx: EmitContext): void { + if (ctx.opts.isRoot) { + for (const directive of module.directives) { + emitDirective(directive, ctx); + } + } + for (const decl of module.declarations) { + emitDecl(decl, ctx); + } +} + +function emitDecl(e: GlobalDeclarationElem, ctx: EmitContext): void { + if (!evaluateIfAttribute(ctx.opts.conditions, e.attributes)) { + return; + } + + emitAttributes(e.attributes, ctx); + if (e.kind === "alias") { + emitAlias(e, ctx); + } else if (e.kind === "assert") { + emitAssert(e, ctx); + } else if (e.kind === "declaration") { + emitDeclaration(e, ctx); + } else if (e.kind === "function") { + emitFunction(e, ctx); + } else if (e.kind === "struct") { + emitStruct(e, ctx); + } else { + assertUnreachable(e); + } +} +function emitAlias(e: AliasElem, ctx: EmitContext) { + ctx.srcBuilder.addRange("alias", e.span[0]); + ctx.srcBuilder.addSynthetic(" "); + emitDeclIdent(e.name, ctx); + ctx.srcBuilder.addSynthetic(" = "); + emitTemplatedIdentElem(e.type, ctx); + ctx.srcBuilder.addRange(";", e.span[1] - 1); + ctx.srcBuilder.addSynthetic("\n"); +} + +function emitAssert(e: ConstAssertElem, ctx: EmitContext) { + throw new Error("Function not implemented."); +} + +function emitDeclaration(e: DeclarationElem, ctx: EmitContext) { + throw new Error("Function not implemented."); +} + +function emitFunction(e: FunctionDeclarationElem, ctx: EmitContext) { + throw new Error("Function not implemented."); +} + +function emitStruct(e: StructElem, ctx: EmitContext) { + throw new Error("Function not implemented."); +} + +function lowerAndEmitElem(e: AbstractElem, ctx: EmitContext): void { + switch (e.kind) { + // terminal elements copy strings to the output + case "text": + return emitText(e, ctx); + case "name": + return emitName(e, ctx); + case "synthetic": + return emitSynthetic(e, ctx); + + // identifiers are copied to the output, but with potentially mangled names + case "ref": + return emitRefIdent(e, ctx); + case "decl": + return emitDeclIdent(e, ctx); + + // container elements just emit their child elements + case "param": + case "var": + case "typeDecl": + case "let": + case "member": + case "type": + case "stuff": + return emitContents(e, ctx); + + case "module": + return emitModule(e, ctx); + + // root level container elements get some extra newlines to make the output prettier + case "fn": + case "struct": + case "override": + case "const": + case "assert": + case "alias": + case "gvar": + if (ctx.extracting) { + ctx.srcBuilder.addNl(); + ctx.srcBuilder.addNl(); + } + return emitContents(e, ctx); + + case "attribute": + return emitAttribute(e, ctx); + + default: + assertUnreachable(e); + } +} + +// TODO: Remove this (once we've got our entire AST parsing) +export function emitText(e: TextElem, ctx: EmitContext): void { + ctx.srcBuilder.addCopy(e.span); +} + +export function emitName(e: NameElem, ctx: EmitContext): void { + ctx.srcBuilder.add(e.name, e.span); +} + +export function emitSynthetic(e: SyntheticElem, ctx: EmitContext): void { + const { text } = e; + ctx.srcBuilder.addSynthetic(text, text, [0, text.length]); +} + +export function emitContents(elem: ContainerElem, ctx: EmitContext): void { + elem.contents.forEach(e => lowerAndEmitElem(e, ctx)); +} + +export function emitModule(elem: ModuleElem, ctx: EmitContext): void { + elem.directives.forEach(e => emitDirective(e, ctx)); + elem.declarations.forEach(e => lowerAndEmitElem(e, ctx)); + + // TODO: Remove + elem.contents.forEach(e => lowerAndEmitElem(e, ctx)); +} + +export function emitRefIdent(e: RefIdentElem, ctx: EmitContext): void { + if (e.ident.std) { + ctx.srcBuilder.add(e.ident.originalName, e.span); + } else { + const declIdent = findDecl(e.ident); + const mangledName = displayName(declIdent); + ctx.srcBuilder.add(mangledName!, e.span); + } +} + +function emitDeclIdent(e: DeclIdent, ctx: EmitContext): void { + if (e.symbolRef === null) { + ctx.srcBuilder.add(e.name, e.span, true); + } else { + const symbol = getSymbol(ctx.opts.tables, ctx.opts.tableId, e.symbolRef); + assertThat( + symbol.kind === "name", + "Compilation step should have resolved this import", + ); + ctx.srcBuilder.add(symbol.value, e.span, true); + } +} + +function emitAttributes( + e: AttributeElem[] | undefined, + ctx: EmitContext, +): void { + e?.forEach(v => emitAttribute(v, ctx)); +} + +function emitAttribute(e: AttributeElem, ctx: EmitContext): void { + const { kind } = e.attribute; + if (kind === "attribute") { + const { params } = e.attribute; + if (params.length === 0) { + ctx.srcBuilder.add(str`@${e.attribute.name}`, e.span); + } else { + ctx.srcBuilder.add(str`@${e.attribute.name}(`, [e.span[0], params[0]); + ctx.srcBuilder.add( + "@" + + e.attribute.name + + "(" + + params.map(expressionToString).join(", ") + + ")", + e.span, + ); + ctx.srcBuilder.addRange(")", e.span[1] - 1); + } + } else if (kind === "@builtin") { + ctx.srcBuilder.add("@builtin(" + e.attribute.param.name + ")", e.span); + } else if (kind === "@diagnostic") { + ctx.srcBuilder.add( + "@diagnostic" + + diagnosticControlToString(e.attribute.severity, e.attribute.rule), + e.span, + ); + } else if (kind === "@if") { + ctx.srcBuilder.add( + `@if(${expressionToString(e.attribute.param.expression)})`, + e.span, + ); + } else if (kind === "@interpolate") { + ctx.srcBuilder.add( + `@interpolate(${e.attribute.params.map(v => v.name).join(", ")})`, + e.span, + ); + } else { + assertUnreachable(kind); + } +} + +export function diagnosticControlToString( + severity: NameElem, + rule: [NameElem, NameElem | null], +): string { + const ruleStr = rule[0].name + (rule[1] !== null ? "." + rule[1].name : ""); + return `(${severity.name}, ${ruleStr})`; +} + +export function expressionToString(elem: ExpressionElem): string { + const { kind } = elem; + if (kind === "binary-expression") { + return `${expressionToString(elem.left)} ${elem.operator.value} ${expressionToString(elem.right)}`; + } else if (kind === "unary-expression") { + return `${elem.operator.value}${expressionToString(elem.expression)}`; + } else if (kind === "templated-ident") { + return templatedIdentToString(elem); + } else if (kind === "literal") { + return elem.value; + } else if (kind === "name") { + return elem.name; + } else if (kind === "parenthesized-expression") { + return `(${expressionToString(elem.expression)})`; + } else if (kind === "component-expression") { + return `${expressionToString(elem.base)}[${expressionToString(elem.access)}]`; + } else if (kind === "component-member-expression") { + return `${expressionToString(elem.base)}.${elem.access.name}`; + } else if (kind === "call-expression") { + return `${expressionToString(elem.function)}(${elem.arguments.map(expressionToString).join(", ")})`; + } else { + assertUnreachable(kind); + } +} + +export function templatedIdentToString(elem: TemplatedIdentElem): string { + let name = elem.ident.name; + if (elem.path !== undefined && elem.path.length > 0) { + name = elem.path.map(p => p.name).join("::") + "::" + name; + } + let params = ""; + if (elem.template !== undefined) { + const paramStrs = elem.template.map(expressionToString).join(", "); + params = "<" + paramStrs + ">"; + } + return name + params; +} + +export function lhsExpressionToString(elem: LhsExpression): string { + const { kind } = elem; + if (kind === "unary-expression") { + return `${elem.operator.value}${lhsExpressionToString(elem.expression)}`; + } else if (kind === "lhs-ident") { + return elem.name.name; + } else if (kind === "parenthesized-expression") { + return `(${lhsExpressionToString(elem.expression)})`; + } else if (kind === "component-expression") { + return `${lhsExpressionToString(elem.base)}[${expressionToString(elem.access)}]`; + } else if (kind === "component-member-expression") { + return `${lhsExpressionToString(elem.base)}.${elem.access.name}`; + } else { + assertUnreachable(kind); + } +} + +function templateToString(template: ExpressionElem[] | undefined): string { + if (template === undefined) return ""; + if (template.length === 0) return ""; + + return "<" + template.map(expressionToString).join(", ") + ">"; +} + +function emitDirective(e: DirectiveElem, ctx: EmitContext): void { + const { directive } = e; + const { kind } = directive; + if (kind === "diagnostic") { + ctx.srcBuilder.add( + `diagnostic${diagnosticControlToString(directive.severity, directive.rule)};`, + e.span, + ); + } else if (kind === "enable") { + ctx.srcBuilder.add( + `enable ${directive.extensions.map(v => v.name).join(", ")};`, + e.span, + ); + } else if (kind === "requires") { + ctx.srcBuilder.add( + `requires ${directive.extensions.map(v => v.name).join(", ")};`, + e.span, + ); + } else { + assertUnreachable(kind); + } +} + +function displayName(declIdent: DeclIdent): string { + if (isGlobal(declIdent)) { + // mangled name was set in binding step + const mangledName = declIdent.mangledName; + if (tracing && !mangledName) { + console.log( + "ERR: mangled name not found for decl ident", + identToString(declIdent), + ); + } + return mangledName!; + } + + return declIdent.mangledName || declIdent.originalName; +} + +function findDecl(ident: any) { + throw new Error("Function not implemented."); +} + +function isGlobal(declIdent: DeclIdent) { + throw new Error("Function not implemented."); +} + +function identToString(declIdent: DeclIdent): any { + throw new Error("Function not implemented."); +} diff --git a/tools/packages/wesl/src/lower/TranslationUnit.ts b/tools/packages/wesl/src/lower/TranslationUnit.ts new file mode 100644 index 000000000..a97653b37 --- /dev/null +++ b/tools/packages/wesl/src/lower/TranslationUnit.ts @@ -0,0 +1,125 @@ +import { ParserInit } from "mini-parse"; +import { LinkedWesl } from "../LinkedWesl.ts"; +import { CompilationOptions, compileToWgsl } from "./CompileToWgsl.ts"; +import { ModulePath, ModulePathString, SrcModule, WeslAST } from "../Module.ts"; +import { WeslStream } from "../parse/WeslStream.ts"; +import { normalize, noSuffix } from "../PathUtil.ts"; +import { str } from "../Util.ts"; +import { weslRoot } from "../parse/WeslGrammar.ts"; +import { VirtualFilesystem } from "../VirtualFilesystem.ts"; + +export class TranslationUnit { + /** Key is module path, turned into a string */ + private modules = new Map(); + + constructor(filesystem: VirtualFilesystem) { + // We ignore @if blocks during the "fetch => parse => find all imports (including inline usages) => kick off more fetches" operation + // Inline usages can be dependent on conditions, but pre-bundling gets to ignore that. + // I had a cool partial condition evaluator on a branch, but long term, + // I really don't want to maintain a complex conditional compilation infrastructure + // ... + // Also, libraries do not need to be added up front + // It's entirely possible to switch out a library on the fly, just like one would switch out a single module on the fly. + } + + addModule(module: SrcModule): WeslAST { + const result = parseSrcModule(module); + this.modules.set(module.modulePath.toString(), result); + return result; + } + + /* LATER + async prebundleModules(rootModulePath: string[]) { + const seenModules = new Set(this.modules.keys()); + + // the logic to dispatch a free worker with "parse(moduleSrc)" would go here + // for now we're parsing it on the main thread + const addModule = async (module: SrcModule) => this.addModule(module); + + // Maybe it shouldn't be "modulePath" but instead be a file path? + async function fetchAndParseModuleInner(modulePath: string[]) { + // to do: Add the virtual filesystem to the parsed registry + const moduleSource = await this.virtualFilesystem.fetch(modulePath); + const ast = await addModule(moduleSource); + + const childPromises: Promise[] = []; + // to do: imports are dependent on conditional compilation, and also include inline usages + // of course we're just skipping the conditional compilation here + for(const importElem of ast.moduleElem.imports) { + if(seenModules.has(importElem)) { + // skip + } else { + seenModules.add(importElem); + childPromises.push(fetchAndParseModule(import)); + } + } + await Promise.all(childPromises); + } + + await(fetchAndParseModuleInner(rootModulePath)); + }*/ + + compile(rootModulePath: ModulePath, opts: CompilationOptions): LinkedWesl { + return compileToWgsl(rootModulePath, this.modules, opts); + } + + /** Gets a module via its absolute path */ + getModule(modulePath: ModulePath): WeslAST | null { + return this.modules.get(modulePath.toString()) ?? null; + } + + getModules(): WeslAST[] { + return [...this.modules.values()]; + } + + toDebugString(): string { + return `modules: ${[...Object.keys(this.modules)]}`; + } +} + +export function parsedRegistry(): TranslationUnit { + return new TranslationUnit(); +} + +/** + * @param srcFiles map of source strings by file path + * key is '/' separated relative path (relative to srcRoot, not absolute file path ) + * value is wesl source string + * @param registry add parsed modules to this registry + * @param packageName name of package + */ +export function parseIntoRegistry( + srcFiles: Record, + registry: TranslationUnit, + packageName: string = "package", + debugWeslRoot?: string, +): void { + if (debugWeslRoot === undefined) { + debugWeslRoot = ""; + } else if (!debugWeslRoot.endsWith("/")) { + debugWeslRoot += "/"; + } + Object.entries(srcFiles).forEach(([filePath, src]) => { + const modulePath = fileToModulePath(filePath, packageName); + const debugFilePath = debugWeslRoot + filePath; + if (registry.getModule(modulePath)) { + throw new Error(str`duplicate module path: '${modulePath}'`); + } + registry.addModule({ + modulePath, + debugFilePath, + src, + }); + }); +} + +export function parseSrcModule(srcModule: SrcModule): WeslAST { + const stream = new WeslStream(srcModule.src); + const init: ParserInit = { stream }; + const parseResult = weslRoot.parse(init); + if (parseResult === null) { + throw new Error("parseWESL failed"); + } + + return { srcModule, moduleElem: parseResult.value }; +} diff --git a/tools/packages/wesl/src/parse/BaseGrammar.ts b/tools/packages/wesl/src/parse/BaseGrammar.ts new file mode 100644 index 000000000..7bf576224 --- /dev/null +++ b/tools/packages/wesl/src/parse/BaseGrammar.ts @@ -0,0 +1,50 @@ +import { + or, + Parser, + repeatPlus, + seq, + Stream, + terminated, + token, + tokenKind, + withSep, +} from "mini-parse"; +import { WeslToken } from "./WeslStream.ts"; +import { FullIdent, NameElem } from "./WeslElems.ts"; + +export const name = tokenKind("word").map(makeName); +export const symbol = (symbol: string) => token("symbol", symbol); + +const full_ident_continue = withSep(symbol("::"), tokenKind("word"), { + requireOne: true, + trailing: false, +}); +export const full_ident: WeslParser = or( + seq(terminated(token("keyword", "package"), "::"), full_ident_continue).map( + ([a, b]) => [a, ...b], + ), + seq( + repeatPlus(terminated(token("keyword", "super"), "::")), + full_ident_continue, + ).map(([a, b]) => [...a, ...b]), + full_ident_continue, +).map(makeFullIdent); + +export type WeslParser = Parser, T>; + +function makeName(token: WeslToken<"word">): NameElem { + return { + kind: "name", + name: token.text, + span: token.span, + }; +} + +function makeFullIdent( + tokens: (WeslToken<"keyword"> | WeslToken<"word">)[], +): FullIdent { + return { + segments: tokens.map(v => v.text), + span: [tokens[0].span[0], tokens[tokens.length - 1].span[1]], + }; +} diff --git a/tools/packages/wesl/src/parse/DirectiveElem.ts b/tools/packages/wesl/src/parse/DirectiveElem.ts new file mode 100644 index 000000000..2948c05c1 --- /dev/null +++ b/tools/packages/wesl/src/parse/DirectiveElem.ts @@ -0,0 +1,23 @@ +import { Span } from "mini-parse"; +import { AttributeElem, NameElem } from "./WeslElems.ts"; + +export interface DirectiveElem { + kind: "directive"; + attributes: AttributeElem[]; + directive: DiagnosticDirective | EnableDirective | RequiresDirective; + span: Span; +} + +export interface DiagnosticDirective { + kind: "diagnostic"; + severity: NameElem; + rule: [NameElem, NameElem | null]; +} +export interface EnableDirective { + kind: "enable"; + extensions: NameElem[]; +} +export interface RequiresDirective { + kind: "requires"; + extensions: NameElem[]; +} diff --git a/tools/packages/wesl/src/parse/ExpressionElem.ts b/tools/packages/wesl/src/parse/ExpressionElem.ts new file mode 100644 index 000000000..19fd1ff22 --- /dev/null +++ b/tools/packages/wesl/src/parse/ExpressionElem.ts @@ -0,0 +1,82 @@ +import type { Span } from "mini-parse"; +import type { NameElem, FullIdent } from "./WeslElems.ts"; + +/** Inspired by https://github.com/wgsl-tooling-wg/wesl-rs/blob/3b2434eac1b2ebda9eb8bfb25f43d8600d819872/crates/wgsl-parse/src/syntax.rs#L364 */ +export type ExpressionElem = + | Literal + | TemplatedIdentElem + | ParenthesizedExpression + | ComponentExpression + | ComponentMemberExpression + | UnaryExpression + | BinaryExpression + | FunctionCallExpression; + +/** A literal value in WESL source. A boolean or a number. */ +export interface Literal { + kind: "literal"; + value: string; + span: Span; +} + +/** an identifier with template arguments */ +export interface TemplatedIdentElem { + kind: "templated-ident"; + /** + * A symbol can either point at an entry in the symbol table, or + * - not be set yet (after parsing) + * - be a predeclared identifier + */ + symbolRef: number | null; + ident: FullIdent; + template?: ExpressionElem[]; + span: Span; +} + +/** (expr) */ +export interface ParenthesizedExpression { + kind: "parenthesized-expression"; + expression: ExpressionElem; +} +/** `foo[expr]` */ +export interface ComponentExpression { + kind: "component-expression"; + base: ExpressionElem; + access: ExpressionElem; +} +/** `foo.member` */ +export interface ComponentMemberExpression { + kind: "component-member-expression"; + base: ExpressionElem; + access: NameElem; +} +/** `+foo` */ +export interface UnaryExpression { + kind: "unary-expression"; + operator: UnaryOperator; + expression: ExpressionElem; +} +/** `foo + bar` */ +export interface BinaryExpression { + kind: "binary-expression"; + operator: BinaryOperator; + left: ExpressionElem; + right: ExpressionElem; +} +/** `foo(arg, arg)` */ +export interface FunctionCallExpression { + kind: "call-expression"; + function: TemplatedIdentElem; + arguments: ExpressionElem[]; +} +export interface UnaryOperator { + value: "!" | "&" | "*" | "-" | "~"; + span: Span; +} +export interface BinaryOperator { + value: + | ("||" | "&&" | "+" | "-" | "*" | "/" | "%" | "==") + | ("!=" | "<" | "<=" | ">" | ">=" | "|" | "&" | "^") + | ("<<" | ">>"); + span: Span; +} diff --git a/tools/packages/wesl/src/parse/ExpressionGrammar.ts b/tools/packages/wesl/src/parse/ExpressionGrammar.ts new file mode 100644 index 000000000..55922b44e --- /dev/null +++ b/tools/packages/wesl/src/parse/ExpressionGrammar.ts @@ -0,0 +1,465 @@ +import { + delimited, + fn, + opt, + or, + preceded, + repeat, + repeatPlus, + req, + seq, + span, + Span, + token, + tokenKind, + tokenOf, + tracing, + withSep, + withSepPlus, + yes, +} from "mini-parse"; +import { full_ident, name, WeslParser } from "./BaseGrammar.ts"; +import { + templateClose, + templateOpen, + weslExtension, + WeslToken, +} from "./WeslStream.ts"; +import { + BinaryExpression, + BinaryOperator, + ExpressionElem, + FunctionCallExpression, + Literal, + ParenthesizedExpression, + TemplatedIdentElem, + UnaryExpression, + UnaryOperator, +} from "./ExpressionElem.ts"; +import { + FullIdent, + LhsExpression, + LhsIdentElem, + LhsParenthesizedExpression, + LhsUnaryExpression, + LhsUnaryOperator, + NameElem, +} from "./WeslElems.ts"; + +const literal = or( + tokenOf("keyword", ["true", "false"]), + tokenKind("number"), +).map(makeLiteral); + +const paren_expression = delimited( + "(", + req(fn(() => expression)), + req(")"), +).map(makeParenthesizedExpression); + +const primary_expression: WeslParser = or( + literal, + paren_expression, + seq( + fn(() => templated_ident), + opt(fn(() => argument_expression_list)), + ).map(tryMakeFunctionCall), +); +export const component_or_swizzle: WeslParser<(NameElem | ExpressionElem)[]> = + repeatPlus( + or( + preceded(".", name), + delimited("[", () => expression, req("]")), + ), + ); +const unary_expression: WeslParser = or( + seq(unaryOperator(["!", "&", "*", "-", "~"]), () => unary_expression).map( + makeUnaryExpression, + ), + seq(primary_expression, opt(component_or_swizzle)).map( + tryMakeComponentOrSwizzle, + ), +); +const bitwise_post_unary: WeslParser = or( + // LATER I can skip template list discovery in these cases, because a&b => { + const shift_left = seq(binaryOperator("<<"), unary_expression).map(v => [v]); + const shift_right = seq(binaryOperator(">>"), unary_expression).map(v => [v]); + const mul_add: WeslParser = seq( + repeat(seq(multiplicative_operator, unary_expression)), + repeat( + seq( + additive_operator, + seq( + unary_expression, + repeat(seq(multiplicative_operator, unary_expression)), + ).map(makeRepeatingBinaryExpression), + ), + ), + ).map(([a, b]) => [...a, ...b]); + + return inTemplate ? + or(shift_left, mul_add) + : or(shift_left, shift_right, mul_add); +}; +const relational_post_unary = ( + inTemplate: boolean, +): WeslParser => { + return seq( + shift_post_unary(inTemplate), + opt( + seq( + // '<' is unambiguous, since templates were already caught by the primary expression inside of the previous unary_expression! + inTemplate ? + binaryOperator(["<", "<=", "!=", "=="]) + : binaryOperator([">", ">=", "<", "<=", "!=", "=="]), + // LATER I can skip template list discovery in this cases, because a>=b (b !== null ? [...a, b] : a)); +}; + +/** The expression parser exists in two variants + * `true` is template-expression: Refuses to parse parse symbols like `&&` and `||`. + * `false` is maybe-template-expression: Does the template disambiguation. + */ +const expressionParser = (inTemplate: boolean): WeslParser => { + return seq( + unary_expression, + or( + bitwise_post_unary, + seq( + relational_post_unary(inTemplate), + inTemplate ? + // Don't accept || or && in template mode + yes() + : or( + repeatPlus( + seq( + binaryOperator("||"), + seq(unary_expression, relational_post_unary(false)).map( + makeRepeatingBinaryExpression, + ), + ), + ), + repeatPlus( + seq( + binaryOperator("&&"), + seq(unary_expression, relational_post_unary(false)).map( + makeRepeatingBinaryExpression, + ), + ), + ), + yes(), + ), + ).map(([a, b]) => (b !== null ? [...a, ...b] : a)), + ), + ).map(makeRepeatingBinaryExpression); +}; + +let maybe_template = false; +export const expression = expressionParser(maybe_template); +let is_template = true; +const template_arg_expression = expressionParser(is_template); + +export const opt_template_list: WeslParser = opt( + delimited( + templateOpen, + withSepPlus(",", template_arg_expression), + req(templateClose), + ), +); + +export const templated_ident: WeslParser = span( + seq(full_ident, opt_template_list), +).map(makeTemplatedIdent); + +export const argument_expression_list = delimited( + "(", + withSep(",", expression), + req(")"), +); + +//--------- Specialized parser for @if(expr) -----------// +const attribute_if_primary_expression: WeslParser< + Literal | ParenthesizedExpression | TemplatedIdentElem +> = or( + tokenOf("keyword", ["true", "false"]).map(makeLiteral), + delimited( + token("symbol", "("), + fn(() => attribute_if_expression), + token("symbol", ")"), + ).map(makeParenthesizedExpression), + full_ident.map( + (v): TemplatedIdentElem => ({ + kind: "templated-ident", + symbolRef: null, + ident: v, + span: v.span, + }), + ), +); + +const attribute_if_unary_expression: WeslParser = or( + seq( + unaryOperator("!"), + fn(() => attribute_if_unary_expression), + ).map(makeUnaryExpression), + attribute_if_primary_expression, +); + +export const attribute_if_expression: WeslParser = + weslExtension( + seq( + attribute_if_unary_expression, + or( + repeatPlus( + seq(binaryOperator("||"), req(attribute_if_unary_expression)), + ), + repeatPlus( + seq(binaryOperator("&&"), req(attribute_if_unary_expression)), + ), + yes().map(() => []), + ), + ).map(makeRepeatingBinaryExpression), + ); + +export const lhs_expression: WeslParser = or( + seq(full_ident.map(makeLhsIdentElem), opt(component_or_swizzle)).map( + tryMakeLhsComponentOrSwizzle, + ), + seq( + delimited("(", () => lhs_expression, ")").map( + makeLhsParenthesizedExpression, + ), + opt(component_or_swizzle), + ).map(tryMakeLhsComponentOrSwizzle), + seq(lhsUnaryOperator("&"), () => lhs_expression).map(makeLhsUnaryExpression), + seq(lhsUnaryOperator("*"), () => lhs_expression).map(makeLhsUnaryExpression), +); + +function tryMakeFunctionCall([ident, args]: [ + TemplatedIdentElem, + ExpressionElem[] | null, +]): TemplatedIdentElem | FunctionCallExpression { + if (args !== null) { + return { + kind: "call-expression", + function: ident, + arguments: args, + }; + } else { + return ident; + } +} + +// LATER how do I combine the two? +function tryMakeComponentOrSwizzle([expression, componentOrSwizzle]: [ + ExpressionElem, + (NameElem | ExpressionElem)[] | null, +]): ExpressionElem { + if (componentOrSwizzle === null || componentOrSwizzle.length === 0) { + return expression; + } + let result = expression; + for (const v of componentOrSwizzle) { + if (v.kind === "name") { + result = { + kind: "component-member-expression", + access: v, + base: result, + }; + } else { + result = { + kind: "component-expression", + access: v, + base: result, + }; + } + } + return result; +} +function tryMakeLhsComponentOrSwizzle([expression, componentOrSwizzle]: [ + LhsExpression, + (NameElem | ExpressionElem)[] | null, +]): LhsExpression { + if (componentOrSwizzle === null || componentOrSwizzle.length === 0) { + return expression; + } + let result = expression; + for (const v of componentOrSwizzle) { + if (v.kind === "name") { + result = { + kind: "component-member-expression", + access: v, + base: result, + }; + } else { + result = { + kind: "component-expression", + access: v, + base: result, + }; + } + } + return result; +} + +function makeTemplatedIdent({ + value: [full_ident, template], + span, +}: { + value: [FullIdent, ExpressionElem[] | null]; + span: Span; +}): TemplatedIdentElem { + return { + kind: "templated-ident", + span, + symbolRef: null, + ident: full_ident, + template: template ?? undefined, + }; +} +function makeLhsIdentElem(full_ident: FullIdent): LhsIdentElem { + return { + kind: "lhs-ident", + symbolRef: null, + name: full_ident, + }; +} + +function makeLiteral(token: WeslToken<"keyword" | "number">): Literal { + return { + kind: "literal", + value: token.text, + span: token.span, + }; +} + +function makeParenthesizedExpression( + expression: ExpressionElem, +): ParenthesizedExpression { + return { + kind: "parenthesized-expression", + expression, + }; +} +function makeLhsParenthesizedExpression( + expression: LhsExpression, +): LhsParenthesizedExpression { + return { + kind: "parenthesized-expression", + expression, + }; +} + +function unaryOperator( + text: UnaryOperator["value"] | UnaryOperator["value"][], +): WeslParser { + return ( + Array.isArray(text) ? + tokenOf("symbol", text) + : token("symbol", text)).map(token => ({ + value: token.text as any, + span: token.span, + })); +} +function lhsUnaryOperator( + text: LhsUnaryOperator["value"] | LhsUnaryOperator["value"][], +): WeslParser { + return ( + Array.isArray(text) ? + tokenOf("symbol", text) + : token("symbol", text)).map(token => ({ + value: token.text as any, + span: token.span, + })); +} + +function binaryOperator( + text: BinaryOperator["value"] | BinaryOperator["value"][], +): WeslParser { + return ( + Array.isArray(text) ? + tokenOf("symbol", text) + : token("symbol", text)).map(token => ({ + value: token.text as any, + span: token.span, + })); +} +function makeUnaryExpression([operator, expression]: [ + UnaryOperator, + ExpressionElem, +]): UnaryExpression { + return { + kind: "unary-expression", + operator, + expression, + }; +} +function makeLhsUnaryExpression([operator, expression]: [ + LhsUnaryOperator, + LhsExpression, +]): LhsUnaryExpression { + return { + kind: "unary-expression", + operator, + expression, + }; +} + +type PartialBinaryExpression = [BinaryOperator, ExpressionElem]; +/** A list of left-to-right associative binary expressions */ +function makeRepeatingBinaryExpression([start, repeating]: [ + ExpressionElem, + PartialBinaryExpression[], +]): ExpressionElem { + let result: ExpressionElem = start; + for (const [op, left] of repeating) { + result = makeBinaryExpression([result, op, left]); + } + return result; +} +function makeBinaryExpression([left, operator, right]: [ + ExpressionElem, + BinaryOperator, + ExpressionElem, +]): BinaryExpression { + return { + kind: "binary-expression", + operator, + left, + right, + }; +} + +if (tracing) { + const names: Record> = { + argument_expression_list, + templated_ident, + opt_template_list, + literal, + paren_expression, + primary_expression, + component_or_swizzle, + unary_expression, + expression, + template_arg_expression, + }; + + Object.entries(names).forEach(([name, parser]) => { + parser.setTraceName(name); + }); +} diff --git a/tools/packages/wesl/src/parse/ImportElems.ts b/tools/packages/wesl/src/parse/ImportElems.ts new file mode 100644 index 000000000..afcdd37b7 --- /dev/null +++ b/tools/packages/wesl/src/parse/ImportElems.ts @@ -0,0 +1,48 @@ +import { Span } from "mini-parse"; +import { AttributeElem } from "./WeslElems.ts"; + +/** Holds an import statement, and has a span */ +export interface ImportElem { + kind: "import"; + attributes: AttributeElem[]; + imports: ImportStatement; + span: Span; +} + +/** + * An import statement, which is tree shaped. + * `import foo::bar::{baz, cat as neko}; + */ +export interface ImportStatement { + kind: "import-statement"; + segments: ImportSegment[]; + finalSegment: ImportCollection | ImportItem; +} + +/** + * A collection of import trees. + * `{baz, cat as neko}` + */ +export interface ImportSegment { + kind: "import-segment"; + name: string; +} + +/** + * A primitive segment in an import statement. + * `foo` + */ +export interface ImportCollection { + kind: "import-collection"; + subtrees: ImportStatement[]; +} + +/** + * A renamed item at the end of an import statement. + * `cat as neko` + */ +export interface ImportItem { + kind: "import-item"; + name: string; + as?: string; +} diff --git a/tools/packages/wesl/src/parse/ImportGrammar.ts b/tools/packages/wesl/src/parse/ImportGrammar.ts index 8cafb816d..837246244 100644 --- a/tools/packages/wesl/src/parse/ImportGrammar.ts +++ b/tools/packages/wesl/src/parse/ImportGrammar.ts @@ -1,34 +1,32 @@ import { delimited, fn, + kind, opt, or, Parser, preceded, - repeat, repeatPlus, req, seq, seqObj, span, Stream, - tagScope, terminated, tracing, withSepPlus, yes, } from "mini-parse"; -import type { +import { assertUnreachable } from "../Assertions.ts"; +import { WeslToken } from "./WeslStream.ts"; +import { + ImportSegment, ImportCollection, - ImportElem, ImportItem, - ImportSegment, ImportStatement, -} from "../AbstractElems.js"; -import { assertUnreachable } from "../Assertions.js"; -import { importElem } from "../WESLCollect.js"; -import { word } from "./WeslBaseGrammar.js"; -import { WeslToken } from "./WeslStream.js"; + ImportElem, +} from "./ImportElems.ts"; +import { WeslParser } from "./BaseGrammar.ts"; function makeStatement( segments: ImportSegment[], @@ -56,6 +54,8 @@ function prependSegments( return statement; } +const word = kind("word"); + // forward references for mutual recursion let import_collection: Parser< Stream, @@ -110,7 +110,7 @@ const import_relative = or( ), ); -const import_statement = span( +export const import_statement: WeslParser = span( delimited( "import", seqObj({ @@ -131,24 +131,18 @@ const import_statement = span( ).map( (v): ImportElem => ({ kind: "import", + attributes: [], // LATER Parse and fill in imports: v.value, - start: v.span[0], - end: v.span[1], + span: v.span, }), ); -/** parse a WESL style wgsl import statement. */ -export const weslImports: Parser, ImportElem[]> = tagScope( - repeat(import_statement).ptag("owo").collect(importElem), -); - if (tracing) { const names: Record, unknown>> = { import_collection, import_path_or_item, import_relative, import_statement, - weslImports, }; Object.entries(names).forEach(([name, parser]) => { diff --git a/tools/packages/wesl/src/parse/WeslBaseGrammar.ts b/tools/packages/wesl/src/parse/WeslBaseGrammar.ts deleted file mode 100644 index acb70b38d..000000000 --- a/tools/packages/wesl/src/parse/WeslBaseGrammar.ts +++ /dev/null @@ -1,8 +0,0 @@ -import { kind, or, withSepPlus } from "mini-parse"; -import { WeslTokenKind } from "./WeslStream"; - -export const word = kind("word"); -export const keyword = kind("keyword"); - -export const qualified_ident = withSepPlus("::", or(word, "package", "super")); // LATER consider efficiency (it's a pretty hot area of the grammar.) -export const number = kind("number"); diff --git a/tools/packages/wesl/src/parse/WeslElems.ts b/tools/packages/wesl/src/parse/WeslElems.ts new file mode 100644 index 000000000..f483cdbd4 --- /dev/null +++ b/tools/packages/wesl/src/parse/WeslElems.ts @@ -0,0 +1,385 @@ +import type { Span } from "mini-parse"; +import type { ImportElem } from "./ImportElems.ts"; +import type { DirectiveElem } from "./DirectiveElem.ts"; +import { ExpressionElem, TemplatedIdentElem } from "./ExpressionElem.ts"; + +/** A name is either a string, or refers to an entry in the symbols table. */ +// export type SymbolReference = string | number; + +/** a name that doesn't need to be an Ident + * e.g. + * - a struct member name + * - a diagnostic rule name + * - an enable-extension name + * - an interpolation sampling name + */ +export interface NameElem { + kind: "name"; + name: string; + span: Span; +} + +/** + * Either a single ident `foo`, or a qualified ident, like `package::foo::bar` + * + * Implementation detail: We only treat imports with 2 or more elements as a possible package reference. + */ +export interface FullIdent { + segments: string[]; + span: Span; +} + +/** an identifier declaration */ +export interface DeclIdent { + /** Null before the symbols table pass */ + symbolRef: number | null; + name: string; + span: Span; +} + +/** a wesl module */ +export interface ModuleElem { + kind: "module"; + + /** imports found in this module */ + imports: ImportElem[]; + + /** directives found in this module */ + directives: DirectiveElem[]; + + /** declarations found in this module */ + declarations: GlobalDeclarationElem[]; +} + +export type GlobalDeclarationElem = + | AliasElem + | ConstAssertElem + | DeclarationElem + | FunctionDeclarationElem + | StructElem; + +interface GlobalDeclarationBase { + kind: GlobalDeclarationElem["kind"]; + span: Span; + attributes?: AttributeElem[]; +} + +/** an alias statement */ +export interface AliasElem extends GlobalDeclarationBase { + kind: "alias"; + name: DeclIdent; + type: TemplatedIdentElem; +} + +/** a const_assert statement */ +export interface ConstAssertElem extends GlobalDeclarationBase { + kind: "assert"; + expression: ExpressionElem; +} + +/** a var/let/const/override declaration. Can also be used as a normal statement. */ +export interface DeclarationElem extends GlobalDeclarationBase { + kind: "declaration"; + variant: DeclarationVariant; + name: DeclIdent; + type?: TemplatedIdentElem; + initializer?: ExpressionElem; +} + +export type DeclarationVariant = + | { kind: "const" } + | { kind: "override" } + | { kind: "let" } + | { kind: "var"; template?: ExpressionElem[] }; + +/** a function declaration */ +export interface FunctionDeclarationElem extends GlobalDeclarationBase { + kind: "function"; + name: DeclIdent; + params: FunctionParam[]; + returnAttributes?: AttributeElem[]; + returnType?: TemplatedIdentElem; + body: CompoundStatement; +} +export interface FunctionParam { + attributes?: AttributeElem[]; + name: DeclIdent; + type: TemplatedIdentElem; +} + +/** a struct declaration */ +export interface StructElem extends GlobalDeclarationBase { + kind: "struct"; + name: DeclIdent; + members: StructMemberElem[]; + span: Span; + bindingStruct?: true; // used later during binding struct transformation +} + +/** a member of a struct declaration */ +export interface StructMemberElem { + name: NameElem; + type: TemplatedIdentElem; + attributes?: AttributeElem[]; + mangledVarName?: string; // root name if transformed to a var (for binding struct transformation) +} + +/** an attribute like '@compute' or '@binding(0)' */ +export interface AttributeElem { + kind: "attribute"; + attribute: Attribute; + span: Span; +} + +export type Attribute = + | StandardAttribute + | InterpolateAttribute + | BuiltinAttribute + | DiagnosticAttribute + | IfAttribute; + +export interface StandardAttribute { + kind: "attribute"; + name: string; + params: ExpressionElem[]; +} + +export interface InterpolateAttribute { + kind: "@interpolate"; + params: NameElem[]; +} + +export interface BuiltinAttribute { + kind: "@builtin"; + param: NameElem; +} + +export interface DiagnosticAttribute { + kind: "@diagnostic"; + severity: NameElem; + rule: [NameElem, NameElem | null]; +} + +export interface IfAttribute { + kind: "@if"; + param: ConditionalExpressionElem; +} + +/** For conditional compilation */ +export interface ConditionalExpressionElem { + kind: "translate-time-expression"; + expression: ExpressionElem; + span: Span; +} + +export type Statement = + | ForStatement + | IfStatement + | LoopStatement + | SwitchStatement + | WhileStatement + | CompoundStatement + | FunctionCallStatement + | DeclarationElem + | AssignmentStatement + | PostfixStatement + | BreakStatement + | ContinueStatement + | DiscardStatement + | ReturnStatement + | ConstAssertElem; + +interface StatementBase { + kind: Statement["kind"]; + attributes?: AttributeElem[]; + span: Span; +} + +/** for(let i = 0; i < 10; i++) { } */ +export interface ForStatement extends StatementBase { + kind: "for-statement"; + initializer?: ForInitStatement; + condition?: ExpressionElem; + update?: ForUpdateStatement; + body: CompoundStatement; +} + +export type ForInitStatement = + | FunctionCallStatement + | DeclarationElem + | AssignmentStatement + | PostfixStatement; + +export type ForUpdateStatement = + | FunctionCallStatement + | AssignmentStatement + | PostfixStatement; + +/** if(1 == 1) { } */ +export interface IfStatement extends StatementBase { + kind: "if-else-statement"; + main: IfClause; +} + +/** A clause in an if statement (`if`, `else if`, `else`), without attributes. */ +export interface IfClause { + kind: "if-clause"; + condition: ExpressionElem; + accept: CompoundStatement; + reject?: IfClause | CompoundStatement; +} + +export interface LoopStatement extends StatementBase { + kind: "loop-statement"; + body: CompoundStatement; + /** Last element in the body */ + continuing?: ContinuingStatement; +} + +export interface ContinuingStatement { + kind: "continuing-statement"; + attributes?: AttributeElem[]; + body: CompoundStatement; + /** Last element in the body */ + breakIf?: BreakIfStatement; + span: Span; +} + +export interface BreakIfStatement { + kind: "break-if-statement"; + attributes?: AttributeElem[]; + expression: ExpressionElem; + span: Span; +} + +export interface SwitchStatement extends StatementBase { + kind: "switch-statement"; + selector: ExpressionElem; + bodyAttributes?: AttributeElem[]; + clauses: SwitchClause[]; +} +/** + * `case foo: {}` or `default: {}`. + * A `default:` is modeled as a `case default:` + */ +export interface SwitchClause { + attributes?: AttributeElem[]; + cases: SwitchCaseSelector[]; + body: CompoundStatement; + span: Span; +} +export type SwitchCaseSelector = ExpressionCaseSelector | DefaultCaseSelector; +export interface ExpressionCaseSelector { + expression: ExpressionElem; +} +export interface DefaultCaseSelector { + expression: "default"; + span: Span; +} + +export interface WhileStatement extends StatementBase { + kind: "while-statement"; + condition: ExpressionElem; + body: CompoundStatement; +} + +export interface CompoundStatement extends StatementBase { + kind: "compound-statement"; + body: Statement[]; +} + +/** `foo(arg, arg);` */ +export interface FunctionCallStatement extends StatementBase { + kind: "call-statement"; + function: TemplatedIdentElem; + arguments: ExpressionElem[]; +} + +export interface AssignmentStatement extends StatementBase { + kind: "assignment-statement"; + left: LhsExpression | LhsDiscard; + operator: AssignmentOperator; + right: ExpressionElem; +} + +export interface AssignmentOperator { + value: + | ("=" | "<<=" | ">>=" | "%=" | "&=") + | ("*=" | "+=" | "-=" | "/=" | "^=" | "|="); + span: Span; +} + +export interface PostfixStatement extends StatementBase { + kind: "postfix-statement"; + operator: PostfixOperator; + expression: LhsExpression; +} + +export interface PostfixOperator { + value: "++" | "--"; + span: Span; +} + +export interface BreakStatement extends StatementBase { + kind: "break-statement"; +} + +export interface ContinueStatement extends StatementBase { + kind: "continue-statement"; +} + +export interface DiscardStatement extends StatementBase { + kind: "discard-statement"; +} + +export interface ReturnStatement extends StatementBase { + kind: "return-statement"; + expression?: ExpressionElem; +} + +export interface LhsDiscard { + kind: "discard-expression"; + span: Span; +} + +export type LhsExpression = + | LhsUnaryExpression + | LhsComponentExpression + | LhsComponentMemberExpression + | LhsParenthesizedExpression + | LhsIdentElem; + +/** Analogous to the {@link TemplatedIdentElem} */ +export interface LhsIdentElem { + kind: "lhs-ident"; + symbolRef: null | number; + name: FullIdent; +} + +/** (expr) */ +export interface LhsParenthesizedExpression { + kind: "parenthesized-expression"; + expression: LhsExpression; +} +/** `foo[expr]` */ +export interface LhsComponentExpression { + kind: "component-expression"; + base: LhsExpression; + access: ExpressionElem; +} +/** `foo.member` */ +export interface LhsComponentMemberExpression { + kind: "component-member-expression"; + base: LhsExpression; + access: NameElem; +} +/** `+foo` */ +export interface LhsUnaryExpression { + kind: "unary-expression"; + operator: LhsUnaryOperator; + expression: LhsExpression; +} +export interface LhsUnaryOperator { + value: "&" | "*"; + span: Span; +} diff --git a/tools/packages/wesl/src/parse/WeslExpression.ts b/tools/packages/wesl/src/parse/WeslExpression.ts deleted file mode 100644 index d29a806b0..000000000 --- a/tools/packages/wesl/src/parse/WeslExpression.ts +++ /dev/null @@ -1,207 +0,0 @@ -import { - collectArray, - delimited, - fn, - opt, - or, - Parser, - preceded, - repeat, - repeatPlus, - req, - seq, - Stream, - tagScope, - tokenOf, - tracing, - withSep, - withSepPlus, - yes, -} from "mini-parse"; -import { - expressionCollect, - memberRefCollect, - nameCollect, - refIdent, - stuffCollect, - typeRefCollect, -} from "../WESLCollect"; -import { number, qualified_ident, word } from "./WeslBaseGrammar"; -import { templateClose, templateOpen, WeslToken } from "./WeslStream"; - -export const opt_template_list = opt( - seq( - templateOpen, - withSepPlus(",", () => template_parameter), - req(templateClose, "invalid template, expected '>'"), - ), -); - -// prettier-ignore -const template_elaborated_ident = seq( - qualified_ident.collect(refIdent), - opt_template_list -); -const literal = or("true", "false", number); -const paren_expression = seq( - "(", - () => expression, - req(")", "invalid expression, expected ')'"), -); - -const primary_expression = or( - literal, - paren_expression, - seq(template_elaborated_ident, opt(fn(() => argument_expression_list))), -); -export const component_or_swizzle = repeatPlus( - or( - preceded(".", word), - collectArray( - delimited( - "[", - () => expression, - req("]", "invalid expression, expected ']'"), - ), - ), - ), -); -// LATER Remove -// prettier-ignore -/** parse simple struct.member style references specially, for binding struct lowering */ -export const simple_component_reference = tagScope( - seq( - qualified_ident.collect(refIdent, "structRef"), - seq(".", word.collect(nameCollect, "component")), - opt(component_or_swizzle.collect(stuffCollect, "extra_components")) - ).collect(memberRefCollect) -); -const unary_expression: Parser, any> = or( - seq(tokenOf("symbol", ["!", "&", "*", "-", "~"]), () => unary_expression), - or( - simple_component_reference, - seq(primary_expression, opt(component_or_swizzle)), - ), -); -const bitwise_post_unary = or( - // LATER I can skip template list discovery in these cases, because a&b { - const shift_left = seq("<<", unary_expression); - const shift_right = seq(">>", unary_expression); - const mul_add = seq( - repeat(seq(multiplicative_operator, unary_expression)), - repeat( - seq( - additive_operator, - unary_expression, - repeat(seq(multiplicative_operator, unary_expression)), - ), - ), - ); - return inTemplate ? - or(shift_left, mul_add) - : or(shift_left, shift_right, mul_add); -}; -const relational_post_unary = (inTemplate: boolean) => { - return seq( - shift_post_unary(inTemplate), - opt( - seq( - // '<' is unambiguous, since templates were already caught by the primary expression inside of the previous unary_expression! - inTemplate ? - tokenOf("symbol", ["<", "<=", "!=", "=="]) - : tokenOf("symbol", [">", ">=", "<", "<=", "!=", "=="]), - // LATER I can skip template list discovery in this cases, because a>=b, any> => { - return seq( - unary_expression, - or( - bitwise_post_unary, - seq( - relational_post_unary(inTemplate), - inTemplate ? - // Don't accept || or && in template mode - yes() - : or( - repeatPlus( - seq("||", seq(unary_expression, relational_post_unary(false))), - ), - repeatPlus( - seq("&&", seq(unary_expression, relational_post_unary(false))), - ), - yes().map(() => []), - ), - ), - ), - ); -}; - -let maybe_template = false; -export const expression = expressionParser(maybe_template); -let is_template = true; -const template_arg_expression = expressionParser(is_template); - -// prettier-ignore -const std_type_specifier = seq( - qualified_ident .collect(refIdent, "typeRefName"), - () => opt_template_list, -) .collect(typeRefCollect); - -// prettier-ignore -export const type_specifier: Parser,any> = tagScope( - std_type_specifier, -) .ctag("typeRefElem"); - -/** a template_arg_expression with additional collection for parameters - * that are types like array vs. expressions like 1+2 */ -// prettier-ignore -const template_parameter = or( - // LATER Remove this, it's wrong. This should instead be done by inspecting the syntax tree. - type_specifier.ctag("templateParam"), - template_arg_expression.collect(expressionCollect, "templateParam") -); - -export const argument_expression_list = seq( - "(", - withSep(",", expression), - req(")", "invalid fn arguments, expected ')'"), -); - -if (tracing) { - const names: Record, unknown>> = { - argument_expression_list, - type_specifier, - opt_template_list, - template_elaborated_ident, - literal, - paren_expression, - primary_expression, - component_or_swizzle, - unary_expression, - expression, - template_arg_expression, - }; - - Object.entries(names).forEach(([name, parser]) => { - parser.setTraceName(name); - }); -} diff --git a/tools/packages/wesl/src/parse/WeslGrammar.ts b/tools/packages/wesl/src/parse/WeslGrammar.ts index e95bcf20c..56240a83b 100644 --- a/tools/packages/wesl/src/parse/WeslGrammar.ts +++ b/tools/packages/wesl/src/parse/WeslGrammar.ts @@ -4,6 +4,7 @@ import { fn, opt, or, + ParseError, Parser, preceded, repeat, @@ -11,10 +12,10 @@ import { req, separated_pair, seq, + seqObj, Span, span, Stream, - tagScope, terminated, text, token, @@ -23,71 +24,77 @@ import { tracing, withSep, withSepPlus, - yes, } from "mini-parse"; import { - BinaryExpression, - BinaryOperator, + AliasElem, + AssignmentOperator, + AssignmentStatement, + Attribute, + AttributeElem, + BreakIfStatement, + BreakStatement, BuiltinAttribute, + CompoundStatement, + ConstAssertElem, + ContinueStatement, + ContinuingStatement, + DeclarationElem, + DeclarationVariant, + DefaultCaseSelector, DiagnosticAttribute, - DiagnosticDirective, - EnableDirective, - ExpressionElem, + DiscardStatement, + ExpressionCaseSelector, + ForStatement, + FunctionCallStatement, + FunctionDeclarationElem, + FunctionParam, + GlobalDeclarationElem, + DeclIdent, IfAttribute, + IfClause, + IfStatement, InterpolateAttribute, - Literal, + LhsDiscard, + LhsExpression, + LoopStatement, + ModuleElem, NameElem, - ParenthesizedExpression, - RequiresDirective, + PostfixOperator, + PostfixStatement, + ReturnStatement, StandardAttribute, - TranslateTimeExpressionElem, - TranslateTimeFeature, - UnaryExpression, - UnaryOperator, - UnknownExpressionElem, -} from "../AbstractElems.ts"; -import { - aliasCollect, - assertCollect, - collectAttribute, - collectFnParam, - collectModule, - collectStruct, - collectStructMember, - collectVarLike, - declCollect, - directiveCollect, - expressionCollect, - fnCollect, - globalAssertCollect, - globalDeclCollect, - nameCollect, - partialScopeCollect, - refIdent, - scopeCollect, - specialAttribute, - statementCollect, - switchClauseCollect, - typedDecl, -} from "../WESLCollect.ts"; -import { weslImports } from "./ImportGrammar.ts"; -import { qualified_ident, word } from "./WeslBaseGrammar.ts"; + Statement, + StructElem, + StructMemberElem, + SwitchCaseSelector, + SwitchClause, + SwitchStatement, + ConditionalExpressionElem, + WhileStatement, + ForInitStatement, + ForUpdateStatement, +} from "./WeslElems.ts"; +import { import_statement } from "./ImportGrammar.ts"; +import { name, WeslParser, symbol } from "./BaseGrammar.ts"; import { argument_expression_list, - component_or_swizzle, + attribute_if_expression, expression, + lhs_expression, opt_template_list, - simple_component_reference, - type_specifier, -} from "./WeslExpression.ts"; + templated_ident, +} from "./ExpressionGrammar.ts"; import { weslExtension, WeslToken } from "./WeslStream.ts"; +import { + DiagnosticDirective, + DirectiveElem, + EnableDirective, + RequiresDirective, +} from "./DirectiveElem.ts"; +import { ExpressionElem, TemplatedIdentElem } from "./ExpressionElem.ts"; +import { ImportElem } from "./ImportElems.ts"; -const name = tokenKind("word").map(makeName); - -const diagnostic_rule_name = seq( - name, - opt(preceded(".", req(name, "invalid diagnostic rule name, expected name"))), -); +const diagnostic_rule_name = seq(name, opt(preceded(".", req(name)))); const diagnostic_control = delimited( "(", req( @@ -97,562 +104,455 @@ const diagnostic_control = delimited( seq(opt(","), req(")", "invalid diagnostic control, expected ')'")), ); +const decl_ident = tokenKind("word").map( + (v): DeclIdent => ({ + symbolRef: null, + name: v.text, + span: v.span, + }), +); + /** list of words that aren't identifiers (e.g. for @interpolate) */ const name_list = withSep(",", name, { requireOne: true }); -// LATER Add proper error reporting here. e.g. @3 should throw an error pointing at the 3 -// Currently it's not possible, since we neither accumulate the necessary context, -// nor can we add a `req` parser, since this here relies on backtracking -// prettier-ignore -const special_attribute = tagScope( - preceded("@", - or( - // These attributes have no arguments - or("compute", "const", "fragment", "invariant", "must_use", "vertex") - .map(name => makeStandardAttribute([name, []])), - - // These attributes have arguments, but the argument doesn't have any identifiers - preceded("interpolate", req(delimited("(", name_list, ")"), "invalid @interpolate, expected '('")) - .map(makeInterpolateAttribute), - preceded("builtin", req(delimited("(", name, ")"), "invalid @builtin, expected '('")) - .map(makeBuiltinAttribute), - preceded("diagnostic", req(diagnostic_control, "invalid @diagnostic, expected '('")) - .map(makeDiagnosticAttribute), - ) .ptag("attr_variant") - ) .collect(specialAttribute) -); - -// prettier-ignore -const if_attribute = tagScope( - preceded(seq("@", weslExtension("if")), - span( +const attribute: WeslParser = preceded( + "@", + or( + // These attributes have no arguments + or("compute", "const", "fragment", "invariant", "must_use", "vertex").map( + name => makeStandardAttribute([name, []]), + ), + // These attributes have arguments, but the argument doesn't have any identifiers + preceded("interpolate", req(delimited("(", name_list, ")"))).map( + makeInterpolateAttribute, + ), + preceded("builtin", req(delimited("(", name, ")"))).map( + makeBuiltinAttribute, + ), + preceded("diagnostic", req(diagnostic_control)).map( + makeDiagnosticAttribute, + ), + preceded( + weslExtension("if"), delimited( "(", fn(() => attribute_if_expression), seq(opt(","), ")"), + ).mapSpanned(makeTranslateTimeExpressionElem), + ).map(makeIfAttribute), + // These are normal attributes + seq( + or( + "workgroup_size", + "align", + "binding", + "blend_src", + "group", + "id", + "location", + "size", ), - ) .map(makeTranslateTimeExpressionElem), - ) .map(makeIfAttribute) - .ptag("attr_variant") - .collect(specialAttribute) -); - -// prettier-ignore -const normal_attribute = tagScope( - preceded("@", - or( - // These are normal attributes, with required arguments - seq( - or( - "workgroup_size", - "align", - "binding", - "blend_src", - "group", - "id", - "location", - "size", - ) .ptag("name"), - req(() => attribute_argument_list, "invalid attribute, expected '('"), - ), - - // Everything else is also a normal attribute, optional expression list - seq( - // we don't want this to interfere with if_attribute, - // but not("if") isn't necessary for now, since 'if' is a keyword, not a word - word .ptag("name"), - opt(() => attribute_argument_list), - ), - ), - ) .collect(collectAttribute), -); + req(() => attribute_argument_list), + ).map(makeStandardAttribute), + // Everything else is also a normal attribute, it might have an expression list + seq( + req(tokenKind("word").map(v => v.text)), + opt(() => attribute_argument_list).map(v => v ?? []), + ).map(makeStandardAttribute), + ), +).mapSpanned(makeAttributeElem); -// prettier-ignore const attribute_argument_list = delimited( "(", - withSep( - ",", - span(fn(() => expression)) .collect(expressionCollect, "attrParam"), // LATER These unknown expressions have decls inside of them, that's why they're tough to replace! - ), - req(")", "invalid attribute arguments, expected ')'"), + withSep(",", expression), + req(")"), ); -// separate statements with if from statements - -// prettier-ignore -const attribute_no_if = or( - special_attribute, - normal_attribute -) .ctag("attribute"); - -// prettier-ignore -const attribute_incl_if = or( - if_attribute, - special_attribute, - normal_attribute, -) .ctag("attribute"); - -const opt_attributes = repeat(attribute_incl_if); +const opt_attributes = repeat(attribute); -const opt_attributes_no_if = repeat(attribute_no_if); +const lhs_discard: WeslParser = symbol("_").map(v => ({ + kind: "discard-expression", + span: v.span, +})); -// prettier-ignore -const globalTypeNameDecl = - req( - word .collect(globalDeclCollect, "type_name"), - "invalid type name, expected a name" - ); - -// prettier-ignore -const fnNameDecl = - req( - word .collect(globalDeclCollect, "fn_name"), - "missing fn name", - ); - -// prettier-ignore -const optionally_typed_ident = tagScope( +const variable_updating_statement: WeslParser< + AssignmentStatement | PostfixStatement +> = or( seq( - word .collect(declCollect, "decl_elem"), - opt(seq(":", type_specifier)), - ) .collect(typedDecl) -) .ctag("var_name"); + lhs_expression, + or( + seq( + assignmentOperator([ + ...(["=", "<<=", ">>=", "%=", "&="] as const), + ...(["*=", "+=", "-=", "/=", "^=", "|="] as const), + ]), + expression, + ), + postfixOperator(["++", "--"]), + ), + ).mapSpanned(makeVariableUpdatingStatement), + seq(lhs_discard, assignmentOperator("="), expression).mapSpanned( + makeVariableDiscardStatement, + ), +); -const req_optionally_typed_ident = req(optionally_typed_ident, "invalid ident"); +const struct_member = seqObj({ + attributes: opt_attributes, + name: name, + _1: ":", + typeRef: req(templated_ident), +}).map(makeStructMember); -// prettier-ignore -const global_ident = tagScope( - req( - seq( - word .collect(globalDeclCollect, "decl_elem"), - opt(seq(":", type_specifier)), - ) .collect(typedDecl) - ) -) .ctag("var_name"); - -// prettier-ignore -const struct_member = tagScope( - seq( - opt_attributes, - word .collect(nameCollect, "nameElem"), - req(":", "invalid struct member, expected ':'"), - req(type_specifier, "invalid struct member, expected type specifier"), - ) .collect(collectStructMember) -) .ctag("members"); - -// prettier-ignore -const struct_decl = seq( - weslExtension(opt_attributes) .collect((cc) => cc.tags.attribute, "attributes"), +const struct_decl = preceded( "struct", - req(globalTypeNameDecl, "invalid struct, expected name"), seq( - req("{", "invalid struct, expected '{'"), - withSepPlus(",", struct_member), - req("}", "invalid struct, expected '}'"), - ) .collect(scopeCollect, "struct_scope"), -) .collect(collectStruct); + req(decl_ident), + delimited(req("{"), withSepPlus(",", struct_member), req("}")), + ), +).mapSpanned(makeStruct); /** Also covers func_call_statement.post.ident */ -// prettier-ignore -const fn_call = seq( - qualified_ident .collect(refIdent), - () => opt_template_list, +const fn_call: WeslParser = seq( + templated_ident, argument_expression_list, +).mapSpanned(makeFunctionCall); + +/** + * Covers variable_or_value_statement, variable_decl, global_variable_decl, global_value_decl. + * Does not include a semicolon. + */ +const declaration: WeslParser = seqObj({ + variant: tokenOf("keyword", ["const", "var", "override", "let"]), + varTemplate: () => opt_template_list, + typedIdent: req(seq(decl_ident, opt(preceded(":", templated_ident)))), + initializer: opt(preceded("=", expression)), +}).mapSpanned(makeDeclarationElem); + +/** Does not include the semicolon */ +const const_assert = preceded( + token("keyword", "const_assert"), + req(expression), +).mapSpanned( + (expression, span): ConstAssertElem => ({ + kind: "assert", + expression, + span, + }), ); -// prettier-ignore -const fnParam = tagScope( - seq( - opt_attributes .collect((cc) => cc.tags.attribute, "attributes"), - word .collect(declCollect, "decl_elem"), - opt(seq(":", req(type_specifier, "invalid fn parameter, expected type specifier"))) - .collect(typedDecl, "param_name"), - ) .collect(collectFnParam), -) .ctag("fn_param"); - -const fnParamList = seq("(", withSep(",", fnParam), ")"); - -// prettier-ignore -const local_variable_decl = seq( - "var", - () => opt_template_list, - req_optionally_typed_ident, - opt(seq("=", () => expression)), // no decl_scope, but I think that's ok -) .collect(collectVarLike("var")); - -// prettier-ignore -const global_variable_decl = seq( - "var", - () => opt_template_list, - global_ident, - // TODO shouldn't decl_scope include the ident type? - opt(seq("=", () => expression .collect(scopeCollect, "decl_scope"))), -); - -const attribute_if_primary_expression: Parser< - Stream, - Literal | ParenthesizedExpression | TranslateTimeFeature -> = or( - tokenOf("keyword", ["true", "false"]).map(makeLiteral), - delimited( - token("symbol", "("), - fn(() => attribute_if_expression), - token("symbol", ")"), - ).map(makeParenthesizedExpression), - tokenKind("word").map(makeTranslateTimeFeature), +const compound_statement: WeslParser = delimited( + text("{"), + fn(() => statements), + req("}"), +).mapSpanned(makeCompoundStatement); + +const for_init: WeslParser = or( + fn_call, + declaration, + variable_updating_statement, ); -const attribute_if_unary_expression: Parser< - Stream, - ExpressionElem -> = or( - seq( - token("symbol", "!").map(makeUnaryOperator), - fn(() => attribute_if_unary_expression), - ).map(makeUnaryExpression), - attribute_if_primary_expression, +const for_update: WeslParser = or( + fn_call, + variable_updating_statement, ); -const attribute_if_expression: Parser< - Stream, - ExpressionElem -> = weslExtension( +const for_statement = preceded( + "for", seq( - attribute_if_unary_expression, - or( - repeatPlus( - seq( - token("symbol", "||").map(makeBinaryOperator), - req( - attribute_if_unary_expression, - "invalid expression, expected expression", - ), - ), - ), - repeatPlus( - seq( - token("symbol", "&&").map(makeBinaryOperator), - req( - attribute_if_unary_expression, - "invalid expression, expected expression", - ), - ), + delimited( + "(", + seq( + terminated(opt(for_init), req(";")), + terminated(opt(expression), req(";")), + opt(for_update), ), - yes().map(() => []), + req(")"), ), - ).map(makeRepeatingBinaryExpression), -); + compound_statement, + ), +).mapSpanned(makeForStatement); -const unscoped_compound_statement = seq( - opt_attributes, - text("{"), - repeat(() => statement), - req("}", "invalid block, expected }"), -).collect(statementCollect); +const if_statement: WeslParser = preceded( + token("keyword", "if"), + req( + seqObj({ + ifBranch: req(seq(expression, compound_statement)), + elseIfBranch: repeat( + preceded(seq("else", "if"), req(seq(expression, compound_statement))), + ), + elseBranch: opt(preceded("else", req(compound_statement))), + }), + ), +).mapSpanned(makeIfStatement); -// prettier-ignore -const compound_statement = tagScope( - seq( - opt_attributes, - seq( - text("{"), - repeat(() => statement), - req("}", "invalid block, expected '}'"), - ) .collect(scopeCollect), - ) .collect(statementCollect) -); +interface CustomCompoundStatement { + attributes: AttributeElem[]; + body: (Statement | ContinuingStatement | BreakIfStatement)[]; + span: Span; +} -const for_init = seq( +const custom_compound_statement: WeslParser = seq( opt_attributes, - or( - fn_call, - () => variable_or_value_statement, - () => variable_updating_statement, + delimited( + text("{"), + fn(() => statements), + req("}"), ), -); +).mapSpanned(([attributes, body], span) => ({ attributes, body, span })); -const for_update = seq( - opt_attributes, - or(fn_call, () => variable_updating_statement), -); +const continuing_statement = preceded( + "continuing", + custom_compound_statement, +).mapSpanned(makeContinuingStatement); -// prettier-ignore -const for_statement = seq( // LATER consider allowing @if on for_init, expression and for_update - "for", - seq( - req("(", "invalid for loop, expected '('"), - opt(for_init), - req(";", "invalid for loop, expected ';'"), - opt(expression), - req(";", "invalid for loop, expected ';'"), - opt(for_update), - req(")", "invalid for loop, expected ')'"), - unscoped_compound_statement, - ) .collect(scopeCollect), +const break_if_statement = preceded(seq("break", "if"), expression).mapSpanned( + makeBreakIfStatement, ); -const if_statement = seq( - "if", - req(seq(expression, compound_statement), "invalid if statement"), - repeat( - seq( - "else", - "if", - req(seq(expression, compound_statement), "invalid else if branch"), - ), - ), - opt( - seq("else", req(compound_statement, "invalid else branch, expected '{'")), - ), +const loop_statement = preceded("loop", custom_compound_statement).mapSpanned( + makeLoopStatement, ); -// prettier-ignore -const loop_statement = seq( - "loop", - opt_attributes_no_if, - req( - seq( - "{", - repeat(() => statement), - opt( - tagScope( - seq( - opt_attributes, - "continuing", - opt_attributes_no_if, - "{", - repeat(() => statement), - tagScope( - opt( - seq( - opt_attributes, - seq("break", "if", expression, ";") - ) .collect(statementCollect) - ) - ), - "}", - ) .collect(statementCollect) - .collect(scopeCollect) - ), - ), - "}", - ), - "invalid loop statement" - ), -) .collect(scopeCollect); - -const case_selector = or("default", expression); - -// prettier-ignore -const switch_clause = tagScope( - seq( - opt_attributes, - or( - seq( - "case", +const case_selector: WeslParser = or( + token("keyword", "default").map(makeDefaultCaseSelector), + expression.map(makeExpressionCaseSelector), +); +const switch_clause = seq( + opt_attributes, + or( + preceded( + "case", + separated_pair( withSep(",", case_selector, { requireOne: true }), opt(":"), compound_statement, ), - seq("default", opt(":"), compound_statement), - ). collect(switchClauseCollect), - ) -); -const switch_body = seq(opt_attributes, "{", repeatPlus(switch_clause), "}"); -const switch_statement = seq("switch", expression, switch_body); - -const while_statement = seq("while", expression, compound_statement); - -const regular_statement = or( - for_statement, - if_statement, - loop_statement, - switch_statement, - while_statement, - seq("break", ";"), // ambiguous with break if - seq("continue", req(";", "invalid statement, expected ';'")), - seq(";"), // LATER this one cannot have attributes in front of it - () => const_assert, - seq("discard", req(";", "invalid statement, expected ';'")), - seq("return", opt(expression), req(";", "invalid statement, expected ';'")), - seq(fn_call, req(";", "invalid statement, expected ';'")), - seq( - () => variable_or_value_statement, - req(";", "invalid statement, expected ';'"), + ).mapSpanned(makeSwitchClause), + separated_pair( + token("keyword", "default").map(makeDefaultCaseSelector), + opt(":"), + compound_statement, + ).mapSpanned(makeDefaultClause), ), +).map(attachAttributes); + +const switch_statement: WeslParser = preceded( + "switch", seq( - () => variable_updating_statement, - req(";", "invalid statement, expected ';'"), + expression, + opt_attributes, + delimited("{", repeatPlus(switch_clause), "}"), ), -); +).mapSpanned(makeSwitchStatement); -// prettier-ignore -const conditional_statement = tagScope( - seq( - opt_attributes, - regular_statement - ) .collect(statementCollect) - .collect(partialScopeCollect)); +const while_statement: WeslParser = preceded( + "while", + seq(expression, compound_statement), +).mapSpanned(makeWhileStatement); -// prettier-ignore -const unconditional_statement = tagScope( - seq( - opt_attributes_no_if, - regular_statement, - ) +const break_statement = token("keyword", "break").map( + (v): BreakStatement => ({ kind: "break-statement", span: v.span }), ); - -// prettier-ignore -const statement: Parser, any> = or( - compound_statement, - unconditional_statement, - conditional_statement +const continue_statement = token("keyword", "continue").map( + (v): ContinueStatement => ({ kind: "continue-statement", span: v.span }), ); - -// prettier-ignore -const lhs_expression: Parser,any> = or( - simple_component_reference, - seq( - qualified_ident .collect(refIdent), - opt(component_or_swizzle) - ), - seq( - "(", - () => lhs_expression, - ")", - opt(component_or_swizzle) // LATER this doesn't find member references. - ), - seq("&", () => lhs_expression), - seq("*", () => lhs_expression), +const discard_statement = token("keyword", "discard").map( + (v): DiscardStatement => ({ kind: "discard-statement", span: v.span }), ); - -// prettier-ignore -const variable_or_value_statement = tagScope( // LATER consider collecting these as var elems and scopes - or( - // Also covers the = expression case - local_variable_decl, - seq("const", req_optionally_typed_ident, req("=", "invalid const declaration, expected '='"), expression), - seq( - "let", - req_optionally_typed_ident, - req("=", "invalid let declaration, expected '='"), - expression - ) - ) +const return_statement = preceded( + token("keyword", "return"), + opt(expression), +).mapSpanned( + (expression, span): ReturnStatement => ({ + kind: "return-statement", + expression: expression ?? undefined, + span, + }), ); -const variable_updating_statement = or( - seq( - lhs_expression, - or("=", "<<=", ">>=", "%=", "&=", "*=", "+=", "-=", "/=", "^=", "|="), - expression, +const statement: WeslParser< + Statement | ContinuingStatement | BreakIfStatement +> = seq( + opt_attributes, + or( + for_statement, + if_statement, + loop_statement, + switch_statement, + while_statement, + compound_statement, + terminated(break_statement, ";"), + terminated(continue_statement, ";"), + // seq(";") is excluded, since it is parsed by the statements parser below + terminated(const_assert, ";"), + terminated(discard_statement, ";"), + terminated(return_statement, ";"), + terminated(fn_call, ";"), + terminated( + declaration.verifyMap(v => + v.variant.kind !== "override" ? { value: v } : null, + ), + ";", + ), + terminated(() => variable_updating_statement, ";"), + // Those extra statement types are parsed here to avoid backtracking (when parsing attributes above a statement) + continuing_statement, + break_if_statement, ), - seq(lhs_expression, or("++", "--")), - seq("_", "=", expression), +).map(attachAttributes); + +const statements: WeslParser< + (Statement | ContinuingStatement | BreakIfStatement)[] +> = preceded(repeat(";"), repeat(terminated(statement, repeat(";")))); + +const function_param: WeslParser = seq( + opt_attributes, + decl_ident, + preceded(":", req(templated_ident)), +).map(([attributes, name, type]) => ({ + attributes, + name, + type, +})); + +const function_param_list: WeslParser = delimited( + "(", + withSep(",", function_param), + ")", ); -// prettier-ignore -const fn_decl = seq( - tagScope( - opt_attributes .collect((cc) => cc.tags.attribute || []), - ) .ctag("fn_attributes"), +const function_decl: WeslParser = preceded( text("fn"), - req(fnNameDecl, "invalid fn, expected function name"), seq( - req(fnParamList, "invalid fn, expected function parameters") - .collect(scopeCollect, "header_scope"), - opt(seq( - "->", - opt_attributes .collect((cc) => cc.tags.attribute, "return_attributes"), - type_specifier .ctag("return_type") - .collect(scopeCollect, "return_scope") - )), - req( - unscoped_compound_statement, - "invalid fn, expected function body" - ) .ctag("body_statement") - .collect(scopeCollect, "body_scope"), - ) -) .collect(partialScopeCollect, "fn_partial_scope") - .collect(fnCollect); - -// prettier-ignore -const global_value_decl = or( - seq( - opt_attributes, - "override", - global_ident, - seq(opt(seq("=", expression .collect(scopeCollect, "decl_scope")))), // TODO partial scopes for decl_scopes? - ";", - ) .collect(collectVarLike("override")), - seq( - opt_attributes, - "const", - global_ident, - "=", - seq(expression) .collect(scopeCollect, "decl_scope"), - ";", - ) .collect(collectVarLike("const")), + req(decl_ident), + req(function_param_list), + opt(preceded(symbol("->"), seq(opt_attributes, templated_ident))), + req(compound_statement), + ), +).mapSpanned(makeFunctionDeclarationElem); + +const global_alias: WeslParser = terminated( + seqObj({ + _1: "alias", + name: req(decl_ident), + _2: req("="), + type: req(templated_ident), + }).mapSpanned(({ name, type }, span): AliasElem => { + return { kind: "alias", name, type, span }; + }), + req(";"), ); -// prettier-ignore -const global_alias = seq( - weslExtension(opt_attributes) .collect((cc) => cc.tags.attribute, "attributes"), - "alias", - req(word, "invalid alias, expected name") .collect(globalDeclCollect, "alias_name"), - req("=", "invalid alias, expected '='"), - req(type_specifier, "invalid alias, expected type") .collect(scopeCollect, "alias_scope"), - req(";", "invalid alias, expected ';'"), -) .collect(aliasCollect); - -// prettier-ignore -const const_assert = tagScope( - seq( - opt_attributes, - "const_assert", - req(expression, "invalid const_assert, expected expression"), - req(";", "invalid statement, expected ';'") - ) .collect(assertCollect) -) .ctag("const_assert"); - -// prettier-ignore -const global_directive = tagScope( +const global_directive = terminated( seq( + // LATER Hoist the attributes further up for even less backtracking opt_attributes, - terminated( + span( or( - preceded("diagnostic", diagnostic_control) .map(makeDiagnosticDirective), - preceded("enable", name_list) .map(makeEnableDirective), - preceded("requires", name_list) .map(makeRequiresDirective), - ) .ptag("directive"), - ";", + preceded("diagnostic", diagnostic_control).map(makeDiagnosticDirective), + preceded("enable", name_list).map(makeEnableDirective), + preceded("requires", name_list).map(makeRequiresDirective), + ), ), - ) .collect(directiveCollect) -); + ), + ";", +).map(([attributes, { value: directive, span }]): DirectiveElem => { + return { kind: "directive", attributes, directive: directive, span }; +}); -// prettier-ignore -const global_decl = tagScope( +const global_decl: WeslParser = seq( + opt_attributes, or( - fn_decl, - seq( - opt_attributes, - global_variable_decl, - ";") .collect(collectVarLike("gvar")), - global_value_decl, - ";", + function_decl, + terminated( + declaration.verifyMap(v => + v.variant.kind !== "let" ? { value: v } : null, + ), + ";", + ), global_alias, - const_assert .collect(globalAssertCollect), + terminated(const_assert, ";"), struct_decl, ), +).map(attachAttributes); + +/** The translation_unit rule allows for stray semicolons */ +const global_decls: WeslParser = preceded( + repeat(";"), + repeat(terminated(global_decl, repeat(";"))), ); -// prettier-ignore export const weslRoot = seq( - weslExtension(weslImports), - repeat(global_directive), - repeat(global_decl), - req(eof(), "invalid WESL, expected EOF"), - ) .collect(collectModule, "collectModule"); + weslExtension(repeat(import_statement)), + repeat(global_directive), + global_decls, + req(eof()), +).map(makeModule); + +function makeDeclarationElem( + value: { + variant: WeslToken<"keyword">; + varTemplate: ExpressionElem[] | null; + typedIdent: [DeclIdent, TemplatedIdentElem | null]; + initializer: ExpressionElem | null; + }, + span: Span, +): DeclarationElem { + let variant: DeclarationVariant; + if (value.variant.text === "const") { + // LATER: We can actually report good errors here + if (value.varTemplate !== null) { + throw new ParseError("const