Haskell: Can I read integers directly into an array?

cob*_*bra 10 arrays performance stdin haskell

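In this programming problem, the input is an n×m matrix of integers. Typically, n ≈ 10^5 and m ≈ 10. The official solution (1606D, tutorial) is very imperative: it involves some matrix manipulation, precomputation and aggregation. For fun, I used it as an exercise in implementing things with STUArray.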

Issue

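I have managed to implement it with STUArray, but the program still uses more memory than allowed (256 MB). Even when run locally, the maximum resident set size is over 400 MB. Profiling shows that reading from stdin seems to dominate the memory footprint: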

Profiling

函数readvreadv.readInt负责解析整数并将它们保存到 2D 列表中,大约占用 50-70 MB,而不是大约 16 MB = (10 6 个整数) \xc3\x97 (每个整数 8 个字节 + 每个链接 8 个字节) 。

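Can I realistically expect the total memory to stay below 256 MB? I am already using the Text package for input. Maybe I should avoid lists entirely and read the integers from stdin directly into an array. How can we do that? Or is the problem somewhere else?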
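Reading straight into an array is possible with only the array and bytestring packages that ship with GHC. The sketch below assumes the whole input fits comfortably in memory as a strict ByteString. It parses integers from the front of that buffer directly into an STUArray, so no per-row list is ever built. The name readIntsInto is hypothetical. It returns the unconsumed input so the next test case can continue from there.

```haskell
{-# LANGUAGE BangPatterns #-}
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, freeze, newArray_, writeArray)
import Data.Array.Unboxed (UArray, (!))
import qualified Data.ByteString.Char8 as BC
import Data.Char (isSpace)
import Data.Int (Int32)

-- Parse n*m whitespace-separated integers from the front of the input
-- directly into an unboxed n-by-m array indexed from (1, 1).
-- Returns the array together with the unconsumed rest of the input.
readIntsInto :: Int -> Int -> BC.ByteString -> (UArray (Int, Int) Int32, BC.ByteString)
readIntsInto n m input = runST $ do
  arr <- newArray_ ((1, 1), (n, m)) :: ST s (STUArray s (Int, Int) Int32)
  let go !k bs
        | k == n * m = pure bs
        | otherwise = case BC.readInt (BC.dropWhile isSpace bs) of
            Just (x, rest) -> do
              -- k counts parsed numbers; convert it to a row-major (row, column)
              let (i, j) = k `divMod` m
              writeArray arr (i + 1, j + 1) (fromIntegral x)
              go (k + 1) rest
            Nothing -> error "readIntsInto: not enough integers in input"
  rest <- go (0 :: Int) input
  frozen <- freeze arr
  pure (frozen, rest)
```

In a real program you would call input <- BC.getContents once. You would then read t, n and m with BC.readInt in the same way and thread the returned remainder through each test case.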

Code

{-# OPTIONS_GHC -O2 #-}
module CF1606D where
import qualified Data.Text as T
import qualified Data.Text.IO as TI
import qualified Data.Text.Read as TR
import Control.Monad
import qualified Data.List as DL
import qualified Data.IntSet as DS
import Control.Monad.ST
import Data.Array.ST.Safe
import Data.Int (Int32)
import Data.Array.Unboxed

solve :: IO ()
solve = do
  ~[n,m] <- readv
  -- 2D list
  input <- {-# SCC input #-} replicateM (fromIntegral n) readv
  let
      ints = [1..]
      sorted = DL.sortOn (head.fst) (zip input ints)
      (rows,indices) = {-# SCC rows_inds #-} unzip sorted
      -- 2D list converted into matrix:
      matrix = mat (fromIntegral n) (fromIntegral m) rows
      infinite = 10^7
      asc x y = [x,x+1..y]
      desc x y = [y,y-1..x]
      -- Four prefix-matrices:
      tlMax = runSTUArray $ prefixMat max 0 asc asc (subtract 1) (subtract 1) =<< matrix
      blMin = runSTUArray $ prefixMat min infinite desc asc (+1) (subtract 1) =<< matrix
      trMin = runSTUArray $ prefixMat min infinite asc desc (subtract 1) (+1) =<< matrix
      brMax = runSTUArray $ prefixMat max 0 desc desc (+1) (+1) =<< matrix
      good _ (i,j)
        | tlMax!(i,j) < blMin!(i+1,j) && brMax!(i+1,j+1) < trMin!(i,j+1) = Left (i,j)
        | otherwise = Right ()
      {-# INLINABLE good #-}
      nearAns = foldM good () [(i,j)|i<-[1..n-1],j<-[1..m-1]]
      ans = either (\(i,j)-> "YES\n" ++ color n (take i indices) ++ " " ++ show j) (const "NO") nearAns
  putStrLn ans

type I = Int32
type S s = (STUArray s (Int, Int) I)
type R = Int -> Int -> [Int]
type F = Int -> Int

mat :: Int -> Int -> [[I]] -> ST s (S s)
mat n m rows = newListArray ((1,1),(n,m)) $ concat rows

prefixMat :: (I->I->I) -> I -> R -> R -> F -> F -> S s -> ST s (S s)
prefixMat opt worst ordi ordj previ prevj mat = do
  ((ilo,jlo),(ihi,jhi)) <- getBounds mat
  pre <- newArray ((ilo-1,jlo-1),(ihi+1,jhi+1)) worst
  forM_ (ordi ilo ihi) $ \i-> do
    forM_ (ordj jlo jhi) $ \j -> do
      matij <- readArray mat (i,j)
      prei <- readArray pre (previ i,j)
      prej <- readArray pre (i, prevj j)
      writeArray pre (i,j) (opt (opt prei prej) matij)
  return pre

color :: Int -> [Int] -> String
color n inds = let
  temp = DS.fromList inds
  colors = [if DS.member i temp then 'B' else 'R' | i<-[1..n]]
  in colors

readv :: Integral t => IO [t]
readv = map readInt . T.words <$> TI.getLine where
  readInt = fromIntegral . either (const 0) fst . TR.signed TR.decimal
{-# INLINABLE readv #-}

main :: IO ()
main = do
  ~[n] <- readv
  replicateM_ n solve

A brief description of the code above:

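  1. Read n rows, each containing m integers.
  2. Sort the rows by their first element.
  3. Compute four "prefix matrices", one for each corner. For the top-left and bottom-right corners these are prefix maxima; for the other two corners they are prefix minima.
  4. Find a cell [i,j] where the prefix matrices satisfy top_left[i,j] < bottom_left[i+1,j] and top_right[i,j+1] > bottom_right[i+1,j+1].
  5. Mark the original indices (i.e. the positions in the unsorted input matrix) of rows 1 through i as blue, and mark the rest as red.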
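The recurrence behind step 3 can be illustrated on small inputs with a pure, lazily defined array. The sketch below computes only the top-left prefix maximum, the same recurrence the question's prefixMat evaluates imperatively. The other three corners mirror it, for example by looking at i + 1 instead of i - 1, or by using min with maxBound as the neutral element. topLeftMax is a hypothetical name. Because it relies on a boxed, self-referential array, it shows the idea only and is not a memory-friendly way to handle 10^6 elements.

```haskell
import Data.Array (Array, bounds, listArray, (!))

-- Top-left prefix maximum: tl ! (i, j) is the maximum of mat over
-- rows 1..i and columns 1..j. Assumes the matrix is indexed from (1, 1).
-- Works because boxed listArray elements are lazy, so cells can refer
-- to previously defined cells of the same array.
topLeftMax :: Array (Int, Int) Int -> Array (Int, Int) Int
topLeftMax mat = tl
  where
    bnds@(_, (n, m)) = bounds mat
    tl = listArray bnds [cell i j | i <- [1 .. n], j <- [1 .. m]]
    cell i j = maximum [mat ! (i, j), at (i - 1) j, at i (j - 1)]
    -- Cells outside the matrix act as the neutral element for max.
    at i j
      | i < 1 || j < 1 = minBound
      | otherwise = tl ! (i, j)
```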

Sample input and commands

Sample input: inp3.txt

Commands:

> stack ghc -- -main-is CF1606D.main -with-rtsopts="-s -h -p -P" -rtsopts -prof -fprof-auto CF1606D
> gtime -v ./CF1606D < inp3.txt > outp
    ...
    ...
    MUT     time    2.990s  (  3.744s elapsed)    #    RTS -s output
    GC      time    4.525s  (  6.231s elapsed)    #    RTS -s output
    ...
    ...
    Maximum resident set size (kbytes): 408532    #    >256 MB (gtime output)

> stack exec -- hp2ps -t0.1 -e8in -c CF1606D.hp && open CF1606D.ps

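A question about GC: as the +RTS -s output above shows, GC seems to take longer than the actual program logic (MUT). Is this normal? Is there a way to visualize GC activity over time? I tried making the matrices strict, but it made no difference.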
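One low-tech way to see where allocation and GC time come from is to sample the runtime's counters at a few points in the program, for example after reading the input and again after building the prefix matrices. This is a sketch that was not part of the original thread. GHC.Stats exposes these counters when the program runs with +RTS -T. The helper name reportPhase is hypothetical. For a full timeline, the RTS eventlog (+RTS -l) viewed in a tool such as ThreadScope is the more usual route.

```haskell
import GHC.Stats (RTSStats (..), getRTSStats, getRTSStatsEnabled)

-- Print cumulative allocation, peak live heap (as measured at major GCs),
-- and CPU time split between the garbage collector and the mutator.
-- The counters are only available when the program runs with +RTS -T.
reportPhase :: String -> IO ()
reportPhase label = do
  enabled <- getRTSStatsEnabled
  if not enabled
    then putStrLn (label ++ ": RTS stats disabled (run with +RTS -T)")
    else do
      s <- getRTSStats
      putStrLn $
        label
          ++ ": allocated_bytes=" ++ show (allocated_bytes s)
          ++ " max_live_bytes=" ++ show (max_live_bytes s)
          ++ " gc_cpu_ms=" ++ show (gc_cpu_ns s `div` 1000000)
          ++ " mutator_cpu_ms=" ++ show (mutator_cpu_ns s `div` 1000000)
```

Because the counters are cumulative, the differences between consecutive reports show which phase drives the allocation and GC time.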

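Maybe this simply isn't a functional-friendly problem (though I would be happy to be proven wrong on that). On the other hand, Java also uses GC, yet there are many successful Java submissions. Either way, I would like to see how far I can take this. Thanks!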

leh*_*ins 3

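Contrary to popular belief, Haskell is quite friendly to problems like this one. The real issue is that the array library that ships with GHC is total garbage. Another big problem is that everyone in Haskell is taught to use lists where arrays should be used instead, and this is usually one of the major sources of slow code and memory-bloated programs. So it is not surprising that GC takes a long time: far too much is being allocated. Here is a run of the solution provided below on the supplied input: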

   1,483,547,096 bytes allocated in the heap
         566,448 bytes copied during GC
      18,703,640 bytes maximum residency (3 sample(s))
       1,223,400 bytes maximum slop
              32 MiB total memory in use (0 MB lost due to fragmentation)

                                     Tot time (elapsed)  Avg pause  Max pause
  Gen  0      1399 colls,     0 par    0.009s   0.009s     0.0000s    0.0011s
  Gen  1         3 colls,     0 par    0.002s   0.002s     0.0006s    0.0016s

  TASKS: 4 (1 bound, 3 peak workers (3 total), using -N1)

  SPARKS: 0 (0 converted, 0 overflowed, 0 dud, 0 GC'd, 0 fizzled)

  INIT    time    0.001s  (  0.001s elapsed)
  MUT     time    0.484s  (  0.517s elapsed)
  GC      time    0.011s  (  0.011s elapsed)
  EXIT    time    0.001s  (  0.002s elapsed)
  Total   time    0.496s  (  0.530s elapsed)

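The solution provided below uses the massiv array library, which makes it impossible to submit to Codeforces. Hopefully, though, the goal here is to get better at Haskell rather than to score points on some website.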

The red-blue matrix problem can be split into two stages: reading and solving.

Read dimensions

main函数中,我们只读取数组的总数和每个数组的维度。我们还打印结果。这里没什么令人兴奋的。(请注意,链接文件的inp3.txt数组大于问题中定义的限制n*m <= 10^6:)

Read arrays

Loading the input into an array is the main source of trouble in the original question:

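  • There is no need to depend on text, because ASCII characters are the only valid input the problem expects.
  • The input is read into a list of lists, and that list of lists is the real source of the memory overhead.
  • Sorting lists is very slow and memory hungry.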
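The sorting cost in the last point can be reduced without changing the algorithm. Instead of sorting a list of rows, sort only the row numbers by each row's first element, then access rows through that permutation. The minimal sketch below uses the bundled array package. rowOrder is a hypothetical helper, and it still builds one list of n row numbers. An in-place sort on a mutable array would avoid even that; the massiv code imports quicksortByM_, which suggests it takes that route.

```haskell
import Data.Array.Unboxed (UArray, bounds, listArray, (!))
import Data.Int (Int32)
import Data.List (sortOn)

-- Return the row numbers ordered by each row's first element.
-- Only n Ints are sorted; the n-by-m matrix itself is never copied.
rowOrder :: UArray (Int, Int) Int32 -> UArray Int Int
rowOrder mat = listArray (1, rowCount) (sortOn firstElem [r0 .. r1])
  where
    ((r0, c0), (r1, _)) = bounds mat
    rowCount = r1 - r0 + 1
    firstElem r = mat ! (r, c0)
```

The sorted row numbers play the role of indices in the question's code, so the coloring step can use them directly.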

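Normally, in situations like this it is much better to read the input in a streaming fashion with something like conduit. In particular, reading the input as a stream of bytes and parsing those bytes into numbers would be the optimal solution. That being said, the problem description puts a hard limit on the width of each array. We can therefore get away with reading the input line by line as a ByteString, parsing the numbers in each line (assumed unsigned for simplicity), and writing them into the array as we go. This ensures that at this stage we only allocate the resulting array and a single line as a sequence of bytes. It could be done with a parsing library such as attoparsec, but the problem is simple enough to do it ad hoc.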

{-# LANGUAGE BangPatterns #-}
import Control.Monad.ST
import Control.Monad
import qualified Data.ByteString as BS
import Data.Massiv.Array as A hiding (B)
import Data.Massiv.Array.Mutable.Algorithms (quicksortByM_)
import Control.Scheduler (trivialScheduler_)

main :: IO ()
main = do
  t <- Prelude.read <$> getLine
  when (t < 1 || t > 1000) $ error $ "Invalid t: " ++ show t
  replicateM_ t $ do
    dimsStr <- getLine
    case Prelude.map Prelude.read (words dimsStr) of
      -- Test file fails this check: && n * m <= 10 ^ (6 :: Int) -> do
      [n, m] | n >= 2 && m > 0 && m <= 5 * 10 ^ (5 :: Int) -> do
        mat <- readMatrix n m
        case solve mat of
          Nothing -> putStrLn "NO"
          Just (ix, cs) -> do
            putStrLn "YES"
            putStr $ foldMap show cs
            putStr " "
            print ix
      _ -> putStrLn $ "Unexpected dimensions: " ++ show dimsStr

Solve

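This is the stage where we implement the actual solution. Instead of trying to fix the solution provided in this SO question, I went ahead and translated the C++ solution linked in the question. I went this route for two reasons: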

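  • The C++ solution is highly imperative, and I wanted to demonstrate that imperative array manipulation is not foreign to Haskell, so I tried to make the translation as close as possible.
  • I knew that solution works.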

Note that it should be possible to rewrite the solution below using the array package, since in the end all it needs are read, write and allocate operations.

type Val = Word

readMatrix :: Int -> Int -> IO (Matrix P Val)
readMatrix n m = createArrayS_ (Sz2 n m) readMMatrix

readMMatrix :: MMatrix RealWorld P Val -> IO ()
readMMatrix mat =
  loopM_ 0 (< n) (+ 1) $ \i -> do
    line <- BS.getLine
    --- ^ reads at most 10Mb because it is known that input will be at most
    -- 5*10^5 Words: 19 digits max per Word and one for space: 5*10^5 * 20bytes
    loopM 0 (< m) (+ 1) line $ \j bs ->
      let (word, bs') = parseWord bs
       in bs' <$ write_ mat (i :. j) word
  where
    Sz2 n m = sizeOfMArray mat
    isSpace = (== 32)
    isDigit w8 = w8 >= 48 && w8 <= 57
    parseWord bs =
      case BS.uncons bs of
        Just (w8, bs')
          | isDigit w8 -> parseWordLoop (fromIntegral (w8 - 48)) bs'
          | otherwise -> error $ "Unexpected byte: " ++ show w8
        Nothing -> error "Unexpected end of input"
    parseWordLoop !acc bs =
      case BS.uncons bs of
        Nothing -> (acc, bs)
        Just (w8, bs')
          | isSpace w8 -> (acc, bs')
          | isDigit w8 -> parseWordLoop (acc * 10 + fromIntegral (w8 - 48)) bs'
          | otherwise -> error $ "Unexpected byte: " ++ show w8
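Since reading needs nothing beyond allocate, read and write, the same line-by-line reader can also be written with GHC's bundled array and bytestring packages instead of massiv. The sketch below is an adaptation, not code from the answer. It uses IOUArray, and it uses Data.ByteString.Char8.readInt in place of the hand-written parseWord. It keeps 0-based indices like readMMatrix. It also takes the line-reading action as a parameter: pass BC.getLine in a real program, or a fixed line for testing. The name readMatrixWith is hypothetical.

```haskell
import Control.Monad (foldM_, forM_)
import Data.Array.IO (IOUArray, freeze, newArray_, writeArray)
import Data.Array.Unboxed (UArray, (!))
import qualified Data.ByteString.Char8 as BC
import Data.Char (isSpace)

-- Read an n-by-m matrix of unsigned integers, one input line per row,
-- writing each number into a mutable unboxed array as soon as it is parsed.
-- Only the current line and the array itself are live at any time.
-- Indices are 0-based, matching readMMatrix.
readMatrixWith :: IO BC.ByteString -> Int -> Int -> IO (UArray (Int, Int) Word)
readMatrixWith getLineAction n m = do
  arr <- newArray_ ((0, 0), (n - 1, m - 1)) :: IO (IOUArray (Int, Int) Word)
  forM_ [0 .. n - 1] $ \i -> do
    line <- getLineAction
    -- Thread the unconsumed remainder of the line through the columns.
    foldM_ (writeNext arr i) line [0 .. m - 1]
  freeze arr -- copies once; Data.Array.Unsafe.unsafeFreeze would avoid the copy
  where
    writeNext arr i bs j =
      case BC.readInt (BC.dropWhile isSpace bs) of
        Just (x, rest)
          | x >= 0 -> rest <$ writeArray arr (i, j) (fromIntegral x)
        _ -> error ("row " ++ show i ++ ": expected " ++ show m ++ " unsigned integers")
```

In main this would be called as readMatrixWith BC.getLine n m. Mixing BC.getLine with Prelude's getLine on the same handle is fine because both read from stdin's shared buffer.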