Skip to content

Copy @use and @consume annotations to parameter types #23324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,9 @@ extension (tp: Type)
RefinedType(tp, name,
AnnotatedType(rinfo, Annotation(defn.RefineOverrideAnnot, util.Spans.NoSpan)))

def dropUseAndConsumeAnnots(using Context): Type =
tp.dropAnnot(defn.UseAnnot).dropAnnot(defn.ConsumeAnnot)

extension (tp: MethodType)
/** A method marks an existential scope unless it is the prefix of a curried method */
def marksExistentialScope(using Context): Boolean =
Expand Down Expand Up @@ -490,18 +493,24 @@ extension (sym: Symbol)
def hasTrackedParts(using Context): Boolean =
!CaptureSet.ofTypeDeeply(sym.info).isAlwaysEmpty

/** `sym` is annotated @use or it is a type parameter with a matching
/** `sym` itself or its info is annotated @use or it is a type parameter with a matching
* @use-annotated term parameter that contains `sym` in its deep capture set.
*/
def isUseParam(using Context): Boolean =
sym.hasAnnotation(defn.UseAnnot)
|| sym.info.hasAnnotation(defn.UseAnnot)
|| sym.is(TypeParam)
&& sym.owner.rawParamss.nestedExists: param =>
param.is(TermParam) && param.hasAnnotation(defn.UseAnnot)
&& param.info.deepCaptureSet.elems.exists:
case c: TypeRef => c.symbol == sym
case _ => false

/** `sym` or its info is annotated with `@consume`. */
def isConsumeParam(using Context): Boolean =
sym.hasAnnotation(defn.ConsumeAnnot)
|| sym.info.hasAnnotation(defn.ConsumeAnnot)

def isUpdateMethod(using Context): Boolean =
sym.isAllOf(Mutable | Method, butNot = Accessor)

