-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbinop.mjs
101 lines (92 loc) · 2.97 KB
/
binop.mjs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
/* Define binary operations for both GPU and CPU */
/**
 * Base class shared by every binary operator in this module.
 *
 * The constructor copies the caller-supplied fields (e.g. `datatype`) onto the
 * new instance verbatim; the subclasses below then derive identity/op/wgsl*
 * fields from them.
 */
class BinOp {
  constructor(args) {
    // No defaults on purpose: an omitted field stays undefined instead of
    // being guessed, and `undefined`/`null` args simply copy nothing.
    const fields = args;
    Object.assign(this, fields);
  }
}
/**
 * Addition operator. The identity is 0 regardless of datatype.
 *
 * `wgslatomic` is only attached when the datatype is neither "f32" nor "i32"
 * (i.e. for "u32" and for an absent/unrecognized datatype); otherwise the
 * property is not created at all.
 */
class BinOpAdd extends BinOp {
  constructor(args) {
    super(args);
    const type = this.datatype;
    this.identity = 0;
    // NOTE(review): WGSL also defines atomicAdd on atomic<i32>; confirm whether
    // excluding i32 here is deliberate (e.g. callers only allocate u32 atomics).
    const skipAtomic = type === "f32" || type === "i32";
    if (!skipAtomic) {
      this.wgslatomic = "atomicAdd";
    }
    this.wgslfn = `fn binop(a : ${type}, b : ${type}) -> ${type} {return a+b;}`;
    this.subgroupOp = "subgroupAdd";
    this.op = (lhs, rhs) => lhs + rhs;
  }
}
export const BinOpAddU32 = new BinOpAdd({ datatype: "u32" });
export const BinOpAddF32 = new BinOpAdd({ datatype: "f32" });
export const BinOpAddI32 = new BinOpAdd({ datatype: "i32" });
/**
 * Minimum operator.
 *
 * `identity` is the largest value of the datatype, so op(identity, x) === x
 * for every representable x.
 *
 * `wgslatomic` is only attached for non-f32 datatypes: WGSL atomics
 * (atomicMin etc.) are defined only on atomic<i32>/atomic<u32>, so there is
 * no f32 atomic to advertise. BinOpAdd likewise leaves it unset for f32.
 */
class BinOpMin extends BinOp {
  constructor(args) {
    super(args);
    /* identity depends on datatype */
    switch (this.datatype) {
      case "f32":
        // FLT_MAX computed exactly. The rounded literal 3.402823466385288e38
        // is a slightly smaller double, which made op(identity, FLT_MAX)
        // return the identity instead of FLT_MAX on the CPU.
        this.identity = (2 - 2 ** -23) * 2 ** 127;
        break;
      case "i32":
        this.identity = 0x7fffffff; // INT32_MAX
        break;
      case "u32": // fall-through OK
      default:
        this.identity = 0xffffffff; // UINT32_MAX
        break;
    }
    this.op = (a, b) => Math.min(a, b);
    this.wgslfn = `fn binop(a : ${this.datatype}, b : ${this.datatype}) -> ${this.datatype} {return min(a,b);}`;
    if (this.datatype !== "f32") {
      this.wgslatomic = "atomicMin";
    }
    this.subgroupOp = "subgroupMin";
  }
}
export const BinOpMinU32 = new BinOpMin({ datatype: "u32" });
export const BinOpMinF32 = new BinOpMin({ datatype: "f32" });
export const BinOpMinI32 = new BinOpMin({ datatype: "i32" });
/**
 * Maximum operator.
 *
 * `identity` is the smallest value of the datatype, so op(identity, x) === x
 * for every representable x.
 *
 * `wgslatomic` is only attached for non-f32 datatypes: WGSL atomics
 * (atomicMax etc.) are defined only on atomic<i32>/atomic<u32>, so there is
 * no f32 atomic to advertise. BinOpAdd likewise leaves it unset for f32.
 */
class BinOpMax extends BinOp {
  constructor(args) {
    super(args);
    /* identity depends on datatype */
    switch (this.datatype) {
      case "f32":
        // -FLT_MAX computed exactly; the rounded literal was a slightly
        // different double, so op(identity, -FLT_MAX) !== -FLT_MAX on the CPU.
        this.identity = -(2 - 2 ** -23) * 2 ** 127;
        break;
      case "i32":
        // INT32_MIN. The previous 0xf0000000 is +4026531840 as a JS number,
        // which is not an i32 at all and wins every CPU-side Math.max; even
        // stored in an Int32Array it is only -268435456, not the minimum.
        this.identity = -0x80000000;
        break;
      case "u32": // fall-through OK
      default:
        this.identity = 0;
        break;
    }
    this.op = (a, b) => Math.max(a, b);
    this.wgslfn = `fn binop(a : ${this.datatype}, b : ${this.datatype}) -> ${this.datatype} {return max(a,b);}`;
    if (this.datatype !== "f32") {
      this.wgslatomic = "atomicMax";
    }
    this.subgroupOp = "subgroupMax";
  }
}
export const BinOpMaxU32 = new BinOpMax({ datatype: "u32" });
export const BinOpMaxF32 = new BinOpMax({ datatype: "f32" });
export const BinOpMaxI32 = new BinOpMax({ datatype: "i32" });
/**
 * Multiplication operator. The identity is 1 for every datatype.
 *
 * `wgslatomic` is intentionally left unset: WGSL has no atomicMul builtin, so
 * the previous value "atomicMul" would yield WGSL that fails to compile if a
 * caller emitted it. This matches BinOpAdd, which leaves it unset when no
 * atomic is available.
 *
 * For "i32"/"u32" the CPU `op` wraps to 32 bits, matching WGSL integer `*`.
 * A plain `a * b` on JS doubles neither wraps nor stays exact once the
 * product exceeds 2^53, so the low 32 bits could not be recovered afterwards.
 */
class BinOpMultiply extends BinOp {
  constructor(args) {
    super(args);
    this.identity = 1;
    switch (this.datatype) {
      case "i32":
        this.op = (a, b) => Math.imul(a, b);
        break;
      case "u32":
        // Math.imul yields the signed 32-bit product; >>> 0 reinterprets it as u32.
        this.op = (a, b) => Math.imul(a, b) >>> 0;
        break;
      default: // f32, or an absent/unrecognized datatype
        this.op = (a, b) => a * b;
        break;
    }
    this.wgslfn = `fn binop(a : ${this.datatype}, b : ${this.datatype}) -> ${this.datatype} {return a*b;}`;
    this.subgroupOp = "subgroupMul";
  }
}
export const BinOpMultiplyU32 = new BinOpMultiply({ datatype: "u32" });
export const BinOpMultiplyF32 = new BinOpMultiply({ datatype: "f32" });
export const BinOpMultiplyI32 = new BinOpMultiply({ datatype: "i32" });