Skip to content

Commit 45176a4

Browse files
committed
x86_64: rewrite scalar <<|
Closes #23035
1 parent c75df44 commit 45176a4

16 files changed

+4142
-587
lines changed

lib/std/math.zig

+21-11
Original file line number | Diff line number | Diff line change
@@ -774,18 +774,15 @@ pub fn Log2IntCeil(comptime T: type) type {
774774
/// Returns the smallest integer type that can hold both from and to.
775775
pub fn IntFittingRange(comptime from: comptime_int, comptime to: comptime_int) type {
776776
assert(from <= to);
777-
if (from == 0 and to == 0) {
778-
return u0;
779-
}
780777
const signedness: std.builtin.Signedness = if (from < 0) .signed else .unsigned;
781-
const largest_positive_integer = @max(if (from < 0) (-from) - 1 else from, to); // two's complement
782-
const base = log2(largest_positive_integer);
783-
const upper = (1 << base) - 1;
784-
var magnitude_bits = if (upper >= largest_positive_integer) base else base + 1;
785-
if (signedness == .signed) {
786-
magnitude_bits += 1;
787-
}
788-
return std.meta.Int(signedness, magnitude_bits);
778+
return @Type(.{ .int = .{
779+
.signedness = signedness,
780+
.bits = @as(u16, @intFromBool(signedness == .signed)) +
781+
switch (if (from < 0) @max(@abs(from) - 1, to) else to) {
782+
0 => 0,
783+
else => |pos_max| 1 + log2(pos_max),
784+
},
785+
} });
789786
}
790787

791788
test IntFittingRange {
@@ -1267,6 +1264,19 @@ pub fn log2_int(comptime T: type, x: T) Log2Int(T) {
12671264
return @as(Log2Int(T), @intCast(@typeInfo(T).int.bits - 1 - @clz(x)));
12681265
}
12691266

1267+
test log2_int {
1268+
try testing.expect(log2_int(u32, 1) == 0);
1269+
try testing.expect(log2_int(u32, 2) == 1);
1270+
try testing.expect(log2_int(u32, 3) == 1);
1271+
try testing.expect(log2_int(u32, 4) == 2);
1272+
try testing.expect(log2_int(u32, 5) == 2);
1273+
try testing.expect(log2_int(u32, 6) == 2);
1274+
try testing.expect(log2_int(u32, 7) == 2);
1275+
try testing.expect(log2_int(u32, 8) == 3);
1276+
try testing.expect(log2_int(u32, 9) == 3);
1277+
try testing.expect(log2_int(u32, 10) == 3);
1278+
}
1279+
12701280
/// Return the log base 2 of integer value x, rounding up to the
12711281
/// nearest integer.
12721282
pub fn log2_int_ceil(comptime T: type, x: T) Log2IntCeil(T) {

lib/std/math/big/int.zig

+4-4
Original file line number | Diff line number | Diff line change
@@ -415,12 +415,12 @@ pub const Mutable = struct {
415415
// in the case that scalar happens to be small in magnitude within its type, but it
416416
// is well worth being able to use the stack and not needing an allocator passed in.
417417
// Note that Mutable.init still sets len to calcLimbLen(scalar) in any case.
418-
const limb_len = comptime switch (@typeInfo(@TypeOf(scalar))) {
418+
const limbs_len = comptime switch (@typeInfo(@TypeOf(scalar))) {
419419
.comptime_int => calcLimbLen(scalar),
420420
.int => |info| calcTwosCompLimbCount(info.bits),
421421
else => @compileError("expected scalar to be an int"),
422422
};
423-
var limbs: [limb_len]Limb = undefined;
423+
var limbs: [limbs_len]Limb = undefined;
424424
const operand = init(&limbs, scalar).toConst();
425425
return add(r, a, operand);
426426
}
@@ -2454,12 +2454,12 @@ pub const Const = struct {
24542454
// in the case that scalar happens to be small in magnitude within its type, but it
24552455
// is well worth being able to use the stack and not needing an allocator passed in.
24562456
// Note that Mutable.init still sets len to calcLimbLen(scalar) in any case.
2457-
const limb_len = comptime switch (@typeInfo(@TypeOf(scalar))) {
2457+
const limbs_len = comptime switch (@typeInfo(@TypeOf(scalar))) {
24582458
.comptime_int => calcLimbLen(scalar),
24592459
.int => |info| calcTwosCompLimbCount(info.bits),
24602460
else => @compileError("expected scalar to be an int"),
24612461
};
2462-
var limbs: [limb_len]Limb = undefined;
2462+
var limbs: [limbs_len]Limb = undefined;
24632463
const rhs = Mutable.init(&limbs, scalar);
24642464
return order(lhs, rhs.toConst());
24652465
}

lib/std/math/big/int_test.zig

-4
Original file line number | Diff line number | Diff line change
@@ -2295,8 +2295,6 @@ test "sat shift-left signed simple positive" {
22952295
}
22962296

22972297
test "sat shift-left signed multi positive" {
2298-
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
2299-
23002298
var x: SignedDoubleLimb = 1;
23012299
_ = &x;
23022300

@@ -2310,8 +2308,6 @@ test "sat shift-left signed multi positive" {
23102308
}
23112309

23122310
test "sat shift-left signed multi negative" {
2313-
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
2314-
23152311
var x: SignedDoubleLimb = -1;
23162312
_ = &x;
23172313

lib/std/math/log2.zig

+11-10
Original file line number | Diff line number | Diff line change
@@ -12,12 +12,10 @@ const expect = std.testing.expect;
1212
/// - log2(nan) = nan
1313
pub fn log2(x: anytype) @TypeOf(x) {
1414
const T = @TypeOf(x);
15-
switch (@typeInfo(T)) {
16-
.comptime_float => {
17-
return @as(comptime_float, @log2(x));
18-
},
19-
.float => return @log2(x),
15+
return switch (@typeInfo(T)) {
16+
.comptime_float, .float => @log2(x),
2017
.comptime_int => comptime {
18+
std.debug.assert(x > 0);
2119
var x_shifted = x;
2220
// First, calculate floorPowerOfTwo(x)
2321
var shift_amt = 1;
@@ -34,12 +32,15 @@ pub fn log2(x: anytype) @TypeOf(x) {
3432
}
3533
return result;
3634
},
37-
.int => |IntType| switch (IntType.signedness) {
38-
.signed => @compileError("log2 not implemented for signed integers"),
39-
.unsigned => return math.log2_int(T, x),
40-
},
35+
.int => |int_info| math.log2_int(switch (int_info.signedness) {
36+
.signed => @Type(.{ .int = .{
37+
.signedness = .unsigned,
38+
.bits = int_info.bits -| 1,
39+
} }),
40+
.unsigned => T,
41+
}, @intCast(x)),
4142
else => @compileError("log2 not implemented for " ++ @typeName(T)),
42-
}
43+
};
4344
}
4445

4546
test log2 {

lib/zig.h

+20-11
Original file line number | Diff line number | Diff line change
@@ -1115,14 +1115,15 @@ static inline bool zig_mulo_i16(int16_t *res, int16_t lhs, int16_t rhs, uint8_t
11151115
\
11161116
static inline uint##w##_t zig_shls_u##w(uint##w##_t lhs, uint##w##_t rhs, uint8_t bits) { \
11171117
uint##w##_t res; \
1118-
if (rhs >= bits) return lhs != UINT##w##_C(0) ? zig_maxInt_u(w, bits) : lhs; \
1119-
return zig_shlo_u##w(&res, lhs, (uint8_t)rhs, bits) ? zig_maxInt_u(w, bits) : res; \
1118+
if (rhs < bits && !zig_shlo_u##w(&res, lhs, rhs, bits)) return res; \
1119+
return lhs == INT##w##_C(0) ? INT##w##_C(0) : zig_maxInt_u(w, bits); \
11201120
} \
11211121
\
1122-
static inline int##w##_t zig_shls_i##w(int##w##_t lhs, int##w##_t rhs, uint8_t bits) { \
1122+
static inline int##w##_t zig_shls_i##w(int##w##_t lhs, uint##w##_t rhs, uint8_t bits) { \
11231123
int##w##_t res; \
1124-
if ((uint##w##_t)rhs < (uint##w##_t)bits && !zig_shlo_i##w(&res, lhs, (uint8_t)rhs, bits)) return res; \
1125-
return lhs < INT##w##_C(0) ? zig_minInt_i(w, bits) : zig_maxInt_i(w, bits); \
1124+
if (rhs < bits && !zig_shlo_i##w(&res, lhs, rhs, bits)) return res; \
1125+
return lhs == INT##w##_C(0) ? INT##w##_C(0) : \
1126+
lhs < INT##w##_C(0) ? zig_minInt_i(w, bits) : zig_maxInt_i(w, bits); \
11261127
} \
11271128
\
11281129
static inline uint##w##_t zig_adds_u##w(uint##w##_t lhs, uint##w##_t rhs, uint8_t bits) { \
@@ -1851,15 +1852,23 @@ static inline bool zig_shlo_i128(zig_i128 *res, zig_i128 lhs, uint8_t rhs, uint8
18511852

18521853
static inline zig_u128 zig_shls_u128(zig_u128 lhs, zig_u128 rhs, uint8_t bits) {
18531854
zig_u128 res;
1854-
if (zig_cmp_u128(rhs, zig_make_u128(0, bits)) >= INT32_C(0))
1855-
return zig_cmp_u128(lhs, zig_make_u128(0, 0)) != INT32_C(0) ? zig_maxInt_u(128, bits) : lhs;
1856-
return zig_shlo_u128(&res, lhs, (uint8_t)zig_lo_u128(rhs), bits) ? zig_maxInt_u(128, bits) : res;
1855+
if (zig_cmp_u128(rhs, zig_make_u128(0, bits)) < INT32_C(0) && !zig_shlo_u128(&res, lhs, (uint8_t)zig_lo_u128(rhs), bits)) return res;
1856+
switch (zig_cmp_u128(lhs, zig_make_u128(0, 0))) {
1857+
case 0: return zig_make_i128(0, 0);
1858+
case 1: return zig_maxInt_u(128, bits);
1859+
default: zig_unreachable();
1860+
}
18571861
}
18581862

1859-
static inline zig_i128 zig_shls_i128(zig_i128 lhs, zig_i128 rhs, uint8_t bits) {
1863+
static inline zig_i128 zig_shls_i128(zig_i128 lhs, zig_u128 rhs, uint8_t bits) {
18601864
zig_i128 res;
1861-
if (zig_cmp_u128(zig_bitCast_u128(rhs), zig_make_u128(0, bits)) < INT32_C(0) && !zig_shlo_i128(&res, lhs, (uint8_t)zig_lo_i128(rhs), bits)) return res;
1862-
return zig_cmp_i128(lhs, zig_make_i128(0, 0)) < INT32_C(0) ? zig_minInt_i(128, bits) : zig_maxInt_i(128, bits);
1865+
if (zig_cmp_u128(rhs, zig_make_u128(0, bits)) < INT32_C(0) && !zig_shlo_i128(&res, lhs, (uint8_t)zig_lo_u128(rhs), bits)) return res;
1866+
switch (zig_cmp_i128(lhs, zig_make_i128(0, 0))) {
1867+
case -1: return zig_minInt_i(128, bits);
1868+
case 0: return zig_make_i128(0, 0);
1869+
case 1: return zig_maxInt_i(128, bits);
1870+
default: zig_unreachable();
1871+
}
18631872
}
18641873

18651874
static inline zig_u128 zig_adds_u128(zig_u128 lhs, zig_u128 rhs, uint8_t bits) {

src/Air.zig

+3-1
Original file line number | Diff line number | Diff line change
@@ -257,7 +257,9 @@ pub const Inst = struct {
257257
/// it shifts out any bits that disagree with the resultant sign bit.
258258
/// Uses the `bin_op` field.
259259
shl_exact,
260-
/// Saturating integer shift left. `<<|`
260+
/// Saturating integer shift left. `<<|`. The result is the same type as the `lhs`.
261+
/// The `rhs` must have the same vector shape as the `lhs`, but with any unsigned
262+
/// integer as the scalar type.
261263
/// Uses the `bin_op` field.
262264
shl_sat,
263265
/// Bitwise XOR. `^`

src/Sema.zig

+43-28
Original file line number | Diff line number | Diff line change
@@ -14215,14 +14215,15 @@ fn zirShl(
1421514215
const rhs_ty = sema.typeOf(rhs);
1421614216

1421714217
const src = block.nodeOffset(inst_data.src_node);
14218-
const lhs_src = switch (air_tag) {
14219-
.shl, .shl_sat => block.src(.{ .node_offset_bin_lhs = inst_data.src_node }),
14220-
.shl_exact => block.builtinCallArgSrc(inst_data.src_node, 0),
14221-
else => unreachable,
14222-
};
14223-
const rhs_src = switch (air_tag) {
14224-
.shl, .shl_sat => block.src(.{ .node_offset_bin_rhs = inst_data.src_node }),
14225-
.shl_exact => block.builtinCallArgSrc(inst_data.src_node, 1),
14218+
const lhs_src, const rhs_src = switch (air_tag) {
14219+
.shl, .shl_sat => .{
14220+
block.src(.{ .node_offset_bin_lhs = inst_data.src_node }),
14221+
block.src(.{ .node_offset_bin_rhs = inst_data.src_node }),
14222+
},
14223+
.shl_exact => .{
14224+
block.builtinCallArgSrc(inst_data.src_node, 0),
14225+
block.builtinCallArgSrc(inst_data.src_node, 1),
14226+
},
1422614227
else => unreachable,
1422714228
};
1422814229

@@ -14231,8 +14232,7 @@ fn zirShl(
1423114232
const scalar_ty = lhs_ty.scalarType(zcu);
1423214233
const scalar_rhs_ty = rhs_ty.scalarType(zcu);
1423314234

14234-
// TODO coerce rhs if air_tag is not shl_sat
14235-
const rhs_is_comptime_int = try sema.checkIntType(block, rhs_src, scalar_rhs_ty);
14235+
_ = try sema.checkIntType(block, rhs_src, scalar_rhs_ty);
1423614236

1423714237
const maybe_lhs_val = try sema.resolveValueResolveLazy(lhs);
1423814238
const maybe_rhs_val = try sema.resolveValueResolveLazy(rhs);
@@ -14245,7 +14245,7 @@ fn zirShl(
1424514245
if (try rhs_val.compareAllWithZeroSema(.eq, pt)) {
1424614246
return lhs;
1424714247
}
14248-
if (scalar_ty.zigTypeTag(zcu) != .comptime_int and air_tag != .shl_sat) {
14248+
if (air_tag != .shl_sat and scalar_ty.zigTypeTag(zcu) != .comptime_int) {
1424914249
const bit_value = try pt.intValue(Type.comptime_int, scalar_ty.intInfo(zcu).bits);
1425014250
if (rhs_ty.zigTypeTag(zcu) == .vector) {
1425114251
var i: usize = 0;
@@ -14282,6 +14282,8 @@ fn zirShl(
1428214282
rhs_val.fmtValueSema(pt, sema),
1428314283
});
1428414284
}
14285+
} else if (scalar_rhs_ty.isSignedInt(zcu)) {
14286+
return sema.fail(block, rhs_src, "shift by signed type '{}'", .{rhs_ty.fmt(pt)});
1428514287
}
1428614288

1428714289
const runtime_src = if (maybe_lhs_val) |lhs_val| rs: {
@@ -14309,18 +14311,34 @@ fn zirShl(
1430914311
return Air.internedToRef(val.toIntern());
1431014312
} else lhs_src;
1431114313

14312-
const new_rhs = if (air_tag == .shl_sat) rhs: {
14313-
// Limit the RHS type for saturating shl to be an integer as small as the LHS.
14314-
if (rhs_is_comptime_int or
14315-
scalar_rhs_ty.intInfo(zcu).bits > scalar_ty.intInfo(zcu).bits)
14316-
{
14317-
const max_int = Air.internedToRef((try lhs_ty.maxInt(pt, lhs_ty)).toIntern());
14318-
const rhs_limited = try sema.analyzeMinMax(block, rhs_src, .min, &.{ rhs, max_int }, &.{ rhs_src, rhs_src });
14319-
break :rhs try sema.intCast(block, src, lhs_ty, rhs_src, rhs_limited, rhs_src, false, false);
14320-
} else {
14321-
break :rhs rhs;
14322-
}
14323-
} else rhs;
14314+
const rt_rhs = switch (air_tag) {
14315+
else => unreachable,
14316+
.shl, .shl_exact => rhs,
14317+
// The backend can handle a large runtime rhs better than we can, but
14318+
// we can limit a large comptime rhs better here. This also has the
14319+
// necessary side effect of preventing rhs from being a `comptime_int`.
14320+
.shl_sat => if (maybe_rhs_val) |rhs_val| Air.internedToRef(rt_rhs: {
14321+
const bit_count = scalar_ty.intInfo(zcu).bits;
14322+
const rt_rhs_scalar_ty = try pt.smallestUnsignedInt(bit_count);
14323+
if (!rhs_ty.isVector(zcu)) break :rt_rhs (try pt.intValue(
14324+
rt_rhs_scalar_ty,
14325+
@min(try rhs_val.getUnsignedIntSema(pt) orelse bit_count, bit_count),
14326+
)).toIntern();
14327+
const rhs_len = rhs_ty.vectorLen(zcu);
14328+
const rhs_elems = try sema.arena.alloc(InternPool.Index, rhs_len);
14329+
for (rhs_elems, 0..) |*rhs_elem, i| rhs_elem.* = (try pt.intValue(
14330+
rt_rhs_scalar_ty,
14331+
@min(try (try rhs_val.elemValue(pt, i)).getUnsignedIntSema(pt) orelse bit_count, bit_count),
14332+
)).toIntern();
14333+
break :rt_rhs try pt.intern(.{ .aggregate = .{
14334+
.ty = (try pt.vectorType(.{
14335+
.len = rhs_len,
14336+
.child = rt_rhs_scalar_ty.toIntern(),
14337+
})).toIntern(),
14338+
.storage = .{ .elems = rhs_elems },
14339+
} });
14340+
}) else rhs,
14341+
};
1432414342

1432514343
try sema.requireRuntimeBlock(block, src, runtime_src);
1432614344
if (block.wantSafety()) {
@@ -14374,7 +14392,7 @@ fn zirShl(
1437414392
return sema.tupleFieldValByIndex(block, op_ov, 0, op_ov_tuple_ty);
1437514393
}
1437614394
}
14377-
return block.addBinOp(air_tag, lhs, new_rhs);
14395+
return block.addBinOp(air_tag, lhs, rt_rhs);
1437814396
}
1437914397

1438014398
fn zirShr(
@@ -36432,10 +36450,7 @@ fn generateUnionTagTypeSimple(
3643236450
const enum_ty = try ip.getGeneratedTagEnumType(gpa, pt.tid, .{
3643336451
.name = name,
3643436452
.owner_union_ty = union_type,
36435-
.tag_ty = if (enum_field_names.len == 0)
36436-
(try pt.intType(.unsigned, 0)).toIntern()
36437-
else
36438-
(try pt.smallestUnsignedInt(enum_field_names.len - 1)).toIntern(),
36453+
.tag_ty = (try pt.smallestUnsignedInt(enum_field_names.len -| 1)).toIntern(),
3643936454
.names = enum_field_names,
3644036455
.values = &.{},
3644136456
.tag_mode = .auto,

src/Type.zig

+4-4
Original file line number | Diff line number | Diff line change
@@ -4132,10 +4132,10 @@ pub const empty_tuple: Type = .{ .ip_index = .empty_tuple_type };
41324132
pub const generic_poison: Type = .{ .ip_index = .generic_poison_type };
41334133

41344134
pub fn smallestUnsignedBits(max: u64) u16 {
4135-
if (max == 0) return 0;
4136-
const base = std.math.log2(max);
4137-
const upper = (@as(u64, 1) << @as(u6, @intCast(base))) - 1;
4138-
return @as(u16, @intCast(base + @intFromBool(upper < max)));
4135+
return switch (max) {
4136+
0 => 0,
4137+
else => 1 + std.math.log2_int(u64, max),
4138+
};
41394139
}
41404140

41414141
/// This is only used for comptime asserts. Bump this number when you make a change

0 commit comments

Comments
 (0)