如何优化 Haskell 中的数值积分性能(举例)

pen*_*sky 5 optimization haskell numeric numerical-integration

如何优化数值积分例程(与 C 相比)?

目前已经做了什么:

  1. 我用未装箱的向量替换了列表(显而易见)。
  2. 我应用了《Real World Haskell》一书中描述的性能分析技术:http://book.realworldhaskell.org/read/profiling-and-optimization.html 。我内联了一些简单的函数,并在各处加入了很多 bang 模式(严格性标注 `!`)。这带来了大约 10 倍的加速。
  3. 我重构了代码(即提取iterator函数)。这带来了 3 倍的加速。
  4. 我按照问题《Optimizing numeric array performance in Haskell》的答案,尝试把多态签名替换为具体的 Float 类型。这几乎提高了 2 倍的速度。
  5. 我这样编译 cabal exec ghc -- Simul.hs -O2 -fforce-recomp -fllvm -Wall
  6. 更新:按照 cchalmers 的建议,type Sample = (F, F) 被替换为 data Sample = Sample {-# UNPACK #-} !F {-# UNPACK #-} !F

现在的性能几乎和C代码一样好。我们可以做得更好吗?

{-# LANGUAGE BangPatterns #-}

module Main
  where

import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import qualified Control.Monad.Primitive as PrimitiveM

import Dynamics.Nonlin ( birefrP )

-- | Scalar type used for every numeric value in this module.
type F = Float
-- | Unboxed vector of delayed (historical) values; 'integrator' receives the
-- initial history in this form.
type Delay = U.Vector F
-- | Unboxed vector of input values.
type Input = U.Vector F
-- | A sample with two strict, unpacked components; 'getX' returns the first.
-- As declared here it always has exactly two components (the original note
-- says a sample could conceptually be a vector of any length: x, y, z, ...).
data Sample = Sample {-# UNPACK #-} !F {-# UNPACK #-} !F
-- | Exactly two strict, unpacked values; 'heun2' receives the two consecutive
-- delayed values in this form.
data Pair = Pair {-# UNPACK #-} !F {-# UNPACK #-} !F

-- | A function from the current sample paired with one delayed value to a
-- sample (the right-hand side evaluated by 'heun2').
type ParametrizedDelayFunction = (Sample, F) -> Sample

-- | Project the first component out of a 'Sample'.
getX :: Sample -> F
getX sample = case sample of
  Sample firstComponent _ -> firstComponent
{-# INLINE getX #-}

-- | Build a 'Delay' vector from a list of values, preserving their order.
toDelay :: [F] -> Delay
toDelay values = U.fromList values

-- | Number of integration steps per node. 'initialCond' multiplies it by the
-- node count to obtain the number of samples per delay.
stepsPerNode :: Int
stepsPerNode = 40  -- Number of integration steps per node

-- | Component-wise sum of two samples.
infixl 6 ..+..
(..+..) :: Sample -> Sample -> Sample
Sample xa ya ..+.. Sample xb yb = Sample (xa + xb) (ya + yb)
{-# INLINE (..+..) #-}

-- | Scale both components of a sample by a scalar factor.
infixl 7 .*..
(.*..) :: F -> Sample -> Sample
factor .*.. Sample xs ys = Sample (factor * xs) (factor * ys)
{-# INLINE (.*..) #-}

-- | Ikeda model (dynamical system, DDE).
--
-- Given the nonlinearity @f@, the current sample and the delayed value @x_h@,
-- returns the derivative sample: the first component is
-- @recip_epsilon * (-x + f x_h)@ and the second component is always 0.
-- The second component of the incoming sample is not used.
ikeda_model2
  :: (F -> F) -> (Sample, F) -> Sample
ikeda_model2 f (!(Sample x _), !x_h) = Sample x' y'
  where
    -- The bang is written without a following space: since GHC 9.0 "! x'"
    -- is parsed as an infix operator, not as a bang pattern.
    !x' = recip_epsilon * (-x + f x_h)
    y' = 0
    -- 2^6 = 64
    recip_epsilon = 2^(6 :: Int)

-- | Integrate using improved Euler's (Heun's) method with a fixed step.
--
-- * @hOver2@ is already half of the step size @h@.
-- * @f@ is the function to integrate; it receives the current sample paired
--   with a delayed value.
-- * @x@ is the current argument (x and y).
-- * The 'Pair' carries @x_h@, the historical (delayed) value, and @x_h2@,
--   the value after @x_h@.
--
-- Returns the sample after one step.
heun2 :: F -> ParametrizedDelayFunction
  -> Sample -> Pair -> Sample
heun2 hOver2 f !x !(Pair x_h x_h2) = x_1
  where
    -- Bangs are written without a following space: since GHC 9.0 "! v" is
    -- parsed as an infix operator, not as a bang pattern.
    !f1 = f (x, x_h)
    -- Predictor: explicit Euler step with the full step size h = 2 * hOver2.
    !x_1' = x ..+.. 2 * hOver2 .*.. f1
    !f2 = f (x_1', x_h2)
    -- Corrector: (h / 2) * (f1 + f2), i.e. the average of the two slopes.
    !x_1 = x ..+.. hOver2 .*.. (f1 ..+.. f2)


-- | Build the initial conditions for a system with @nodesN@ nodes.
--
-- Returns the initial sample (both components zero), a constant history
-- vector, and the number of samples per delay
-- (@nodesN * stepsPerNode@), which is also the length of that vector.
initialCond :: Int -> (Sample, Delay, Int)
initialCond nodesN =
  ( Sample 0.0 0.0
  , U.replicate delaySamples fixedPoint
  , delaySamples
  )
  where
    delaySamples = nodesN * stepsPerNode
    -- A fixed point for birefrP (value taken as given by the author)
    fixedPoint = 1.1247695e-4 :: F

-- | Integrate @total@ fixed steps of a delayed system and record the x
-- component of every computed state.
--
-- * @iterate1@ advances the current state by one step, given a 'Pair' holding
--   the delayed value for that step and the delayed value that follows it.
-- * @len@ is the delay measured in steps. Precondition (unchecked, because
--   unsafe indexing is used): @len >= 1@ and @history0@ holds at least @len@
--   elements.
-- * @total@ is the number of steps to compute.
-- * @xy0@ is the initial state and @history0@ the @len@ delayed values that
--   precede it; the input vector is currently ignored.
--
-- Returns the last computed x (paired with a zero y) together with the vector
-- of all computed x values. When @total@ is 0 nothing is computed and the x
-- of @xy0@ is returned.
integrator
  :: PrimitiveM.PrimMonad m =>
    (Sample -> Pair -> Sample)
    -> Int
    -> Int
    -> (Sample, (Delay, Input))
    -> m (Sample, U.Vector F)
integrator iterate1 len total (xy0, (history0, _input)) = do
    v <- UM.new total
    go v 0 xy0
    history <- U.unsafeFreeze v
    -- Zero y value, currently not used
    let lastX
          | total > 0 = history `U.unsafeIndex` (total - 1)
          | otherwise = getX xy0
    return (Sample lastX 0.0, history)
  where
    h i = history0 `U.unsafeIndex` i
    x0 = getX xy0

    -- The termination test comes first so that no write ever lands outside
    -- the @total@ allocated slots, even when @total < len@.
    go !v !i !xy
      | i >= total = return ()
      | otherwise = do
          !p <- delayed v i
          let !r = iterate1 xy p
          UM.unsafeWrite v i (getX r)
          go v (i + 1) r

    -- Delayed pair for step i. Writing x_k for the state after k steps
    -- (x_0 = xy0), slot v[k] holds x_(k+1) and history0[j] holds x_(j - len);
    -- step i needs x_(i - len) and x_(i - len + 1).
    delayed !v !i
      -- Both values still come from the initial history.
      | i < len - 1 = return $! Pair (h i) (h (i + 1))
      -- Last history value followed by the initial state.
      | i == len - 1 = return $! Pair (h i) x0
      -- Initial state followed by the first computed value.
      | i == len = do
          newX <- UM.unsafeRead v 0
          return $! Pair x0 newX
      -- Both values come from previously computed results.
      | otherwise = do
          newX0 <- UM.unsafeRead v (i - len - 1)
          newX <- UM.unsafeRead v (i - len)
          return $! Pair newX0 newX

-- | Empty input vector. It is passed to 'integrator', which does not use its
-- input argument in this version.
zero :: Input
zero = U.empty

-- | Node count that 'main' passes to 'initialCond'.
nodes :: Int
nodes = 306

-- | Integrate the Ikeda model with Heun's scheme over the whole time trace
-- and print the x component of the final sample.
main :: IO ()
main = do
  let delayCount = 4000
      (startSample, startHistory, stepsPerDelay) = initialCond nodes
      -- 'recip 2^(7::Int)' parses as '(recip 2)^7', i.e. 1/128; heun2 takes
      -- this value as half of the step size.
      halfStep = recip 2^(7::Int) :: F
      -- One step of Heun's scheme applied to the Ikeda model
      stepOnce = heun2 halfStep (ikeda_model2 birefrP)
      stepCount = stepsPerDelay * delayCount

  -- Compute the whole time trace; only the final sample is reported
  (finalSample, _trace) <-
    integrator stepOnce stepsPerDelay stepCount (startSample, (startHistory, zero))
  print (getX finalSample)
Run Code Online (Sandbox Code Playgroud)

非线性函数(导入)可以如下所示:

-- | Coefficients read by 'birefr'.
data Parameters = Parameters
  { beta :: Float
  , alpha :: Float
  , phi :: Float
  } deriving Show
-- | Parameter set used by 'birefrP'.
paramA :: Parameters
paramA =
  Parameters
    { beta = 1.1
    , alpha = 1.0
    , phi = 0.01
    }

-- | Nonlinear function
-- @0.5 * beta * (1 - alpha * cos (2 * (x + phi)))@, with the coefficients
-- taken from the given 'Parameters'.
birefr :: Parameters -> Float -> Float
birefr par !x = 0.5 * beta' * (1 - alpha' * cos (2.0 * (x + phi')))
  where
    -- Bangs are written without a following space: since GHC 9.0 "! v" is
    -- parsed as an infix operator, not as a bang pattern.
    !beta' = beta par
    !alpha' = alpha par
    !phi' = phi par

-- | 'birefr' specialised to the 'paramA' parameter set.
birefrP :: Float -> Float
birefrP x = birefr paramA x
Run Code Online (Sandbox Code Playgroud)