是否可以使用延续来使 FoldRight 尾部递归?

发布于 2024-12-21 13:03:59 字数 1524 浏览 1 评论 0原文

以下博客文章展示了如何在 F# foldBack 中 可以使用连续传递风格进行尾递归。

在 Scala 中,这意味着:

def foldBack[T,U](l: List[T], acc: U)(f: (T, U) => U): U = {
  l match {
    case x :: xs => f(x, foldBack(xs, acc)(f))
    case Nil => acc
  }
} 

可以通过这样做进行尾递归:

def foldCont[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  @annotation.tailrec
  def loop(l: List[T], k: (U) => U): U = {
    l match {
      case x :: xs => loop(xs, (racc => k(f(x, racc))))
      case Nil => k(acc)
    }
  }
  loop(list, u => u)
} 

不幸的是,对于长列表,我仍然遇到堆栈溢出。循环是尾递归和优化的,但我猜堆栈积累只是移到了继续调用中。

为什么这对于 F# 来说不是问题? Scala 有什么办法可以解决这个问题吗?

编辑:这里有一些显示堆栈深度的代码:

def showDepth(s: Any) {
  println(s.toString + ": " + (new Exception).getStackTrace.size)
}

def foldCont[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  @annotation.tailrec
  def loop(l: List[T], k: (U) => U): U = {
    showDepth("loop")
    l match {
      case x :: xs => loop(xs, (racc => { showDepth("k"); k(f(x, racc)) }))
      case Nil => k(acc)
    }
  }
  loop(list, u => u)
} 

foldCont(List.fill(10)(1), 0)(_ + _)

这会打印:

loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
k: 51
k: 52
k: 53
k: 54
k: 55
k: 56
k: 57
k: 58
k: 59
k: 60
res2: Int = 10

The following blog article shows how in F# foldBack can be made tail recursive using continuation passing style.

In Scala this would mean that:

def foldBack[T,U](l: List[T], acc: U)(f: (T, U) => U): U = {
  l match {
    case x :: xs => f(x, foldBack(xs, acc)(f))
    case Nil => acc
  }
} 

can be made tail recursive by doing this:

def foldCont[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  @annotation.tailrec
  def loop(l: List[T], k: (U) => U): U = {
    l match {
      case x :: xs => loop(xs, (racc => k(f(x, racc))))
      case Nil => k(acc)
    }
  }
  loop(list, u => u)
} 

Unfortunately, I still get a stack overflow for long lists. loop is tail recursive and optimized but I guess the stack accumulation is just moved into the continuation calls.

Why is this not a problem with F#? And is there any way to work around this with Scala?

Edit: here some code that shows depth of stack:

def showDepth(s: Any) {
  println(s.toString + ": " + (new Exception).getStackTrace.size)
}

def foldCont[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  @annotation.tailrec
  def loop(l: List[T], k: (U) => U): U = {
    showDepth("loop")
    l match {
      case x :: xs => loop(xs, (racc => { showDepth("k"); k(f(x, racc)) }))
      case Nil => k(acc)
    }
  }
  loop(list, u => u)
} 

foldCont(List.fill(10)(1), 0)(_ + _)

This prints:

loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
k: 51
k: 52
k: 53
k: 54
k: 55
k: 56
k: 57
k: 58
k: 59
k: 60
res2: Int = 10

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(4

世俗缘 2024-12-28 13:03:59

乔恩,nm,谢谢您的回答。根据您的评论,我想我应该尝试一下使用蹦床。一些研究表明 Scala 在 TailCalls 中提供了对蹦床的库支持。这是我经过一番摆弄后得出的结论:

def foldContTC[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  import scala.util.control.TailCalls._
  @annotation.tailrec
  def loop(l: List[T], k: (U) => TailRec[U]): TailRec[U] = {
    l match {
      case x :: xs => loop(xs, (racc => tailcall(k(f(x, racc)))))
      case Nil => k(acc)
    }
  }
  loop(list, u => done(u)).result
} 

我有兴趣看看这与没有蹦床以及默认的 foldLeftfoldRight 的解决方案相比如何实施。以下是基准代码和一些结果:

val size = 1000
val list = List.fill(size)(1)
val warm = 10
val n = 1000
bench("foldContTC", warm, lots(n, foldContTC(list, 0)(_ + _)))
bench("foldCont", warm, lots(n, foldCont(list, 0)(_ + _)))
bench("foldRight", warm, lots(n, list.foldRight(0)(_ + _)))
bench("foldLeft", warm, lots(n, list.foldLeft(0)(_ + _)))
bench("foldLeft.reverse", warm, lots(n, list.reverse.foldLeft(0)(_ + _)))

时间安排是:

foldContTC: warming...
Elapsed: 0.094
foldCont: warming...
Elapsed: 0.060
foldRight: warming...
Elapsed: 0.160
foldLeft: warming...
Elapsed: 0.076
foldLeft.reverse: warming...
Elapsed: 0.155

基于此,蹦床似乎实际上产生了相当好的性能。我怀疑装箱/拆箱之上的惩罚相对来说并没有那么糟糕。

编辑:根据乔恩的评论建议,以下是 100 万个项目的计时,证实列表较大时性能会下降。我还发现库 List.foldLeft 实现没有被覆盖,所以我用以下的foldLeft2计时:

def foldLeft2[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  list match {
    case x :: xs => foldLeft2(xs, f(x, acc))(f)
    case Nil => acc
  }
} 

val size = 1000000
val list = List.fill(size)(1)
val warm = 10
val n = 2
bench("foldContTC", warm, lots(n, foldContTC(list, 0)(_ + _)))
bench("foldLeft", warm, lots(n, list.foldLeft(0)(_ + _)))
bench("foldLeft2", warm, lots(n, foldLeft2(list, 0)(_ + _)))
bench("foldLeft.reverse", warm, lots(n, list.reverse.foldLeft(0)(_ + _)))
bench("foldLeft2.reverse", warm, lots(n, foldLeft2(list.reverse, 0)(_ + _)))

收益:

foldContTC: warming...
Elapsed: 0.801
foldLeft: warming...
Elapsed: 0.156
foldLeft2: warming...
Elapsed: 0.054
foldLeft.reverse: warming...
Elapsed: 0.808
foldLeft2.reverse: warming...
Elapsed: 0.221

所以foldLeft2.reverse是赢家......

Jon, n.m., thank you for your answers. Based on your comments I thought I'd give a try and use trampoline. A bit of research shows Scala has library support for trampolines in TailCalls. Here is what I came up with after a bit of fiddling around:

def foldContTC[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  import scala.util.control.TailCalls._
  @annotation.tailrec
  def loop(l: List[T], k: (U) => TailRec[U]): TailRec[U] = {
    l match {
      case x :: xs => loop(xs, (racc => tailcall(k(f(x, racc)))))
      case Nil => k(acc)
    }
  }
  loop(list, u => done(u)).result
} 

I was interested to see how this compares to the solution without the trampoline as well as the default foldLeft and foldRight implementations. Here is the benchmark code and some results:

val size = 1000
val list = List.fill(size)(1)
val warm = 10
val n = 1000
bench("foldContTC", warm, lots(n, foldContTC(list, 0)(_ + _)))
bench("foldCont", warm, lots(n, foldCont(list, 0)(_ + _)))
bench("foldRight", warm, lots(n, list.foldRight(0)(_ + _)))
bench("foldLeft", warm, lots(n, list.foldLeft(0)(_ + _)))
bench("foldLeft.reverse", warm, lots(n, list.reverse.foldLeft(0)(_ + _)))

The timings are:

foldContTC: warming...
Elapsed: 0.094
foldCont: warming...
Elapsed: 0.060
foldRight: warming...
Elapsed: 0.160
foldLeft: warming...
Elapsed: 0.076
foldLeft.reverse: warming...
Elapsed: 0.155

Based on this, it would seem that trampolining is actually yielding pretty good performance. I suspect that the penalty on top of the boxing/unboxing is relatively not that bad.

Edit: as suggested by Jon's comments, here are the timings on 1M items which confirm that performance degrades with larger lists. Also I found out that library List.foldLeft implementation is not overriden, so I timed with the following foldLeft2:

def foldLeft2[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  list match {
    case x :: xs => foldLeft2(xs, f(x, acc))(f)
    case Nil => acc
  }
} 

val size = 1000000
val list = List.fill(size)(1)
val warm = 10
val n = 2
bench("foldContTC", warm, lots(n, foldContTC(list, 0)(_ + _)))
bench("foldLeft", warm, lots(n, list.foldLeft(0)(_ + _)))
bench("foldLeft2", warm, lots(n, foldLeft2(list, 0)(_ + _)))
bench("foldLeft.reverse", warm, lots(n, list.reverse.foldLeft(0)(_ + _)))
bench("foldLeft2.reverse", warm, lots(n, foldLeft2(list.reverse, 0)(_ + _)))

yields:

foldContTC: warming...
Elapsed: 0.801
foldLeft: warming...
Elapsed: 0.156
foldLeft2: warming...
Elapsed: 0.054
foldLeft.reverse: warming...
Elapsed: 0.808
foldLeft2.reverse: warming...
Elapsed: 0.221

So foldLeft2.reverse is the winner...

享受孤独 2024-12-28 13:03:59

问题在于连续函数 (racc => k(f(x, racc))) 本身。它应该针对整个业务的工作进行尾调用优化,但事实并非如此。

Scala 无法对任意尾部调用进行尾部调用优化,只能对那些可以转换为循环的尾部调用进行优化(即当函数调用自身而不是其他函数时)。

The problem is the continuation function (racc => k(f(x, racc))) itself. It should be tailcall optimized for this whole business to work, but isn't.

Scala cannot make tailcall optimizations for arbitrary tail calls, only for those it can transform into loops (i.e. when the function calls itself, not some other function).

兔姬 2024-12-28 13:03:59

为什么这对于 F# 来说不是问题?

F# 已优化所有尾部调用。

有什么办法可以用 Scala 解决这个问题吗?

您可以使用其他技术(如蹦床)来实现 TCO,但您会失去互操作性,因为它改变了调用约定,并且速度慢了约 10 倍。这是我不使用 Scala 的三个原因之一。

编辑

您的基准测试结果表明,Scala 的蹦床比我上次测试时快很多。此外,使用 F# 和较大的列表添加等效基准也很有趣(因为在小列表上执行 CPS 没有意义!)。

对于配备 1.67GHz N570 Intel Atom 的上网本上的 1,000 元素列表的 1,000x,我得到:

List.fold     0.022s
List.rev+fold 0.116s
List.foldBack 0.047s
foldContTC    0.334s

对于 1x 1,000,000 元素列表,我得到:

List.fold     0.024s
List.rev+fold 0.188s
List.foldBack 0.054s
foldContTC    0.570s

您可能还对 caml-list 上有关此问题的旧讨论感兴趣将 OCaml 的非尾递归列表函数替换为优化的尾递归函数的上下文。

Why is this not a problem with F#?

F# has all tail calls optimized.

And is there any way to work around this with Scala?

You can do TCO using other techniques like trampolines but you lose interop because it changes the calling convention and it is ~10× slower. This is one of the three reasons I don't use Scala.

EDIT

Your benchmark results indicate that Scala's trampolines are a lot faster than they were the last time I tested them. Also, it is interesting to add equivalent benchmarks using F# and for larger lists (because there's no point in doing CPS on small lists!).

