优化Haskell程序

Edv*_*olm 5 optimization haskell

我昨天开始关注Haskell,目的是实际学习它.我在编程语言课程中编写了一些简单的程序,但没有一个真正关心效率.我试图了解如何改善以下程序的运行时间.

我的程序解决了以下玩具问题(我知道如果你知道什么是阶乘,那么手动计算答案是很简单的,但我正在使用后继函数进行蛮力方式):

http://projecteuler.net/problem=24

给定有限长度列表的词典排序的后继函数的算法如下:

  1. 如果列表已经按递减顺序排列,那么我们在词典排序中有最大元素,因此没有后继者.

  2. 给定一个列表h:t,或者在词典排序中t是最大的,或者不是.在后一种情况下,计算t的后继.在前一种情况下,进行如下.

  3. 选择t大于h的最小元素d.

  4. 用h代替d,给出一个新的列表t'.排序中的下一个元素是d :(排序t')

我实现此目的的程序如下(许多这些函数可能在标准库中):

max_list :: (Ord a) => [a] -> a
max_list []     = error "Empty list has no maximum!"
max_list (h:[]) = h
max_list (h:t)  = max h (max_list t)

min_list :: (Ord a) => [a] -> a
min_list []     = error "Empty list has no minimum!"
min_list (h:[]) = h
min_list (h:t)  = min h (min_list t)

-- replaces first occurrence of x in list with y
replace :: (Eq a) => a -> a -> [a] -> [a]
replace _ _ []  = []
replace x y (h:t)
    | h == x    = y : t
    | otherwise = h : (replace x y t)

-- sort in increasing order
sort_list :: (Ord a) => [a] -> [a]
sort_list []    = []
sort_list (h:t) = (sort_list (filter (\x -> x <= h) t))
               ++ [h]
               ++ (sort_list (filter (\x -> x > h) t))

-- checks if list is in descending order
descending :: (Ord a) => [a] -> Bool
descending []     = True
descending (h:[]) = True
descending (h:t)
    | h > (max_list t) = descending t
    | otherwise        = False

succ_list :: (Ord a) => [a] -> [a]
succ_list []      = []
succ_list (h:[])  = [h]
succ_list (h:t)
    | descending (h:t)   = (h:t)
    | not (descending t) = h : succ_list t
    | otherwise = next_h : sort_list (replace next_h h t)
    where next_h = min_list (filter (\x -> x > h) t)

-- apply function n times
apply_times :: (Integral n) => n -> (a -> a) -> a -> a
apply_times n _ a
    | n <= 0      = a
apply_times n f a = apply_times (n-1) f (f a)

main = putStrLn (show (apply_times 999999 succ_list [0,1,2,3,4,5,6,7,8,9]))
Run Code Online (Sandbox Code Playgroud)

现在是实际的问题.在注意到我的程序运行了一段时间后,我写了一个等效的C程序进行比较.我的猜测是,对Haskell的惰性求值导致apply_times函数在实际开始评估结果之前在内存中构建一个巨大的列表.我不得不增加运行时的堆栈大小.由于高效的Haskell编程似乎与技巧有关,是否有任何可用于最小化内存消耗的好技巧?如何最小化复制和垃圾收集,因为列表不断被创建,而C实现将完成所有操作.

既然Haskell被认为是有效的,我想必须有办法吗?关于Haskell,我不得不说的一件很酷的事情是程序在第一次编译时工作正常,因此该部分语言似乎确实填补了它的承诺.

Dan*_*her 12

很多这些功能可能都在标准库中

确实.如果你import Data.List,那是sort可用的,maximum并且minimum可以从Prelude.该sortData.List什么都是比准快速排序更有效,特别是因为你有很多的排序块在这里列出.

descending :: (Ord a) => [a] -> Bool
descending []     = True
descending (h:[]) = True
descending (h:t)
    | h > (max_list t) = descending t
    | otherwise        = False
Run Code Online (Sandbox Code Playgroud)

是低效的 - O(n²)因为它在每一步中遍历整个左尾,但如果列表正在下降,则尾部的最大值必须是它的头部. 但这在这里有一个很好的结果.它可以防止thunk的堆积,因为第三个等式的第一个保护succ_list强制列表被完全评估.但是,这可以通过明确强制列表一次来更有效地完成.

descending (h:t@(ht:_)) = h > ht && descending t
Run Code Online (Sandbox Code Playgroud)

会使它成为线性的.那

在注意到我的程序运行了一段时间后,我写了一个等效的C程序进行比较.

