Sal*_*Sal 16 monads state haskell loops
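The go worker tail-recursive loop pattern seems to work very well for writing pure code. What is the equivalent way to write this kind of loop for the ST monad? More specifically, I want to avoid new heap allocation in the loop iterations. My guess is that it involves either a CPS transformation or fixST to rewrite the code so that all values that change across the loop are passed along in each iteration, making register locations (or the stack, in case of spills) available for those values across iterations. I have a simplified example below (don't try to run it - it will likely crash with a segmentation fault!) involving a function findSnakes that uses the go worker pattern, but whose changing state values are not passed through accumulator arguments: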
{-# LANGUAGE BangPatterns #-}
module Test where

import Data.Vector.Unboxed.Mutable as MU
import Data.Vector.Unboxed as U hiding (mapM_)
import Control.Monad.ST as ST
import Control.Monad.Primitive (PrimState)
import Control.Monad as CM (when,forM_)
import Data.Int

type MVI1 s = MVector (PrimState (ST s)) Int

-- function to find previous y
findYP :: MVI1 s -> Int -> Int -> ST s Int
findYP fp k offset = do
    y0 <- MU.unsafeRead fp (k+offset-1) >>= \x -> return $ 1+x
    y1 <- MU.unsafeRead fp (k+offset+1)
    if y0 > y1 then return y0
      else return y1
{-# INLINE findYP #-}

findSnakes :: Vector Int32 -> MVI1 s -> Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakes a fp !k !ct !op = go 0 k
  where
    offset = 1 + U.length a
    go x k'
      | x < ct = do
          yp <- findYP fp k' offset
          MU.unsafeWrite fp (k'+offset) (yp + k')
          go (x+1) (op k' 1)
      | otherwise = return ()
{-# INLINE findSnakes #-}
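Looking at the cmm output of ghc 7.6.1 (with my limited knowledge of cmm - please correct me if I got it wrong), I see the following call flow, with the loop in s1tb_info (which causes heap allocation and a heap check in each iteration):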
findSnakes_info -> a1_r1qd_info -> $wa_r1qc_info (new stack allocation, SpLim check)
-> s1sy_info -> s1sj_info: if arg > 1 then s1w8_info else R1 (can't figure out
what that register points to)
-- I am guessing this one below is for go loop
s1w8_info -> s1w7_info (big heap allocation, HpLim check) -> s1tb_info: if arg >= 1
then s1td_info else R1
s1td_info (big heap allocation, HpLim check) -> if arg >= 1 then s1tb_info
(a loop) else s1tb_info (after executing a different block of code)
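My guess is that the checks of the form arg >= 1 in the cmm code determine whether the go loop has terminated. If that is correct, it seems that unless the go loop is rewritten to pass yp through the loop, heap allocation will happen in the loop for new values (I am guessing yp is what causes the heap allocation). What would be an efficient way to write the go loop in the example above? I guess yp will have to be passed as an argument to the go loop, or something equivalent done through fixST or a CPS transformation. I can't think of a good way to rewrite the go loop above to remove the heap allocations, and would appreciate help with it.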
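I rewrote your function to avoid any explicit recursion and removed some redundant operations that compute the offset. This compiles to much nicer Core than your original function.
By the way, Core is probably a better way to look at your compiled code for this kind of analysis. Use ghc -ddump-simpl to see the generated Core output, or a tool such as ghc-core.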
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Int
import qualified Data.Vector.Unboxed.Mutable as M
import qualified Data.Vector.Unboxed as U

type MVI1 s = M.MVector (PrimState (ST s)) Int

findYP :: MVI1 s -> Int -> ST s Int
findYP fp offset = do
    y0 <- M.unsafeRead fp (offset+0)
    y1 <- M.unsafeRead fp (offset+2)
    return $ max (y0 + 1) y1

findSnakes :: U.Vector Int32 -> MVI1 s -> Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakes a fp k0 ct op = U.mapM_ writeAt $ U.iterateN ct (`op` 1) k0
    where writeAt k = do
              let offset = U.length a + k
              yp <- findYP fp offset
              M.unsafeWrite fp (offset + 1) (yp + k)

          -- or inline findYP manually
          writeAt k = do
              let offset = U.length a + k
              y0 <- M.unsafeRead fp (offset + 0)
              y1 <- M.unsafeRead fp (offset + 2)
              M.unsafeWrite fp (offset + 1) (k + max (y0 + 1) y1)
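For inspecting Core as suggested above, one option is to set the dump flags in the module itself. The following is only a minimal sketch: sumLoop and sumDemo are hypothetical names that do not appear in the question, and the loop is a toy go worker in ST whose changing values (index and accumulator) are passed as strict arguments. With -ddump-to-file the Core should be written to a dump file next to the module rather than to the terminal; whether the loop ends up allocation-free is something to check in that output, not something this sketch guarantees.

```haskell
{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -O2 -ddump-simpl -dsuppress-all -ddump-to-file #-}
import Control.Monad.ST (ST, runST)
import qualified Data.Vector.Unboxed.Mutable as M

-- Hypothetical example: sum a mutable unboxed vector with a go worker.
-- Everything that changes between iterations (i and acc) is an argument
-- of go, and both arguments are forced with bang patterns.
sumLoop :: M.MVector s Int -> ST s Int
sumLoop v = go 0 0
  where
    n = M.length v
    go !i !acc
      | i < n = do
          x <- M.unsafeRead v i   -- i is in [0, n), so the index is valid
          go (i + 1) (acc + x)
      | otherwise = return acc

-- Hypothetical driver: a vector of five 2s.
sumDemo :: Int
sumDemo = runST $ do
  v <- M.replicate 5 2
  sumLoop v
```

Loading the module in GHCi and evaluating sumDemo gives 10. The flag -dsuppress-all only trims type and coercion noise from the dump so the loop is easier to find.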
Also, you pass a U.Vector Int32 to findSnakes only to compute its length, and never use a again. Why not pass in the length directly?
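As a sketch of what passing the length directly could look like, here is a variant written with an explicit go loop. The names findSnakesLen and demo are hypothetical, and the sketch uses the bounds-checked M.read and M.write instead of the unsafe versions so that it can be run safely; it is one possible shape, not the code from the question or the answer.

```haskell
{-# LANGUAGE BangPatterns #-}
import Control.Monad.ST (ST, runST)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M

-- Hypothetical variant of findSnakes: takes the length n of the input
-- vector instead of the vector itself. The loop counter x and the
-- diagonal k are passed as strict arguments of go.
findSnakesLen :: Int -> M.MVector s Int -> Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakesLen !n fp !k0 !ct op = go 0 k0
  where
    go !x !k
      | x < ct = do
          let offset = n + k
          y0 <- M.read fp offset          -- bounds-checked (assumption: safety over speed)
          y1 <- M.read fp (offset + 2)
          M.write fp (offset + 1) (k + max (y0 + 1) y1)
          go (x + 1) (k `op` 1)
      | otherwise = return ()

-- Hypothetical driver: 8 zeroed cells, n = 2, start at k = 0, 3 steps, op = (+).
demo :: [Int]
demo = runST $ do
  fp <- M.replicate 8 0
  findSnakesLen 2 fp 0 3 (+)
  U.toList <$> U.freeze fp
```

With these inputs the highest index touched is 6, which is inside the 8-element vector, and demo evaluates to [0,0,0,1,3,6,0,0]. Once the index arithmetic is known to stay in bounds, M.read and M.write can be swapped for M.unsafeRead and M.unsafeWrite.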