-
Notifications
You must be signed in to change notification settings - Fork 20
Fix AST transformations of by-name parameters #27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
aae630a
647d1ea
c354504
20ce237
4136755
a9a7885
e07ca52
36ef738
23bf279
c2f47df
db52258
9ade443
1e78d94
48c1c85
c35b6b7
c476e7d
bd1c75a
1387ecb
1bc8ded
06e28eb
2b576cf
ffd0ea7
28adad9
b2983ad
570221f
fa55257
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,7 +90,7 @@ trait AstCanonicalization[C <: Context] { | |
val localvarname = TermName(c.freshName("x")) | ||
val decls = List( | ||
q"var $localvarname = null.asInstanceOf[Boolean]", | ||
q""" | ||
q""" | ||
..$conddecls | ||
if ($condident) { | ||
..$thendecls | ||
|
@@ -122,6 +122,48 @@ trait AstCanonicalization[C <: Context] { | |
(decls, q"$localvarname") | ||
case q"$selector[..$tpts](...$paramss)" if tpts.length > 0 || paramss.length > 0 => | ||
// application | ||
|
||
// Detect by-name parameters and prevent their canonicalization. | ||
|
||
// Maps to the parameter lists, saying whether or not they are by-name. | ||
val byNameParams: Seq[Seq[Boolean]] = { | ||
// Check that `selector` has a non-trivial symbol. If it doesn't, we assume | ||
// that there were no by-name parameters. | ||
if (selector.symbol != null && selector.symbol != NoSymbol) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When is the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Related to #27 (comment). I think lambdas were problematic. There might have been more cases, though. I should have documented this more heavily at the time. |
||
val paramLists = selector.symbol.asMethod.paramLists | ||
|
||
// This value does not contain repeated parameters. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that repeated by name parameters are not allowed in Scala. Also, the |
||
val noRepeatedParamsSeq = paramLists.map { currentParamList => | ||
currentParamList.map { param => param.asTerm.isByNameParam } | ||
} | ||
|
||
// Check that parameters were both passed in to and requested by the method. | ||
if (paramss.length > 0 && noRepeatedParamsSeq.length > 0) { | ||
// Initially, we assume that no parameters were repeated. | ||
var paramsSeq = noRepeatedParamsSeq | ||
|
||
/** If the number of passed-in parameters is greater than the number of | ||
* parameters needed by the method, then we assume that there are repeated | ||
* parameters. | ||
* | ||
* This problem is fixed by appending the last value in | ||
* `noRepeatedParamsSeq` to `paramsSeq`. This is done until the number of | ||
* parameters are equal. | ||
*/ | ||
while (paramss(0).length > paramsSeq(0).length) { | ||
val newHead = paramsSeq(0) :+ paramsSeq(0).last | ||
val newTail = paramsSeq.tail | ||
paramsSeq = newHead :: newTail | ||
} | ||
paramsSeq | ||
} else { | ||
noRepeatedParamsSeq | ||
} | ||
} else { | ||
for (params <- paramss) yield | ||
for (p <- params) yield false | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should not inspect the method symbol and match the parameter trees to the method type information to figure out which is which - you're reproducing complex typechecker logic by doing that, and not covering all the edge cases. Instead, see this: http://stackoverflow.com/questions/29757584/handling-by-name-parameters-in-scala-macro The We should have the following behavior:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can do! After I write unit tests, I will change my implementation. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Working on this change here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would the first condition disallow calling methods with repeated parameters? |
||
val (rdecls, newselector) = selector match { | ||
case q"$r.$method" => | ||
val (rdecls, rident) = canonicalize(r) | ||
|
@@ -130,10 +172,35 @@ trait AstCanonicalization[C <: Context] { | |
(Nil, q"$method") | ||
} | ||
for (tpt <- tpts) disallowCoroutinesIn(tpt) | ||
val (pdeclss, pidents) = paramss.map(_.map(canonicalize).unzip).unzip | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you do just:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By Also, where is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By |
||
// Canonicalize parameters that are not by-name. | ||
val paramsByNameUnmodified = { | ||
val modifiedParamLists = mutable.Seq.fill[ | ||
List[(List[c.universe.Tree], c.universe.Tree)]](paramss.length)(null) | ||
|
||
// Canonicalize each parameter list. | ||
for (i <- 0 until paramss.length) { | ||
val modifiedParams = mutable.Seq.fill[ | ||
(List[c.universe.Tree], c.universe.Tree)](paramss(i).length)(null) | ||
for (j <- 0 until modifiedParams.length) { | ||
if (byNameParams(i)(j)) { | ||
new NestedContextValidator().traverse(paramss(i)(j)) | ||
modifiedParams(j) = (List(q""), paramss(i)(j)) | ||
} else { | ||
modifiedParams(j) = canonicalize(paramss(i)(j)) | ||
} | ||
} | ||
modifiedParamLists(i) = modifiedParams.toList | ||
} | ||
modifiedParamLists.toList | ||
} | ||
val pdeclss = | ||
paramsByNameUnmodified.map((_.map(tuple => tuple._1))) | ||
.flatten.flatten.filter{ decl => decl != q"" } | ||
val pidents = paramsByNameUnmodified.map((_.map(tuple => tuple._2))) | ||
val localvarname = TermName(c.freshName("x")) | ||
val localvartree = q"val $localvarname = $newselector[..$tpts](...$pidents)" | ||
(rdecls ++ pdeclss.flatten.flatten ++ List(localvartree), q"$localvarname") | ||
(rdecls ++ pdeclss ++ List(localvartree), q"$localvarname") | ||
case q"$r[..$tpts]" if tpts.length > 0 => | ||
// type application | ||
for (tpt <- tpts) disallowCoroutinesIn(tpt) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,6 @@ import scala.annotation.unchecked.uncheckedVariance | |
import scala.concurrent._ | ||
import scala.concurrent.duration._ | ||
import scala.concurrent.ExecutionContext.Implicits.global | ||
import scala.language.{ reflectiveCalls, postfixOps } | ||
import scala.util.Success | ||
|
||
|
||
|
@@ -237,15 +236,14 @@ class AsyncAwaitTest extends FunSuite with Matchers { | |
val fut = AsyncAwaitTest.async(coroutine { () => | ||
Option(new ParamWrapper("value!")) match { | ||
case Some(valueHolder) => | ||
AsyncAwaitTest.await(Future(doAThing(valueHolder))) | ||
AsyncAwaitTest.await(doAThing(valueHolder)) | ||
case None => | ||
None | ||
} | ||
}) | ||
|
||
val result = Await.result(fut, 5 seconds) | ||
assert(result.asInstanceOf[Future[ParamWrapper[String]]].value == | ||
Some(Success(None))) | ||
assert(result == None) | ||
} | ||
|
||
// Source: https://git.io/vr7NW | ||
|
@@ -309,25 +307,47 @@ class AsyncAwaitTest extends FunSuite with Matchers { | |
assert(result == 103) | ||
} | ||
|
||
// Source: https://git.io/vrFj3 | ||
/** Source: https://git.io/vrFj3 | ||
* Modified so that there is no coroutine passed as a by-name parameter. For more | ||
* information, see https://git.io/vKkM5. | ||
*/ | ||
test("nested await as bare expression") { | ||
val c = async(coroutine { () => | ||
await(Future(await(Future("")).isEmpty)) | ||
val emptyString = await(Future("")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I will add some unit tests to assert that this check happens. |
||
await(Future(emptyString.isEmpty)) | ||
}) | ||
val result = Await.result(c, 5 seconds) | ||
assert(result == true) | ||
} | ||
|
||
// Source: https://git.io/vrAnM | ||
/** Source: https://git.io/vrAnM | ||
* Modified so that there is no coroutine passed as a by-name parameter. For more | ||
* information, see https://git.io/vKkM5. | ||
*/ | ||
test("nested await in block") { | ||
val c = async(coroutine { () => | ||
() | ||
await(Future(await(Future("")).isEmpty)) | ||
val emptyString = await(Future("")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See #24. |
||
await(Future(emptyString.isEmpty)) | ||
}) | ||
val result = Await.result(c, 5 seconds) | ||
assert(result == true) | ||
} | ||
|
||
/** Source: https://git.io/vrAlJ | ||
* Modified so that there is no coroutine passed as a by-name parameter. For more | ||
* information, see https://git.io/vKkM5. | ||
*/ | ||
test("by-name expressions aren't lifted") { | ||
def foo(ignored: => Any, b: Int) = b | ||
val c = async(coroutine { () => | ||
val innerValue = await(Future(1)) | ||
await(Future(foo(???, innerValue))) | ||
}) | ||
val result = Await.result(c, 5 seconds) | ||
assert(result == 1) | ||
} | ||
|
||
// Source: https://git.io/vrhTe | ||
test("named and default arguments respect evaluation order") { | ||
var i = 0 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package org.coroutines | ||
|
||
|
||
|
||
import org.scalatest._ | ||
|
||
|
||
|
||
class ByNameTest extends FunSuite with Matchers { | ||
test("coroutine applications should not be allowed as by-name parameters 1") { | ||
import scala.concurrent.Future | ||
import scala.concurrent.ExecutionContext.Implicits.global | ||
|
||
val c1 = coroutine { (i: Int) => | ||
i | ||
} | ||
""" | ||
val c2 = coroutine { () => | ||
val f = Future(c1(4)) | ||
10 | ||
} | ||
""" shouldNot typeCheck | ||
} | ||
|
||
test("coroutine applications should not be allowed as by-name parameters 2") { | ||
def foo(a: Int, b: => Coroutine._0[Nothing, Int], c: Int): Int = { | ||
val instance = call(b()) | ||
instance.resume | ||
a + instance.result + c | ||
} | ||
|
||
val c1 = coroutine { (i: Int) => | ||
i | ||
} | ||
""" | ||
val c2 = coroutine { () => | ||
foo(1, c1(2), 3) | ||
} | ||
""" shouldNot typeCheck | ||
} | ||
|
||
test("repeated parameters") { | ||
def foo(a: => Int, otherInts: Int*): Int = { | ||
otherInts.foldLeft(a)((sum: Int, current: Int) => sum + current) | ||
} | ||
|
||
val c = coroutine { (starter: Int) => | ||
foo(starter, 1, 2, 3) | ||
} | ||
val instance = call(c(4)) | ||
assert(!instance.resume) | ||
assert(instance.result == 10) | ||
} | ||
|
||
// Inspired by https://git.io/vrAlJ | ||
test("by-name arguments aren't lifted when they surround a non by-name one") { | ||
def foo(firstIgnored: => Any, b: Int, secondIgnored: => Any) = b | ||
val c = coroutine { () => | ||
foo(???, 1, { throw new Exception }) | ||
} | ||
val instance = call(c()) | ||
assert(!instance.resume) | ||
assert(instance.result == 1) | ||
} | ||
|
||
test("mixed argument lists can end in non by-name parameters") { | ||
def foo(firstIgnored: => Any, b: Int, secondIgnored: => Any, d: Int) = { b + d } | ||
val c = coroutine { () => | ||
foo(???, 1, { throw new Exception }, 2) | ||
} | ||
val instance = call(c()) | ||
assert(!instance.resume) | ||
assert(instance.result == 3) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is a non-trivial symbol?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right: non-trivial isn't the best way to describe what I meant.
Some methods threw errors when I called
asMethod
on them. These checks are my way of preventing those exceptions.I labeled the problematic functions trivial. Pretty sure they were lambdas.