|
| 1 | +/* |
| 2 | + * Copyright 2010-2020 JetBrains s.r.o. and Kotlin Programming Language contributors. |
| 3 | + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. |
| 4 | + */ |
| 5 | + |
| 6 | +package kotlin |
| 7 | + |
| 8 | +import kotlin.coroutines.* |
| 9 | +import kotlin.coroutines.intrinsics.* |
| 10 | + |
| 11 | +/** |
| 12 | + * Defines deep recursive function that keeps its stack on the heap, |
| 13 | + * which allows very deep recursive computations that do not use the actual call stack. |
| 14 | + * To initiate a call to this deep recursive function use its [invoke] function. |
| 15 | + * As a rule of thumb, it should be used if recursion goes deeper than a thousand calls. |
| 16 | + * |
| 17 | + * The [DeepRecursiveFunction] takes one parameter of type [T] and returns a result of type [R]. |
| 18 | + * The [block] of code defines the body of a recursive function. In this block |
| 19 | + * [callRecursive][DeepRecursiveScope.callRecursive] function can be used to make a recursive call |
| 20 | + * to the declared function. Other instances of [DeepRecursiveFunction] can be called |
| 21 | + * in this scope with `callRecursive` extension, too. |
| 22 | + * |
| 23 | + * For example, take a look at the following recursive tree class and a deeply |
| 24 | + * recursive instance of this tree with 100K nodes: |
| 25 | + * |
| 26 | + * ``` |
| 27 | + * class Tree(val left: Tree? = null, val right: Tree? = null) |
| 28 | + * val deepTree = generateSequence(Tree()) { Tree(it) }.take(100_000).last() |
| 29 | + * ``` |
| 30 | + * |
| 31 | + * A regular recursive function can be defined to compute a depth of a tree: |
| 32 | + * |
| 33 | + * ``` |
| 34 | + * fun depth(t: Tree?): Int = |
| 35 | + * if (t == null) 0 else max(depth(t.left), depth(t.right)) + 1 |
| 36 | + * println(depth(deepTree)) // StackOverflowError |
| 37 | + * ``` |
| 38 | + * |
| 39 | + * If this `depth` function is called for a `deepTree` it produces [StackOverflowError] because of deep recursion. |
| 40 | + * However, the `depth` function can be rewritten using `DeepRecursiveFunction` in the following way, and then |
| 41 | + * it successfully computes [`depth(deepTree)`][DeepRecursiveFunction.invoke] expression: |
| 42 | + * |
| 43 | + * ``` |
| 44 | + * val depth = DeepRecursiveFunction<Tree?, Int> { t -> |
| 45 | + * if (t == null) 0 else max(callRecursive(t.left), callRecursive(t.right)) + 1 |
| 46 | + * } |
| 47 | + * println(depth(deepTree)) // Ok |
| 48 | + * ``` |
| 49 | + * |
| 50 | + * Deep recursive functions can also mutually call each other using a heap for the stack via |
| 51 | + * [callRecursive][DeepRecursiveScope.callRecursive] extension. For example, the |
| 52 | + * following pair of mutually recursive functions computes the number of tree nodes at even depth in the tree. |
| 53 | + * |
| 54 | + * ``` |
| 55 | + * val mutualRecursion = object { |
| 56 | + * val even: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t -> |
| 57 | + * if (t == null) 0 else odd.callRecursive(t.left) + odd.callRecursive(t.right) + 1 |
| 58 | + * } |
| 59 | + * val odd: DeepRecursiveFunction<Tree?, Int> = DeepRecursiveFunction { t -> |
| 60 | + * if (t == null) 0 else even.callRecursive(t.left) + even.callRecursive(t.right) |
| 61 | + * } |
| 62 | + * } |
| 63 | + * ``` |
| 64 | + * |
| 65 | + * @param [T] the function parameter type. |
| 66 | + * @param [R] the function result type. |
| 67 | + * @param block the function body. |
| 68 | + */ |
| 69 | +@SinceKotlin("1.4") |
| 70 | +@ExperimentalStdlibApi |
| 71 | +public class DeepRecursiveFunction<T, R>( |
| 72 | + internal val block: suspend DeepRecursiveScope<T, R>.(T) -> R |
| 73 | +) |
| 74 | + |
| 75 | +/** |
| 76 | + * Initiates a call to this deep recursive function, forming a root of the call tree. |
| 77 | + * |
| 78 | + * This operator should not be used from inside of [DeepRecursiveScope] as it uses the call stack slot for |
| 79 | + * initial recursive invocation. From inside of [DeepRecursiveScope] use |
| 80 | + * [callRecursive][DeepRecursiveScope.callRecursive]. |
| 81 | + */ |
| 82 | +@SinceKotlin("1.4") |
| 83 | +@ExperimentalStdlibApi |
| 84 | +public operator fun <T, R> DeepRecursiveFunction<T, R>.invoke(value: T): R = |
| 85 | + DeepRecursiveScopeImpl<T, R>(block, value).runCallLoop() |
| 86 | + |
| 87 | +/** |
| 88 | + * A scope class for [DeepRecursiveFunction] function declaration that defines [callRecursive] methods to |
| 89 | + * recursively call this function or another [DeepRecursiveFunction] putting the call activation frame on the heap. |
| 90 | + * |
| 91 | + * @param [T] function parameter type. |
| 92 | + * @param [R] function result type. |
| 93 | + */ |
| 94 | +@RestrictsSuspension |
| 95 | +@SinceKotlin("1.4") |
| 96 | +@ExperimentalStdlibApi |
| 97 | +public sealed class DeepRecursiveScope<T, R> { |
| 98 | + /** |
| 99 | + * Makes recursive call to this [DeepRecursiveFunction] function putting the call activation frame on the heap, |
| 100 | + * as opposed to the actual call stack that is used by a regular recursive call. |
| 101 | + */ |
| 102 | + public abstract suspend fun callRecursive(value: T): R |
| 103 | + |
| 104 | + /** |
| 105 | + * Makes call to the specified [DeepRecursiveFunction] function putting the call activation frame on the heap, |
| 106 | + * as opposed to the actual call stack that is used by a regular call. |
| 107 | + */ |
| 108 | + public abstract suspend fun <U, S> DeepRecursiveFunction<U, S>.callRecursive(value: U): S |
| 109 | + |
| 110 | + @Deprecated( |
| 111 | + level = DeprecationLevel.ERROR, |
| 112 | + message = |
| 113 | + "'invoke' should not be called from DeepRecursiveScope. " + |
| 114 | + "Use 'callRecursive' to do recursion in the heap instead of the call stack.", |
| 115 | + replaceWith = ReplaceWith("this.callRecursive(value)") |
| 116 | + ) |
| 117 | + @Suppress("UNUSED_PARAMETER") |
| 118 | + public operator fun DeepRecursiveFunction<*, *>.invoke(value: Any?): Nothing = |
| 119 | + throw UnsupportedOperationException("Should not be called from DeepRecursiveScope") |
| 120 | +} |
| 121 | + |
| 122 | +// ================== Implementation ================== |
| 123 | + |
| 124 | +@ExperimentalStdlibApi |
| 125 | +private typealias DeepRecursiveFunctionBlock = suspend DeepRecursiveScope<*, *>.(Any?) -> Any? |
| 126 | + |
| 127 | +private val UNDEFINED_RESULT = Result.success(COROUTINE_SUSPENDED) |
| 128 | + |
| 129 | +@Suppress("UNCHECKED_CAST") |
| 130 | +@ExperimentalStdlibApi |
| 131 | +private class DeepRecursiveScopeImpl<T, R>( |
| 132 | + block: suspend DeepRecursiveScope<T, R>.(T) -> R, |
| 133 | + value: T |
| 134 | +) : DeepRecursiveScope<T, R>(), Continuation<R> { |
| 135 | + // Active function block |
| 136 | + private var function: DeepRecursiveFunctionBlock = block as DeepRecursiveFunctionBlock |
| 137 | + |
| 138 | + // Value to call function with |
| 139 | + private var value: Any? = value |
| 140 | + |
| 141 | + // Continuation of the current call |
| 142 | + private var cont: Continuation<Any?>? = this as Continuation<Any?> |
| 143 | + |
| 144 | + // Completion result (completion of the whole call stack) |
| 145 | + private var result: Result<Any?> = UNDEFINED_RESULT |
| 146 | + |
| 147 | + override val context: CoroutineContext |
| 148 | + get() = EmptyCoroutineContext |
| 149 | + |
| 150 | + override fun resumeWith(result: Result<R>) { |
| 151 | + this.cont = null |
| 152 | + this.result = result |
| 153 | + } |
| 154 | + |
| 155 | + override suspend fun callRecursive(value: T): R = suspendCoroutineUninterceptedOrReturn { cont -> |
| 156 | + // calling the same function that is currently active |
| 157 | + this.cont = cont as Continuation<Any?> |
| 158 | + this.value = value |
| 159 | + COROUTINE_SUSPENDED |
| 160 | + } |
| 161 | + |
| 162 | + override suspend fun <U, S> DeepRecursiveFunction<U, S>.callRecursive(value: U): S = suspendCoroutineUninterceptedOrReturn { cont -> |
| 163 | + // calling another recursive function |
| 164 | + val function = block as DeepRecursiveFunctionBlock |
| 165 | + with(this@DeepRecursiveScopeImpl) { |
| 166 | + val currentFunction = this.function |
| 167 | + if (function !== currentFunction) { |
| 168 | + // calling a different function -- create a trampoline to restore function ref |
| 169 | + this.function = function |
| 170 | + this.cont = crossFunctionCompletion(currentFunction, cont as Continuation<Any?>) |
| 171 | + } else { |
| 172 | + // calling the same function -- direct |
| 173 | + this.cont = cont as Continuation<Any?> |
| 174 | + } |
| 175 | + this.value = value |
| 176 | + } |
| 177 | + COROUTINE_SUSPENDED |
| 178 | + } |
| 179 | + |
| 180 | + private fun crossFunctionCompletion( |
| 181 | + currentFunction: DeepRecursiveFunctionBlock, |
| 182 | + cont: Continuation<Any?> |
| 183 | + ): Continuation<Any?> = Continuation(EmptyCoroutineContext) { |
| 184 | + this.function = currentFunction |
| 185 | + // When going back from a trampoline we cannot just call cont.resume (stack usage!) |
| 186 | + // We delegate the cont.resumeWith(it) call to runCallLoop |
| 187 | + this.cont = cont |
| 188 | + this.result = it |
| 189 | + } |
| 190 | + |
| 191 | + @Suppress("UNCHECKED_CAST") |
| 192 | + fun runCallLoop(): R { |
| 193 | + while (true) { |
| 194 | + // Note: cont is set to null in DeepRecursiveScopeImpl.resumeWith when the whole computation completes |
| 195 | + val result = this.result |
| 196 | + val cont = this.cont |
| 197 | + ?: return (result as Result<R>).getOrThrow() // done -- final result |
| 198 | + // The order of comparison is important here for that case of rogue class with broken equals |
| 199 | + if (UNDEFINED_RESULT == result) { |
| 200 | + // call "function" with "value" using "cont" as completion |
| 201 | + val r = try { |
| 202 | + // This is block.startCoroutine(this, value, cont) |
| 203 | + function.startCoroutineUninterceptedOrReturn(this, value, cont) |
| 204 | + } catch (e: Throwable) { |
| 205 | + cont.resumeWithException(e) |
| 206 | + continue |
| 207 | + } |
| 208 | + // If the function returns without suspension -- calls its continuation immediately |
| 209 | + if (r !== COROUTINE_SUSPENDED) |
| 210 | + cont.resume(r as R) |
| 211 | + } else { |
| 212 | + // we returned from a crossFunctionCompletion trampoline -- call resume here |
| 213 | + this.result = UNDEFINED_RESULT // reset result back |
| 214 | + cont.resumeWith(result) |
| 215 | + } |
| 216 | + } |
| 217 | + } |
| 218 | +} |
0 commit comments