Skip to content

Experiment with Capless-like Scheme for Capture Checking #23291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
48 changes: 29 additions & 19 deletions compiler/src/dotty/tools/dotc/cc/Capability.scala
Original file line number Diff line number Diff line change
Expand Up @@ -473,27 +473,28 @@ object Capabilities:
case info: OrType => viaInfo(info.tp1)(test) && viaInfo(info.tp2)(test)
case _ => false

def trySubpath(y: TermRef): Boolean =
y.prefix.match
case ypre: Capability =>
this.subsumes(ypre)
|| this.match
case x @ TermRef(xpre: Capability, _) if x.symbol == y.symbol =>
// To show `{x.f} <:< {y.f}`, it is important to prove `x` and `y`
// are equvalent, which means `x =:= y` in terms of subtyping,
// not just `{x} =:= {y}` in terms of subcapturing.
// It is possible to construct two singleton types `x` and `y`,
// which subsume each other, but are not equal references.
// See `tests/neg-custom-args/captures/path-prefix.scala` for example.
withMode(Mode.IgnoreCaptures):
TypeComparer.isSameRef(xpre, ypre)
case _ =>
false
case _ => false

try (this eq y)
|| maxSubsumes(y, canAddHidden = !vs.isOpen)
|| y.match
case y: TermRef =>
y.prefix.match
case ypre: Capability =>
this.subsumes(ypre)
|| this.match
case x @ TermRef(xpre: Capability, _) if x.symbol == y.symbol =>
// To show `{x.f} <:< {y.f}`, it is important to prove `x` and `y`
// are equvalent, which means `x =:= y` in terms of subtyping,
// not just `{x} =:= {y}` in terms of subcapturing.
// It is possible to construct two singleton types `x` and `y`,
// which subsume each other, but are not equal references.
// See `tests/neg-custom-args/captures/path-prefix.scala` for example.
withMode(Mode.IgnoreCaptures):
TypeComparer.isSameRef(xpre, ypre)
case _ =>
false
case _ => false
|| viaInfo(y.info)(subsumingRefs(this, _))
case y: TermRef => trySubpath(y) || viaInfo(y.info)(subsumingRefs(this, _))
case Maybe(y1) => this.stripMaybe.subsumes(y1)
case ReadOnly(y1) => this.stripReadOnly.subsumes(y1)
case y: TypeRef if y.derivesFrom(defn.Caps_CapSet) =>
Expand All @@ -507,6 +508,15 @@ object Capabilities:
this.subsumes(hi)
case _ =>
y.captureSetOfInfo.elems.forall(this.subsumes)
case Reach(y1: TermRef) =>
val sym = y1.symbol
def isUseClassParam: Boolean =
sym.owner match
case classSym: ClassSymbol =>
val paramSym = classSym.primaryConstructor.paramNamed(sym.name)
paramSym.isUseParam
case _ => false
isUseClassParam && trySubpath(y1)
case _ => false
|| this.match
case Reach(x1) => x1.subsumes(y.stripReach)
Expand Down Expand Up @@ -858,4 +868,4 @@ object Capabilities:
case tp1 => tp1
end toResultInResults

end Capabilities
end Capabilities
12 changes: 11 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ extension (tp: Type)
val tp1 = narrowCaps(tp)
if narrowCaps.change then
capt.println(i"narrow $tp of $ref to $tp1")
//println(i"reach refinement $tp at $ref to $tp1 (${ctx.compilationUnit})")
tp1
else
tp
Expand All @@ -395,6 +396,9 @@ extension (tp: Type)
RefinedType(tp, name,
AnnotatedType(rinfo, Annotation(defn.RefineOverrideAnnot, util.Spans.NoSpan)))

def dropUseAndConsumeAnnots(using Context): Type =
tp.dropAnnot(defn.UseAnnot).dropAnnot(defn.ConsumeAnnot)

extension (tp: MethodType)
/** A method marks an existential scope unless it is the prefix of a curried method */
def marksExistentialScope(using Context): Boolean =
Expand Down Expand Up @@ -490,18 +494,24 @@ extension (sym: Symbol)
def hasTrackedParts(using Context): Boolean =
!CaptureSet.ofTypeDeeply(sym.info).isAlwaysEmpty

/** `sym` is annotated @use or it is a type parameter with a matching
/** `sym` itself or its info is annotated @use or it is a type parameter with a matching
* @use-annotated term parameter that contains `sym` in its deep capture set.
*/
def isUseParam(using Context): Boolean =
sym.hasAnnotation(defn.UseAnnot)
|| sym.info.hasAnnotation(defn.UseAnnot)
|| sym.is(TypeParam)
&& sym.owner.rawParamss.nestedExists: param =>
param.is(TermParam) && param.hasAnnotation(defn.UseAnnot)
&& param.info.deepCaptureSet.elems.exists:
case c: TypeRef => c.symbol == sym
case _ => false

/** `sym` or its info is annotated with `@consume`. */
def isConsumeParam(using Context): Boolean =
sym.hasAnnotation(defn.ConsumeAnnot)
|| sym.info.hasAnnotation(defn.ConsumeAnnot)

def isUpdateMethod(using Context): Boolean =
sym.isAllOf(Mutable | Method, butNot = Accessor)

Expand Down
14 changes: 11 additions & 3 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ class CheckCaptures extends Recheck, SymTransformer:
funtpe.paramInfos.zipWithConserve(funtpe.paramNames): (formal, pname) =>
val param = meth.paramNamed(pname)
def copyAnnot(tp: Type, cls: ClassSymbol) = param.getAnnotation(cls) match
case Some(ann) => AnnotatedType(tp, ann)
case Some(ann) if !tp.hasAnnotation(cls) => AnnotatedType(tp, ann)
case _ => tp
copyAnnot(copyAnnot(formal, defn.UseAnnot), defn.ConsumeAnnot)
funtpe.derivedLambdaType(paramInfos = paramInfosWithUses)
Expand Down Expand Up @@ -789,6 +789,7 @@ class CheckCaptures extends Recheck, SymTransformer:
case appType @ CapturingType(appType1, refs)
if qualType.exists
&& !tree.fun.symbol.isConstructor
&& funType.paramInfos.isEmpty
&& qualCaptures.mightSubcapture(refs)
&& argCaptures.forall(_.mightSubcapture(refs)) =>
val callCaptures = argCaptures.foldLeft(qualCaptures)(_ ++ _)
Expand Down Expand Up @@ -845,10 +846,14 @@ class CheckCaptures extends Recheck, SymTransformer:
initCs ++ FreshCap(Origin.NewCapability(core)).readOnly.singletonCaptureSet
else initCs
for (getterName, argType) <- mt.paramNames.lazyZip(argTypes) do
val paramSym = cls.primaryConstructor.paramNamed(getterName)
val getter = cls.info.member(getterName).suchThat(_.isRefiningParamAccessor).symbol
if !getter.is(Private) && getter.hasTrackedParts then
refined = refined.refinedOverride(getterName, argType.unboxed) // Yichen you might want to check this
allCaptures ++= argType.captureSet
if paramSym.isUseParam then
allCaptures ++= argType.deepCaptureSet
else
allCaptures ++= argType.captureSet
(refined, allCaptures)