Expand Down
7 changes: 5 additions & 2 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ class CheckCaptures extends Recheck, SymTransformer:
funtpe.paramInfos.zipWithConserve(funtpe.paramNames): (formal, pname) =>
val param = meth.paramNamed(pname)
def copyAnnot(tp: Type, cls: ClassSymbol) = param.getAnnotation(cls) match
case Some(ann) => AnnotatedType(tp, ann)
case Some(ann) if !tp.hasAnnotation(cls) => AnnotatedType(tp, ann)
case _ => tp
copyAnnot(copyAnnot(formal, defn.UseAnnot), defn.ConsumeAnnot)
funtpe.derivedLambdaType(paramInfos = paramInfosWithUses)
Expand Down Expand Up @@ -1616,7 +1616,10 @@ class CheckCaptures extends Recheck, SymTransformer:
if noWiden(actual, expected) then
actual
else
val improvedVAR = improveCaptures(actual.widen.dealiasKeepAnnots, actual)
// Compute the widened type. Drop `@use` and `@consume` annotations from the type,
// since they obscures the capturing type.
val widened = actual.widen.dealiasKeepAnnots.dropUseAndConsumeAnnots
val improvedVAR = improveCaptures(widened, actual)
val improved = improveReadOnly(improvedVAR, expected)
val adapted = adaptBoxed(
improved.withReachCaptures(actual), expected, tree,
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/cc/SepCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
if currentOwner.enclosingMethodOrClass.isProperlyContainedIn(refSym.maybeOwner.enclosingMethodOrClass) then
report.error(em"""Separation failure: $descr non-local $refSym""", pos)
else if refSym.is(TermParam)
&& !refSym.hasAnnotation(defn.ConsumeAnnot)
&& !refSym.isConsumeParam
&& currentOwner.isContainedIn(refSym.owner)
then
badParams += refSym
Expand Down Expand Up @@ -899,7 +899,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
if !isUnsafeAssumeSeparate(tree) then trace(i"checking separate $tree"):
checkUse(tree)
tree match
case tree @ Select(qual, _) if tree.symbol.is(Method) && tree.symbol.hasAnnotation(defn.ConsumeAnnot) =>
case tree @ Select(qual, _) if tree.symbol.is(Method) && tree.symbol.isConsumeParam =>
traverseChildren(tree)
checkConsumedRefs(
captures(qual).footprint(), qual.nuType,
Expand Down Expand Up @@ -962,4 +962,4 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
consumeInLoopError(ref, pos)
case _ =>
traverseChildren(tree)
end SepCheck
end SepCheck
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ class Definitions {

// Set of annotations that are not printed in types except under -Yprint-debug
@tu lazy val SilentAnnots: Set[Symbol] =
Set(InlineParamAnnot, ErasedParamAnnot, RefineOverrideAnnot, SilentIntoAnnot)
Set(InlineParamAnnot, ErasedParamAnnot, RefineOverrideAnnot, SilentIntoAnnot, UseAnnot, ConsumeAnnot)

// A list of annotations that are commonly used to indicate that a field/method argument or return
// type is not null. These annotations are used by the nullification logic in JavaNullInterop to
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4234,6 +4234,11 @@ object Types extends TypeUtils {
paramType = addAnnotation(paramType, defn.InlineParamAnnot, param)
if param.is(Erased) then
paramType = addAnnotation(paramType, defn.ErasedParamAnnot, param)
// Copy `@use` and `@consume` annotations from parameter symbols to the type.
if param.hasAnnotation(defn.UseAnnot) then
paramType = addAnnotation(paramType, defn.UseAnnot, param)
if param.hasAnnotation(defn.ConsumeAnnot) then
paramType = addAnnotation(paramType, defn.ConsumeAnnot, param)
paramType

def adaptParamInfo(param: Symbol)(using Context): Type =
Expand Down
12 changes: 11 additions & 1 deletion compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ class PlainPrinter(_ctx: Context) extends Printer {

protected def argText(arg: Type, isErased: Boolean = false): Text =
keywordText("erased ").provided(isErased)
~ specialAnnotText(defn.UseAnnot, arg)
~ specialAnnotText(defn.ConsumeAnnot, arg)
~ homogenizeArg(arg).match
case arg: TypeBounds => "?" ~ toText(arg)
case arg => toText(arg)
Expand Down Expand Up @@ -376,10 +378,18 @@ class PlainPrinter(_ctx: Context) extends Printer {
try "(" ~ toTextRef(tp) ~ " : " ~ toTextGlobal(tp.underlying) ~ ")"
finally elideCapabilityCaps = saved

/** Print the annotation that are meant to be on the parameter symbol but was moved
* to parameter types. Examples are `@use` and `@consume`. */
protected def specialAnnotText(sym: ClassSymbol, tp: Type): Text =
Str(s"@${sym.name} ").provided(tp.hasAnnotation(sym))

protected def paramsText(lam: LambdaType): Text = {
def paramText(ref: ParamRef) =
val erased = ref.underlying.hasAnnotation(defn.ErasedParamAnnot)
keywordText("erased ").provided(erased) ~ ParamRefNameString(ref) ~ hashStr(lam) ~ toTextRHS(ref.underlying, isParameter = true)
keywordText("erased ").provided(erased)
~ specialAnnotText(defn.UseAnnot, ref.underlying)
~ specialAnnotText(defn.ConsumeAnnot, ref.underlying)
~ ParamRefNameString(ref) ~ hashStr(lam) ~ toTextRHS(ref.underlying, isParameter = true)
Text(lam.paramRefs.map(paramText), ", ")
}

Expand Down
18 changes: 18 additions & 0 deletions tests/neg-custom-args/captures/cc-annot-value-classes.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import language.experimental.captureChecking
import caps.*

class Runner(val x: Int) extends AnyVal:
def runOps(@use ops: List[() => Unit]): Unit =
ops.foreach(_()) // ok

class RunnerAlt(val x: Int):
def runOps(@use ops: List[() => Unit]): Unit =
ops.foreach(_()) // ok, of course

class RunnerAltAlt(val x: Int) extends AnyVal:
def runOps(ops: List[() => Unit]): Unit =
ops.foreach(_()) // error, as expected

class RunnerAltAltAlt(val x: Int):
def runOps(ops: List[() => Unit]): Unit =
ops.foreach(_()) // error, as expected
16 changes: 16 additions & 0 deletions tests/neg-custom-args/captures/cc-annot-value-classes2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import language.experimental.captureChecking
import caps.*
trait Ref extends Mutable
def kill(@consume x: Ref^): Unit = ()

class C1:
def myKill(@consume x: Ref^): Unit = kill(x) // ok

class C2(val dummy: Int) extends AnyVal:
def myKill(@consume x: Ref^): Unit = kill(x) // ok, too

class C3:
def myKill(x: Ref^): Unit = kill(x) // error

class C4(val dummy: Int) extends AnyVal:
def myKill(x: Ref^): Unit = kill(x) // error, too
10 changes: 5 additions & 5 deletions tests/neg-custom-args/captures/leak-problem-unboxed.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ def useBoxedAsync1(@use x: Box[Async^]): Unit = x.get.read() // ok
def test(): Unit =

val f: Box[Async^] => Unit = (x: Box[Async^]) => useBoxedAsync(x) // error
val _: Box[Async^] => Unit = useBoxedAsync(_) // error
val _: Box[Async^] => Unit = useBoxedAsync // error
val _ = useBoxedAsync(_) // error
val _ = useBoxedAsync // error
val t1: Box[Async^] => Unit = useBoxedAsync(_) // error
val t2: Box[Async^] => Unit = useBoxedAsync // error
val t3 = useBoxedAsync(_) // was error, now ok
val t4 = useBoxedAsync // was error, now ok

def boom(x: Async^): () ->{f} Unit =
() => f(Box(x))

val leaked = usingAsync[() ->{f} Unit](boom)

leaked() // scope violation
leaked() // scope violation
6 changes: 3 additions & 3 deletions tests/neg-custom-args/captures/unbox-overrides.check
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
-- [E164] Declaration Error: tests/neg-custom-args/captures/unbox-overrides.scala:8:6 ----------------------------------
8 | def foo(x: C): C // error
| ^
|error overriding method foo in trait A of type (x: C): C;
|error overriding method foo in trait A of type (@use x: C): C;
| method foo of type (x: C): C has a parameter x with different @use status than the corresponding parameter in the overridden definition
|
| longer explanation available when compiling with `-explain`
-- [E164] Declaration Error: tests/neg-custom-args/captures/unbox-overrides.scala:9:6 ----------------------------------
9 | def bar(@use x: C): C // error
| ^
|error overriding method bar in trait A of type (x: C): C;
| method bar of type (x: C): C has a parameter x with different @use status than the corresponding parameter in the overridden definition
| method bar of type (@use x: C): C has a parameter x with different @use status than the corresponding parameter in the overridden definition
|
| longer explanation available when compiling with `-explain`
-- [E164] Declaration Error: tests/neg-custom-args/captures/unbox-overrides.scala:15:15 --------------------------------
15 |abstract class C extends A[C], B2 // error
| ^
|error overriding method foo in trait A of type (x: C): C;
|error overriding method foo in trait A of type (@use x: C): C;
| method foo in trait B2 of type (x: C): C has a parameter x with different @use status than the corresponding parameter in the overridden definition
|
| longer explanation available when compiling with `-explain`
4 changes: 2 additions & 2 deletions tests/neg-custom-args/captures/unsound-reach-4.check
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
17 | def use(@consume x: F): File^ = x // error @consume override
| ^
| error overriding method use in trait Foo of type (x: File^): box File^;
| method use of type (x: File^): File^² has incompatible type
| method use of type (@consume x: File^): File^² has incompatible type
|
| where: ^ refers to the universal root capability
| ^² refers to a root capability associated with the result type of (x: File^): File^²
| ^² refers to a root capability associated with the result type of (@consume x: File^): File^²
|
| longer explanation available when compiling with `-explain`
10 changes: 10 additions & 0 deletions tests/pos-custom-args/captures/cc-use-iterable.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import language.experimental.captureChecking
trait IterableOnce[+T]
trait Iterable[+T] extends IterableOnce[T]:
def flatMap[U](@caps.use f: T => IterableOnce[U]^): Iterable[U]^{this, f*}


class IterableOnceExtensionMethods[T](val it: IterableOnce[T]) extends AnyVal:
def flatMap[U](@caps.use f: T => IterableOnce[U]^): IterableOnce[U]^{f*} = it match
case it: Iterable[T] => it.flatMap(f)

Loading