《Scala 编程》中的合并排序导致堆栈溢出

发布于 2024-08-20 03:07:23 字数 578 浏览 4 评论 0原文

直接剪切并粘贴以下算法:

def msort[T](less: (T, T) => Boolean)
            (xs: List[T]): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys
      case (_, Nil) => xs
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) x :: merge(xs1, ys)
        else y :: merge(xs, ys1)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
     merge(msort(less)(ys), msort(less)(zs))
  }
}

在 5000 个长列表上导致 StackOverflowError。

有没有办法优化一下,避免这种情况发生?

A direct cut and paste of the following algorithm:

def msort[T](less: (T, T) => Boolean)
            (xs: List[T]): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys
      case (_, Nil) => xs
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) x :: merge(xs1, ys)
        else y :: merge(xs, ys1)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
     merge(msort(less)(ys), msort(less)(zs))
  }
}

causes a StackOverflowError on 5000 long lists.

Is there any way to optimize this so that this doesn't occur?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(3

鲜血染红嫁衣 2024-08-27 03:07:23

它这样做是因为它不是尾递归的。您可以通过使用非严格集合或使其尾递归来解决此问题。

后一种解决方案如下:

def msort[T](less: (T, T) => Boolean) 
            (xs: List[T]): List[T] = { 
  def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] = 
    (xs, ys) match { 
      case (Nil, _) => ys.reverse ::: acc 
      case (_, Nil) => xs.reverse ::: acc
      case (x :: xs1, y :: ys1) => 
        if (less(x, y)) merge(xs1, ys, x :: acc) 
        else merge(xs, ys1, y :: acc) 
    } 
  val n = xs.length / 2 
  if (n == 0) xs 
  else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs), Nil).reverse
  } 
} 

使用非严格性涉及按名称传递参数,或使用非严格集合,例如 Stream。以下代码使用 Stream 只是为了防止堆栈溢出,并在其他地方使用 List

def msort[T](less: (T, T) => Boolean) 
            (xs: List[T]): List[T] = { 
  def merge(left: List[T], right: List[T]): Stream[T] = (left, right) match {
    case (x :: xs, y :: ys) if less(x, y) => Stream.cons(x, merge(xs, right))
    case (x :: xs, y :: ys) => Stream.cons(y, merge(left, ys))
    case _ => if (left.isEmpty) right.toStream else left.toStream
  }
  val n = xs.length / 2 
  if (n == 0) xs 
  else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs)).toList
  } 
}

It is doing this because it isn't tail-recursive. You can fix this by either using a non-strict collection, or by making it tail-recursive.

The latter solution goes like this:

def msort[T](less: (T, T) => Boolean) 
            (xs: List[T]): List[T] = { 
  def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] = 
    (xs, ys) match { 
      case (Nil, _) => ys.reverse ::: acc 
      case (_, Nil) => xs.reverse ::: acc
      case (x :: xs1, y :: ys1) => 
        if (less(x, y)) merge(xs1, ys, x :: acc) 
        else merge(xs, ys1, y :: acc) 
    } 
  val n = xs.length / 2 
  if (n == 0) xs 
  else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs), Nil).reverse
  } 
} 

Using non-strictness involves either passing parameters by-name, or using non-strict collections such as Stream. The following code uses Stream just to prevent stack overflow, and List elsewhere:

def msort[T](less: (T, T) => Boolean) 
            (xs: List[T]): List[T] = { 
  def merge(left: List[T], right: List[T]): Stream[T] = (left, right) match {
    case (x :: xs, y :: ys) if less(x, y) => Stream.cons(x, merge(xs, right))
    case (x :: xs, y :: ys) => Stream.cons(y, merge(left, ys))
    case _ => if (left.isEmpty) right.toStream else left.toStream
  }
  val n = xs.length / 2 
  if (n == 0) xs 
  else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs)).toList
  } 
}
枕梦 2024-08-27 03:07:23

只是玩弄 scala 的 TailCalls (蹦床支持),我怀疑最初提出这个问题时它并不存在。这是 Rex 的回答中合并的递归不可变版本。

