Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1348,26 +1348,29 @@ object desugar {
case _ => body
cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction]

/** Apply function-level parameter flags such as `given` and `erased` to term parameters. */
private def addFunctionParamFlags(params: List[ValDef], funFlags: FlagSet, erasedParams: List[Boolean])(using Context): List[ValDef] =
val commonFlags = funFlags.toTermFlags & GivenOrImplicit
params.zipWithConserve(erasedParams): (param, isErased) =>
val flags = commonFlags | (if isErased then Erased else EmptyFlags)
if flags.isEmpty then param else param.withAddedFlags(flags)

/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
*/
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = (tree: @unchecked) match
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) =>
val paramFlags = fun match
val vparams0 = vparamTypes.zipWithIndex.map {
case (p: ValDef, _) => p
case (p, n) => makeSyntheticParameter(n + 1, p)
}.toList
val vparams = fun match
case fun: FunctionWithMods =>
// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// Function flags to be propagated to each parameter in the desugared method type.
val givenFlag = fun.mods.flags.toTermFlags & Given
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
addFunctionParamFlags(vparams0, fun.mods.flags, fun.erasedParams)
case _ =>
vparamTypes.map(_ => EmptyFlags)

val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList
vparams0

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
Expand Down Expand Up @@ -2076,8 +2079,9 @@ object desugar {
if augmenting then paramNamesOrNil.map(ContextFunctionParamName.fresh(_))
else paramNamesOrNil
else List.fill(formals.length)(ContextFunctionParamName.fresh())
val params = for (tpt, pname) <- formals.zip(paramNames) yield
ValDef(pname, tpt, EmptyTree).withFlags(Given | Param)
val params0 = for (tpt, pname) <- formals.zip(paramNames) yield
ValDef(pname, tpt, EmptyTree).withFlags(Param)
val params = addFunctionParamFlags(params0, Given, erasedParams)
FunctionWithMods(params, body, Modifiers(Given), erasedParams)

private def derivedValDef(originalSpan: Span, named: NameTree, tpt: Tree, rhs: Tree, mods: Modifiers)(using Context) =
Expand Down
5 changes: 5 additions & 0 deletions tests/run/i25725.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
value: 99
value: 99
varA: 99
varB: 99
varC: 99 ok
33 changes: 33 additions & 0 deletions tests/run/i25725.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//> using options -language:experimental.erasedDefinitions

class Proof
erased given Proof = Proof()

class P1
class P2
erased given P1 = P1()
erased given P2 = P2()

def mixed: (erased Proof) ?=> Int ?=> String =
s"value: ${summon[Int]}"

def reversed: Int ?=> (erased Proof) ?=> String =
s"value: ${summon[Int]}"

def varA: (erased P1) ?=> (erased P2) ?=> Int ?=> String =
s"varA: ${summon[Int]}"

def varB: (erased P1) ?=> Int ?=> (erased P2) ?=> String =
s"varB: ${summon[Int]}"

def varC: Int ?=> (erased P1) ?=> String ?=> String =
s"varC: ${summon[Int]} ${summon[String]}"

@main def Test(): Unit =
given Int = 99
given String = "ok"
println(mixed)
println(reversed)
println(varA)
println(varB)
println(varC)
Loading