Skip to content

Commit 798a96e

Browse files
committed
refactor sub-cases projection
1 parent 4810547 commit 798a96e

File tree

6 files changed

+187
-24
lines changed

6 files changed

+187
-24
lines changed

compiler/src/dotty/tools/dotc/transform/patmat/Space.scala

Lines changed: 62 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -936,46 +936,84 @@ object SpaceEngine {
936936
* ^ pat ^ selector
937937
*
938938
*/
939-
private def selectorIsBoundVar(selector: Tree, pat: Tree)(using Context): Boolean =
940-
pat match
941-
case b: Bind => selector.symbol == b.symbol
942-
case _ => false
939+
private object SelectorBoundVar:
940+
def unapply(args: (Tree, Tree))(using Context): Boolean =
941+
val (selector, pat) = args
942+
pat match
943+
case b: Bind => selector.symbol == b.symbol
944+
case _ => false
943945

944946
/** Find the index of the parameter in an outer UnApply pattern that directly binds the selector symbol.
945947
*
946948
* case Wrapper(c) if c match
947949
* ^ returns Some(0)
948950
*
949951
*/
950-
private def selectorParamIndex(selector: Tree, pat: Tree)(using Context): Option[Int] =
951-
unbind(pat) match
952-
case UnApply(_, _, pats) =>
953-
val idx = pats.indexWhere {
954-
case b: Bind => b.symbol == selector.symbol
955-
case _ => false
956-
}
957-
if idx >= 0 then Some(idx) else None
952+
private object SelectorParamIndex:
953+
def unapply(args: (Tree, Tree))(using Context): Option[Int] =
954+
val (selector, pat) = args
955+
unbind(pat) match
956+
case UnApply(_, _, pats) =>
957+
val idx = pats.indexWhere {
958+
case b: Bind => b.symbol == selector.symbol
959+
case _ => false
960+
}
961+
Option.when(idx >= 0)(idx)
962+
case _ => None
963+
964+
/** Find the constructor parameter index corresponding to a field access on the outer pattern's bound var.
965+
*
966+
* case x if x.version match -- returns Some(1) for Document(title, version)
967+
* ^^^^^^^^^ selector
968+
*
969+
*/
970+
private object SelectorFieldIndex:
971+
def unapply(args: (Tree, Tree))(using Context): Option[Int] =
972+
args match
973+
case (Select(qual, fieldName), b: Bind) if b.symbol == qual.symbol =>
974+
val cls = toUnderlying(qual.tpe).classSymbol
975+
if cls.is(CaseClass) && !cls.isOneOf(AbstractOrTrait) then
976+
val idx = cls.caseAccessors.indexWhere(_.name == fieldName)
977+
Option.when(idx >= 0)(idx)
978+
else None
979+
case _ => None
980+
981+
private def narrowProdParam(patSpace: Space, idx: Int, subSpace: Space)(using Context): Option[Space] =
982+
def narrow(prod: Prod): Option[Space] =
983+
val Prod(tp, unappTp, params) = prod
984+
if idx >= params.length then None
985+
else
986+
val narrowedParam = simplify(intersect(params(idx), subSpace))
987+
Some(simplify(Prod(tp, unappTp, params.updated(idx, narrowedParam))))
988+
patSpace match
989+
case prod @ Prod(tp, unappTp1, _) =>
990+
expandCaseClass(tp) match
991+
case null => None
992+
case Prod(_, unappTp2, _) if isSameUnapply(unappTp1, unappTp2) => narrow(prod)
993+
case Typ(tp, _) =>
994+
expandCaseClass(tp) match
995+
case null => None
996+
case prod => narrow(prod)
958997
case _ => None
959998

960999
private def projectSubMatch(pat: Tree, sm: SubMatch)(using Context): Option[Space] =
9611000
val Match(selector, cases) = sm
9621001

9631002
val subSpace = Or(cases.map(projectCaseDef))
1003+
if simplify(subSpace) == Empty then return None // all sub-cases are guarded or empty; treat outer case as partial
9641004
def selTyp = toUnderlying(selector.tpe)
9651005
def patSpace = project(pat)
9661006

967-
if selectorIsBoundVar(selector, pat) then
968-
Some(simplify(intersect(patSpace, subSpace)))
969-
else selectorParamIndex(selector, pat) match
970-
case Some(idx) =>
971-
patSpace match
972-
case Prod(tp, unappTp, params) =>
973-
val narrowedParam = simplify(intersect(params(idx), subSpace))
974-
Some(simplify(Prod(tp, unappTp, params.updated(idx, narrowedParam))))
975-
case _ => None
976-
case None =>
977-
if simplify(minus(project(selTyp), subSpace)) == Empty then Some(patSpace)
978-
else None
1007+
(selector, pat) match
1008+
case SelectorBoundVar() =>
1009+
Some(simplify(intersect(patSpace, subSpace)))
1010+
case SelectorParamIndex(idx) =>
1011+
narrowProdParam(patSpace, idx, subSpace)
1012+
case SelectorFieldIndex(idx) =>
1013+
narrowProdParam(patSpace, idx, subSpace)
1014+
case _ if simplify(minus(project(selTyp), subSpace)) == Empty =>
1015+
Some(patSpace)
1016+
case _ => None
9791017

9801018
/** Resolve the space covered by a case and whether it may be partial.
9811019
* @return (space, maybePartial) where maybePartial is true when the case

tests/pos/sub-cases-exhaustivity.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,40 @@ def getVersion(d: Option[Document]): String = d match
7878
case Version.Stable(m, n) => s"$m.$n"
7979
case Version.Legacy => "legacy"
8080
case None => "none"
81+
82+
sealed trait Shape
83+
case class Circle(r: Double) extends Shape
84+
case class Rectangle(w: Double, h: Double) extends Shape
85+
86+
def tupleFirstExhaustive(pair: (Color, Color)): String = pair match
87+
case (a, b) if a match
88+
case Color.Red => "red first"
89+
case Color.Green => "green first"
90+
case Color.Blue => "blue first"
91+
92+
def tupleSecondExhaustive(pair: (Color, Color)): String = pair match
93+
case (a, b) if b match
94+
case Color.Red => "red second"
95+
case Color.Green => "green second"
96+
case Color.Blue => "blue second"
97+
98+
def typedBoundExhaustive(s: Shape): String = s match
99+
case x: Circle if x match
100+
case Circle(r) => s"circle r=$r"
101+
case _: Rectangle => "rectangle"
102+
103+
def typedGuardedFallback(s: Shape): String = s match
104+
case x: Circle if x match
105+
case Circle(r) if r > 0 => "positive circle"
106+
case _: Circle => "other circle"
107+
case _: Rectangle => "rectangle"
108+
109+
def fieldAccessExhaustive(d: Document): String = d match
110+
case x if x.version match
111+
case Version.Stable(m, n) => s"$m.$n"
112+
case Version.Legacy => "legacy"
113+
114+
def fieldAccessTypedBound(d: Document): String = d match
115+
case x: Document if x.version match
116+
case Version.Stable(m, n) => s"$m.$n"
117+
case Version.Legacy => "legacy"

tests/warn/sub-cases-exhaustivity.check

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,35 @@
3838
| It would fail on pattern case: Blue
3939
|
4040
| longer explanation available when compiling with `-explain`
41+
-- [E029] Pattern Match Exhaustivity Warning: tests/warn/sub-cases-exhaustivity.scala:61:54 ----------------------------
42+
61 |def tupleFirstMissing(pair: (Color, Color)): String = pair match // warn
43+
| ^^^^
44+
| match may not be exhaustive.
45+
|
46+
| It would fail on pattern case: (Blue, _)
47+
|
48+
| longer explanation available when compiling with `-explain`
49+
-- [E029] Pattern Match Exhaustivity Warning: tests/warn/sub-cases-exhaustivity.scala:66:55 ----------------------------
50+
66 |def tupleSecondMissing(pair: (Color, Color)): String = pair match // warn
51+
| ^^^^
52+
| match may not be exhaustive.
53+
|
54+
| It would fail on pattern case: (_, Blue)
55+
|
56+
| longer explanation available when compiling with `-explain`
57+
-- [E029] Pattern Match Exhaustivity Warning: tests/warn/sub-cases-exhaustivity.scala:71:45 ----------------------------
58+
71 |def typedGuardedSubcases(s: Shape): String = s match // warn
59+
| ^
60+
| match may not be exhaustive.
61+
|
62+
| It would fail on pattern case: Circle(_)
63+
|
64+
| longer explanation available when compiling with `-explain`
65+
-- [E029] Pattern Match Exhaustivity Warning: tests/warn/sub-cases-exhaustivity.scala:82:46 ----------------------------
66+
82 |def fieldAccessMissing(d: Document): String = d match // warn
67+
| ^
68+
| match may not be exhaustive.
69+
|
70+
| It would fail on pattern case: Document(_, Legacy)
71+
|
72+
| longer explanation available when compiling with `-explain`

tests/warn/sub-cases-exhaustivity.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,32 @@ def testAlternativeMissing(c: Color): String =
5353
c match // warn: match may not be exhaustive: It would fail on pattern case: Color.Blue
5454
case c1 if c1 match
5555
case Color.Red | Color.Green => "warm"
56+
57+
sealed trait Shape
58+
case class Circle(r: Double) extends Shape
59+
case class Rectangle(w: Double, h: Double) extends Shape
60+
61+
def tupleFirstMissing(pair: (Color, Color)): String = pair match // warn
62+
case (a, b) if a match
63+
case Color.Red => "red first"
64+
case Color.Green => "green first"
65+
66+
def tupleSecondMissing(pair: (Color, Color)): String = pair match // warn
67+
case (a, b) if b match
68+
case Color.Red => "red second"
69+
case Color.Green => "green second"
70+
71+
def typedGuardedSubcases(s: Shape): String = s match // warn
72+
case x: Circle if x match
73+
case Circle(r) if r > 0 => "positive circle"
74+
case _: Rectangle => "rectangle"
75+
76+
enum Version:
77+
case Legacy
78+
case Stable(major: Int, minor: Int)
79+
80+
case class Document(title: String, version: Version)
81+
82+
def fieldAccessMissing(d: Document): String = d match // warn
83+
case x if x.version match
84+
case Version.Stable(m, n) => s"$m.$n"

tests/warn/sub-cases-reachability.check

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,11 @@
2222
50 | case Color.Red => "unreachable" // warn
2323
| ^^^^^^^^^
2424
| Unreachable case
25+
-- [E030] Match case Unreachable Warning: tests/warn/sub-cases-reachability.scala:63:7 ---------------------------------
26+
63 | case (Color.Red, _) => "unreachable" // warn
27+
| ^^^^^^^^^^^^^^
28+
| Unreachable case
29+
-- [E030] Match case Unreachable Warning: tests/warn/sub-cases-reachability.scala:69:15 --------------------------------
30+
69 | case Document(_, Version.Legacy) => "unreachable" // warn
31+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
32+
| Unreachable case

tests/warn/sub-cases-reachability.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,22 @@ def testAlternativeReachability(c: Color): String = c match
4848
case c1 if c1 match
4949
case Color.Red | Color.Green | Color.Blue => "all"
5050
case Color.Red => "unreachable" // warn
51+
52+
enum Version:
53+
case Legacy
54+
case Stable(major: Int, minor: Int)
55+
56+
case class Document(title: String, version: Version)
57+
58+
def tupleFirstReachability(pair: (Color, Color)): String = pair match
59+
case (a, b) if a match
60+
case Color.Red => "red first"
61+
case Color.Green => "green first"
62+
case Color.Blue => "blue first"
63+
case (Color.Red, _) => "unreachable" // warn
64+
65+
def fieldAccessReachability(d: Document): String = d match
66+
case x if x.version match
67+
case Version.Stable(m, n) => s"$m.$n"
68+
case Version.Legacy => "legacy"
69+
case Document(_, Version.Legacy) => "unreachable" // warn

0 commit comments

Comments
 (0)