Haskell recursion efficiency

Question by Nat*_*han · tags: memory, recursion, haskell

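I'm working through some Project Euler problems (not as homework, just for fun and learning), and I'm learning Haskell. One of the problems is to find the longest Collatz sequence with a starting number under one million (http://projecteuler.net/problem=14).

Anyway, I was able to do it. My algorithm works and gives the correct answer when compiled. However, it uses recursion 1,000,000 levels deep.

So my questions are: Did I do this right? As it stands, what is the proper Haskell way to do this? How could I make it faster? Also, in terms of memory usage, how is recursion actually implemented at a low level? How does it use memory?

(SPOILER ALERT: if you want to solve Project Euler problem #14 yourself without seeing the answer, don't look at this.)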

-- Haskell script -- problem: find the longest Collatz chain with a start under one million.

collatzLength x| x == 1 = 1
               | otherwise = 1 + collatzLength(nextStep x)


longestChain (num, numLength) bound counter
           | counter >= bound = (num, numLength)
           | otherwise = longestChain (longerOf (num,numLength)
             (counter,   (collatzLength counter)) ) bound (counter + 1)
           --I know this is a messy function, but I was doing this problem just 
           --for myself, so I didn't bother making some utility functions for it.
           --also, I split the big line in half to display on here nicer, would
           --it actually run with this line split?


longerOf (a1,a2) (b1,b2)| a2 > b2 = (a1,a2)
                        | otherwise = (b1,b2)

nextStep n | mod n 2 == 0 = (n `div` 2)
           | otherwise = 3*n + 1

main = print (longestChain (0,0) 1000000 1)

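Compiled with -O2, the program runs in about 7.5 seconds.

So, any suggestions or comments? I'd like to make the program run faster with less memory, and I'd like to do it in a very Haskellian (that should be a word) way.

Thanks in advance!

Answer by Tho*_*son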

Edited to answer the questions

Did I do this right?

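Mostly. As the comments say, you build up a big 1+(1+(1+...)) expression. Use a strict accumulator, or a higher-level function that handles this for you. There are other minor issues, such as defining your own function to compare second elements instead of using maximumBy (comparing snd), but those are more a matter of style.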
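A minimal sketch of the "higher-level function" route (not the answerer's code), assuming a 64-bit Int, since intermediate Collatz values for starts below one million exceed the 32-bit range. Here `length` does the counting, and `maximumBy (comparing snd)` takes the place of `longerOf`. Whether GHC's list fusion removes the intermediate lists at -O2 is not guaranteed, so this may be slower than a hand-written loop.

```haskell
import Data.List (maximumBy)
import Data.Ord (comparing)

-- One Collatz step, as in the question (quot instead of div, since n > 0).
nextStep :: Int -> Int
nextStep n
  | even n    = n `quot` 2
  | otherwise = 3 * n + 1

-- Chain length, counting both the start and the final 1.
-- length counts with an internal accumulator in GHC's base library,
-- so no 1 + (1 + ...) expression is built by this function.
collatzLength :: Int -> Int
collatzLength n = 1 + length (takeWhile (/= 1) (iterate nextStep n))

-- (start, length) of the longest chain for starting values 1 .. bound - 1.
-- maximumBy (comparing snd) replaces the hand-written longerOf.
longestChain :: Int -> (Int, Int)
longestChain bound =
  maximumBy (comparing snd) [ (n, collatzLength n) | n <- [1 .. bound - 1] ]

main :: IO ()
main = print (longestChain 1000000)
```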

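As it stands, what is the proper Haskell way to do this?

It is acceptable, idiomatic Haskell code.

How could I make it faster?

See the benchmarks below. The extremely common answers to Euler performance questions are: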

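  • Use -O2 (as you did).
  • Try -fllvm (GHC's native code generator is sub-optimal).
  • Use worker/wrappers to reduce arguments or, in your case, to introduce an accumulator.
  • Use fast, unboxable types (Int instead of Integer when you can, Int64 if you need portability, etc.).
  • Use rem instead of mod when all values are positive. In your case it is also useful to know, or discover, that div tends to compile to something slower than quot.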
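A small sketch of why that choice matters: `quot`/`rem` truncate toward zero, which is what typical hardware division does, while `div`/`mod` round toward negative infinity. For an `Int` that might be negative, GHC therefore has to add a sign correction to `div`/`mod`. For non-negative operands, such as Collatz values, the two pairs give identical results.

```haskell
main :: IO ()
main = do
  let a = -7 :: Int
  -- div/mod round the quotient toward negative infinity
  print (a `div` 2, a `mod` 2)    -- prints (-4,1)
  -- quot/rem truncate toward zero
  print (a `quot` 2, a `rem` 2)   -- prints (-3,-1)
  -- for non-negative operands the two pairs agree
  print (all (\n -> n `div` 2 == n `quot` 2 && n `mod` 2 == n `rem` 2)
             [0 .. 1000 :: Int])  -- prints True
```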

Also, in terms of memory usage, how is recursion actually implemented at a low level? How does it use memory?

Both these questions are very broad. Complete answers would likely need to address lazy evaluation, tail call optimization, worker transformations, garbage collection, etc. I suggest you explore these topics in more depth over time (or hope someone here writes the complete answer I'm avoiding).
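A full treatment is out of scope, but a rough sketch of the three shapes this recursion can take may help. This is illustrative code, not from the answer. `step`, `lenNaive`, `lenLazyAcc` and `lenStrictAcc` are hypothetical names, and the comments describe typical GHC behaviour rather than guarantees.

```haskell
{-# LANGUAGE BangPatterns #-}

-- One Collatz step (hypothetical helper, same rule as nextStep in the question).
step :: Int -> Int
step n
  | even n    = n `quot` 2
  | otherwise = 3 * n + 1

-- 1) Not tail recursive: each (+) must wait for the recursive call,
--    so evaluation needs one stack frame per step until the chain hits 1.
lenNaive :: Int -> Int
lenNaive 1 = 1
lenNaive n = 1 + lenNaive (step n)

-- 2) Tail recursive, but the accumulator is lazy: without optimisation, acc
--    can grow into an unevaluated thunk ((1 + 1) + 1) + ... that is only
--    forced at the end. GHC's strictness analysis usually removes this at -O.
lenLazyAcc :: Int -> Int
lenLazyAcc = go 1
  where
    go acc 1 = acc
    go acc n = go (acc + 1) (step n)

-- 3) Tail recursive with a strict accumulator: the bang patterns force acc
--    on every call, so the loop runs in constant space.
lenStrictAcc :: Int -> Int
lenStrictAcc = go 1
  where
    go !acc 1 = acc
    go !acc n = go (acc + 1) (step n)

main :: IO ()
main = print (map ($ 27) [lenNaive, lenLazyAcc, lenStrictAcc])
```

All three compute the same lengths. They differ only in how much stack or heap is live while the answer is being computed.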

Original Post - Benchmark numbers

Original:

$ ghc -O2 so.hs ; time ./so
[1 of 1] Compiling Main             ( so.hs, so.o )
Linking so ...
(837799,525)

real    0m5.971s
user    0m5.940s
sys 0m0.019s

Use a worker function with an accumulator for collatzLength:

$ ghc -O2 so.hs ; time ./so
[1 of 1] Compiling Main             ( so.hs, so.o )
Linking so ...
(837799,525)

real    0m5.617s
user    0m5.590s
sys 0m0.012s

Use Int instead of defaulting to Integer (adding type signatures also makes the code easier to read):

$ ghc -O2 so.hs ; time ./so
[1 of 1] Compiling Main             ( so.hs, so.o )
Linking so ...
(837799,525)

real    0m2.937s
user    0m2.932s
sys 0m0.001s

Use rem and not mod:

$ ghc -O2 so.hs ; time ./so
[1 of 1] Compiling Main             ( so.hs, so.o )
Linking so ...
(837799,525)

real    0m2.436s
user    0m2.431s
sys 0m0.001s

Use quotRem instead of rem followed by div:

$ ghc -O2 so.hs ; time ./so
[1 of 1] Compiling Main             ( so.hs, so.o )
Linking so ...
(837799,525)

real    0m1.672s
user    0m1.669s
sys 0m0.002s

This is all very much like a previous question: Speed comparison with Project Euler: C vs Python vs Erlang vs Haskell

EDIT: And yes, as Daniel Fischer suggests, using the bit operations .&. and shiftR improves on quotRem:

$ ghc -O2 so.hs ; time ./so
(837799,525)

real    0m0.314s
user    0m0.312s
sys 0m0.001s

Or you can just use LLVM and let it do its magic (NB: this version still uses quotRem):

$ time ./so
(837799,525)

real    0m0.286s
user    0m0.283s
sys 0m0.002s

LLVM actually does well, so long as you avoid the hideousness that is mod: it optimizes the guard-based code, whether it uses rem or even, just as well as the hand-optimized .&. with shiftR.

The result is around 20x faster than the original.

EDIT: People are surprised that quotRem performs as well as bit manipulation in the face of Int. The code is included, but I'm not clear on the surprise: just because something is possibly negative doesn't mean you can't handle it with very similar bit manipulations that could be of identical cost on the right hardware. All three versions of nextStep seem to be performing identically (ghc -O2 -fforce-recomp -fllvm, ghc version 7.6.3, LLVM 3.3, x86-64).

{-# LANGUAGE BangPatterns, UnboxedTuples #-}

import Data.Bits

collatzLength :: Int -> Int
collatzLength x| x == 1    = 1
               | otherwise = go x 0
 where
    go 1 a  = a + 1
    go x !a = go (nextStep x) (a+1)


longestChain :: (Int, Int) -> Int -> Int -> (Int,Int)
longestChain (num, numLength) bound !counter
   | counter >= bound = (num, numLength)
   | otherwise = longestChain (longerOf (num,numLength) (counter, collatzLength counter)) bound (counter + 1)
           --I know this is a messy function, but I was doing this problem just 
           --for myself, so I didn't bother making some utility functions for it.
           --also, I split the big line in half to display on here nicer, would
           --it actually run with this line split?


longerOf :: (Int,Int) -> (Int,Int) -> (Int,Int)
longerOf (a1,a2) (b1,b2)| a2 > b2 = (a1,a2)
                        | otherwise = (b1,b2)
{-# INLINE longerOf #-}

nextStep :: Int -> Int
-- Version 'bits'
nextStep n = if 0 == n .&. 1 then n `shiftR` 1 else 3*n+1
-- Version 'quotRem'
-- nextStep n = let (q,r) = quotRem n 2 in if r == 0 then q else 3*n+1
-- Version 'almost the original'
-- nextStep n | even n = quot n 2
--            | otherwise  = 3*n + 1
{-# INLINE nextStep #-}


main = print (longestChain (0,0) 1000000 1)
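To illustrate the point about negative numbers, here is a small check (a sketch, not part of the original answer) that the three nextStep variants above return the same value for every Int in a range that includes negatives. It assumes GHC's two's-complement Int. The names `nextBits`, `nextQuotRem` and `nextEven` are hypothetical.

```haskell
import Data.Bits ((.&.), shiftR)

-- 'bits' version: test the low bit, halve with an arithmetic right shift.
nextBits :: Int -> Int
nextBits n = if n .&. 1 == 0 then n `shiftR` 1 else 3 * n + 1

-- 'quotRem' version.
nextQuotRem :: Int -> Int
nextQuotRem n = let (q, r) = n `quotRem` 2 in if r == 0 then q else 3 * n + 1

-- 'almost the original' version.
nextEven :: Int -> Int
nextEven n
  | even n    = n `quot` 2
  | otherwise = 3 * n + 1

-- Halving only happens for even n, where shiftR, quot and div are all exact,
-- and with two's-complement Int the low bit of any odd n is 1. So the three
-- versions agree on negative inputs too.
main :: IO ()
main = print (and [ nextBits n == nextQuotRem n && nextQuotRem n == nextEven n
                  | n <- [-100000 .. 100000] ])
```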

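  • @PetrPudlák GHC hasn't been taught many of these optimizations. Few people work on that part, and their time goes into higher-level optimizations and type-system trickery, so GHC doesn't do this by itself (so far). The LLVM backend performs the optimization when you use the right types (last I checked, only for `quot`/`rem`, not for `div`/`mod`). But of course, for `Int` it has to account for possibly negative numbers, so you can beat it (though not by much) with the domain knowledge that only positive numbers occur [unless `Int` is 32 bits, in which case you overflow].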