Memory-efficient algorithm for the `take n (sort xs)` ("sorted prefix") problem

Ten*_*ner 7 sorting haskell memory-management lazy-evaluation

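I want to take the n largest elements from a lazy list.

I have heard that the mergesort implemented in Data.List.sort is lazy and does not produce more elements than necessary. That may be true in terms of comparisons, but it is certainly not true of memory usage. The following program illustrates the problem: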

{-# LANGUAGE ScopedTypeVariables #-}

module Main where

import qualified Data.Heap as Heap
import qualified Data.List as List

import System.Random.MWC
import qualified Data.Vector.Unboxed as Vec

import System.Environment

limitSortL n xs = take n (List.sort xs)
limitSortH n xs = List.unfoldr Heap.uncons (List.foldl' (\ acc x -> Heap.take n (Heap.insert x acc) ) Heap.empty xs) 

main = do
  st <- create
  rxs :: [Int] <- Vec.toList `fmap` uniformVector st (10^7)

  args <- getArgs
  case args of
    ["LIST"] -> print (limitSortL 20 rxs)
    ["HEAP"] -> print (limitSortH 20 rxs)

  return ()

Running it:

Data.List:

./lazyTest LIST +RTS -s 
[-9223371438221280004,-9223369283422017686,-9223368296903201811,-9223365203042113783,-9223364809100004863,-9223363058932210878,-9223362160334234021,-9223359019266180408,-9223358851531436915,-9223345045262962114,-9223343191568060219,-9223342956514809662,-9223341125508040302,-9223340661319591967,-9223337771462470186,-9223336010230770808,-9223331570472117335,-9223329558935830150,-9223329536207787831,-9223328937489459283]
   2,059,921,192 bytes allocated in the heap
   2,248,105,704 bytes copied during GC
     552,350,688 bytes maximum residency (5 sample(s))
       3,390,456 bytes maximum slop
            1168 MB total memory in use (0 MB lost due to fragmentation)

  Generation 0:  3772 collections,     0 parallel,  1.44s,  1.48s elapsed
  Generation 1:     5 collections,     0 parallel,  0.90s,  1.13s elapsed

  INIT  time    0.00s  (  0.00s elapsed)
  MUT   time    0.82s  (  0.84s elapsed)
  GC    time    2.34s  (  2.61s elapsed)
  EXIT  time    0.00s  (  0.00s elapsed)
  Total time    3.16s  (  3.45s elapsed)

  %GC time      74.1%  (75.7% elapsed)

  Alloc rate    2,522,515,156 bytes per MUT second

  Productivity  25.9% of total user, 23.7% of total elapsed

Data.Heap:

./lazyTest HEAP +RTS -s 
[-9223371438221280004,-9223369283422017686,-9223368296903201811,-9223365203042113783,-9223364809100004863,-9223363058932210878,-9223362160334234021,-9223359019266180408,-9223358851531436915,-9223345045262962114,-9223343191568060219,-9223342956514809662,-9223341125508040302,-9223340661319591967,-9223337771462470186,-9223336010230770808,-9223331570472117335,-9223329558935830150,-9223329536207787831,-9223328937489459283]
 177,559,536,928 bytes allocated in the heap
     237,093,320 bytes copied during GC
      80,031,376 bytes maximum residency (2 sample(s))
         745,368 bytes maximum slop
              78 MB total memory in use (0 MB lost due to fragmentation)

  Generation 0: 338539 collections,     0 parallel,  1.24s,  1.31s elapsed
  Generation 1:     2 collections,     0 parallel,  0.00s,  0.00s elapsed

  INIT  time    0.00s  (  0.00s elapsed)
  MUT   time   35.24s  ( 35.46s elapsed)
  GC    time    1.24s  (  1.31s elapsed)
  EXIT  time    0.00s  (  0.00s elapsed)
  Total time   36.48s  ( 36.77s elapsed)

  %GC time       3.4%  (3.6% elapsed)

  Alloc rate    5,038,907,812 bytes per MUT second

  Productivity  96.6% of total user, 95.8% of total elapsed

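Clearly limitSortL is much faster, but it is also very memory hungry. On larger lists it uses up all available RAM.

Is there a faster algorithm for this problem that is not so memory hungry?

Edit: to clarify, I am using Data.Heap from the heaps package; I have not tried the heap package.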

Ten*_*ner 4

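So, I have actually managed to solve the problem. The idea is to throw away the fancy data structures and work by hand ;-) Essentially we split the input list into chunks, sort them, and then fold over the resulting [[Int]], selecting the n smallest elements at each step. The tricky part is merging the accumulator with the sorted chunk in the right way. We have to use seq, otherwise laziness bites you and the result still needs a lot of memory to compute. I also combine the merge with take n, just to optimize things further. Here is the whole program, along with the previous attempts: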

{-# LANGUAGE ScopedTypeVariables, PackageImports #-}     
module Main where

import qualified Data.List as List
import qualified Data.List.Split as Split
import qualified "heaps" Data.Heap as Heap -- qualified import from "heaps" package

import System.Random.MWC
import qualified Data.Vector.Unboxed as Vec

import System.Environment

limitSortL n xs = take n (List.sort xs)
limitSortH n xs = List.unfoldr Heap.uncons (List.foldl' (\ acc x -> Heap.take n (Heap.insert x acc) ) Heap.empty xs)
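-- Sort each chunk of n elements and merge it into the running n smallest.
-- chunksOf was called splitEvery in older versions of the split package.
takeSortMerge n inp = List.foldl'
                        (\acc lst -> merge n acc (List.sort lst))
                        [] (Split.chunksOf n inp)
    where
     -- Merge two sorted lists, keeping at most f elements.
     merge 0 _ _ = []
     merge f [] xs = take f xs
     merge f ys [] = take f ys
     -- seq forces the rest of the merge so no chain of thunks builds up.
     merge f (x:xs) (y:ys) | x < y = let rest = merge (f-1) xs (y:ys) in rest `seq` (x:rest)
                           | otherwise = let rest = merge (f-1) (x:xs) ys in rest `seq` (y:rest)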


main = do
  st <- create

  let n1 = 10^7
      n2 = 20

  rxs :: [Int] <- Vec.toList `fmap` uniformVector st (n1)

  args <- getArgs

  case args of
    ["LIST"] ->  print (limitSortL n2 rxs)
    ["HEAP"] ->  print (limitSortH n2 rxs)
    ["MERGE"] -> print (takeSortMerge n2 rxs)
    _ -> putStrLn "Nothing..."

  return ()
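For anyone who wants to try the chunk-and-merge idea without installing split or mwc-random, here is a minimal sketch that depends only on base. It is not the program above, just the same technique restated; the names takeSorted, mergeTake and chunksOfN are hypothetical and chosen for illustration. It assumes the chunk size equals n, as in the program above.

```haskell
import Data.List (foldl', sort)

-- | Split a list into chunks of at most k elements (k must be positive).
-- Hypothetical helper, written by hand to avoid depending on "split".
chunksOfN :: Int -> [a] -> [[a]]
chunksOfN k xs
  | null xs   = []
  | otherwise = let (h, t) = splitAt k xs in h : chunksOfN k t

-- | Merge two sorted lists, keeping only the k smallest elements.
-- The `seq` forces the rest of the merge before each cons cell is
-- returned, so the merged part of the result holds no chain of thunks.
mergeTake :: Ord a => Int -> [a] -> [a] -> [a]
mergeTake k _ _ | k <= 0 = []
mergeTake k [] ys = take k ys
mergeTake k xs [] = take k xs
mergeTake k (x:xs) (y:ys)
  | x <= y    = let rest = mergeTake (k - 1) xs (y:ys) in rest `seq` (x : rest)
  | otherwise = let rest = mergeTake (k - 1) (x:xs) ys in rest `seq` (y : rest)

-- | The n smallest elements of xs in ascending order.
-- Only one chunk of n elements plus the accumulator (at most n elements)
-- needs to be sorted and merged at any time.
takeSorted :: Ord a => Int -> [a] -> [a]
takeSorted n xs
  | n <= 0    = []
  | otherwise = foldl' step [] (chunksOfN n xs)
  where
    step acc chunk = mergeTake n acc (sort chunk)
```

In GHCi, `takeSorted 3 [5,1,4,2,9,0,7]` evaluates to `[0,1,2]`, the same as `take 3 (sort [5,1,4,2,9,0,7])`. Whether the input list itself can be garbage collected while the fold runs depends on how it is produced and on whether anything else still refers to it.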

Running time, memory consumption, and GC time:

algorithm   time     memory    GC time
list        3.96s    1168 MB   75 %
heap        35.29s   78 MB     3.6 %
merge       1.00s    78 MB     3.0 %
just rxs    0.21s    78 MB     0.0 %    -- just evaluating the random vector