/** Augment result type of constructor with refinements and captures.
Expand Down Expand Up @@ -1616,7 +1621,10 @@ class CheckCaptures extends Recheck, SymTransformer:
if noWiden(actual, expected) then
actual
else
val improvedVAR = improveCaptures(actual.widen.dealiasKeepAnnots, actual)
// Compute the widened type. Drop `@use` and `@consume` annotations from the type,
// since they obscures the capturing type.
val widened = actual.widen.dealiasKeepAnnots.dropUseAndConsumeAnnots
val improvedVAR = improveCaptures(widened, actual)
val improved = improveReadOnly(improvedVAR, expected)
val adapted = adaptBoxed(
improved.withReachCaptures(actual), expected, tree,
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/cc/SepCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
if currentOwner.enclosingMethodOrClass.isProperlyContainedIn(refSym.maybeOwner.enclosingMethodOrClass) then
report.error(em"""Separation failure: $descr non-local $refSym""", pos)
else if refSym.is(TermParam)
&& !refSym.hasAnnotation(defn.ConsumeAnnot)
&& !refSym.isConsumeParam
&& currentOwner.isContainedIn(refSym.owner)
then
badParams += refSym
Expand Down Expand Up @@ -899,7 +899,7 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
if !isUnsafeAssumeSeparate(tree) then trace(i"checking separate $tree"):
checkUse(tree)
tree match
case tree @ Select(qual, _) if tree.symbol.is(Method) && tree.symbol.hasAnnotation(defn.ConsumeAnnot) =>
case tree @ Select(qual, _) if tree.symbol.is(Method) && tree.symbol.isConsumeParam =>
traverseChildren(tree)
checkConsumedRefs(
captures(qual).footprint(), qual.nuType,
Expand Down Expand Up @@ -962,4 +962,4 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
consumeInLoopError(ref, pos)
case _ =>
traverseChildren(tree)
end SepCheck
end SepCheck
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4234,6 +4234,11 @@ object Types extends TypeUtils {
paramType = addAnnotation(paramType, defn.InlineParamAnnot, param)
if param.is(Erased) then
paramType = addAnnotation(paramType, defn.ErasedParamAnnot, param)
// Copy `@use` and `@consume` annotations from parameter symbols to the type.
if param.hasAnnotation(defn.UseAnnot) then
paramType = addAnnotation(paramType, defn.UseAnnot, param)
if param.hasAnnotation(defn.ConsumeAnnot) then
paramType = addAnnotation(paramType, defn.ConsumeAnnot, param)
paramType

def adaptParamInfo(param: Symbol)(using Context): Type =
Expand Down
6 changes: 3 additions & 3 deletions scala2-library-cc/src/scala/collection/Iterable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -682,9 +682,9 @@ trait IterableOps[+A, +CC[_], +C] extends Any with IterableOnce[A] with Iterable

def map[B](f: A => B): CC[B]^{this, f} = iterableFactory.from(new View.Map(this, f))

def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f} = iterableFactory.from(new View.FlatMap(this, f))
def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B]^{this, f*} = iterableFactory.from(new View.FlatMap(this, f))

def flatten[B](implicit asIterable: A -> IterableOnce[B]): CC[B]^{this} = flatMap(asIterable)
def flatten[B](implicit asIterable: A -> IterableOnce[B]): CC[B]^{this, asIterable*} = flatMap(asIterable)

def collect[B](pf: PartialFunction[A, B]^): CC[B]^{this, pf} =
iterableFactory.from(new View.Collect(this, pf))
Expand Down Expand Up @@ -902,7 +902,7 @@ object IterableOps {
def map[B](f: A => B): CC[B]^{this, f} =
self.iterableFactory.from(new View.Map(filtered, f))

def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f} =
def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B]^{this, f*} =
self.iterableFactory.from(new View.FlatMap(filtered, f))

def foreach[U](f: A => U): Unit = filtered.foreach(f)
Expand Down
5 changes: 2 additions & 3 deletions scala2-library-cc/src/scala/collection/IterableOnce.scala
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,9 @@ final class IterableOnceExtensionMethods[A](private val it: IterableOnce[A]) ext
}

@deprecated("Use .iterator.flatMap instead or consider requiring an Iterable", "2.13.0")
def flatMap[B](f: A => IterableOnce[B]^): IterableOnce[B]^{f} = it match {
def flatMap[B](@caps.use f: A => IterableOnce[B]^): IterableOnce[B]^{f*} = it match
case it: Iterable[A] => it.flatMap(f)
case _ => it.iterator.flatMap(f)
}

@deprecated("Use .iterator.sameElements instead", "2.13.0")
def sameElements[B >: A](that: IterableOnce[B]): Boolean = it.iterator.sameElements(that)
Expand Down Expand Up @@ -439,7 +438,7 @@ trait IterableOnceOps[+A, +CC[_], +C] extends Any { this: IterableOnce[A]^ =>
* @return a new $coll resulting from applying the given collection-valued function
* `f` to each element of this $coll and concatenating the results.
*/
def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f}
def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B]^{this, f*}

/** Converts this $coll of iterable collections into
* a $coll formed by the elements of these iterable
Expand Down
10 changes: 5 additions & 5 deletions scala2-library-cc/src/scala/collection/Iterator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -588,8 +588,8 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
def next() = f(self.next())
}

def flatMap[B](f: A => IterableOnce[B]^): Iterator[B]^{this, f} = new AbstractIterator[B] {
private[this] var cur: Iterator[B]^{f} = Iterator.empty
def flatMap[B](@caps.use f: A => IterableOnce[B]^): Iterator[B]^{this, f*} = new AbstractIterator[B] {
private[this] var cur: Iterator[B]^{f*} = Iterator.empty
/** Trillium logic boolean: -1 = unknown, 0 = false, 1 = true */
private[this] var _hasNext: Int = -1