这将是不寻常的.到目前为止,很少有人会使用C中的链表,在此基础上实施惰性评估将是一项艰巨的任务.

用C语言编写一个等效的程序是非常不同寻常的.在C中,实现算法的自然方式是使用数组和就地变异.这在这里自动更有效.

我的猜测是,对Haskell的惰性求值导致apply_times函数在实际开始评估结果之前在内存中构建一个巨大的列表.

不完全是,它构建的是一个巨大的thunk,

apply_times 999999 succ_list [0,1,2,3,4,5,6,7,8,9]
~> apply_times 999998 succ_list (succ_list [0 .. 9])
~> apply_times 999997 succ_list (succ_list (succ_list [0 .. 9]))
~> apply_times 999996 succ_list (succ_list (succ_list (succ_list [0 .. 9])))
...
succ_list (succ_list (succ_list ... (succ_list [0 .. 9])...))
Run Code Online (Sandbox Code Playgroud)

并且,在构建了thunk之后,必须对其进行评估.要评估最外层的调用,必须对下一个调用进行足够的评估,以找出最外层调用中哪个模式匹配.因此,最外层的调用被推入堆栈,并开始评估下一个调用.为此,必须确定哪个模式匹配,因此需要第三个调用的部分结果.因此,第二个呼叫被推到堆栈上...... 最后,您在堆栈上有999998个调用,并开始评估最里面的调用.然后你在每个调用和下一个外部调用之间播放一些乒乓(至少,依赖关系可能会进一步扩展),同时冒泡并从堆栈中弹出调用.

有什么好的技巧可以用来减少内存消耗

是的,强制中间列表在成为参数之前进行评估apply_times.你需要在这里完成评估,所以香草seq不够好

import Control.DeepSeq

apply_times' :: (NFData a, Integral n) => n -> (a -> a) -> a -> a
apply_times' 0 _ x = x
apply_times' k f x = apply_times' (k-1) f $!! f x
Run Code Online (Sandbox Code Playgroud)

这可以防止thunk的堆积,因此你不需要比构造的一些短列表succ_list和计数器更多的内存.

如何最小化复制和垃圾收集,因为列表不断被创建,而C实现将完成所有操作.

是的,那仍然会分配(和垃圾收集)很多.现在,GHC是非常在分配和垃圾收集短命数据好(在我的盒子,它可以很容易地在每MUT 2GB的速度分配第二而不慢),不过,不分配所有这些列表会更快.

因此,如果您想推送它,请使用就地变异.工作

STUArray s Int Int
Run Code Online (Sandbox Code Playgroud)

