diff --git a/scala/scala-impl/src/org/jetbrains/plugins/scala/codeInspection/typeChecking/ComparingUnrelatedTypesInspection.scala b/scala/scala-impl/src/org/jetbrains/plugins/scala/codeInspection/typeChecking/ComparingUnrelatedTypesInspection.scala index bbf6fb20a9e..77b6e54a088 100644 --- a/scala/scala-impl/src/org/jetbrains/plugins/scala/codeInspection/typeChecking/ComparingUnrelatedTypesInspection.scala +++ b/scala/scala-impl/src/org/jetbrains/plugins/scala/codeInspection/typeChecking/ComparingUnrelatedTypesInspection.scala @@ -1,6 +1,6 @@ package org.jetbrains.plugins.scala.codeInspection.typeChecking -import com.intellij.codeInspection.{LocalInspectionTool, ProblemHighlightType, ProblemsHolder} +import com.intellij.codeInspection.{LocalInspectionTool, ProblemsHolder} import com.intellij.psi.PsiMethod import com.siyeh.ig.psiutils.MethodUtils import org.jetbrains.annotations.Nls @@ -8,9 +8,10 @@ import org.jetbrains.plugins.scala.codeInspection.collections.MethodRepr import org.jetbrains.plugins.scala.codeInspection.typeChecking.ComparingUnrelatedTypesInspection._ import org.jetbrains.plugins.scala.codeInspection.{PsiElementVisitorSimple, ScalaInspectionBundle} import org.jetbrains.plugins.scala.extensions._ +import org.jetbrains.plugins.scala.lang.psi.api.base.types.ScParameterizedTypeElement import org.jetbrains.plugins.scala.lang.psi.api.expr.{ScExpression, ScReferenceExpression} import org.jetbrains.plugins.scala.lang.psi.api.statements.ScFunction -import org.jetbrains.plugins.scala.lang.psi.api.toplevel.typedef.ScClass +import org.jetbrains.plugins.scala.lang.psi.api.toplevel.typedef.{ScClass, ScGiven} import org.jetbrains.plugins.scala.lang.psi.impl.toplevel.synthetic.ScSyntheticFunction import org.jetbrains.plugins.scala.lang.psi.types._ import org.jetbrains.plugins.scala.lang.psi.types.api._ @@ -127,12 +128,36 @@ object ComparingUnrelatedTypesInspection { } } } + + private def hasCanEqual(expr: ScExpression, source: ScType, target: ScType): Boolean = { + lazy val expressionTypes: Seq[ScType] = List(source, target) + lazy val canEqualExists: Boolean = expr + .contexts + .flatMap(_.children) + .filterByType[ScGiven] + .filter(_.`type`().map(_.canonicalText.matches("_root_\\.scala\\.CanEqual\\[.+?, .+?]")).getOrElse(false)) + .flatMap(_.children.filterByType[ScParameterizedTypeElement]) + .map(_.typeArgList.typeArgs.flatMap(_.`type`().map(_.tryExtractDesignatorSingleton).toSeq)) + .exists(_ + .zip(expressionTypes) + .forall { + case (givenType, compType) => + !checkComparability(givenType, compType, isBuiltinOperation = true).shouldNotBeCompared + } + ) + + val wideSource: ScType = source.widenIfLiteral + // Even though CanEqual[Primitive | String, _] can be defined and will satisfy compiler in strictEquals mode, + // it is not possible to override equals method on the primitives or Strings + !wideSource.isPrimitive && + !wideSource.canonicalText.matches("_root_\\.java\\.lang\\.String") && + (expr.isCompilerStrictEqualityMode || canEqualExists) + } } class ComparingUnrelatedTypesInspection extends LocalInspectionTool { override def buildVisitor(holder: ProblemsHolder, isOnTheFly: Boolean): PsiElementVisitorSimple = { - case e if e.isInScala3File => () // TODO Handle Scala 3 code (`CanEqual` instances, etc.), SCL-19722 case MethodRepr(expr, Some(left), Some(oper), Seq(right)) if isComparingFunctions(oper.refName) => // "blub" == 3 val needHighlighting = oper.resolve() match { @@ -145,7 +170,8 @@ class ComparingUnrelatedTypesInspection extends LocalInspectionTool { case Seq(Right(leftType), Right(rightType)) => val isBuiltinOperation = isIdentityFunction(oper.refName) || !hasNonDefaultEquals(leftType) val comparability = checkComparability(leftType, rightType, isBuiltinOperation) - if (comparability.shouldNotBeCompared) { + if ((!expr.isInScala3File && comparability.shouldNotBeCompared) || + (expr.isInScala3File && comparability.shouldNotBeCompared && !hasCanEqual(expr, leftType, rightType))) { val message = generateComparingUnrelatedTypesMsg(leftType, rightType)(expr) holder.registerProblem(expr, message) } @@ -158,7 +184,8 @@ class ComparingUnrelatedTypesInspection extends LocalInspectionTool { ParameterizedType(_, Seq(elemType)) <- receiverType(baseExpr, ref).map(_.tryExtractDesignatorSingleton) argType <- arg.`type`().toOption comparability = checkComparability(elemType, argType, isBuiltinOperation = !hasNonDefaultEquals(elemType)) - if comparability.shouldNotBeCompared + if (!baseExpr.isInScala3File && comparability.shouldNotBeCompared) || + (baseExpr.isInScala3File && comparability.shouldNotBeCompared && !hasCanEqual(baseExpr, elemType, argType)) } { val message = generateComparingUnrelatedTypesMsg(elemType, argType)(arg) holder.registerProblem(arg, message) diff --git a/scala/scala-impl/src/org/jetbrains/plugins/scala/project/ScalaModuleSettings.scala b/scala/scala-impl/src/org/jetbrains/plugins/scala/project/ScalaModuleSettings.scala index bbbdb3e9a8c..66c40f58f7d 100644 --- a/scala/scala-impl/src/org/jetbrains/plugins/scala/project/ScalaModuleSettings.scala +++ b/scala/scala-impl/src/org/jetbrains/plugins/scala/project/ScalaModuleSettings.scala @@ -7,7 +7,7 @@ import com.intellij.openapi.roots.{OrderEnumerator, OrderRootType, libraries} import com.intellij.openapi.util.ModificationTracker import com.intellij.openapi.util.io.JarUtil.{containsEntry, getJarAttribute} import com.intellij.openapi.vfs.VirtualFile -import com.intellij.util.CommonProcessors.{CollectProcessor, FindProcessor} +import com.intellij.util.CommonProcessors.FindProcessor import org.jetbrains.plugins.scala.ScalaVersion import org.jetbrains.plugins.scala.caches.cached import org.jetbrains.plugins.scala.project.ScalaFeatures.SerializableScalaFeatures @@ -18,7 +18,6 @@ import org.jetbrains.sbt.project.SbtVersionProvider import java.io.File import java.util.jar.Attributes -import scala.jdk.CollectionConverters.IteratorHasAsScala private class ScalaModuleSettings private( module: Module, @@ -142,6 +141,9 @@ private class ScalaModuleSettings private( val isCompilerStrictMode: Boolean = settingsForHighlighting.exists(_.strict) + val isCompilerStrictEqualityMode: Boolean = + settingsForHighlighting.exists(_.strictEquality) + val customDefaultImports: Option[Seq[String]] = additionalCompilerOptions.collectFirst { case Yimports(imports) if scalaLanguageLevel >= Scala_2_13 => imports diff --git a/scala/scala-impl/src/org/jetbrains/plugins/scala/project/package.scala b/scala/scala-impl/src/org/jetbrains/plugins/scala/project/package.scala index 4d256f23b35..0b831aae325 100644 --- a/scala/scala-impl/src/org/jetbrains/plugins/scala/project/package.scala +++ b/scala/scala-impl/src/org/jetbrains/plugins/scala/project/package.scala @@ -288,6 +288,9 @@ package object project { def isCompilerStrictMode: Boolean = scalaModuleSettings.exists(_.isCompilerStrictMode) + def isCompilerStrictEqualityMode: Boolean = + scalaModuleSettings.exists(_.isCompilerStrictEqualityMode) + def scalaCompilerClasspath: Seq[File] = module.scalaSdk .fold(throw new ScalaSdkNotConfiguredException(module)) { _.properties.compilerClasspath @@ -537,6 +540,8 @@ package object project { def isCompilerStrictMode: Boolean = module.exists(_.isCompilerStrictMode) + def isCompilerStrictEqualityMode: Boolean = isInScala3Module && module.exists(_.isCompilerStrictEqualityMode) + def scalaLanguageLevel: Option[ScalaLanguageLevel] = fromFeaturesOrModule(_.languageLevel, _.scalaLanguageLevel) diff --git a/scala/scala-impl/src/org/jetbrains/plugins/scala/project/settings/ScalaCompilerSettings.scala b/scala/scala-impl/src/org/jetbrains/plugins/scala/project/settings/ScalaCompilerSettings.scala index 24ff6185117..289cb49891a 100644 --- a/scala/scala-impl/src/org/jetbrains/plugins/scala/project/settings/ScalaCompilerSettings.scala +++ b/scala/scala-impl/src/org/jetbrains/plugins/scala/project/settings/ScalaCompilerSettings.scala @@ -50,6 +50,8 @@ case class ScalaCompilerSettings(compileOrder: CompileOrder, val languageWildcard: Boolean = additionalCompilerOptions.contains("-language:_") || additionalCompilerOptions.contains("--language:_") val strict: Boolean = additionalCompilerOptions.contains("-strict") + val strictEquality: Boolean = additionalCompilerOptions.contains("-language:strictEquality") || + additionalCompilerOptions.contains("--language:strictEquality") def getOptionsAsStrings(forScala3Compiler: Boolean): Seq[String] = { val state = toState diff --git a/scala/scala-impl/test/org/jetbrains/plugins/scala/codeInspection/typeChecking/ComparingUnrelatedTypesInspectionTest.scala b/scala/scala-impl/test/org/jetbrains/plugins/scala/codeInspection/typeChecking/ComparingUnrelatedTypesInspectionTest.scala index 205c4b175a7..aad27244126 100644 --- a/scala/scala-impl/test/org/jetbrains/plugins/scala/codeInspection/typeChecking/ComparingUnrelatedTypesInspectionTest.scala +++ b/scala/scala-impl/test/org/jetbrains/plugins/scala/codeInspection/typeChecking/ComparingUnrelatedTypesInspectionTest.scala @@ -9,7 +9,8 @@ import org.junit.runner.RunWith @RunWith(classOf[MultipleScalaVersionsRunner]) @RunWithScalaVersions(Array( TestScalaVersion.Scala_2_12, - TestScalaVersion.Scala_2_13 + TestScalaVersion.Scala_2_13, + TestScalaVersion.Scala_3_Latest )) abstract class ComparingUnrelatedTypesInspectionTest extends ScalaInspectionTestBase { @@ -316,6 +317,8 @@ class Test20 extends ComparingUnrelatedTypesInspectionTest { class Test21 extends ComparingUnrelatedTypesInspectionTest { + override def supportedIn(version: ScalaVersion): Boolean = version < LatestScalaVersions.Scala_3_0 + override protected val description: String = ScalaInspectionBundle.message("comparing.unrelated.types.hint", "Some[Int]", "List[_]") @@ -324,6 +327,22 @@ class Test21 extends ComparingUnrelatedTypesInspectionTest { ) } +class Test21_Scala_3 extends ComparingUnrelatedTypesInspectionTest { + + override def supportedIn(version: ScalaVersion): Boolean = version >= LatestScalaVersions.Scala_3_0 + + override protected val description: String = + ScalaInspectionBundle.message("comparing.unrelated.types.hint", "Some[Int]", "List[?]") + + def testExistential(): Unit = checkTextHasError( + s"${START}Some(1).isInstanceOf[List[?]]$END" + ) + + def testExistentialOldWildcard(): Unit = checkTextHasError( + s"${START}Some(1).isInstanceOf[List[_]]$END" + ) +} + class Test22 extends ComparingUnrelatedTypesInspectionTest { override protected val description: String = @@ -336,6 +355,8 @@ class Test22 extends ComparingUnrelatedTypesInspectionTest { class Test23 extends ComparingUnrelatedTypesInspectionTest { + override def supportedIn(version: ScalaVersion): Boolean = version < LatestScalaVersions.Scala_3_0 + override protected val description: String = ScalaInspectionBundle.message("comparing.unrelated.types.hint", "Some[_]", "Seq[Int]") @@ -344,6 +365,22 @@ class Test23 extends ComparingUnrelatedTypesInspectionTest { ) } +class Test23_Scala_3 extends ComparingUnrelatedTypesInspectionTest { + + override def supportedIn(version: ScalaVersion): Boolean = version >= LatestScalaVersions.Scala_3_0 + + override protected val description: String = + ScalaInspectionBundle.message("comparing.unrelated.types.hint", "Some[?]", "Seq[Int]") + + def testExistential(): Unit = checkTextHasError( + s"def foo(x: Some[?]) { ${START}x == Seq(1)$END }" + ) + + def testExistentialOldWildcard(): Unit = checkTextHasError( + s"def foo(x: Some[_]) { ${START}x == Seq(1)$END }" + ) +} + class Test24 extends ComparingUnrelatedTypesInspectionTest { override protected val description: String = @@ -437,7 +474,8 @@ class Test29 extends ComparingUnrelatedTypesInspectionTest { @RunWithScalaVersions(Array( TestScalaVersion.Scala_2_10, TestScalaVersion.Scala_2_12, - TestScalaVersion.Scala_2_13 + TestScalaVersion.Scala_2_13, + TestScalaVersion.Scala_3_Latest )) class Test30 extends ComparingUnrelatedTypesInspectionTest { @@ -659,9 +697,6 @@ class TestInstanceOfAutoBoxing3 extends ComparingUnrelatedTypesInspectionTest { ) } -@RunWithScalaVersions(Array( - TestScalaVersion.Scala_2_13 -)) class TestLiteralTypes extends ComparingUnrelatedTypesInspectionTest { override protected def supportedIn(version: ScalaVersion): Boolean = version >= LatestScalaVersions.Scala_2_13 @@ -1026,3 +1061,214 @@ class TestVariousCasesWithStdTypes extends ComparingUnrelatedTypesInspectionTest |}""".stripMargin ) } + +class Test33 extends ComparingUnrelatedTypesInspectionTest { + override protected def supportedIn(version: ScalaVersion): Boolean = version >= LatestScalaVersions.Scala_3_0 + + override protected val description: String = + ScalaInspectionBundle.message("comparing.unrelated.types.hint", "String", "Int") + + def testCanEqualLiteralTypes(): Unit = checkTextHasError( + s""" + |given CanEqual[String, Int] = CanEqual.derived + |$START"1" == 1$END + |""".stripMargin + ) +} + +class Test34 extends ComparingUnrelatedTypesInspectionTest { + override protected def supportedIn(version: ScalaVersion): Boolean = version >= LatestScalaVersions.Scala_3_0 + + override protected val description: String = + ScalaInspectionBundle.message("comparing.unrelated.types.hint", "Boolean", "Int") + + def testCanEqualPrimitiveTypes(): Unit = checkTextHasError( + s""" + |given CanEqual[Boolean, Int] = CanEqual.derived + |val b: Boolean = true + |val i: Int = 1 + |${START}b == i$END + |""".stripMargin + ) +} + +class Test35 extends ComparingUnrelatedTypesInspectionTest { + override protected def supportedIn(version: ScalaVersion): Boolean = version >= LatestScalaVersions.Scala_3_0 + + override protected val description: String = + ScalaInspectionBundle.message("comparing.unrelated.types.hint", "Boolean", "Int") + + def testCanEqualPrimitiveTypes(): Unit = checkTextHasError( + s""" + |given CanEqual[Boolean, Int] = CanEqual.derived + |val b: Boolean = true + |val i: Int = 1 + |${START}b == i$END + |""".stripMargin + ) +} + +class Test36 extends ComparingUnrelatedTypesInspectionTest { + override protected def supportedIn(version: ScalaVersion): Boolean = version >= LatestScalaVersions.Scala_3_0 + + override protected val description: String = + ScalaInspectionBundle.message("comparing.unrelated.types.hint", "A", "B") + + def testCanEqualDefined(): Unit = checkTextHasNoErrors( + s""" + |case class A() + |case class B() + |given CanEqual[A, B] = CanEqual.derived + |A() == B() + |""".stripMargin + ) + + def testCanEqualComparisonWithUnderscore(): Unit = checkTextHasNoErrors( + s""" + |case class A() + |case class B() + |given CanEqual[A, B] = CanEqual.derived + |List(A()).find(_ == B()) + |""".stripMargin + ) + + def testCanEqualSuperType(): Unit = checkTextHasNoErrors( + s""" + |trait T + |case class A() extends T + |case class B() + |given CanEqual[T, B] = CanEqual.derived + |A() == B() + |""".stripMargin + ) + + def testCanEqualOutsideComparisonContext(): Unit = checkTextHasNoErrors( + s""" + |case class A() + |case class B() + |given CanEqual[A, B] = CanEqual.derived + |def test(): Boolean = A() == B() + |""".stripMargin + ) + + def testCanEqualInHigherContext(): Unit = checkTextHasNoErrors( + s""" + |case class A() + |case class B() + |given CanEqual[A, B] = CanEqual.derived + | + |object Test: + | def test(): Boolean = A() == B() + |""".stripMargin + ) + + def testCanEqualSeveralContextsHigher(): Unit = checkTextHasNoErrors( + s""" + |case class A() + |case class B() + | + |object Highest: + | given CanEqual[A, B] = CanEqual.derived + | + | object Higher: + | + | object High: + | def test(): Boolean = A() == B() + |""".stripMargin + ) + + def testCanEqualInUnrelatedContext(): Unit = checkTextHasError( + s""" + |case class A() + |case class B() + | + |object Unrelated: + | given CanEqual[A, B] = CanEqual.derived + | + |object Test: + | def test(): Boolean = ${START}A() == B()$END + |""".stripMargin + ) + + def testCanEqualNotDefined(): Unit = checkTextHasError( + s""" + |case class A() + |case class B() + |${START}A() == B()$END + |""".stripMargin + ) + + def testCanEqualForIncompatibleSourceType(): Unit = checkTextHasError( + s""" + |case class A() + |case class B() + |case class C() + |given CanEqual[C, B] = CanEqual.derived + |${START}A() == B()$END + |""".stripMargin + ) + + def testCanEqualIncompatibleTargetType(): Unit = checkTextHasError( + s""" + |case class A() + |case class B() + |case class C() + |given CanEqual[A, C] = CanEqual.derived + |${START}A() == B()$END + |""".stripMargin + ) + + def testCanEqualCaseObjects(): Unit = checkTextHasNoErrors( + s""" + |case object A + |case object B + |given CanEqual[A, B] = CanEqual.derived + |A == B + |""".stripMargin + ) + + def testCanEqualIncompatibleCaseObjects(): Unit = checkTextHasError( + s""" + |case object A + |case object B + |case object C + |given CanEqual[A, C] = CanEqual.derived + |${START}A == B$END + |""".stripMargin + ) + + def testCanEqualImported(): Unit = checkTextHasNoErrors( + s""" + |case class A() + |case class B() + | + |import Givens.given + |A() == B() + | + |object Givens: + | given CanEqual[A, B] = CanEqual.derived + |""".stripMargin + ) + + def testCanEqualImportedFromCompanion(): Unit = checkTextHasNoErrors( + s""" + |case class A() + |case class B() + | + |object A: + | given CanEqual[A, B] = CanEqual.derived + | + |A() == B() + |""".stripMargin + ) + + def testCanEqualDerivedOnTrait(): Unit = checkTextHasNoErrors( + s""" + |trait T derives CanEqual + |case class A() extends T + |case class B() extends T + | + |A() == B() + |""".stripMargin + ) +}