Skip to content

Commit a567f38

Browse files
committed
spirv: improve shuffle codegen
1 parent a3b1ba8 commit a567f38

File tree

2 files changed

+138
-8
lines changed

2 files changed

+138
-8
lines changed

src/codegen/spirv.zig

+55-8
Original file line numberDiff line numberDiff line change
@@ -4082,25 +4082,72 @@ const DeclGen = struct {
40824082
const b = try self.resolve(extra.b);
40834083
const mask = Value.fromInterned(extra.mask);
40844084

4085-
const ty = self.typeOfIndex(inst);
4085+
// Note: number of components in the result, a, and b may differ.
4086+
const result_ty = self.typeOfIndex(inst);
4087+
const a_ty = self.typeOf(extra.a);
4088+
const b_ty = self.typeOf(extra.b);
4089+
4090+
const scalar_ty = result_ty.scalarType(mod);
4091+
const scalar_ty_id = try self.resolveType(scalar_ty, .direct);
4092+
4093+
// If all of the types are SPIR-V vectors, we can use OpVectorShuffle.
4094+
if (self.isSpvVector(result_ty) and self.isSpvVector(a_ty) and self.isSpvVector(b_ty)) {
4095+
// The SPIR-V shuffle instruction is similar to the Air instruction, except that the elements are
4096+
// numbered consecutively instead of using negatives.
4097+
4098+
const components = try self.gpa.alloc(Word, result_ty.vectorLen(mod));
4099+
defer self.gpa.free(components);
4100+
4101+
const a_len = a_ty.vectorLen(mod);
4102+
4103+
for (components, 0..) |*component, i| {
4104+
const elem = try mask.elemValue(mod, i);
4105+
if (elem.isUndef(mod)) {
4106+
// This is explicitly valid for OpVectorShuffle, it indicates undefined.
4107+
component.* = 0xFFFF_FFFF;
4108+
continue;
4109+
}
4110+
4111+
const index = elem.toSignedInt(mod);
4112+
if (index >= 0) {
4113+
component.* = @intCast(index);
4114+
} else {
4115+
component.* = @intCast(~index + a_len);
4116+
}
4117+
}
40864118

4087-
var wip = try self.elementWise(ty, true);
4088-
defer wip.deinit();
4089-
for (wip.results, 0..) |*result_id, i| {
4119+
const result_id = self.spv.allocId();
4120+
try self.func.body.emit(self.spv.gpa, .OpVectorShuffle, .{
4121+
.id_result_type = try self.resolveType(result_ty, .direct),
4122+
.id_result = result_id,
4123+
.vector_1 = a,
4124+
.vector_2 = b,
4125+
.components = components,
4126+
});
4127+
return result_id;
4128+
}
4129+
4130+
// Fall back to manually extracting and inserting components.
4131+
4132+
const components = try self.gpa.alloc(IdRef, result_ty.vectorLen(mod));
4133+
defer self.gpa.free(components);
4134+
4135+
for (components, 0..) |*id, i| {
40904136
const elem = try mask.elemValue(mod, i);
40914137
if (elem.isUndef(mod)) {
4092-
result_id.* = try self.spv.constUndef(wip.ty_id);
4138+
id.* = try self.spv.constUndef(scalar_ty_id);
40934139
continue;
40944140
}
40954141

40964142
const index = elem.toSignedInt(mod);
40974143
if (index >= 0) {
4098-
result_id.* = try self.extractVectorComponent(wip.ty, a, @intCast(index));
4144+
id.* = try self.extractVectorComponent(scalar_ty, a, @intCast(index));
40994145
} else {
4100-
result_id.* = try self.extractVectorComponent(wip.ty, b, @intCast(~index));
4146+
id.* = try self.extractVectorComponent(scalar_ty, b, @intCast(~index));
41014147
}
41024148
}
4103-
return try wip.finalize();
4149+
4150+
return try self.constructVector(result_ty, components);
41044151
}
41054152

41064153
fn indicesToIds(self: *DeclGen, indices: []const u32) ![]IdRef {

test/behavior/shuffle.zig

+83
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ const std = @import("std");
22
const builtin = @import("builtin");
33
const mem = std.mem;
44
const expect = std.testing.expect;
5+
const expectEqual = std.testing.expectEqual;
56

67
test "@shuffle int" {
78
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
@@ -49,6 +50,88 @@ test "@shuffle int" {
4950
try comptime S.doTheTest();
5051
}
5152

53+
test "@shuffle int strange sizes" {
54+
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
55+
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
56+
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
57+
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
58+
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
59+
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
60+
61+
try comptime testShuffle(2, 2, 2);
62+
try testShuffle(2, 2, 2);
63+
try comptime testShuffle(4, 4, 4);
64+
try testShuffle(4, 4, 4);
65+
try comptime testShuffle(7, 4, 4);
66+
try testShuffle(7, 4, 4);
67+
try comptime testShuffle(8, 6, 4);
68+
try testShuffle(8, 6, 4);
69+
try comptime testShuffle(2, 7, 5);
70+
try testShuffle(2, 7, 5);
71+
try comptime testShuffle(13, 16, 12);
72+
try testShuffle(13, 16, 12);
73+
try comptime testShuffle(19, 3, 17);
74+
try testShuffle(19, 3, 17);
75+
try comptime testShuffle(1, 10, 1);
76+
try testShuffle(1, 10, 1);
77+
}
78+
79+
fn testShuffle(
80+
comptime x_len: comptime_int,
81+
comptime a_len: comptime_int,
82+
comptime b_len: comptime_int,
83+
) !void {
84+
const T = i32;
85+
const XT = @Vector(x_len, T);
86+
const AT = @Vector(a_len, T);
87+
const BT = @Vector(b_len, T);
88+
89+
const a_elems = comptime blk: {
90+
var elems: [a_len]T = undefined;
91+
for (&elems, 0..) |*elem, i| elem.* = @intCast(100 + i);
92+
break :blk elems;
93+
};
94+
var a: AT = a_elems;
95+
_ = &a;
96+
97+
const b_elems = comptime blk: {
98+
var elems: [b_len]T = undefined;
99+
for (&elems, 0..) |*elem, i| elem.* = @intCast(1000 + i);
100+
break :blk elems;
101+
};
102+
var b: BT = b_elems;
103+
_ = &b;
104+
105+
const mask_seed: []const i32 = &.{ -14, -31, 23, 1, 21, 13, 17, -21, -10, -27, -16, -5, 15, 14, -2, 26, 2, -31, -24, -16 };
106+
107+
const mask = comptime blk: {
108+
var elems: [x_len]i32 = undefined;
109+
for (&elems, 0..) |*elem, i| {
110+
const mask_val = mask_seed[i];
111+
if (mask_val >= 0) {
112+
elem.* = @mod(mask_val, a_len);
113+
} else {
114+
elem.* = @mod(mask_val, -b_len);
115+
}
116+
}
117+
118+
break :blk elems;
119+
};
120+
121+
const x: XT = @shuffle(T, a, b, mask);
122+
123+
const x_elems: [x_len]T = x;
124+
for (mask, x_elems) |m, x_elem| {
125+
if (m >= 0) {
126+
// Element from A
127+
try expectEqual(x_elem, a_elems[@intCast(m)]);
128+
} else {
129+
// Element from B
130+
try expectEqual(x_elem, b_elems[@intCast(~m)]);
131+
}
132+
}
133+
}
134+
52135
test "@shuffle bool 1" {
53136
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
54137
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO

0 commit comments

Comments
 (0)