或者是一个未装箱的可变矢量(我更喜欢array包装提供的界面,但大多数更喜欢vector界面;在性能方面,vector包装内置了很多优化,如果你使用array包装,你必须写快速编写代码,但编写良好的代码在所有实际用途中都是相同的.


我现在做了一些测试.我没有测试过原始的懒惰apply_times,只测试了deepseq每个应用程序f,并修复了所有涉及实体的类型Int.

通过该设置,替换sort_listData:list.sort将运行时间从1.82秒减少到1.65(但增加了分配的字节数).没有太大的区别,但是这些列表还不够长,以至于准快速排序的坏情况真的很糟糕.

最大的不同之处在于改变了descending所提出的,将时间降低到0.48秒,Alloc每秒MUT的速率为2,170,566,037字节,GC时间为0.01秒(然后使用sort_list而不是sort将时间带到0.58秒).

用更简单的方法替换列表结尾段的排序reverse- 算法保证在排序时按降序排序 - 将时间缩短到0.43秒.

算法的相当直接的翻译,以使用未装箱的可变数组,

{-# LANGUAGE BangPatterns #-}
module Main (main) where

import Data.Array.ST
import Data.Array.Base
import Control.Monad.ST
import Control.Monad (when, replicateM_)

sortPart :: STUArray s Int Int -> Int -> Int -> ST s ()
sortPart a lo hi
   | lo < hi   = do
       let lscan !p h i
               | i < h = do
                   v <- unsafeRead a i
                   if p < v then return i else lscan p h (i+1)
               | otherwise = return i
           rscan !p l i
               | l < i = do
                   v <- unsafeRead a i
                   if v < p then return i else rscan p l (i-1)
               | otherwise = return i
           swap i j = do
               v <- unsafeRead a i
               unsafeRead a j >>= unsafeWrite a i
               unsafeWrite a j v
           sloop !p l h
               | l < h = do
                   l1 <- lscan p h l
                   h1 <- rscan p l1 h
                   if (l1 < h1) then (swap l1 h1 >> sloop p l1 h1) else return l1
               | otherwise = return l
       piv <- unsafeRead a hi
       i <- sloop piv lo hi
       swap i hi
       sortPart a lo (i-1)
       sortPart a (i+1) hi
   | otherwise = return ()

descending :: STUArray s Int Int -> Int -> Int -> ST s Bool
descending arr lo hi
    | lo < hi   = do
        let check i !v
                | hi < i    = return True
                | otherwise = do
                    w <- unsafeRead arr i
                    if w < v
                      then check (i+1) w
                      else return False
        x <- unsafeRead arr lo
        check (lo+1) x
    | otherwise = return True

findAndReplace :: STUArray s Int Int -> Int -> Int -> ST s ()
findAndReplace arr lo hi
    | lo < hi   = do
        x <- unsafeRead arr lo
        let go !mi !mv i
                | hi < i    = when (lo < mi) $ unsafeWrite arr mi x >> unsafeWrite arr lo mv
                | otherwise = do
                    w <- unsafeRead arr i
                    if x < w && w < mv
                      then go i w (i+1)
                      else go mi mv (i+1)
            look i
                | hi < i    = return ()
                | otherwise = do
                    w <- unsafeRead arr i
                    if x < w
                      then go i w (i+1)
                      else look (i+1)
        look (lo+1)
    | otherwise = return ()

succArr :: STUArray s Int Int -> Int -> Int -> ST s ()
succArr arr lo hi
    | lo < hi   = do
        end <- descending arr lo hi
        if end
          then return ()
          else do
              needSwap <- descending arr (lo+1) hi
              if needSwap
                then do
                    findAndReplace arr lo hi
                    sortPart arr (lo+1) hi
                else succArr arr (lo+1) hi
    | otherwise = return ()

solution :: [Int]
solution = runST $ do
    arr <- newListArray (0,9) [0 .. 9]
    replicateM_ 999999 $ succArr arr 0 9
    getElems arr

main :: IO ()
main = print solution
Run Code Online (Sandbox Code Playgroud)

在0.15秒内完成.通过更简单的零件反转来替换分类使其降至0.11.

将算法拆分为小的顶级函数,每个函数执行一个任务,使其更具可读性,但这需要付出代价.需要在函数之间传递更多参数,因此并非所有参数都可以在寄存器中传递,并且一些传递的参数 - 数组边界和元素数 - 根本不被使用,因此传递了自重.使所有其他功能本地功能在solution一定程度上减少了总体分配和运行时间(排序为0.13秒,反向为0.09),因为现在只需要传递必要的参数.

偏离给定的算法并使其恢复正常工作,

module Main (main) where

import Data.Array.ST
import Data.Array.Base
import Data.Array.Unboxed
import Control.Monad.ST
import Control.Monad (when)
import Data.Bits

lexPerm :: Int -> Int -> [Int]
lexPerm idx num = elems (runSTUArray $ do
    arr <- unsafeNewArray_ (0,num)
    let fill i
            | num < i   = return ()
            | otherwise = unsafeWrite arr i i >> fill (i+1)
        swap i j = do
            x <- unsafeRead arr i
            y <- unsafeRead arr j
            unsafeWrite arr j x
            unsafeWrite arr i y
        flop i j
            | i < j     = do
                swap i j
                flop (i+1) (j-1)
            | otherwise = return ()
        binsearch v a b = go a b
          where
            go i j
              | i < j     = do
                let m = (i+j+1) `unsafeShiftR` 1
                w <- unsafeRead arr m
                if w < v
                  then go i (m-1)
                  else go m j
              | otherwise = swap a i
        upstep k j
            | k < 1     = return ()
            | j == num-1 = unsafeRead arr num >>= flip (back k) (num-1)
            | otherwise  = nextP k (num-1)
        back k v i
            | i < 0     = return ()
            | otherwise = do
                w <- unsafeRead arr i
                if w < v
                  then nextP k i
                  else back k w (i-1)
        nextP k up
            | k < 1 || up < 0   = return ()
            | otherwise = do
                v <- unsafeRead arr up
                binsearch v up num
                flop (up+1) num
                upstep (k-1) up
    fill 0
    nextP (idx-1) (num-1)
    return arr)

main :: IO ()
main = print $ lexPerm 1000000 9
Run Code Online (Sandbox Code Playgroud)

我们可以在0.02秒内完成任务.

然而,问题中提到的聪明算法在更短的时间内用更少的代码解决了任务.