For 1,000x on a 1,000-element list on my netbook with a 1.67GHz N570 Intel Atom, I get:

List.fold     0.022s
List.rev+fold 0.116s
List.foldBack 0.047s
foldContTC    0.334s

For 1x 1,000,000-element list, I get:

List.fold     0.024s
List.rev+fold 0.188s
List.foldBack 0.054s
foldContTC    0.570s

You may also be interested in the old discussions about this on the caml-list in the context of replacing OCaml's non-tail-recursive list functions with optimized tail recursive ones.

ぃ双果 2024-12-28 13:03:59

我迟到了这个问题,但我想展示如何在不使用完整蹦床的情况下编写尾递归 FoldRight;通过累积连续列表(而不是让它们在完成时互相调用,这会导致堆栈溢出)并在最后折叠它们,有点像保留堆栈,但在堆上:

object FoldRight {

  def apply[A, B](list: Seq[A])(init: B)(f: (A, B) => B): B = {
    @scala.annotation.tailrec
    def step(current: Seq[A], conts: List[B => B]): B = current match {
      case Seq(last) => conts.foldLeft(f(last, init)) { (acc, next) => next(acc) }
      case Seq(x, xs @ _*) => step(xs, { acc: B => f(x, acc) } +: conts)
      case Nil => init
    }
    step(list, Nil)
  }

}

