Hug*_*ira 10 scala fold trampolines tail-call-optimization
Let's start with a simple definition of foldRight:
def foldRight[T, U](base: U)(f: (T, => U) => U)(as: Seq[T]): U = {
as match {
case Nil => base
case head +: next => f(head, foldRight(base)(f)(next))
}
}
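One of the nice things about this combinator is that it lets us write something like the following (I use an if to make the short-circuiting behavior of || more explicit):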
def containsElement[T](e: T)(as: Seq[T]): Boolean = {
foldRight(false)((el: T, acc) => if (el == e) true else acc)(as)
}
Then it works on infinite structures:
val bs = 0 #:: 1 #:: 2 #:: 3 #:: LazyList.continually(1)
containsElement(3)(bs)
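To make the short-circuiting visible, here is a small instrumented sketch. The object ShortCircuitDemo and the helper visitedUntilFound are hypothetical, introduced only for illustration: the helper repeats the naive foldRight and records every element that f is actually applied to. Because acc is passed by name, nothing after the first match is ever touched, which is why the infinite LazyList above is fine.

```scala
object ShortCircuitDemo {
  // The naive foldRight from above: the second argument of `f` is by-name (=> U).
  def foldRight[T, U](base: U)(f: (T, => U) => U)(as: Seq[T]): U =
    as match {
      case head +: next => f(head, foldRight(base)(f)(next))
      case _            => base
    }

  // Hypothetical instrumentation: returns the result together with
  // every element that `f` was applied to, in visiting order.
  def visitedUntilFound[T](e: T)(as: Seq[T]): (Boolean, List[T]) = {
    val visited = scala.collection.mutable.ListBuffer.empty[T]
    val found = foldRight(false)((el: T, acc) => {
      visited += el
      // `acc` (the fold over the tail) is evaluated only when el != e
      if (el == e) true else acc
    })(as)
    (found, visited.toList)
  }
}
```

For example, visitedUntilFound(3)(List(0, 1, 2, 3, 4, 5)) yields (true, List(0, 1, 2, 3)); the elements 4 and 5 are never visited.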
However, it does not work on very long sequences, because we blow up the stack:
val veryLongList = List.fill(1_000_000)(0) :+ 3
containsElement(3)(veryLongList)
...results in a java.lang.StackOverflowError.
Enter scala.util.control.TailCalls. We can write a very specialized implementation of containsElement that takes advantage of TCO, for example:
def containsElement[T](e: T)(as: Seq[T]) = {
import scala.util.control.TailCalls._
def _rec(as: Seq[T]): TailRec[Boolean] = {
as match {
case Nil => done(false)
case head +: next => if (head == e) done(true) else _rec(next)
}
}
_rec(as).result
}
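A side remark on this specialized version: _rec calls itself directly in tail position, so scalac's built-in self-tail-call elimination can already turn it into a loop. TailCalls becomes essential once the tail call goes to a different function. The classic illustration is mutual recursion; in the following sketch (the object name EvenOdd is hypothetical), a million nested calls would overflow the stack without the trampoline, but with tailcall/done each step returns a suspended computation that result evaluates in a loop:

```scala
import scala.util.control.TailCalls._

object EvenOdd {
  // Neither call is a *self* call, so scalac cannot turn this into a loop;
  // `tailcall` suspends the next step instead of growing the stack.
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(isEven(n - 1))
}
```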
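But now I want to generalize this to foldRight. The following code is what I got through incremental refactoring, but if I keep following this path, I run into the fact that I need to change the signature of f to f: (T, => TailRec[U]) => U, which is not at all what I want: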
def containsElement[T](e: T)(as: Seq[T]) = {
import scala.util.control.TailCalls._
val base = false
def f(head: T, next: => TailRec[Boolean]): TailRec[Boolean] = if (head == e) done(true) else next
def _rec(as: Seq[T]): TailRec[Boolean] = {
as match {
case Nil => done(base)
case head +: next => f(head, _rec(next))
}
}
_rec(as).result
}
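Question: how can we create an implementation of foldRight that (a) preserves the signature [T, U](base: U)(f: (T, => U) => U)(as: Seq[T]): U, (b) works on infinite structures, and (c) does not crash with a StackOverflowError on very long structures?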
This cannot be done (at least not on a single JVM thread with a bounded stack).
The combination of requirements (a), (b), (c) inevitably leads to pathological solutions (see "Inter-thread recursion" below, as well as the appendix).
In the signature
def foldRight[T, U](base: U)(f: (T, => U) => U)(as: Seq[T]): U
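the type (T, => U) => U means that
- f may evaluate its second (by-name) argument, and cannot return before that evaluation has finished;
- f must be able to keep some data in its stack frame from the beginning to the end of its execution;
- therefore, an unbounded number of stack frames of f must coexist simultaneously.
No amount of trampolining tricks can help, because you cannot change how causality works.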
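The above argument leads to a working implementation (see the "Inter-thread recursion" section below). However, spawning thousands of threads as containers for stack frames is very unnatural and indicates that the signature should be changed.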
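Rather than merely providing a code snippet with a ready-made solution, we want to explain how to get there. In fact, we want to gradually derive a better solution, in two variants: one based on TailRec and one based on Eval. The rest of this answer is structured as follows: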
We first look at a few failed approaches, just to better understand the problem and to collect some requirements:
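The "Inter-thread recursion" section provides a solution that formally satisfies all three requirements ((a), (b), (c)) of the question, but spawns an unbounded number of threads. Since spawning multiple threads is not a viable solution, this forces us to give up the original signature.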
Then we look at alternative signatures:
TailRec can be replaced by cats.Eval.

def foldRight[T, U](base: U)(f: (T, => U) => U)(as: Seq[T]): U = {
as match {
case head +: next => f(head, foldRight(base)(f)(next))
case _ => base
}
}
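The naive recursion attempted in the first paragraph of the question (also mentioned in many blogs and articles, e.g. here) has two seemingly contradictory properties:
- it works on infinite streams;
- it fails with a StackOverflowError on long but finite lists.
There is no paradox here. The first property means that if the result can be determined by looking only at the first few elements of the sequence, the naive foldRight can skip infinitely many elements in the tail. The second property, however, means that the solution cannot look at too many elements at the beginning of the sequence.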
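The first property is the good part of this solution: once f has processed the data it actually needs, we should definitely give it the possibility to skip the rest of the sequence.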
A solution based on TailRec was mentioned in the comments on the question:
def foldRight[T, U](base: U)(f: (T, => U) => U)(as: Seq[T]): U = {
import scala.util.control.TailCalls._
@annotation.tailrec
def _foldRightCPS(fu: U => TailRec[U])(as: Seq[T]): TailRec[U] = {
as match {
case Nil => fu(base)
case head +: next => _foldRightCPS(u => tailcall(fu(f(head, u))))(next)
}
}
_foldRightCPS(u => done(u))(as).result
}
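The problem here is that it does not actually work on infinite streams: it never gives f a chance to decide whether to continue, so it cannot "stop early" before "seeing the end of the infinite stream", which never happens.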
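Instrumenting this CPS version with a visit counter shows both sides of the trade-off. In this sketch, CpsDemo and countVisits are hypothetical names, and the inner helper is renamed to loop: the traversal survives a list with a million elements, but f is applied to every element, even when the very first element already decides the result.

```scala
import scala.util.control.TailCalls._

object CpsDemo {
  // The CPS/TailRec foldRight from the comments (inner helper renamed).
  def foldRight[T, U](base: U)(f: (T, => U) => U)(as: Seq[T]): U = {
    @annotation.tailrec
    def loop(fu: U => TailRec[U])(rest: Seq[T]): TailRec[U] =
      rest match {
        // Only extends the continuation; `f` is not called yet.
        case head +: next => loop(u => tailcall(fu(f(head, u))))(next)
        // The whole sequence has been consumed: now `f` runs right-to-left.
        case _ => fu(base)
      }
    loop(u => done(u))(as).result
  }

  // Hypothetical instrumentation: result plus the number of calls to `f`.
  def countVisits[T](e: T)(as: Seq[T]): (Boolean, Int) = {
    var visits = 0
    val found = foldRight(false)((el: T, acc) => {
      visits += 1
      if (el == e) true else acc
    })(as)
    (found, visits)
  }
}
```

countVisits(3)(List(3, 0, 0, 0, 0)) returns (true, 5): all five elements are visited, although the first one already matches. On an infinite LazyList, loop never reaches the end, so the call does not terminate.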
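It "avoids" the problem of unboundedly many simultaneously coexisting stack frames by taking away from the k-th invocation of f the possibility to decide whether the (k+1)-th invocation is needed at all. For finite lists, this allows it to start with the rightmost invocation of f and work backwards, but if there is no "rightmost" f to begin with, this scheme fails. The good part of this solution is the idea of TailRec trampolining. Although it is not sufficient on its own, it will come in handy later (in "TailRec Done Right").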
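Suppose we don't want to make the same mistake as in the previous section: we definitely want to let f decide whether the rest of the sequence should be looked at.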
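In the introduction, we claimed that this leads to the requirement that an unbounded number of stack frames of f must be able to coexist in memory simultaneously.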
To see this, we need just one example that clearly shows that we must be able to keep an unbounded number of stack frames in memory.
Let's specialize the signature f: (T, => U) => U for T = Boolean and U = Unit, and consider a list consisting of a million trues, followed by a single false.
Suppose that f is implemented as follows:
(t, u) => {
val x = util.Random.nextInt() // this value exists at the beginning
println(x)
if (t) {
u // `f` cannot exit until we're done with the rest
}
println(x) // the value must exist until `f` returns
}
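The first invocation of f generates a random x, then evaluates u, which invokes f a second, third, fourth, ..., millionth time. Before each of these invocations of f exits, it must still have access to the x in its own stack frame. Therefore, a million random values must be stored in a million local variables x in a million stack frames, all of which must be alive at the same time.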
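The same nesting can be observed deterministically on a short list. This sketch assumes a counter in place of util.Random and a buffer in place of println (both substitutions are made only so that the output is predictable; NestingDemo and trace are hypothetical names), and reuses the naive foldRight with the original signature:

```scala
object NestingDemo {
  // Naive foldRight with the signature from the question.
  def foldRight[T, U](base: U)(f: (T, => U) => U)(as: Seq[T]): U =
    as match {
      case head +: next => f(head, foldRight(base)(f)(next))
      case _            => base
    }

  def trace(flags: Seq[Boolean]): List[Int] = {
    val log = scala.collection.mutable.ListBuffer.empty[Int]
    var counter = 0
    foldRight[Boolean, Unit](())((t, u) => {
      counter += 1
      val x = counter // local to this invocation of `f`
      log += x        // plays the role of the first println(x)
      if (t) u        // `f` cannot exit until the rest is done
      log += x        // `x` must still be available here
      ()
    })(flags)
    log.toList
  }
}
```

trace(List(true, true, true, false)) returns List(1, 2, 3, 4, 4, 3, 2, 1): every invocation of f logs its x a second time only after all later invocations have finished, so all of their x values are alive at once.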
On the JVM, there is no way to take a stack frame and turn it into something else (e.g., a heap-allocated object).
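Unless you tweak the JVM settings and allow the stack to grow without bound, you have to limit the height of the stack. Having an unbounded number of stack frames while respecting a maximum stack height means that you need multiple stacks, and to have multiple stacks, you need to start multiple threads.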
This indeed leads to a solution that formally satisfies all three of your requirements:
def foldRight[T, U](base: U)(f: (T, => U) => U)(as: Seq[T]): U = {
// The number of active stack-frames remains bounded for each stack
val MaxFrames = 1000
// This recursively spawns multiple threads;
def interthreadRecursion(remainingSeq: Seq[T]): U = {
// a synchronized mailbox that we use to pass the result
// from the child thread to the caller
val resultFromChildThread = new java.util.concurrent.atomic.AtomicReference[U]
val t = new Thread {
def stackRec(remainingSeq: Seq[T], remainingFrames: Int): U = {
if (remainingFrames == 0) {
// Note that this happens in a different thread,
// the frames of `interthreadRecursion` belong to
// separate stacks
interthreadRecursion(remainingSeq)
} else {
remainingSeq match {
case Nil => base
case head +: next => f(head, stackRec(next, remainingFrames - 1))
}
}
}
override def run(): Unit = {
// start the stack-level recursion
resultFromChildThread.set(stackRec(remainingSeq, MaxFrames))
}
}
t.start()
t.join()
// return the result to the caller
resultFromChildThread.get()
}
// start the thread-level recursion
interthreadRecursion(as)
}
def containsElement[T](e: T)(as: Seq[T]): Boolean = {
foldRight(false)((el: T, acc) => if (el == e) true else acc)(as)
}
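It leaves your foldRight signature exactly as it is (a), and it works for both of your test cases: the infinite stream without a base case (b), and the list with a million entries (c).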
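But creating an unbounded number of threads as containers for stack frames is obviously insane. Therefore, if we want to keep (b) and (c), we are forced to give up the signature (a).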
How can we modify the signature so that we can achieve (b) and (c) without creating multiple threads?
Here is a simple, illustrative solution that works in a single thread and does not rely on any pre-existing stack-safety or trampolining framework:
import util.{Either, Left, Right}
def foldRight[T, U](base: U)(f: T => Either[U => U, U])(as: Seq[T]): U = {
@annotation.tailrec
def rec(remainingAs: Seq[T], todoSteps: List[U => U]): U =
remainingAs match
case head +: tail => f(head) match
case Left(step) => rec(tail, step :: todoSteps)
case Right(done) => todoSteps.foldLeft(done)((r, s) => s(r))
case _ => todoSteps.foldLeft(base)((r, s) => s(r))
rec(as, Nil)
}
def containsElement[T](e: T)(as: Seq[T]): Boolean = {
foldRight(false)((el: T) => if (el == e) Right(true) else Left(identity))(as)
}
It works for both of your examples. The helper method rec is clearly tail-recursive, and it does not require any hard-to-grasp libraries.
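The change in the signature is that f is allowed to look at a T and then return an Either[U => U, U], where Right[U] corresponds to the final result of the current rec call, and Left[U => U] corresponds to the case where the result of the tail must be looked at and post-processed in some way.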
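The reason this solution works is that it creates heap-allocated closures U => U, which are stored in an ordinary List. The information that was previously held in stack frames is moved to the heap, so there is no possibility of a stack overflow.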
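Here is a self-contained sketch of the same Either-based signature, with two extra functions that really do post-process the tail's result (EitherFold and minusFold are hypothetical names; nestBrackets mirrors the appendix). Since subtraction is not associative, minusFold makes the order in which the pending closures are applied observable: the most recently pushed closure, which is the innermost one, runs first.

```scala
object EitherFold {
  // Same shape as the Either-based foldRight above, restated so this block compiles on its own.
  def foldRight[T, U](base: U)(f: T => Either[U => U, U])(as: Seq[T]): U = {
    @annotation.tailrec
    def rec(remaining: Seq[T], todo: List[U => U]): U =
      remaining match {
        case head +: tail =>
          f(head) match {
            // Defer the post-processing step: store it on the heap, not the stack.
            case Left(step)  => rec(tail, step :: todo)
            // Most recently pushed step (innermost) is applied first.
            case Right(done) => todo.foldLeft(done)((r, s) => s(r))
          }
        case _ => todo.foldLeft(base)((r, s) => s(r))
      }
    rec(as, Nil)
  }

  // Every element requests post-processing of the tail's result.
  def nestBrackets(labels: Seq[String], center: String): String =
    foldRight(center)((l: String) => Left((w: String) => s"[$l$w$l]"))(labels)

  // Non-associative, so it exposes the evaluation order: 1 - (2 - (3 - 0)).
  def minusFold(xs: Seq[Int]): Int =
    foldRight(0)((x: Int) => Left((acc: Int) => x - acc))(xs)
}
```

nestBrackets(List("1", "2", "3"), "<*>") gives "[1[2[3<*>3]2]1]", and minusFold(List(1, 2, 3)) gives 1 - (2 - (3 - 0)) = 2, which is what a right fold should produce.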
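This solution has at least one drawback, though: for simple functions such as containsElement, it creates a very long List[U => U] that contains nothing but identity functions that do nothing. Even the GC complains: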
[warn] In the last 9 seconds, 5.043 (59.9%) were spent in GC. [Heap: 0.91GB free of 1.00GB, max 1.00GB] Consider increasing the JVM heap using `-Xmx` or try a different collector, e.g. `-XX:+UseG1GC`, for better performance.
Can we somehow get rid of this list? (We call this new requirement (d).)
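The reason we needed a List[U => U] in the previous section is that f needs a way to "post-process" the result returned from the tail of the sequence. Since simple functions like containsElement do not actually need this, we might try a simpler recursion scheme, such as the following: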
/** If `f` returns `Some[U]`, then this is the result.
* If `f` returns `None`, recursively look at the tail.
*/
def collectFirst[T, U](base: U)(f: T => Option[U])(as: Seq[T]): U = {
@annotation.tailrec
def rec(remainingAs: Seq[T]): U =
remainingAs match
case head +: tail => f(head) match
case Some(done) => done
case None => rec(tail)
case _ => base
rec(as)
}
def containsElement[T](e: T)(as: Seq[T]): Boolean = {
collectFirst(false)((el: T) => if (el == e) Some(true) else None)(as)
}
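Unfortunately, although it is sufficient for containsElement, it is not as expressive as a true foldRight (see nestBrackets and foldNonassocOp in the "Full code" section for concrete examples of functions that cannot be expressed with collectFirst).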
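As an aside, this restricted scheme is essentially what the standard library already offers: Seq.collectFirst takes a PartialFunction, and Function.unlift turns a T => Option[U] into one. A sketch under that reading (the object name CollectFirstViaStdlib is hypothetical):

```scala
object CollectFirstViaStdlib {
  // Same contract as the hand-written collectFirst above:
  // the first Some(u) wins, otherwise fall back to `base`.
  def collectFirst[T, U](base: U)(f: T => Option[U])(as: Seq[T]): U =
    as.collectFirst(Function.unlift(f)).getOrElse(base)

  def containsElement[T](e: T)(as: Seq[T]): Boolean =
    collectFirst(false)((el: T) => if (el == e) Some(true) else None)(as)
}
```

Because Seq.collectFirst iterates rather than recursing and stops at the first defined result, this version also handles both the infinite LazyList and the long list; it shares the expressiveness limitation described above.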
TailRec done right: having looked at some other failed attempts, we come back to TailRec:
import scala.util.control.TailCalls._

def foldRight[T, U](base: U)(f: (T, TailRec[U]) => TailRec[U])(as: Seq[T]): U = {
  def rec(remaining: Seq[T]): TailRec[U] =
    remaining match
      case head +: tail => tailcall(f(head, rec(tail)))
      case _ => done(base)
  rec(as).result
}
It can be used like this:
def containsElement[T](e: T)(as: Seq[T]): Boolean = (
foldRight[T, Boolean]
(false)
((elem, rest) => if (elem == e) done(true) else rest)
(as)
)
This satisfies (b), (c) and (d). Also note how similar it looks to the naive recursion with a nested helper method (see the full code).
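Putting the pieces together, here is a self-contained sketch of the TailRec-based foldRight exercised on both of the question's test cases, plus a fold that post-processes the tail's result through TailRec.map (TailRecFold and sumRight are hypothetical names, not from the question):

```scala
import scala.util.control.TailCalls._

object TailRecFold {
  // `rec(tail)` only builds a suspended computation; nothing is traversed
  // until `f` decides to use it and the trampoline runs it.
  def foldRight[T, U](base: U)(f: (T, TailRec[U]) => TailRec[U])(as: Seq[T]): U = {
    def rec(remaining: Seq[T]): TailRec[U] =
      remaining match {
        case head +: tail => tailcall(f(head, rec(tail)))
        case _            => done(base)
      }
    rec(as).result
  }

  // Drops the rest of the computation as soon as the element is found.
  def containsElement[T](e: T)(as: Seq[T]): Boolean =
    foldRight[T, Boolean](false)((elem, rest) => if (elem == e) done(true) else rest)(as)

  // Post-processes the tail's result with `map` instead of a pending stack frame.
  def sumRight(xs: Seq[Long]): Long =
    foldRight[Long, Long](0L)((x, rest) => rest.map(x + _))(xs)
}
```

The design choice is that f receives the rest of the fold as a value: it can drop it (containsElement after a match) or transform its eventual result with map (sumRight), and TailRec's result evaluates all of this iteratively.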
/** The naive recursion proposed at the beginning of the question.*/
object NaiveRecursive extends SignaturePreservingApproach:
def description = """Naive recursive (from question itself, 1st attempt)"""
def foldRight[T, U](base: U)(f: (T, => U) => U)(as: Seq[T]): U = {
as match {
case head +: next => f(head, foldRight(base)(f)(next))
case _ => base
}
}
/** This attempts to use TailRec, but fails on infinite streams */
object TailRecFromUsersScalaLangOrg extends SignaturePreservingApproach:
def description = "TailRec (from link to discussion on scala-lang.org)"
def foldRight[T, U](base: U)(f: (T, => U) => U)(as: Seq[T]): U = {
import scala.util.control.TailCalls._
@annotation.tailrec
def _foldRightCPS(fu: U => TailRec[U])(as: Seq[T]): TailRec[U] = {
as match {
case Nil => fu(base)
case head +: next => _foldRightCPS(u => tailcall(fu(f(head, u))))(next)
}
}
_foldRightCPS(u => done(u))(as).result
}
/** The solution that satisfies all properties (a), (b), (c),
* but requires multiple threads.
*/
object InterThreadRecursion extends SignaturePreservingApproach:
def description = "Inter-thread recursion"
def foldRight[T, U](base: U)(f: (T, => U) => U)(as: Seq[T]): U = {
// The number of active stack-frames remains bounded for each stack
val MaxFrames = 1000
// This recursively spawns multiple threads;
def interthreadRecursion(remainingSeq: Seq[T]): U = {
// a synchronized mailbox that we use to pass the result
// from the child thread to the caller
val resultFromChildThread = new java.util.concurrent.atomic.AtomicReference[U]
val t = new Thread {
def stackRec(remainingSeq: Seq[T], remainingFrames: Int): U = {
if (remainingFrames == 0) {
// Note that this happens in a different thread,
// the frames of `interthreadRecursion` belong to
// separate stacks
interthreadRecursion(remainingSeq)
} else {
remainingSeq match {
case Nil => base
case head +: next => f(head, stackRec(next, remainingFrames - 1))
}
}
}
override def run(): Unit = {
// start the stack-level recursion
resultFromChildThread.set(stackRec(remainingSeq, MaxFrames))
}
}
t.start()
t.join()
// return the result to the caller
resultFromChildThread.get()
}
// start the thread-level recursion
interthreadRecursion(as)
}
/** A solution that works for (b) and (c), but requires (!a). */
object ReturnEither extends SolutionApproach:
import util.{Either, Left, Right}
def description = "Return Either[U => U, U]"
def foldRight[T, U](base: U)(f: T => Either[U => U, U])(as: Seq[T]): U = {
@annotation.tailrec
def rec(remainingAs: Seq[T], todoSteps: List[U => U]): U =
remainingAs match
case head +: tail => f(head) match
case Left(step) => rec(tail, step :: todoSteps)
case Right(done) => todoSteps.foldLeft(done)((r, s) => s(r))
case _ => todoSteps.foldLeft(base)((r, s) => s(r))
rec(as, Nil)
}
def containsElement[T](e: T)(as: Seq[T]): Boolean = {
foldRight(false)((el: T) => if (el == e) Right(true) else Left(identity))(as)
}
def nestBrackets(labels: Seq[String], center: String): String =
foldRight(center)(l => Left(w => s"[${l}${w}${l}]"))(labels)
def foldNonassocOp(numbers: Seq[Int]): Int = (
foldRight
(0)
(
(n: Int) =>
if n == 0
then Right(0)
else Left((h: Int) => nonassocOp(n, h))
)
(numbers)
)
/** An overly restrictive signature that leads to expressiveness loss. */
object NoPostprocessing extends SolutionApproach:
import util.{Either, Left, Right}
def description = "No postprocessing"
def collectFirst[T, U](base: U)(f: T => Option[U])(as: Seq[T]): U = {
@annotation.tailrec
def rec(remainingAs: Seq[T]): U =
remainingAs match
case head +: tail => f(head) match
case Some(done) => done
case None => rec(tail)
case _ => base
rec(as)
}
def containsElement[T](e: T)(as: Seq[T]): Boolean = {
collectFirst(false)((el: T) => if (el == e) Some(true) else None)(as)
}
def nestBrackets(labels: Seq[String], center: String): String = ???
def foldNonassocOp(numbers: Seq[Int]): Int = ???
/** This is just to demonstrate syntactic similarity to EvalApproach */
object SyntacticallySimilarToEval extends SolutionApproach:
def description = "Syntactically analogous to Eval"
def foldRight[T, U](base: U)(f: (T, U) => U)(as: Seq[T]): U = {
def rec(remaining: Seq[T]): U =
remaining match
case head +: tail => f(head, rec(tail))
case _ => base
rec(as)
}
def containsElement[T](e: T)(as: Seq[T]): Boolean = (
foldRight[T, Boolean]
(false)
((elem, rest) => if (elem == e) true else rest)
(as)
)
def nestBrackets(labels: Seq[String], center: String): String = (
foldRight[String, String]
(center)
((label, middle) => s"[${label}${middle}${label}]")
(labels)
)
def foldNonassocOp(numbers: Seq[Int]): Int = (
foldRight[Int, Int]
(0)
(
(n, acc) =>
if n == 0
then 0
else nonassocOp(n, acc)
)
(numbers)
)
/** A `TailRec`-based solution that works */
object TailRecDoneRight extends SolutionApproach:
def description = "Ok solution with TailRec"
import util.control.TailCalls._
def foldRight[T, U](base: U)(f: (T, TailRec[U]) => TailRec[U])(as: Seq[T]): U = {
def rec(remaining: Seq[T]): TailRec[U] =
remaining match
case head +: tail => tailcall(f(head, rec(tail)))
case _ => done(base)
rec(as).result
}
def containsElement[T](e: T)(as: Seq[T]): Boolean = (
foldRight[T, Boolean]
(false)
((elem, rest) => if (elem == e) done(true) else rest)
(as)
)
def nestBrackets(labels: Seq[String], center: String): String = (
foldRight[String, String]
(center)
((label, middle) => middle.map(m => (s"[${label}${m}${label}]")))
(labels)
)
def foldNonassocOp(numbers: Seq[Int]): Int = (
foldRight[Int, Int]
(0)
(
(n, acc) =>
if n == 0
then done(0)
else acc.map(nonassocOp(n, _))
)
(numbers)
)
/** Bonus: same as "TailRec Done Right", but with `cats.Eval` */
/* requires `cats`
object EvalApproach extends SolutionApproach:
def description = "Nice solution with Eval"
import cats.Eval
def foldRight[T, U](base: U)(f: (T, Eval[U]) => Eval[U])(as: Seq[T]): U = {
def rec(remaining: Seq[T]): Eval[U] =
remaining match
case head +: tail => Eval.defer(f(head, rec(tail)))
case _ => Eval.now(base)
rec(as).value
}
def containsElement[T](e: T)(as: Seq[T]): Boolean = (
foldRight[T, Boolean]
(false)
((elem, rest) => if (elem == e) Eval.now(true) else rest)
(as)
)
def nestBrackets(labels: Seq[String], center: String): String = (
foldRight[String, String]
(center)
((label, middle) => middle.map(m => (s"[${label}${m}${label}]")))
(labels)
)
def foldNonassocOp(numbers: Seq[Int]): Int = (
foldRight[Int, Int]
(0)
(
(n, acc) =>
if n == 0
then Eval.now(0)
else acc.map(nonassocOp(n, _))
)
(numbers)
)
*/
/** A base trait for a series of experiments using
* one particular approach to foldRight implementation.
*/
trait SolutionApproach {
/** Short description that uniquely identifies
* the approach within the context of the question.
*/
def description: String
/** Does it respect the signature requested in the question? */
def respectsSignature: Boolean = false
/** Checks whether `as` contains `e`. */
def containsElement[T](e: T)(as: Seq[T]): Boolean
/** Puts labeled brackets around a central string.
*
* E.g. `nestBrackets(List("1", "2", "3"), "<*>") == "[1[2[3<*>3]2]1]"`.
*/
def nestBrackets(bracketLabels: Seq[String], center: String): String
/** A non-associative operation on integers with
* the property that `nonassocOp(0, x) = 0`.
*/
def nonassocOp(a: Int, b: Int): Int = a * s"string${b}".hashCode
/** fold-right with nonassocOp */
def foldNonassocOp(numbers: Seq[Int]): Int
/** Runs a single experiment, prints the description of the outcome */
private def runExperiment[A](label: String)(intendedOutcome: A)(body: => A): Unit = {
val resultRef = new java.util.concurrent.atomic.AtomicReference[util.Either[String, A]]()
val t = new Thread {
override def run(): Unit = {
try {
resultRef.set(util.Right(body))
} catch {
case so: StackOverflowError => resultRef.set(util.Left("StackOverflowError"))
case ni: NotImplementedError => resultRef.set(util.Left("Not implementable"))
}
}
}
val killer = new Thread {
override def run(): Unit = {
Thread.sleep(2000)
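if (t.isAlive) {
// Record the timeout *before* stopping `t`: the main thread is blocked in
// `t.join()` and reads `resultRef` as soon as `t` has died.
resultRef.set(util.Left("Timed out"))
// Yes, it's bad, it damages objects, blah blah blah...
// (On JDK 20 and later, Thread.stop() throws UnsupportedOperationException.)
t.stop()
}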
}
}
t.start()
killer.start()
t.join()
val result = resultRef.get()
val outcomeString =
result match
case util.Left(err) => s"[failure] ${err}"
case util.Right(r) => {
if (r == intendedOutcome) {
s"[success] ${r} (as expected)"
} else {
s"[failure] ${r} (expected ${intendedOutcome})"
}
}
val formattedOutcome = "%-40s %s".format(
s"${label}:",
outcomeString
)
println(formattedOutcome)
}
/** Runs multiple experiments. */
def runExperiments(): Unit = {