Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/strands/ir_types.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export const NodeTypeRequiredFields = {
export const StatementType = {
DISCARD: 'discard',
BREAK: 'break',
EARLY_RETURN: 'early_return',
EXPRESSION: 'expression', // Used when we want to output a single expression as a statement, e.g. a for loop condition
EMPTY: 'empty', // Used for empty statements like ; in for loops
};
Expand Down
119 changes: 82 additions & 37 deletions src/strands/strands_api.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ import {
isStructType,
OpCode,
StatementType,
NodeType,
// isNativeType
} from './ir_types'
import { strandsBuiltinFunctions } from './strands_builtins'
import { StrandsConditional } from './strands_conditionals'
import { StrandsFor } from './strands_for'
import * as CFG from './ir_cfg'
import * as DAG from './ir_dag';
import * as FES from './strands_FES'
import { getNodeDataFromID } from './ir_dag'
import { StrandsNode, createStrandsNode } from './strands_node'
Expand Down Expand Up @@ -63,6 +65,39 @@ export function initGlobalStrandsAPI(p5, fn, strandsContext) {
return new StrandsFor(strandsContext, initialCb, conditionCb, updateCb, bodyCb, initialVars).build();
};
fn.strandsFor = p5.strandsFor;
p5.strandsEarlyReturn = function(value) {
const { dag, cfg } = strandsContext;

// Ensure we're inside a hook
if (!strandsContext.activeHook) {
throw new Error('strandsEarlyReturn can only be used inside a hook callback');
}

// Convert value to a StrandsNode if it isn't already
const valueNode = value instanceof StrandsNode ? value : p5.strandsNode(value);

// Create a new CFG block for the early return
const earlyReturnBlockID = CFG.createBasicBlock(cfg, BlockType.DEFAULT);
CFG.addEdge(cfg, cfg.currentBlock, earlyReturnBlockID);
CFG.pushBlock(cfg, earlyReturnBlockID);

// Create the early return statement node
const nodeData = DAG.createNodeData({
nodeType: NodeType.STATEMENT,
statementType: StatementType.EARLY_RETURN,
dependsOn: [valueNode.id]
});
const earlyReturnID = DAG.getOrCreateNode(dag, nodeData);
CFG.recordInBasicBlock(cfg, cfg.currentBlock, earlyReturnID);

// Add the value to the hook's earlyReturns array for later type checking
strandsContext.activeHook.earlyReturns.push({ earlyReturnID, valueNode });

CFG.popBlock(cfg);

return valueNode;
};
fn.strandsEarlyReturn = p5.strandsEarlyReturn;
p5.strandsNode = function(...args) {
if (args.length === 1 && args[0] instanceof StrandsNode) {
return args[0];
Expand Down Expand Up @@ -403,53 +438,62 @@ export function createShaderHooksFunctions(strandsContext, fn, shader) {
CFG.addEdge(cfg, cfg.currentBlock, entryBlockID);
CFG.pushBlock(cfg, entryBlockID);
const args = createHookArguments(strandsContext, hookType.parameters);
strandsContext.activeHook = hookImplementation;
const userReturned = hookUserCallback(...args);
strandsContext.activeHook = undefined;
const expectedReturnType = hookType.returnType;
let rootNodeID = null;
if(isStructType(expectedReturnType)) {
const expectedStructType = structType(expectedReturnType);
if (userReturned instanceof StrandsNode) {
const returnedNode = getNodeDataFromID(strandsContext.dag, userReturned.id);
if (returnedNode.baseType !== expectedStructType.typeName) {
FES.userError("type error", `You have returned a ${userReturned.baseType} from ${hookType.name} when a ${expectedStructType.typeName} was expected.`);
const handleRetVal = (retNode) => {
if(isStructType(expectedReturnType)) {
const expectedStructType = structType(expectedReturnType);
if (retNode instanceof StrandsNode) {
const returnedNode = getNodeDataFromID(strandsContext.dag, retNode.id);
if (returnedNode.baseType !== expectedStructType.typeName) {
FES.userError("type error", `You have returned a ${retNode.baseType} from ${hookType.name} when a ${expectedStructType.typeName} was expected.`);
}
const newDeps = returnedNode.dependsOn.slice();
for (let i = 0; i < expectedStructType.properties.length; i++) {
const expectedType = expectedStructType.properties[i].dataType;
const receivedNode = createStrandsNode(returnedNode.dependsOn[i], dag.dependsOn[retNode.id], strandsContext);
newDeps[i] = enforceReturnTypeMatch(strandsContext, expectedType, receivedNode, hookType.name);
}
dag.dependsOn[retNode.id] = newDeps;
return retNode.id;
}
const newDeps = returnedNode.dependsOn.slice();
for (let i = 0; i < expectedStructType.properties.length; i++) {
const expectedType = expectedStructType.properties[i].dataType;
const receivedNode = createStrandsNode(returnedNode.dependsOn[i], dag.dependsOn[userReturned.id], strandsContext);
newDeps[i] = enforceReturnTypeMatch(strandsContext, expectedType, receivedNode, hookType.name);
else {
const expectedProperties = expectedStructType.properties;
const newStructDependencies = [];
for (let i = 0; i < expectedProperties.length; i++) {
const expectedProp = expectedProperties[i];
const propName = expectedProp.name;
const receivedValue = retNode[propName];
if (receivedValue === undefined) {
FES.userError('type error', `You've returned an incomplete struct from ${hookType.name}.\n` +
`Expected: { ${expectedReturnType.properties.map(p => p.name).join(', ')} }\n` +
`Received: { ${Object.keys(retNode).join(', ')} }\n` +
`All of the properties are required!`);
}
const expectedTypeInfo = expectedProp.dataType;
const returnedPropID = enforceReturnTypeMatch(strandsContext, expectedTypeInfo, receivedValue, hookType.name);
newStructDependencies.push(returnedPropID);
}
const newStruct = build.structConstructorNode(strandsContext, expectedStructType, newStructDependencies);
return newStruct.id;
}
dag.dependsOn[userReturned.id] = newDeps;
rootNodeID = userReturned.id;
}
else {
const expectedProperties = expectedStructType.properties;
const newStructDependencies = [];
for (let i = 0; i < expectedProperties.length; i++) {
const expectedProp = expectedProperties[i];
const propName = expectedProp.name;
const receivedValue = userReturned[propName];
if (receivedValue === undefined) {
FES.userError('type error', `You've returned an incomplete struct from ${hookType.name}.\n` +
`Expected: { ${expectedReturnType.properties.map(p => p.name).join(', ')} }\n` +
`Received: { ${Object.keys(userReturned).join(', ')} }\n` +
`All of the properties are required!`);
}
const expectedTypeInfo = expectedProp.dataType;
const returnedPropID = enforceReturnTypeMatch(strandsContext, expectedTypeInfo, receivedValue, hookType.name);
newStructDependencies.push(returnedPropID);
else /*if(isNativeType(expectedReturnType.typeName))*/ {
if (!expectedReturnType.dataType) {
throw new Error(`Missing dataType for return type ${expectedReturnType.typeName}`);
}
const newStruct = build.structConstructorNode(strandsContext, expectedStructType, newStructDependencies);
rootNodeID = newStruct.id;
const expectedTypeInfo = expectedReturnType.dataType;
return enforceReturnTypeMatch(strandsContext, expectedTypeInfo, retNode, hookType.name);
}
}
else /*if(isNativeType(expectedReturnType.typeName))*/ {
if (!expectedReturnType.dataType) {
throw new Error(`Missing dataType for return type ${expectedReturnType.typeName}`);
}
const expectedTypeInfo = expectedReturnType.dataType;
rootNodeID = enforceReturnTypeMatch(strandsContext, expectedTypeInfo, userReturned, hookType.name);
for (const { valueNode, earlyReturnID } of hookImplementation.earlyReturns) {
const id = handleRetVal(valueNode);
dag.dependsOn[earlyReturnID] = [id];
}
rootNodeID = userReturned ? handleRetVal(userReturned) : undefined;
const fullHookName = `${hookType.returnType.typeName} ${hookType.name}`;
const hookInfo = availableHooks[fullHookName];
strandsContext.hooks.push({
Expand All @@ -460,6 +504,7 @@ export function createShaderHooksFunctions(strandsContext, fn, shader) {
});
CFG.popBlock(cfg);
}
hookImplementation.earlyReturns = [];
strandsContext.windowOverrides[hookType.name] = window[hookType.name];
strandsContext.fnOverrides[hookType.name] = fn[hookType.name];
window[hookType.name] = hookImplementation;
Expand Down
10 changes: 7 additions & 3 deletions src/strands/strands_codegen.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { sortCFG } from "./ir_cfg";
import { structType, TypeInfoFromGLSLName } from './ir_types';
import { sortCFG } from './ir_cfg';
import * as DAG from './ir_dag';
import { NodeType, StatementType, structType, TypeInfoFromGLSLName } from './ir_types';

export function generateShaderCode(strandsContext) {
const {
Expand Down Expand Up @@ -68,7 +69,10 @@ export function generateShaderCode(strandsContext) {
}
returnType = hookType.returnType.dataType;
}
backend.generateReturnStatement(strandsContext, generationContext, rootNodeID, returnType);

if (rootNodeID) {
backend.generateReturnStatement(strandsContext, generationContext, rootNodeID, returnType);
}
hooksObj[`${hookType.returnType.typeName} ${hookType.name}`] = [firstLine, ...generationContext.codeLines, '}'].join('\n');
}

Expand Down
Loading
Loading