发生在end 本身就是尾递归。 在 ScalaFiddle 中试试

性能方面,比尾调用版本表现稍差。

[info] Benchmark            (length)  Mode  Cnt   Score    Error  Units
[info] FoldRight.conts           100  avgt   30   0.003 ±  0.001  ms/op
[info] FoldRight.conts         10000  avgt   30   0.197 ±  0.004  ms/op
[info] FoldRight.conts       1000000  avgt   30  77.292 ±  9.327  ms/op
[info] FoldRight.standard        100  avgt   30   0.002 ±  0.001  ms/op
[info] FoldRight.standard      10000  avgt   30   0.154 ±  0.036  ms/op
[info] FoldRight.standard    1000000  avgt   30  18.796 ±  0.551  ms/op
[info] FoldRight.tailCalls       100  avgt   30   0.002 ±  0.001  ms/op
[info] FoldRight.tailCalls     10000  avgt   30   0.176 ±  0.004  ms/op
[info] FoldRight.tailCalls   1000000  avgt   30  33.525 ±  1.041  ms/op

I'm late to this question, but I wanted to show how you can write a tail-recursive FoldRight without using a full trampoline; by accumulating a list of continuations (instead of having them call each other when done, which leads to a stack overflow) and folding over them at the end, kind of like keeping a stack, but on the heap:

object FoldRight {

  def apply[A, B](list: Seq[A])(init: B)(f: (A, B) => B): B = {
    @scala.annotation.tailrec
    def step(current: Seq[A], conts: List[B => B]): B = current match {
      case Seq(last) => conts.foldLeft(f(last, init)) { (acc, next) => next(acc) }
      case Seq(x, xs @ _*) => step(xs, { acc: B => f(x, acc) } +: conts)
      case Nil => init
    }
    step(list, Nil)
  }

}

The fold that happens at the end is itself tail-recursive. Try it out in ScalaFiddle

In terms of performance, it performs slightly worse than the tail call version.

[info] Benchmark            (length)  Mode  Cnt   Score    Error  Units
[info] FoldRight.conts           100  avgt   30   0.003 ±  0.001  ms/op
[info] FoldRight.conts         10000  avgt   30   0.197 ±  0.004  ms/op
[info] FoldRight.conts       1000000  avgt   30  77.292 ±  9.327  ms/op
[info] FoldRight.standard        100  avgt   30   0.002 ±  0.001  ms/op
[info] FoldRight.standard      10000  avgt   30   0.154 ±  0.036  ms/op
[info] FoldRight.standard    1000000  avgt   30  18.796 ±  0.551  ms/op
[info] FoldRight.tailCalls       100  avgt   30   0.002 ±  0.001  ms/op
[info] FoldRight.tailCalls     10000  avgt   30   0.176 ±  0.004  ms/op
[info] FoldRight.tailCalls   1000000  avgt   30  33.525 ±  1.041  ms/op
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文