import scala.util.control.TailCalls._

def merge[T <% Ordered[T]](x:List[T],y:List[T]):List[T] = {

  def build(s:List[T],a:List[T],b:List[T]):TailRec[List[T]] = {
    if (a.isEmpty) {
      done(b.reverse ::: s)
    } else if (b.isEmpty) {
      done(a.reverse ::: s)
    } else if (a.head<b.head) {
      tailcall(build(a.head::s,a.tail,b))
    } else {
      tailcall(build(b.head::s,a,b.tail))
    }
  }

  build(List(),x,y).result.reverse
}

与 64 位 OpenJDK(i7 上的 Debian/Squeeze amd64)上的 Scala 2.9.1 上的大 List[Long] 上的可变版本一样快。

Just playing around with scala's TailCalls (trampolining support), which I suspect wasn't around when this question was originally posed. Here's a recursive immutable version of the merge in Rex's answer.

import scala.util.control.TailCalls._

def merge[T <% Ordered[T]](x:List[T],y:List[T]):List[T] = {

  def build(s:List[T],a:List[T],b:List[T]):TailRec[List[T]] = {
    if (a.isEmpty) {
      done(b.reverse ::: s)
    } else if (b.isEmpty) {
      done(a.reverse ::: s)
    } else if (a.head<b.head) {
      tailcall(build(a.head::s,a.tail,b))
    } else {
      tailcall(build(b.head::s,a,b.tail))
    }
  }

  build(List(),x,y).result.reverse
}

Runs just as fast as the mutable version on big List[Long]s on Scala 2.9.1 on 64bit OpenJDK (Debian/Squeeze amd64 on an i7).

贩梦商人 2024-08-27 03:07:23

万一大牛的解决方案说得不够清楚,问题是merge的递归深度与列表的长度一样深,而且它不是尾递归,因此无法转换为迭代。

Scala 可以将 Daniel 的尾递归合并解决方案转换为与此大致等效的解决方案:

def merge(xs: List[T], ys: List[T]): List[T] = {
  var acc:List[T] = Nil
  var decx = xs
  var decy = ys
  while (!decx.isEmpty || !decy.isEmpty) {
    (decx, decy) match { 
      case (Nil, _) => { acc = decy.reverse ::: acc ; decy = Nil }
      case (_, Nil) => { acc = decx.reverse ::: acc ; decx = Nil }
      case (x :: xs1, y :: ys1) => 
        if (less(x, y)) { acc = x :: acc ; decx = xs1 }
        else { acc = y :: acc ; decy = ys1 }
    }
  }
  acc.reverse
}

但它会为您跟踪所有变量。

(尾递归方法是指方法调用自身来获取要传回的完整答案;它从不调用自身,然后在将结果传回之前对结果执行某些操作。此外,尾递归如果方法可能是多态的,则不能使用递归,因此它通常仅适用于对象或标记为 Final 的类。)

Just in case Daniel's solutions didn't make it clear enough, the problem is that merge's recursion is as deep as the length of the list, and it's not tail-recursion so it can't be converted into iteration.

Scala can convert Daniel's tail-recursive merge solution into something approximately equivalent to this:

def merge(xs: List[T], ys: List[T]): List[T] = {
  var acc:List[T] = Nil
  var decx = xs
  var decy = ys
  while (!decx.isEmpty || !decy.isEmpty) {
    (decx, decy) match { 
      case (Nil, _) => { acc = decy.reverse ::: acc ; decy = Nil }
      case (_, Nil) => { acc = decx.reverse ::: acc ; decx = Nil }
      case (x :: xs1, y :: ys1) => 
        if (less(x, y)) { acc = x :: acc ; decx = xs1 }
        else { acc = y :: acc ; decy = ys1 }
    }
  }
  acc.reverse
}

but it keeps track of all the variables for you.

(A tail-recursive method is one where the method only calls itself to get a complete answer to pass back; it never calls itself and then does something with the result before passing it back. Also, tail-recursion can't be used if the method might be polymorphic, so it generally only works in objects or with classes marked final.)

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文