Expand Down Expand Up @@ -623,7 +623,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
}
}

def flatten[B](implicit ev: A -> IterableOnce[B]): Iterator[B]^{this} =
def flatten[B](implicit ev: A -> IterableOnce[B]): Iterator[B]^{this, ev*} =
flatMap[B](ev)

def concat[B >: A](xs: => IterableOnce[B]^): Iterator[B]^{this, xs} = new Iterator.ConcatIterator[B](self).concat(xs)
Expand Down Expand Up @@ -982,7 +982,7 @@ object Iterator extends IterableFactory[Iterator] {
/** Creates a target $coll from an existing source collection
*
* @param source Source collection
* @tparam A the type of the collections elements
* @tparam A the type of the collection's elements
* @return a new $coll with the elements of `source`
*/
override def from[A](source: IterableOnce[A]^): Iterator[A]^{source} = source.iterator
Expand All @@ -1003,7 +1003,7 @@ object Iterator extends IterableFactory[Iterator] {

/**
* @return A builder for $Coll objects.
* @tparam A the type of the ${coll}s elements
* @tparam A the type of the ${coll}'s elements
*/
def newBuilder[A]: Builder[A, Iterator[A]] =
new ImmutableBuilder[A, Iterator[A]](empty[A]) {
Expand Down
4 changes: 2 additions & 2 deletions scala2-library-cc/src/scala/collection/Map.scala
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ trait MapOps[K, +V, +CC[_, _] <: IterableOps[_, AnyConstr, _], +C]
* @return a new $coll resulting from applying the given collection-valued function
* `f` to each element of this $coll and concatenating the results.
*/
def flatMap[K2, V2](f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] = mapFactory.from(new View.FlatMap(this, f))
def flatMap[K2, V2](@caps.use f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] = mapFactory.from(new View.FlatMap(this, f))

/** Returns a new $coll containing the elements from the left hand operand followed by the elements from the
* right hand operand. The element type of the $coll is the most specific superclass encompassing
Expand Down Expand Up @@ -383,7 +383,7 @@ object MapOps {
def map[K2, V2](f: ((K, V)) => (K2, V2)): CC[K2, V2]^{this, f} =
self.mapFactory.from(new View.Map(filtered, f))

def flatMap[K2, V2](f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2]^{this, f} =
def flatMap[K2, V2](@caps.use f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2]^{this, f*} =
self.mapFactory.from(new View.FlatMap(filtered, f))

override def withFilter(q: ((K, V)) => Boolean): WithFilter[K, V, IterableCC, CC]^{this, q} =
Expand Down
2 changes: 1 addition & 1 deletion scala2-library-cc/src/scala/collection/SortedMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ object SortedMapOps {
def map[K2 : Ordering, V2](f: ((K, V)) => (K2, V2)): CC[K2, V2] =
self.sortedMapFactory.from(new View.Map(filtered, f))

def flatMap[K2 : Ordering, V2](f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] =
def flatMap[K2 : Ordering, V2](@caps.use f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] =
self.sortedMapFactory.from(new View.FlatMap(filtered, f))

override def withFilter(q: ((K, V)) => Boolean): WithFilter[K, V, IterableCC, MapCC, CC]^{this, q} =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ trait StrictOptimizedIterableOps[+A, +CC[_], +C]
b.result()
}

override def flatMap[B](f: A => IterableOnce[B]^): CC[B] =
override def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B] =
strictOptimizedFlatMap(iterableFactory.newBuilder, f)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trait StrictOptimizedMapOps[K, +V, +CC[_, _] <: IterableOps[_, AnyConstr, _], +C
override def map[K2, V2](f: ((K, V)) => (K2, V2)): CC[K2, V2] =
strictOptimizedMap(mapFactory.newBuilder, f)

override def flatMap[K2, V2](f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] =
override def flatMap[K2, V2](@caps.use f: ((K, V)) => IterableOnce[(K2, V2)]^): CC[K2, V2] =
strictOptimizedFlatMap(mapFactory.newBuilder, f)

override def concat[V2 >: V](suffix: IterableOnce[(K, V2)]^): CC[K, V2] =
Expand Down
4 changes: 2 additions & 2 deletions scala2-library-cc/src/scala/collection/View.scala
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ object View extends IterableFactory[View] {

/** A view that flatmaps elements of the underlying collection. */
@SerialVersionUID(3L)
class FlatMap[A, B](underlying: SomeIterableOps[A]^, f: A => IterableOnce[B]^) extends AbstractView[B] {
def iterator: Iterator[B]^{underlying, f} = underlying.iterator.flatMap(f)
class FlatMap[A, B](underlying: SomeIterableOps[A]^, @caps.use f: A => IterableOnce[B]^) extends AbstractView[B] {
def iterator: Iterator[B]^{underlying, f*} = underlying.iterator.flatMap(f)
override def knownSize: Int = if (underlying.knownSize == 0) 0 else super.knownSize
override def isEmpty: Boolean = iterator.isEmpty
}
Expand Down
2 changes: 1 addition & 1 deletion scala2-library-cc/src/scala/collection/WithFilter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ abstract class WithFilter[+A, +CC[_]] extends Serializable {
* of the filtered outer $coll and
* concatenating the results.
*/
def flatMap[B](f: A => IterableOnce[B]^): CC[B]^{this, f}
def flatMap[B](@caps.use f: A => IterableOnce[B]^): CC[B]^{this, f*}

/** Applies a function `f` to all elements of the `filtered` outer $coll.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -592,15 +592,15 @@ final class LazyListIterable[+A] private(@untrackedCaptures lazyState: () => Laz
*/
// optimisations are not for speed, but for functionality
// see tickets #153, #498, #2147, and corresponding tests in run/ (as well as run/stream_flatmap_odds.scala)
override def flatMap[B](f: A => IterableOnce[B]^): LazyListIterable[B]^{this, f} =
override def flatMap[B](@caps.use f: A => IterableOnce[B]^): LazyListIterable[B]^{this, f*} =
if (knownIsEmpty) LazyListIterable.empty
else LazyListIterable.flatMapImpl(this, f)

/** @inheritdoc
*
* $preservesLaziness
*/
override def flatten[B](implicit asIterable: A -> IterableOnce[B]): LazyListIterable[B]^{this} = flatMap(asIterable)
override def flatten[B](implicit asIterable: A -> IterableOnce[B]): LazyListIterable[B]^{this, asIterable*} = flatMap(asIterable)

/** @inheritdoc
*
Expand Down Expand Up @@ -1061,11 +1061,11 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
}
}

private def flatMapImpl[A, B](ll: LazyListIterable[A]^, f: A => IterableOnce[B]^): LazyListIterable[B]^{ll, f} = {
private def flatMapImpl[A, B](ll: LazyListIterable[A]^, f: A => IterableOnce[B]^): LazyListIterable[B]^{ll, f*} = {
// DO NOT REFERENCE `ll` ANYWHERE ELSE, OR IT WILL LEAK THE HEAD
var restRef: LazyListIterable[A]^{ll} = ll // restRef is captured by closure arg to newLL, so A is not recognized as parametric
newLL {
var it: Iterator[B]^{ll, f} = null
var it: Iterator[B]^{ll, f*} = null
var itHasNext = false
var rest = restRef // var rest = restRef.elem
while (!itHasNext && !rest.isEmpty) {
Expand Down Expand Up @@ -1307,7 +1307,7 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
extends collection.WithFilter[A, LazyListIterable] {
private[this] val filtered = lazyList.filter(p)
def map[B](f: A => B): LazyListIterable[B]^{this, f} = filtered.map(f)
def flatMap[B](f: A => IterableOnce[B]^): LazyListIterable[B]^{this, f} = filtered.flatMap(f)
def flatMap[B](@caps.use f: A => IterableOnce[B]^): LazyListIterable[B]^{this, f*} = filtered.flatMap(f)
def foreach[U](f: A => U): Unit = filtered.foreach(f)
def withFilter(q: A => Boolean): collection.WithFilter[A, LazyListIterable]^{this, q} = new WithFilter(filtered, q)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ sealed abstract class List[+A]
}
}

final override def flatMap[B](f: A => IterableOnce[B]^): List[B] = {
final override def flatMap[B](@caps.use f: A => IterableOnce[B]^): List[B] = {
var rest = this
var h: ::[B] = null
var t: ::[B] = null
Expand Down
Loading
Loading