diff --git a/src/parser/cxx/type_checker.cc b/src/parser/cxx/type_checker.cc index b9f9a658..dfc1d8eb 100644 --- a/src/parser/cxx/type_checker.cc +++ b/src/parser/cxx/type_checker.cc @@ -100,6 +100,8 @@ struct TypeChecker::Visitor { [[nodiscard]] auto ensure_prvalue(ExpressionAST*& expr) -> bool; + void adjust_cv(ExpressionAST* expr); + [[nodiscard]] auto implicit_conversion(ExpressionAST*& expr, const Type* destinationType) -> bool; @@ -247,9 +249,9 @@ void TypeChecker::Visitor::operator()(IdExpressionAST* ast) { } else { ast->type = control()->remove_reference(ast->symbol->type()); - if (ast->symbol->isEnumOrScopedEnum() || - ast->symbol->isNonTypeParameter()) { + if (ast->symbol->isEnumerator() || ast->symbol->isNonTypeParameter()) { ast->valueCategory = ValueCategory::kPrValue; + adjust_cv(ast); } else { ast->valueCategory = ValueCategory::kLValue; } @@ -334,6 +336,10 @@ void TypeChecker::Visitor::operator()(CallExpressionAST* ast) { } else { ast->valueCategory = ValueCategory::kPrValue; } + + if (ast->valueCategory == ValueCategory::kPrValue) { + adjust_cv(ast); + } } void TypeChecker::Visitor::operator()(TypeConstructionAST* ast) {} @@ -364,6 +370,10 @@ void TypeChecker::Visitor::operator()(CppCastExpressionAST* ast) { default: break; } // switch + + if (ast->valueCategory == ValueCategory::kPrValue) { + adjust_cv(ast); + } } void TypeChecker::Visitor::check_cpp_cast_expression( @@ -409,6 +419,7 @@ auto TypeChecker::Visitor::check_static_cast(CppCastExpressionAST* ast) auto source = ast->expression; (void)ensure_prvalue(source); + adjust_cv(source); auto sourcePtr = type_cast(source->type); if (!sourcePtr) return false; @@ -493,6 +504,7 @@ void TypeChecker::Visitor::operator()(UnaryExpressionAST* ast) { auto pointerType = type_cast(ast->expression->type); if (pointerType) { (void)ensure_prvalue(ast->expression); + adjust_cv(ast->expression); ast->type = pointerType->elementType(); ast->valueCategory = ValueCategory::kLValue; } @@ -552,10 +564,10 @@ void TypeChecker::Visitor::operator()(UnaryExpressionAST* ast) { case TokenKind::T_PLUS: { ExpressionAST* expr = ast->expression; (void)ensure_prvalue(expr); - auto ty = control()->remove_cvref(expr->type); - if (control()->is_arithmetic_or_unscoped_enum(ty) || - control()->is_pointer(ty)) { - if (control()->is_integral_or_unscoped_enum(ty)) { + adjust_cv(expr); + if (control()->is_arithmetic_or_unscoped_enum(expr->type) || + control()->is_pointer(expr->type)) { + if (control()->is_integral_or_unscoped_enum(expr->type)) { (void)integral_promotion(expr); } ast->expression = expr; @@ -568,9 +580,9 @@ void TypeChecker::Visitor::operator()(UnaryExpressionAST* ast) { case TokenKind::T_MINUS: { ExpressionAST* expr = ast->expression; (void)ensure_prvalue(expr); - auto ty = control()->remove_cvref(expr->type); - if (control()->is_arithmetic_or_unscoped_enum(ty)) { - if (control()->is_integral_or_unscoped_enum(ty)) { + adjust_cv(expr); + if (control()->is_arithmetic_or_unscoped_enum(expr->type)) { + if (control()->is_integral_or_unscoped_enum(expr->type)) { (void)integral_promotion(expr); } ast->expression = expr; @@ -590,8 +602,8 @@ void TypeChecker::Visitor::operator()(UnaryExpressionAST* ast) { case TokenKind::T_TILDE: { ExpressionAST* expr = ast->expression; (void)ensure_prvalue(expr); - auto ty = control()->remove_cvref(expr->type); - if (control()->is_integral_or_unscoped_enum(ty)) { + adjust_cv(expr); + if (control()->is_integral_or_unscoped_enum(expr->type)) { (void)integral_promotion(expr); ast->expression = expr; ast->type = expr->type; @@ -713,8 +725,10 @@ void TypeChecker::Visitor::operator()(CastExpressionAST* ast) { ast->valueCategory = ValueCategory::kLValue; else if (control()->is_rvalue_reference(ast->typeId->type)) ast->valueCategory = ValueCategory::kXValue; - else + else { ast->valueCategory = ValueCategory::kPrValue; + adjust_cv(ast); + } } } @@ -750,12 +764,6 @@ void TypeChecker::Visitor::operator()(BinaryExpressionAST* ast) { break; case TokenKind::T_PERCENT: - if (!control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) - break; - - if (!control()->is_integral_or_unscoped_enum(ast->rightExpression->type)) - break; - ast->type = usual_arithmetic_conversion(ast->leftExpression, ast->rightExpression); @@ -763,15 +771,8 @@ void TypeChecker::Visitor::operator()(BinaryExpressionAST* ast) { case TokenKind::T_LESS_LESS: case TokenKind::T_GREATER_GREATER: - if (!control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) - break; - - if (!control()->is_integral_or_unscoped_enum(ast->rightExpression->type)) - break; - (void)usual_arithmetic_conversion(ast->leftExpression, ast->rightExpression); - ast->type = ast->leftExpression->type; break; @@ -795,12 +796,6 @@ void TypeChecker::Visitor::operator()(BinaryExpressionAST* ast) { case TokenKind::T_AMP: case TokenKind::T_CARET: case TokenKind::T_BAR: - if (!control()->is_integral_or_unscoped_enum(ast->leftExpression->type)) - break; - - if (!control()->is_integral_or_unscoped_enum(ast->rightExpression->type)) - break; - ast->type = usual_arithmetic_conversion(ast->leftExpression, ast->rightExpression); @@ -906,10 +901,9 @@ auto TypeChecker::Visitor::lvalue_to_rvalue_conversion(ExpressionAST*& expr) -> bool { if (!is_glvalue(expr)) return false; - auto unref = control()->remove_cvref(expr->type); - if (control()->is_function(unref)) return false; - if (control()->is_array(unref)) return false; - if (!control()->is_complete(unref)) return false; + if (control()->is_function(expr->type)) return false; + if (control()->is_array(expr->type)) return false; + if (!control()->is_complete(expr->type)) return false; auto cast = make_node(arena()); cast->castKind = ImplicitCastKind::kLValueToRValueConversion; cast->expression = expr; @@ -948,9 +942,8 @@ auto TypeChecker::Visitor::function_to_pointer_conversion(ExpressionAST*& expr) auto TypeChecker::Visitor::integral_promotion(ExpressionAST*& expr) -> bool { if (!is_prvalue(expr)) return false; - auto ty = control()->remove_cv(expr->type); - - if (!control()->is_integral(ty) && !control()->is_enum(ty)) return false; + if (!control()->is_integral(expr->type) && !control()->is_enum(expr->type)) + return false; auto make_implicit_cast = [&](const Type* type) { auto cast = make_node(arena()); @@ -963,7 +956,7 @@ auto TypeChecker::Visitor::integral_promotion(ExpressionAST*& expr) -> bool { // TODO: bit-fields - switch (ty->kind()) { + switch (expr->type->kind()) { case TypeKind::kChar: case TypeKind::kSignedChar: case TypeKind::kUnsignedChar: @@ -990,7 +983,7 @@ auto TypeChecker::Visitor::integral_promotion(ExpressionAST*& expr) -> bool { break; } // switch - if (auto enumType = type_cast(ty)) { + if (auto enumType = type_cast(expr->type)) { auto type = enumType->underlyingType(); if (!type) { @@ -1010,11 +1003,9 @@ auto TypeChecker::Visitor::floating_point_promotion(ExpressionAST*& expr) -> bool { if (!is_prvalue(expr)) return false; - auto ty = control()->remove_cv(expr->type); - - if (!control()->is_floating_point(ty)) return false; + if (!control()->is_floating_point(expr->type)) return false; - if (ty->kind() != TypeKind::kFloat) return false; + if (expr->type->kind() != TypeKind::kFloat) return false; auto cast = make_node(arena()); cast->castKind = ImplicitCastKind::kFloatingPointPromotion; @@ -1206,8 +1197,7 @@ auto TypeChecker::Visitor::pointer_conversion(ExpressionAST*& expr, control()->get_cv_qualifiers(destinationPointerType->elementType())) return false; - if (!control()->is_void( - control()->remove_cv(destinationPointerType->elementType()))) + if (!control()->is_void(destinationPointerType->elementType())) return false; make_implicit_cast(); @@ -1320,7 +1310,7 @@ auto TypeChecker::Visitor::function_pointer_conversion( auto TypeChecker::Visitor::boolean_conversion(ExpressionAST*& expr, const Type* destinationType) -> bool { - if (!type_cast(control()->remove_cv(destinationType))) return false; + if (!type_cast(destinationType)) return false; if (!is_prvalue(expr)) return false; @@ -1363,9 +1353,22 @@ auto TypeChecker::Visitor::ensure_prvalue(ExpressionAST*& expr) -> bool { if (lvalue_to_rvalue_conversion(expr)) return true; if (array_to_pointer_conversion(expr)) return true; if (function_to_pointer_conversion(expr)) return true; + return false; } +void TypeChecker::Visitor::adjust_cv(ExpressionAST* expr) { + if (!is_prvalue(expr)) return; + + auto qualType = type_cast(expr->type); + if (!qualType) return; + + if (control()->is_class(expr->type) || control()->is_array(expr->type)) + return; + + expr->type = qualType->elementType(); +} + auto TypeChecker::Visitor::implicit_conversion(ExpressionAST*& expr, const Type* destinationType) -> bool { @@ -1378,9 +1381,7 @@ auto TypeChecker::Visitor::implicit_conversion(ExpressionAST*& expr, auto savedExpr = expr; auto didConvert = ensure_prvalue(expr); - if (control()->is_scalar(expr->type)) { - expr->type = control()->remove_cv(expr->type); - } + adjust_cv(expr); if (integral_promotion(expr)) return true; if (floating_point_promotion(expr)) return true; @@ -1404,8 +1405,20 @@ auto TypeChecker::Visitor::implicit_conversion(ExpressionAST*& expr, auto TypeChecker::Visitor::usual_arithmetic_conversion(ExpressionAST*& expr, ExpressionAST*& other) -> const Type* { - if (!expr || !expr->type) return nullptr; - if (!other || !other->type) return nullptr; + if (!control()->is_arithmetic(expr->type) && !control()->is_enum(expr->type)) + return nullptr; + + if (!control()->is_arithmetic(other->type) && + !control()->is_enum(other->type)) + return nullptr; + + (void)lvalue_to_rvalue_conversion(expr); + adjust_cv(expr); + + (void)lvalue_to_rvalue_conversion(other); + adjust_cv(other); + + if (control()->is_same(expr->type, other->type)) return expr->type; ExpressionAST* savedExpr = expr; ExpressionAST* savedOther = other; @@ -1416,42 +1429,32 @@ auto TypeChecker::Visitor::usual_arithmetic_conversion(ExpressionAST*& expr, return nullptr; }; - if (!control()->is_arithmetic_or_unscoped_enum(expr->type) || - !control()->is_arithmetic_or_unscoped_enum(other->type)) - return unmodifiedExpressions(); - - (void)lvalue_to_rvalue_conversion(expr); - (void)lvalue_to_rvalue_conversion(other); - if (control()->is_scoped_enum(expr->type) || control()->is_scoped_enum(other->type)) return unmodifiedExpressions(); if (control()->is_floating_point(expr->type) || control()->is_floating_point(other->type)) { - auto leftType = control()->remove_cv(expr->type); - auto rightType = control()->remove_cv(other->type); + if (control()->is_same(expr->type, other->type)) return expr->type; - if (control()->is_same(leftType, rightType)) return leftType; - - if (!control()->is_floating_point(leftType)) { - if (floating_integral_conversion(expr, rightType)) return rightType; + if (!control()->is_floating_point(expr->type)) { + if (floating_integral_conversion(expr, other->type)) return other->type; return unmodifiedExpressions(); } - if (!control()->is_floating_point(rightType)) { - if (floating_integral_conversion(other, leftType)) return leftType; + if (!control()->is_floating_point(other->type)) { + if (floating_integral_conversion(other, expr->type)) return expr->type; return unmodifiedExpressions(); } - if (leftType->kind() == TypeKind::kLongDouble || - rightType->kind() == TypeKind::kLongDouble) { + if (expr->type->kind() == TypeKind::kLongDouble || + other->type->kind() == TypeKind::kLongDouble) { (void)floating_point_conversion(expr, control()->getLongDoubleType()); return control()->getLongDoubleType(); } - if (leftType->kind() == TypeKind::kDouble || - rightType->kind() == TypeKind::kDouble) { + if (expr->type->kind() == TypeKind::kDouble || + other->type->kind() == TypeKind::kDouble) { (void)floating_point_conversion(expr, control()->getDoubleType()); return control()->getDoubleType(); } @@ -1462,13 +1465,11 @@ auto TypeChecker::Visitor::usual_arithmetic_conversion(ExpressionAST*& expr, (void)integral_promotion(expr); (void)integral_promotion(other); - const auto leftType = control()->remove_cv(expr->type); - const auto rightType = control()->remove_cv(other->type); - - if (control()->is_same(leftType, rightType)) return leftType; + if (control()->is_same(expr->type, other->type)) return expr->type; auto match_integral_type = [&](const Type* type) -> bool { - if (leftType->kind() == type->kind() || rightType->kind() == type->kind()) { + if (expr->type->kind() == type->kind() || + other->type->kind() == type->kind()) { (void)integral_conversion(expr, type); (void)integral_conversion(other, type); return true; @@ -1476,7 +1477,7 @@ auto TypeChecker::Visitor::usual_arithmetic_conversion(ExpressionAST*& expr, return false; }; - if (control()->is_signed(leftType) && control()->is_signed(rightType)) { + if (control()->is_signed(expr->type) && control()->is_signed(other->type)) { if (match_integral_type(control()->getLongLongIntType())) { return control()->getLongLongIntType(); } @@ -1490,7 +1491,8 @@ auto TypeChecker::Visitor::usual_arithmetic_conversion(ExpressionAST*& expr, return control()->getIntType(); } - if (control()->is_unsigned(leftType) && control()->is_unsigned(rightType)) { + if (control()->is_unsigned(expr->type) && + control()->is_unsigned(other->type)) { if (match_integral_type(control()->getUnsignedLongLongIntType())) { return control()->getUnsignedLongLongIntType(); } @@ -1567,8 +1569,9 @@ void TypeChecker::check(ExpressionAST* ast) { } void TypeChecker::Visitor::check_addition(BinaryExpressionAST* ast) { - (void)ensure_prvalue(ast->leftExpression); - (void)ensure_prvalue(ast->rightExpression); + // ### TODO: check for user-defined conversion operators + if (control()->is_class(ast->leftExpression->type)) return; + if (control()->is_class(ast->rightExpression->type)) return; if (auto ty = usual_arithmetic_conversion(ast->leftExpression, ast->rightExpression)) { @@ -1576,6 +1579,12 @@ void TypeChecker::Visitor::check_addition(BinaryExpressionAST* ast) { return; } + (void)ensure_prvalue(ast->leftExpression); + adjust_cv(ast->leftExpression); + + (void)ensure_prvalue(ast->rightExpression); + adjust_cv(ast->rightExpression); + const auto left_is_pointer = control()->is_pointer(ast->leftExpression->type); const auto right_is_pointer = @@ -1607,8 +1616,9 @@ void TypeChecker::Visitor::check_addition(BinaryExpressionAST* ast) { } void TypeChecker::Visitor::check_subtraction(BinaryExpressionAST* ast) { - (void)ensure_prvalue(ast->leftExpression); - (void)ensure_prvalue(ast->rightExpression); + // ### TODO: check for user-defined conversion operators + if (control()->is_class(ast->leftExpression->type)) return; + if (control()->is_class(ast->rightExpression->type)) return; if (auto ty = usual_arithmetic_conversion(ast->leftExpression, ast->rightExpression)) { @@ -1616,32 +1626,47 @@ void TypeChecker::Visitor::check_subtraction(BinaryExpressionAST* ast) { return; } - const auto left_is_pointer = control()->is_pointer(ast->leftExpression->type); + (void)ensure_prvalue(ast->leftExpression); + adjust_cv(ast->leftExpression); - const auto right_is_pointer = - control()->is_pointer(ast->rightExpression->type); + (void)ensure_prvalue(ast->rightExpression); + adjust_cv(ast->rightExpression); - if (left_is_pointer && right_is_pointer) { - auto lhs = control()->remove_cv(ast->leftExpression->type); - auto rhs = control()->remove_cv(ast->rightExpression->type); - if (control()->is_same(lhs, rhs)) { + auto check_operand_types = [&]() { + if (!control()->is_pointer(ast->leftExpression->type)) return false; + + if (!control()->is_arithmetic_or_unscoped_enum( + ast->rightExpression->type) && + !control()->is_pointer(ast->rightExpression->type)) + return false; + + return true; + }; + + if (!check_operand_types()) { + error(ast->opLoc, + std::format("invalid operands to binary expression '{}' and '{}'", + to_string(ast->leftExpression->type), + to_string(ast->rightExpression->type))); + return; + } + + if (control()->is_pointer(ast->rightExpression->type)) { + if (control()->is_same(ast->leftExpression->type, + ast->rightExpression->type)) { ast->type = control()->getLongIntType(); // TODO: ptrdiff_t - return; + } else { + error(ast->opLoc, + std::format("'{}' and '{}' are not pointers to compatible types", + to_string(ast->leftExpression->type), + to_string(ast->rightExpression->type))); } - } - if (left_is_pointer && - control()->is_integral_or_unscoped_enum(ast->rightExpression->type)) { - (void)integral_promotion(ast->rightExpression); - ast->type = ast->leftExpression->type; return; } - error(ast->opLoc, - std::format( - "invalid operands of types '{}' and '{}' to binary operator '-'", - to_string(ast->leftExpression->type), - to_string(ast->rightExpression->type))); + (void)integral_promotion(ast->rightExpression); + ast->type = ast->leftExpression->type; } auto TypeChecker::Visitor::check_member_access(MemberExpressionAST* ast)