From 7384097a0df0c38d8910c52792ee3deec8b0e648 Mon Sep 17 00:00:00 2001 From: Parker Liu Date: Wed, 16 Apr 2025 09:55:07 +0800 Subject: [PATCH] sema: The builtin function @max/@min support incompatible arbitrary integer types In normal arithmatic operations on two or more integer operands, all intergers must be compatible with each other. But in the function @max/@min, this constraint need not be satisfied. For example, the expression `@min(@as(i32, -30), @as(u32, 42))` is a legal expression. --- src/Sema.zig | 122 +++++++++++++++++++++++++++--- test/behavior/maximum_minimum.zig | 110 +++++++++++++++++++++++++++ 2 files changed, 221 insertions(+), 11 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index b30f42c2d7b5..ce1e89951b6f 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -24626,6 +24626,7 @@ fn checkSimdBinOp( sema: *Sema, block: *Block, src: LazySrcLoc, + comptime air_tag: Air.Inst.Tag, uncasted_lhs: Air.Inst.Ref, uncasted_rhs: Air.Inst.Ref, lhs_src: LazySrcLoc, @@ -24638,11 +24639,11 @@ fn checkSimdBinOp( try sema.checkVectorizableBinaryOperands(block, src, lhs_ty, rhs_ty, lhs_src, rhs_src); const vec_len: ?usize = if (lhs_ty.zigTypeTag(zcu) == .vector) lhs_ty.vectorLen(zcu) else null; - const result_ty = try sema.resolvePeerTypes(block, src, &.{ uncasted_lhs, uncasted_rhs }, .{ + const result_ty = try sema.resolvePeerTypesWithOp(block, src, air_tag, &.{ uncasted_lhs, uncasted_rhs }, .{ .override = &[_]?LazySrcLoc{ lhs_src, rhs_src }, }); - const lhs = try sema.coerce(block, result_ty, uncasted_lhs, lhs_src); - const rhs = try sema.coerce(block, result_ty, uncasted_rhs, rhs_src); + const lhs = try sema.coerceWithOp(block, result_ty, uncasted_lhs, lhs_src, air_tag); + const rhs = try sema.coerceWithOp(block, result_ty, uncasted_rhs, rhs_src, air_tag); return SimdBinOp{ .len = vec_len, @@ -25993,7 +25994,7 @@ fn analyzeMinMax( continue; }; - const simd_op = try sema.checkSimdBinOp(block, src, cur, operand, cur_minmax_src, operand_src); + const simd_op = try sema.checkSimdBinOp(block, src, air_tag, cur, operand, cur_minmax_src, operand_src); const cur_val = try sema.resolveLazyValue(simd_op.lhs_val.?); // cur_minmax is comptime-known const operand_val = try sema.resolveLazyValue(simd_op.rhs_val.?); // we checked the operand was resolvable above @@ -26085,7 +26086,7 @@ fn analyzeMinMax( const lhs_src = cur_minmax_src; const rhs = operands[idx]; const rhs_src = operand_srcs[idx]; - const simd_op = try sema.checkSimdBinOp(block, src, lhs, rhs, lhs_src, rhs_src); + const simd_op = try sema.checkSimdBinOp(block, src, air_tag, lhs, rhs, lhs_src, rhs_src); if (known_undef) { cur_minmax = try pt.undefRef(simd_op.result_ty); } else { @@ -29455,6 +29456,20 @@ pub fn coerce( }; } +pub fn coerceWithOp( + sema: *Sema, + block: *Block, + dest_ty_unresolved: Type, + inst: Air.Inst.Ref, + inst_src: LazySrcLoc, + op_tag: Air.Inst.Tag, +) CompileError!Air.Inst.Ref { + return sema.coerceExtra(block, dest_ty_unresolved, inst, inst_src, .{ .opt_op_tag = op_tag }) catch |err| switch (err) { + error.NotCoercible => unreachable, + else => |e| return e, + }; +} + const CoersionError = CompileError || error{ /// When coerce is called recursively, this error should be returned instead of using `fail` /// to ensure correct types in compile errors. @@ -29468,6 +29483,8 @@ const CoerceOpts = struct { is_ret: bool = false, /// Should coercion to comptime_int emit an error message. no_cast_to_comptime_int: bool = false, + /// The tag of operator in which the coerce is called + opt_op_tag: ?Air.Inst.Tag = null, param_src: struct { func_inst: Air.Inst.Ref = .none, @@ -29854,6 +29871,13 @@ fn coerceExtra( if (maybe_inst_val) |val| { // comptime-known integer to other number if (!(try sema.intFitsInType(val, dest_ty, null))) { + if (opts.opt_op_tag) |op_tag| { + switch (op_tag) { + .min => return Air.internedToRef((try dest_ty.maxInt(pt, dest_ty)).toIntern()), + .max => return pt.intRef(dest_ty, 0), + else => {}, + } + } if (!opts.report_err) return error.NotCoercible; return sema.fail(block, inst_src, "type '{}' cannot represent integer value '{}'", .{ dest_ty.fmt(pt), val.fmtValueSema(pt, sema) }); } @@ -29881,6 +29905,30 @@ fn coerceExtra( try sema.requireRuntimeBlock(block, inst_src, null); return block.addTyOp(.intcast, dest_ty, inst); } + + if (opts.opt_op_tag) |op_tag| { + switch (op_tag) { + .min => { + if (src_info.signedness != dst_info.signedness and dst_info.signedness == .signed) { + std.debug.assert(dst_info.bits <= src_info.bits); + try sema.requireRuntimeBlock(block, inst_src, null); + const max_int_inst = Air.internedToRef((try dest_ty.maxInt(pt, inst_ty)).toIntern()); + const min_inst = try block.addBinOp(.min, inst, max_int_inst); + return block.addTyOp(.intcast, dest_ty, min_inst); + } + }, + .max => { + if (src_info.signedness != dst_info.signedness and dst_info.signedness == .unsigned) { + std.debug.assert(dst_info.bits >= src_info.bits); + try sema.requireRuntimeBlock(block, inst_src, null); + const zero_inst = try pt.intRef(inst_ty, 0); + const max_inst = try block.addBinOp(.max, inst, zero_inst); + return block.addTyOp(.intcast, dest_ty, max_inst); + } + }, + else => {}, + } + } }, else => {}, }, @@ -30033,9 +30081,9 @@ fn coerceExtra( } } - return sema.coerceArrayLike(block, dest_ty, dest_ty_src, inst, inst_src); + return sema.coerceArrayLike(block, dest_ty, dest_ty_src, inst, inst_src, opts.opt_op_tag); }, - .vector => return sema.coerceArrayLike(block, dest_ty, dest_ty_src, inst, inst_src), + .vector => return sema.coerceArrayLike(block, dest_ty, dest_ty_src, inst, inst_src, opts.opt_op_tag), .@"struct" => { if (inst_ty.isTuple(zcu)) { return sema.coerceTupleToArray(block, dest_ty, dest_ty_src, inst, inst_src); @@ -30044,7 +30092,7 @@ fn coerceExtra( else => {}, }, .vector => switch (inst_ty.zigTypeTag(zcu)) { - .array, .vector => return sema.coerceArrayLike(block, dest_ty, dest_ty_src, inst, inst_src), + .array, .vector => return sema.coerceArrayLike(block, dest_ty, dest_ty_src, inst, inst_src, opts.opt_op_tag), .@"struct" => { if (inst_ty.isTuple(zcu)) { return sema.coerceTupleToArray(block, dest_ty, dest_ty_src, inst, inst_src); @@ -31927,6 +31975,7 @@ fn coerceArrayLike( dest_ty_src: LazySrcLoc, inst: Air.Inst.Ref, inst_src: LazySrcLoc, + opt_op_tag: ?Air.Inst.Tag, ) !Air.Inst.Ref { const pt = sema.pt; const zcu = pt.zcu; @@ -31975,6 +32024,31 @@ fn coerceArrayLike( try sema.requireRuntimeBlock(block, inst_src, null); return block.addTyOp(.intcast, dest_ty, inst); } + + if (opt_op_tag) |op_tag| { + switch (op_tag) { + .min => { + if (src_info.signedness != dst_info.signedness and dst_info.signedness == .signed) { + std.debug.assert(dst_info.bits <= src_info.bits); + try sema.requireRuntimeBlock(block, inst_src, null); + const max_int_inst = Air.internedToRef((try dest_ty.maxInt(pt, inst_ty)).toIntern()); + const min_inst = try block.addBinOp(.min, inst, max_int_inst); + return block.addTyOp(.intcast, dest_ty, min_inst); + } + }, + .max => { + if (src_info.signedness != dst_info.signedness and dst_info.signedness == .unsigned) { + std.debug.assert(dst_info.bits >= src_info.bits); + try sema.requireRuntimeBlock(block, inst_src, null); + const zeros = try sema.splat(inst_ty, try pt.intValue(inst_elem_ty, 0)); + const zero_inst = Air.internedToRef(zeros.toIntern()); + const max_inst = try block.addBinOp(.max, inst, zero_inst); + return block.addTyOp(.intcast, dest_ty, max_inst); + } + }, + else => {}, + } + } }, .float => if (inst_elem_ty.isRuntimeFloat()) { // float widening @@ -31998,7 +32072,10 @@ fn coerceArrayLike( const src = inst_src; // TODO better source location const elem_src = inst_src; // TODO better source location const elem_ref = try sema.elemValArray(block, src, inst_src, inst, elem_src, index_ref, true); - const coerced = try sema.coerce(block, dest_elem_ty, elem_ref, elem_src); + const coerced = if (opt_op_tag) |op_tag| + try sema.coerceWithOp(block, dest_elem_ty, elem_ref, elem_src, op_tag) + else + try sema.coerce(block, dest_elem_ty, elem_ref, elem_src); ref.* = coerced; if (runtime_src == null) { if (try sema.resolveValue(coerced)) |elem_val| { @@ -34073,6 +34150,17 @@ fn resolvePeerTypes( src: LazySrcLoc, instructions: []const Air.Inst.Ref, candidate_srcs: PeerTypeCandidateSrc, +) !Type { + return resolvePeerTypesWithOp(sema, block, src, null, instructions, candidate_srcs); +} + +fn resolvePeerTypesWithOp( + sema: *Sema, + block: *Block, + src: LazySrcLoc, + comptime opt_op_tag: ?Air.Inst.Tag, + instructions: []const Air.Inst.Ref, + candidate_srcs: PeerTypeCandidateSrc, ) !Type { switch (instructions.len) { 0 => return Type.noreturn, @@ -34099,7 +34187,7 @@ fn resolvePeerTypes( val.* = try sema.resolveValue(inst); } - switch (try sema.resolvePeerTypesInner(block, src, peer_tys, peer_vals)) { + switch (try sema.resolvePeerTypesInner(block, src, opt_op_tag, peer_tys, peer_vals)) { .success => |ty| return ty, else => |result| { const msg = try result.report(sema, block, src, instructions, candidate_srcs); @@ -34112,6 +34200,7 @@ fn resolvePeerTypesInner( sema: *Sema, block: *Block, src: LazySrcLoc, + comptime opt_op_tag: ?Air.Inst.Tag, peer_tys: []?Type, peer_vals: []?Value, ) !PeerResolveResult { @@ -34197,6 +34286,7 @@ fn resolvePeerTypesInner( const final_payload = switch (try sema.resolvePeerTypesInner( block, src, + opt_op_tag, peer_tys, peer_vals, )) { @@ -34235,6 +34325,7 @@ fn resolvePeerTypesInner( const child_ty = switch (try sema.resolvePeerTypesInner( block, src, + opt_op_tag, peer_tys, peer_vals, )) { @@ -34384,6 +34475,7 @@ fn resolvePeerTypesInner( const child_ty = switch (try sema.resolvePeerTypesInner( block, src, + opt_op_tag, peer_tys, peer_vals, )) { @@ -34989,6 +35081,14 @@ fn resolvePeerTypesInner( return .{ .success = peer_tys[idx_signed.?].? }; } + if (opt_op_tag) |op_tag| { + switch (op_tag) { + .min => return .{ .success = peer_tys[idx_signed.?].? }, + .max => return .{ .success = peer_tys[idx_unsigned.?].? }, + else => {}, + } + } + // TODO: this is for compatibility with legacy behavior. Before this version of PTR was // implemented, the algorithm very often returned false positives, with the expectation // that you'd just hit a coercion error later. One of these was that for integers, the @@ -35107,7 +35207,7 @@ fn resolvePeerTypesInner( } // Resolve field type recursively - field_ty.* = switch (try sema.resolvePeerTypesInner(block, src, sub_peer_tys, sub_peer_vals)) { + field_ty.* = switch (try sema.resolvePeerTypesInner(block, src, opt_op_tag, sub_peer_tys, sub_peer_vals)) { .success => |ty| ty.toIntern(), else => |result| { const result_buf = try sema.arena.create(PeerResolveResult); diff --git a/test/behavior/maximum_minimum.zig b/test/behavior/maximum_minimum.zig index 54e5db9b85ca..3044412ae09f 100644 --- a/test/behavior/maximum_minimum.zig +++ b/test/behavior/maximum_minimum.zig @@ -183,6 +183,116 @@ test "@min/@max more than two vector arguments" { try expectEqual(@Vector(2, u32){ 5, 2 }, @max(x, y, z)); } +test "@min/@max with incompatible arbitrary integer types" { + const x: i32 = -30; + const y: u32 = 0x8000_0010; + const z: u36 = 0x2_0000_0010; + + const M = struct { + fn min(lhs: i32, rhs: u32) i32 { + return @min(lhs, rhs); + } + + fn min3(fst: i32, snd: u32, thrd: u36) i32 { + return @min(fst, snd, thrd); + } + + fn max(lhs: i32, rhs: u32) u32 { + return @max(lhs, rhs); + } + + fn max3(fst: i32, snd: u32, thrd: u36) u36 { + return @max(fst, snd, thrd); + } + }; + + // test min for comptime value + const min = @min(x, y); + try expectEqual(i6, @TypeOf(min)); + try expectEqual(-30, min); + const min3 = @min(x, y, z); + try expectEqual(i6, @TypeOf(min3)); + try expectEqual(-30, min3); + + // test min for runtime value + const m_min = M.min(x, y); + try expectEqual(-30, m_min); + const m_min3 = M.min3(x, y, z); + try expectEqual(-30, m_min3); + + // test max for comptime value + const max = @max(x, y); + try expectEqual(u32, @TypeOf(max)); + try expectEqual(0x8000_0010, max); + const max3 = @max(x, y, z); + try expectEqual(u34, @TypeOf(max3)); + try expectEqual(0x2_0000_0010, max3); + + // test max for runtime value + const m_max = M.max(x, y); + try expectEqual(0x8000_0010, m_max); + const m_max3 = M.max3(x, y, z); + try expectEqual(0x2_0000_0010, m_max3); +} + +test "@min/@max vector with incompatible arbitrary integer types" { + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + + const x: @Vector(2, i16) = @splat(-30); + const y: @Vector(2, u16) = @splat(0x8010); + const z: @Vector(2, u32) = @splat(0x2_0010); + + const M = struct { + fn min(lhs: @Vector(2, i16), rhs: @Vector(2, u16)) @Vector(2, i16) { + return @min(lhs, rhs); + } + + fn min3(fst: @Vector(2, i16), snd: @Vector(2, u16), thrd: @Vector(2, u32)) @Vector(2, i16) { + return @min(fst, snd, thrd); + } + + fn max(lhs: @Vector(2, i16), rhs: @Vector(2, u16)) @Vector(2, u16) { + return @max(lhs, rhs); + } + + fn max3(fst: @Vector(2, i16), snd: @Vector(2, u16), thrd: @Vector(2, u32)) @Vector(2, u32) { + return @max(fst, snd, thrd); + } + }; + + // test min for comptime value + const min = @min(x, y); + try expectEqual(@Vector(2, i6), @TypeOf(min)); + try expectEqual(@as(@Vector(2, i16), @splat(-30)), min); + const min3 = @min(x, y, z); + try expectEqual(@Vector(2, i6), @TypeOf(min3)); + try expectEqual(@as(@Vector(2, i16), @splat(-30)), min3); + + // test min for runtime value + const m_min = M.min(x, y); + try expectEqual(@as(@Vector(2, i16), @splat(-30)), m_min); + const m_min3 = M.min3(x, y, z); + try expectEqual(@as(@Vector(2, i16), @splat(-30)), m_min3); + + // test max for comptime value + const max = @max(x, y); + try expectEqual(@Vector(2, u16), @TypeOf(max)); + try expectEqual(@as(@Vector(2, u16), @splat(0x8010)), max); + const max3 = @max(x, y, z); + try expectEqual(@Vector(2, u18), @TypeOf(max3)); + try expectEqual(@as(@Vector(2, u18), @splat(0x2_0010)), max3); + + // test max for runtime value + const m_max = M.max(x, y); + try expectEqual(@as(@Vector(2, u16), @splat(0x8010)), m_max); + const m_max3 = M.max3(x, y, z); + try expectEqual(@as(@Vector(2, u32), @splat(0x2_0010)), m_max3); +} + test "@min/@max notices bounds" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO