项目Euler 14:与C和memoization相比的性能

m09*_*m09 11 haskell memoization

我正在研究项目euler问题14.

我用一个编码不好的程序解决了它,没有记忆,运行了386 5秒(见编辑).

这里是:

step :: (Integer, Int) -> Integer -> (Integer, Int)
step (i, m) n   | nextValue > m         = (n, nextValue)
                | otherwise             = (i, m)
                where nextValue = syr n 1

syr :: Integer -> Int -> Int
syr 1 acc   = acc
syr x acc   | even x    = syr (x `div` 2) (acc + 1)
            | otherwise = syr (3 * x + 1) (acc + 1)

p14 = foldl step (0, 0) [500000..999999]
Run Code Online (Sandbox Code Playgroud)

我的问题是关于这个问题的线程中的几个注释,其中提到程序的执行时间<1秒如下(C代码,项目euler论坛用户ix的代码 - 注意:我没有检查执行时间实际上是如上所述):

#include <stdio.h>


int main(int argc, char **argv) {
    int longest = 0;
    int terms = 0;
    int i;
    unsigned long j;
    for (i = 1; i <= 1000000; i++) {
        j = i;
        int this_terms = 1;
        while (j != 1) {
            this_terms++;
            if (this_terms > terms) {
                terms = this_terms;
                longest = i;
            }
            if (j % 2 == 0) {
                j = j / 2;
            } else {
                j = 3 * j + 1;
            }
        }
    }
    printf("longest: %d (%d)\n", longest, terms);
    return 0;
}
Run Code Online (Sandbox Code Playgroud)

在谈到算法时,对我来说,这些程序有点相同.

所以我想知道为什么会有这么大的差异?或者我们的两种算法之间是否有任何基本的区别可以证明x6因素在性能上是合理的?

顺便说一句,我目前正试图用memoization来实现这个算法,但是对我来说有点迷失,用命令式语言实现它更容易(我不操纵monad所以我不能使用这个范例) .因此,如果您有任何适合初学者学习备忘录的好教程,我会很高兴(我遇到的那些不够详细或者不属于我的联盟).

注意:我通过Prolog来进行声明性范例,并且仍处于发现Haskell的早期阶段,所以我可能会错过重要的事情.

注2:欢迎任何关于我的代码的一般建议.

编辑:感谢delnan的帮助,我编译了程序,它现在运行5秒钟,所以我现在主要寻找关于memoization的提示(即使仍然欢迎有关现有x6差距的想法).

Dan*_*her 9

在使用优化编译之后,C程序仍然存在一些差异

  • 你使用div,而C程序使用机器分割(截断)[但任何自尊的C编译器将其转换为移位,这使得它更快],这将quot在Haskell中; 这使运行时间缩短了约15%.
  • C程序使用固定宽度64位(甚至32位,但它只是运气得到正确的答案,因为一些中间值超过32位范围)整数,Haskell程序使用任意精度Integers.如果您Int的GHC(Windows以外的64位操作系统)中有64位,请替换IntegerInt.这样可以将运行时间缩短约3倍.如果您使用的是32位系统,那么运气不好,GHC不会在那里使用本机64位指令,这些操作是作为C调用实现的,但仍然很慢.

对于备忘录,你可以将它外包给hackage上的一个memoisation包,我唯一记得的是数据memocombinators,但还有其他.或者你可以自己做,例如保留以前计算的值的地图 - 这在Statemonad中效果最好,

import Control.Monad.State.Strict
import qualified Data.Map as Map
import Data.Map (Map, singleton)

type Memo = Map Integer Int

syr :: Integer -> State Memo Int
syr n = do
    mb <- gets (Map.lookup n)
    case mb of
      Just l -> return l
      Nothing -> do
          let m = if even n then n `quot` 2 else 3*n+1
          l <- syr m
          let l' = l+1
          modify (Map.insert n l')
          return l'

solve :: Integer -> Int -> Integer -> State Memo (Integer,Int)
solve maxi len start
    | len > 1000000 = return (maxi,len)
    | otherwise = do
         l <- syr start
         if len < l
             then solve start l (start+1)
             else solve maxi len (start+1)

p14 :: (Integer,Int)
p14 = evalState (solve 0 0 500000) (singleton 1 1)
Run Code Online (Sandbox Code Playgroud)

但这可能不会获得太多(即使你已经添加了必要的严格性).问题是a中的查找Map不是太便宜而且插入相对昂贵.

另一种方法是为查找保留一个可变数组.代码变得更加复杂,因为您必须为要缓存的值选择合理的上限(应该不比起始值的界限大很多)并处理落在备忘范围之外的序列部分.但是数组查找和写入速度很快.如果你有64位Ints,下面的代码运行得非常快,这里需要0.03s,限制为100万,0.33s限制为1000万,相应的(尽可能合理的)C代码运行在0.018分别 0.2S.

module Main (main) where

import System.Environment (getArgs)
import Data.Array.ST
import Data.Array.Base
import Control.Monad.ST
import Data.Bits
import Data.Int

main :: IO ()
main = do
    args <- getArgs
    let bd = case args of
               a:_ -> read a
               _   -> 100000
    print $ collMax bd

next :: Int -> Int
next n
    | n .&. 1 == 0  = n `unsafeShiftR` 1
    | otherwise     = 3*n + 1

collMax :: Int -> (Int,Int16)
collMax upper = runST $ do
    arr <- newArray (0,upper) 0 :: ST s (STUArray s Int Int16)
    let go l m
            | upper < m = go (l+1) $ next m
            | otherwise = do
                l' <- unsafeRead arr m
                case l' of
                  0 -> do
                      l'' <- go 1 $ next m
                      unsafeWrite arr m (l'' + 1)
                      return (l+l'')
                  _ -> return (l+l'-1)
        collect mi ml i
            | upper < i = return (mi, ml)
            | otherwise = do
                l <- go 1 i
                if l > ml
                  then collect i l (i+1)
                  else collect mi ml (i+1)
    unsafeWrite arr 1 1
    collect 1 1 2
Run Code Online (Sandbox Code Playgroud)