为什么这个简单的 haskell 算法这么慢?

发布于 2024-12-23 08:15:53 字数 4750 浏览 1 评论 0原文

剧透警告:这与 Project Euler 的问题 14 有关。

以下内容代码运行大约需要 15 秒。我有一个运行时间为 1 秒的非递归 Java 解决方案。我想我应该能够让这段代码更接近那个。

import Data.List

collatz a 1  = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

main = do
  print ((foldl1' max) . map (collatz 1) $ [1..1000000])

我已经使用 +RHS -p 进行了分析,并注意到分配的内存很大,并且随着输入的增长而增长。对于n = 100,000 分配了1gb(!),对于n = 1,000,000 分配了13gb(!!)。

不过,-sstderr 显示虽然分配了很多字节,但总内存使用量为 1mb,生产率为 95% 以上,因此 13gb 可能是转移注意力的。

我可以想到几种可能性:

  1. 有些事情并不像需要的那么严格。我已经发现了 foldl1',但也许我需要做更多的事情?是否可以标记collat​​z 严格(这有意义吗?

  2. collat​​z 不是尾部调用优化。我认为应该是,但不 知道一种确认方法。

  3. 编译器没有做一些我认为应该做的优化 - 例如 任何时候只有 collat​​z 的两个结果需要存储在内存中(最大和当前)

有什么建议吗?

这几乎是 为什么这个 Haskell 表达式这么慢?,尽管我会注意到快速 Java 解决方案不必执行任何记忆。有没有什么方法可以加快速度而不需要求助于它?

作为参考,这里是我的分析输出:

  Wed Dec 28 09:33 2011 Time and Allocation Profiling Report  (Final)

     scratch +RTS -p -hc -RTS

  total time  =        5.12 secs   (256 ticks @ 20 ms)
  total alloc = 13,229,705,716 bytes  (excludes profiling overheads)

COST CENTRE                    MODULE               %time %alloc

collatz                        Main                  99.6   99.4


                                                                                               individual    inherited
COST CENTRE              MODULE                                               no.    entries  %time %alloc   %time %alloc

MAIN                     MAIN                                                   1           0   0.0    0.0   100.0  100.0
 CAF                     Main                                                 208          10   0.0    0.0   100.0  100.0
  collatz                Main                                                 215           1   0.0    0.0     0.0    0.0
  main                   Main                                                 214           1   0.4    0.6   100.0  100.0
   collatz               Main                                                 216           0  99.6   99.4    99.6   99.4
 CAF                     GHC.IO.Handle.FD                                     145           2   0.0    0.0     0.0    0.0
 CAF                     System.Posix.Internals                               144           1   0.0    0.0     0.0    0.0
 CAF                     GHC.Conc                                             128           1   0.0    0.0     0.0    0.0
 CAF                     GHC.IO.Handle.Internals                              119           1   0.0    0.0     0.0    0.0
 CAF                     GHC.IO.Encoding.Iconv                                113           5   0.0    0.0     0.0    0.0

And -sstderr:

./scratch +RTS -sstderr 
525
  21,085,474,908 bytes allocated in the heap
      87,799,504 bytes copied during GC
           9,420 bytes maximum residency (1 sample(s))          
          12,824 bytes maximum slop               
               1 MB total memory in use (0 MB lost due to fragmentation)  

  Generation 0: 40219 collections,     0 parallel,  0.40s,  0.51s elapsed
  Generation 1:     1 collections,     0 parallel,  0.00s,  0.00s elapsed

  INIT  time    0.00s  (  0.00s elapsed)
  MUT   time   35.38s  ( 36.37s elapsed)
  GC    time    0.40s  (  0.51s elapsed)
  RP    time    0.00s  (  0.00s elapsed)  PROF  time    0.00s  (  0.00s elapsed)
  EXIT  time    0.00s  (  0.00s elapsed)
  Total time   35.79s  ( 36.88s elapsed)  %GC time       1.1%  (1.4% elapsed)  Alloc rate    595,897,095 bytes per MUT second

  Productivity  98.9% of total user, 95.9% of total elapsed

和 Java 解决方案(不是我的,取自 Project Euler 论坛,删除了记忆功能):

public class Collatz {
  public int getChainLength( int n )
  {
    long num = n;
    int count = 1;
    while( num > 1 )
    {
      num = ( num%2 == 0 ) ? num >> 1 : 3*num+1;
      count++;
    }
    return count;
  }

  public static void main(String[] args) {
    Collatz obj = new Collatz();
    long tic = System.currentTimeMillis();
    int max = 0, len = 0, index = 0;
    for( int i = 3; i < 1000000; i++ )
    {
      len = obj.getChainLength(i);
      if( len > max )
      {
        max = len;
        index = i;
      }
    }
    long toc = System.currentTimeMillis();
    System.out.println(toc-tic);
    System.out.println( "Index: " + index + ", length = " + max );
  }
}

Spoiler alert: this is related to Problem 14 from Project Euler.

The following code takes around 15s to run. I have a non-recursive Java solution that runs in 1s. I think I should be able to get this code much closer to that.

import Data.List

collatz a 1  = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

main = do
  print ((foldl1' max) . map (collatz 1) $ [1..1000000])

I have profiled with +RHS -p and noticed that the allocated memory is large, and grows as the input grows. For n = 100,000 1gb is allocated(!), for n = 1,000,000 13gb(!!) is allocated.

Then again, -sstderr shows that although lots of bytes were allocated, total memory use was 1mb, and productivity was 95%+, so maybe 13gb is red-herring.

I can think of a few possibilities:

  1. Something isn't as strict as it needs to be. I've already discovered
    foldl1', but maybe I need to do more? Is it possible to mark collatz
    as strict (does that even make sense?

  2. collatz isn't tail-call optimizing. I think it should be but don't
    know a way to confirm.

  3. The compiler isn't doing some optimizations I think it should - for instance
    only two results of collatz need to be in memory at any one time (max and current)

Any suggestions?

This is pretty much a duplicate of Why is this Haskell expression so slow?, though I will note that the fast Java solution does not have to perform any memoization. Are there any ways to speed this up without having to resort to it?

For reference, here is my profiling output:

  Wed Dec 28 09:33 2011 Time and Allocation Profiling Report  (Final)

     scratch +RTS -p -hc -RTS

  total time  =        5.12 secs   (256 ticks @ 20 ms)
  total alloc = 13,229,705,716 bytes  (excludes profiling overheads)

COST CENTRE                    MODULE               %time %alloc

collatz                        Main                  99.6   99.4


                                                                                               individual    inherited
COST CENTRE              MODULE                                               no.    entries  %time %alloc   %time %alloc

MAIN                     MAIN                                                   1           0   0.0    0.0   100.0  100.0
 CAF                     Main                                                 208          10   0.0    0.0   100.0  100.0
  collatz                Main                                                 215           1   0.0    0.0     0.0    0.0
  main                   Main                                                 214           1   0.4    0.6   100.0  100.0
   collatz               Main                                                 216           0  99.6   99.4    99.6   99.4
 CAF                     GHC.IO.Handle.FD                                     145           2   0.0    0.0     0.0    0.0
 CAF                     System.Posix.Internals                               144           1   0.0    0.0     0.0    0.0
 CAF                     GHC.Conc                                             128           1   0.0    0.0     0.0    0.0
 CAF                     GHC.IO.Handle.Internals                              119           1   0.0    0.0     0.0    0.0
 CAF                     GHC.IO.Encoding.Iconv                                113           5   0.0    0.0     0.0    0.0

And -sstderr:

./scratch +RTS -sstderr 
525
  21,085,474,908 bytes allocated in the heap
      87,799,504 bytes copied during GC
           9,420 bytes maximum residency (1 sample(s))          
          12,824 bytes maximum slop               
               1 MB total memory in use (0 MB lost due to fragmentation)  

  Generation 0: 40219 collections,     0 parallel,  0.40s,  0.51s elapsed
  Generation 1:     1 collections,     0 parallel,  0.00s,  0.00s elapsed

  INIT  time    0.00s  (  0.00s elapsed)
  MUT   time   35.38s  ( 36.37s elapsed)
  GC    time    0.40s  (  0.51s elapsed)
  RP    time    0.00s  (  0.00s elapsed)  PROF  time    0.00s  (  0.00s elapsed)
  EXIT  time    0.00s  (  0.00s elapsed)
  Total time   35.79s  ( 36.88s elapsed)  %GC time       1.1%  (1.4% elapsed)  Alloc rate    595,897,095 bytes per MUT second

  Productivity  98.9% of total user, 95.9% of total elapsed

And Java solution (not mine, taken from Project Euler forums with memoization removed):

public class Collatz {
  public int getChainLength( int n )
  {
    long num = n;
    int count = 1;
    while( num > 1 )
    {
      num = ( num%2 == 0 ) ? num >> 1 : 3*num+1;
      count++;
    }
    return count;
  }

  public static void main(String[] args) {
    Collatz obj = new Collatz();
    long tic = System.currentTimeMillis();
    int max = 0, len = 0, index = 0;
    for( int i = 3; i < 1000000; i++ )
    {
      len = obj.getChainLength(i);
      if( len > max )
      {
        max = len;
        index = i;
      }
    }
    long toc = System.currentTimeMillis();
    System.out.println(toc-tic);
    System.out.println( "Index: " + index + ", length = " + max );
  }
}

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(3

梦情居士 2024-12-30 08:15:53

起初,我认为您应该尝试在 collat​​z 中的 a 之前添加一个感叹号:(

collatz !a 1  = a
collatz !a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

您需要添加 {-# LANGUAGE BangPatterns #-}< /code> 在你的源文件的顶部才能工作。)

我的推理如下:问题是你在 collat​​z 的第一个参数中构建了一个巨大的 thunk :它开始关闭为1,然后变成 1 + 1,然后变成 (1 + 1) + 1,...所有这些都无需强迫。这个 bang 模式 强制 < 的第一个参数code>collat​​z 每当进行调用时都会被强制执行,因此它从 1 开始,然后变成 2,依此类推,而不会建立一个大的未评估的 thunk:它只是保持为整数。

请注意,爆炸模式只是使用 seq;在这种情况下,我们可以重写 collat​​z 如下:

collatz a _ | seq a False = undefined
collatz a 1  = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

这里的技巧是在守卫中强制 a ,然后它总是评估为 False (因此主体是不相关的) 。然后继续评估下一个案例,a 已经被评估。然而,刘海图案更清晰。

不幸的是,当使用 -O2 编译时,它的运行速度并不比原始版本快!我们还能尝试什么?好吧,我们可以做的一件事是假设这两个数字永远不会溢出机器大小的整数,并给予 collat​​z 这个类型注释:

collatz :: Int -> Int -> Int

我们将把 bang 模式留在那里,因为我们仍然应该避免构建thunks,即使它们不是性能问题的根源。这使得我的(慢速)计算机上的时间缩短至 8.5 秒。

下一步是尝试使其更接近 Java 解决方案。首先要认识到,在 Haskell 中,div 对于负整数的行为在数学上更正确,但比“正常”C 除法慢,在 Haskell 中称为 quot。将 div 替换为 quot 将运行时间缩短至 5.2 秒,并将 x `quot` 2 替换为 x `shiftR` 1< /code>(导入 Data.Bits)以匹配 Java 解决方案,将其缩短至 4.9 秒。

这大约是我目前能得到的最低值,但我认为这是一个相当不错的结果;因为你的计算机比我的更快,所以它应该更接近 Java 解决方案。

这是最终的代码(我在途中做了一些清理):

{-# LANGUAGE BangPatterns #-}

import Data.Bits
import Data.List

collatz :: Int -> Int
collatz = collatz' 1
  where collatz' :: Int -> Int -> Int
        collatz' !a 1 = a
        collatz' !a x
          | even x    = collatz' (a + 1) (x `shiftR` 1)
          | otherwise = collatz' (a + 1) (3 * x + 1)

main :: IO ()
main = print . foldl1' max . map collatz $ [1..1000000]

查看该程序的 GHC 核心(使用 ghc-core),我认为这可能已经是最好的了; collat​​z 循环使用未装箱的整数,程序的其余部分看起来不错。我能想到的唯一改进是消除 map collat​​z [1..1000000] 迭代中的装箱。

顺便说一句,不要担心“总分配”数字;它是在程序的生命周期内分配的总内存,即使 GC 回收该内存,它也永远不会减少。数 TB 的数字很常见。

At first, I thought you should try putting an exclamation mark before a in collatz:

collatz !a 1  = a
collatz !a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

(You'll need to put {-# LANGUAGE BangPatterns #-} at the top of your source file for this to work.)

My reasoning went as follows: The problem is that you're building up a massive thunk in the first argument to collatz: it starts off as 1, and then becomes 1 + 1, and then becomes (1 + 1) + 1, ... all without ever being forced. This bang pattern forces the first argument of collatz to be forced whenever a call is made, so it starts off as 1, and then becomes 2, and so on, without building up a large unevaluated thunk: it just stays as an integer.

Note that a bang pattern is just shorthand for using seq; in this case, we could rewrite collatz as follows:

collatz a _ | seq a False = undefined
collatz a 1  = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

The trick here is to force a in the guard, which then always evaluates to False (and so the body is irrelevant). Then evaluation continues with the next case, a having already been evaluated. However, a bang pattern is clearer.

Unfortunately, when compiled with -O2, this doesn't run any faster than the original! What else can we try? Well, one thing we can do is assume that the two numbers never overflow a machine-sized integer, and give collatz this type annotation:

collatz :: Int -> Int -> Int

We'll leave the bang pattern there, since we should still avoid building up thunks, even if they aren't the root of the performance problem. This brings the time down to 8.5 seconds on my (slow) computer.

The next step is to try bringing this closer to the Java solution. The first thing to realise is that in Haskell, div behaves in a more mathematically correct manner with respect to negative integers, but is slower than "normal" C division, which in Haskell is called quot. Replacing div with quot brought the runtime down to 5.2 seconds, and replacing x `quot` 2 with x `shiftR` 1 (importing Data.Bits) to match the Java solution brought it down to 4.9 seconds.

This is about as low as I can get it for now, but I think this is a pretty good result; since your computer is faster than mine, it should hopefully be even closer to the Java solution.

Here's the final code (I did a little bit of clean-up on the way):

{-# LANGUAGE BangPatterns #-}

import Data.Bits
import Data.List

collatz :: Int -> Int
collatz = collatz' 1
  where collatz' :: Int -> Int -> Int
        collatz' !a 1 = a
        collatz' !a x
          | even x    = collatz' (a + 1) (x `shiftR` 1)
          | otherwise = collatz' (a + 1) (3 * x + 1)

main :: IO ()
main = print . foldl1' max . map collatz $ [1..1000000]

Looking at the GHC Core for this program (with ghc-core), I think this is probably about as good as it gets; the collatz loop uses unboxed integers and the rest of the program looks OK. The only improvement I can think of would be eliminating the boxing from the map collatz [1..1000000] iteration.

By the way, don't worry about the "total alloc" figure; it's the total memory allocated over the lifetime of the program, and it never decreases even when the GC reclaims that memory. Figures of multiple terabytes are common.

世俗缘 2024-12-30 08:15:53

您可能会丢失列表和爆炸模式,但通过使用堆栈仍然可以获得相同的性能。

import Data.List
import Data.Bits

coll :: Int -> Int
coll 0 = 0
coll 1 = 1
coll 2 = 2
coll n =
  let a = coll (n - 1)
      collatz a 1 = a
      collatz a x
        | even x    = collatz (a + 1) (x `shiftR` 1)
        | otherwise = collatz (a + 1) (3 * x + 1)
  in max a (collatz 1 n)


main = do
  print $ coll 100000

这样做的一个问题是,对于大输入(例如 1_000_000),您必须增加堆栈的大小。

更新:

这是一个尾递归版本,不会遇到堆栈溢出问题。

import Data.Word
collatz :: Word -> Word -> (Word, Word)
collatz a x
  | x == 1    = (a,x)
  | even x    = collatz (a + 1) (x `quot` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

coll :: Word -> Word
coll n = collTail 0 n
  where
    collTail m 1 = m
    collTail m n = collTail (max (fst $ collatz 1 n) m) (n-1)

请注意使用 Word 而不是 Int。它在性能上有所不同。如果您愿意,您仍然可以使用刘海图案,这将使性能几乎翻倍。

You could lose the list and the bang patterns and still get the same performance by using the stack instead.

import Data.List
import Data.Bits

coll :: Int -> Int
coll 0 = 0
coll 1 = 1
coll 2 = 2
coll n =
  let a = coll (n - 1)
      collatz a 1 = a
      collatz a x
        | even x    = collatz (a + 1) (x `shiftR` 1)
        | otherwise = collatz (a + 1) (3 * x + 1)
  in max a (collatz 1 n)


main = do
  print $ coll 100000

One problem with this is that you will have to increase the size of the stack for large inputs, like 1_000_000.

update:

Here is a tail recursive version that doesn't suffer from the stack overflow problem.

import Data.Word
collatz :: Word -> Word -> (Word, Word)
collatz a x
  | x == 1    = (a,x)
  | even x    = collatz (a + 1) (x `quot` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

coll :: Word -> Word
coll n = collTail 0 n
  where
    collTail m 1 = m
    collTail m n = collTail (max (fst $ collatz 1 n) m) (n-1)

Notice the use of Word instead of Int. It makes a difference in performance. You could still use the bang patterns if you want, and that would nearly double the performance.

忘东忘西忘不掉你 2024-12-30 08:15:53

我发现一件事对这个问题产生了惊人的影响。我坚持使用直接的递归关系而不是折叠,你应该原谅这个表达式,以及它的计数。 重写

collatz n = if even n then n `div` 2 else 3 * n + 1

collatz n = case n `divMod` 2 of
            (n', 0) -> n'
            _       -> 3 * n + 1

在具有 2.8 GHz Athlon II X4 430 CPU 的系统上, 我的程序运行时间缩短了 1.2 秒。我最初的更快版本(使用 divMod 后 2.3 秒):

{-# LANGUAGE BangPatterns #-}

import Data.List
import Data.Ord

collatzChainLen :: Int -> Int
collatzChainLen n = collatzChainLen' n 1
    where collatzChainLen' n !l
            | n == 1    = l
            | otherwise = collatzChainLen' (collatz n) (l + 1)

collatz:: Int -> Int
collatz n = case n `divMod` 2 of
                 (n', 0) -> n'
                 _       -> 3 * n + 1

pairMap :: (a -> b) -> [a] -> [(a, b)]
pairMap f xs = [(x, f x) | x <- xs]

main :: IO ()
main = print $ fst (maximumBy (comparing snd) (pairMap collatzChainLen [1..999999]))

也许更惯用的 Haskell 版本运行时间约为 9.7 秒(使用 divMod 为 8.5 秒);它与

collatzChainLen :: Int -> Int
collatzChainLen n = 1 + (length . takeWhile (/= 1) . (iterate collatz)) n

使用 Data.List.Stream 相同,应该允许流融合,这将使该版本的运行更像具有显式累积的版本,但我找不到具有 Data.List.Stream 的 Ubuntu libghc* 包,所以我还无法验证这一点。

One thing I found made a surprising difference in this problem. I stuck with the straight recurrence relation rather than folding, you should pardon the expression, the counting in with it. Rewriting

collatz n = if even n then n `div` 2 else 3 * n + 1

as

collatz n = case n `divMod` 2 of
            (n', 0) -> n'
            _       -> 3 * n + 1

took 1.2 seconds off the runtime for my program on a system with a 2.8 GHz Athlon II X4 430 CPU. My initial faster version (2.3 seconds after the use of divMod):

{-# LANGUAGE BangPatterns #-}

import Data.List
import Data.Ord

collatzChainLen :: Int -> Int
collatzChainLen n = collatzChainLen' n 1
    where collatzChainLen' n !l
            | n == 1    = l
            | otherwise = collatzChainLen' (collatz n) (l + 1)

collatz:: Int -> Int
collatz n = case n `divMod` 2 of
                 (n', 0) -> n'
                 _       -> 3 * n + 1

pairMap :: (a -> b) -> [a] -> [(a, b)]
pairMap f xs = [(x, f x) | x <- xs]

main :: IO ()
main = print $ fst (maximumBy (comparing snd) (pairMap collatzChainLen [1..999999]))

A perhaps more idiomatic Haskell version run in about 9.7 seconds (8.5 with divMod); it's identical save for

collatzChainLen :: Int -> Int
collatzChainLen n = 1 + (length . takeWhile (/= 1) . (iterate collatz)) n

Using Data.List.Stream is supposed to allow stream fusion that would make this version run more like that with the explicit accumulation, but I can't find an Ubuntu libghc* package that has Data.List.Stream, so I can't yet verify that.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文