Complexity of an algorithm in Haskell

Kha*_*ong 8 haskell

I ran into a simple problem on Codeforces. What I want to discuss is not the problem itself but the languages used to solve it, in this case Python3 and Haskell.

Specifically, I have two versions of my algorithm, one in Haskell and one in Python3. Both are written in a functional style. The code looks like this:

Python3

from operator import add
from itertools import accumulate
from functools import reduce
 
 
def floss(l):
    def e(u):
        a, b = u
        return b if a % 2 == 1 else -b
 
    return map(e, enumerate(l))
 
 
def flock(l):
    return accumulate(l, add)
 
 
def search(l):
    b = zip(l, l[1:])
 
    def equal(u):
        x, y = u
        return x == y
 
    c = any(map(equal, b))
    return 'YES\n' if c else 'NO\n'
 
 
def main():
    t = int(input())
 
    def solution(x):
        return search(sorted(list(flock(floss(x)))))
 
    def get():
        _ = input()
        b = [0] + [int(x) for x in input().split()]
        return b
 
    all_data = [get() for _ in range(t)]
    all_solution = map(solution, all_data)
    print(reduce(add, all_solution))
 
 
main()

Haskell

module Main (main) where
import Data.List (sort)
 
main :: IO ()
main = do
  x <- des
  putStrLn x
 
readInts :: IO [Int]
readInts = fmap (map read.words) getLine
 
flock :: [Int] -> [Int]
flock l = scanr (+) 0 l 
  
floss :: [Int] -> [Int]
floss l = map (e :: (Int, Int) -> Int) $ zip [0..] l where {
  e (u, v) = if mod u 2 == 0 then v else -v
  }
  
search :: [Int] -> String 
search l = if c then "YES\n" else "NO\n" where {
  b = zip l $ tail l;
  c = any (\(x, y) -> x == y) b;
  }
  
solution :: [Int] -> String 
solution = search.sort.flock.floss
 
des :: IO String
des = do 
  io <- readInts
  let t = head io
  all_data <- sequence $ replicate t $ do
    _ <- readInts
    b <- readInts
    return b
  let all_solution = map solution all_data
  let output = foldr (++) "" all_solution
  return output

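The two are algorithmically more or less the same. Yet the Python3 version passes the large test cases, while the Haskell version does not. I would like to know why my Haskell code runs slower than the Python3 code, and which Haskell operation is responsible. One suspicious thing I noticed is that the Haskell code uses much more memory than the Python3 code (2-8 times as much).

I only started learning FP recently, so there may be some mistakes in this post.

Update #1: One thing that may help track down the problem is the line let output = foldr (++) "" all_solution in the Haskell code. In an earlier, worse version I used foldl instead of foldr, which made the code very slow. I think this might make the problem a little easier to find.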

Dan*_*ner 10

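Profiling shows that most of the time is spent in readInts. A very dumb reimplementation that avoids the full overhead of read already triples the speed of the program in my tests: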

readIntDumb :: String -> IO Int
readIntDumb = go 0 1 where
    go n sgn [] = pure (sgn * n)
    go n sgn (c:cs) = case c of
        '-' -> go n (negate sgn) cs
        d | '0' <= d && d <= '9' -> go (10*n + fromEnum d - fromEnum '0') sgn cs
        _ -> fail "whoops"
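The snippet above parses a single token, so it still has to be wired into the line reader. Here is a minimal sketch of one way to do that, using a pure Maybe-returning variant; the names parseIntDumb, parseLine and readIntsDumb are hypothetical and not part of the original code.

```haskell
-- Sketch only: parseIntDumb, parseLine and readIntsDumb are hypothetical names.
import Data.Char (isDigit, ord)

-- Pure variant of the hand-written parser: Nothing on an unexpected character.
parseIntDumb :: String -> Maybe Int
parseIntDumb = go 0 1
  where
    go :: Int -> Int -> String -> Maybe Int
    go n sgn [] = Just (sgn * n)
    go n sgn (c:cs)
      | c == '-'  = go n (negate sgn) cs
      | isDigit c = go (10 * n + ord c - ord '0') sgn cs
      | otherwise = Nothing

-- Parse every whitespace-separated token of one line.
parseLine :: String -> Maybe [Int]
parseLine = traverse parseIntDumb . words

-- Read one line from stdin and parse it, failing in IO on bad input.
readIntsDumb :: IO [Int]
readIntsDumb = getLine >>= maybe (fail "bad integer") pure . parseLine
```

Like the original, this does no overflow checking and accepts a '-' anywhere in the token, which is fine for well-formed contest input but not for general use.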

Switching to ByteString-based IO gives another factor of 2:

import Data.ByteString.Char8 (ByteString)
import Data.Char (isSpace)  -- isSpace is used by readIntsBS below
import qualified Data.ByteString.Char8 as BS8

readIntsBS :: ByteString -> [Int]
readIntsBS bs = case BS8.readInt bs of
    Nothing -> []
    Just (n, bs') -> n : readIntsBS (BS8.dropWhile isSpace bs')

readInts :: IO [Int]
readInts = readIntsBS <$> BS8.getLine

At this point, profiling shows that about half of the remaining runtime is due to sort. Switching to nubInt speeds things up a lot:

import Data.Containers.ListUtils
search l = if nubInt l /= l then "YES" else "NO"
solution = search.flock.floss
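To see why comparing nubInt l with l detects a repeated value: nubInt keeps the first occurrence of each element in order, so its result equals the input exactly when nothing was dropped. The sketch below checks that against the sort-and-compare-neighbours approach from the question; hasDuplicate and hasDuplicateSorted are hypothetical helper names used only for illustration.

```haskell
-- Sketch only: hasDuplicate and hasDuplicateSorted are hypothetical names.
import Data.Containers.ListUtils (nubInt)
import Data.List (sort)

-- nubInt removes later repeats, so the list changes iff a value repeats.
hasDuplicate :: [Int] -> Bool
hasDuplicate l = nubInt l /= l

-- The approach from the question: sort, then compare adjacent elements.
hasDuplicateSorted :: [Int] -> Bool
hasDuplicateSorted l = or (zipWith (==) s (drop 1 s))
  where
    s = sort l
```

Both functions should agree on every input; the nubInt version simply skips the sort.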

Or you can implement a custom uniqueness check, although this saves only a little compared to using nubInt:

{-# LANGUAGE TupleSections #-}  -- needed for the (,True) section below
import qualified Data.IntSet as IS

search :: [Int] -> String 
search l = if uniqInt l then "NO" else "YES"

uniqInt :: [Int] -> Bool
uniqInt = go IS.empty where
    go seen [] = True
    go seen (n:ns) = case IS.alterF (,True) n seen of
        (False, seen') -> go seen' ns
        _ -> False

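You will also want to switch from scanr to scanl. (As a rule of thumb, for folds you usually choose between foldr and foldl' depending on the operation being folded, but for scans you almost always want scanl or one of its minor variants, such as scanl1.) This nearly doubles the speed.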

flock = scanl (+) 0
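For reference, here is a small sketch of what the two scans compute and why swapping them does not change the answer. With a strict operator like (+), scanl can emit each running total as soon as it has consumed the corresponding element, whereas scanr must reach the end of the list before its first output is known. The names prefixSums, suffixSums and relationHolds are hypothetical.

```haskell
-- Sketch only: prefixSums, suffixSums and relationHolds are hypothetical names.

-- Running totals from the left: scanl (+) 0 [1,2,3] == [0,1,3,6]
prefixSums :: [Int] -> [Int]
prefixSums = scanl (+) 0

-- Running totals from the right: scanr (+) 0 [1,2,3] == [6,5,3,0]
suffixSums :: [Int] -> [Int]
suffixSums = scanr (+) 0

-- Each suffix sum is the total minus the matching prefix sum, so two suffix
-- sums coincide exactly when the corresponding prefix sums coincide.
relationHolds :: [Int] -> Bool
relationHolds xs = suffixSums xs == map (sum xs -) (prefixSums xs)
```

Because of that relation, a duplicate check on the prefix sums gives the same YES/NO result as one on the suffix sums.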

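At this point, the cumulative savings have brought the runtime on my machine and test file down from 25s to 0.8s; perhaps that is good enough. Here is the complete final result, with a few minor tweaks not explicitly discussed above (for style, not performance).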

{-# LANGUAGE TupleSections #-}  -- needed for the (,True) section in uniqInt
import Control.Monad
import Data.Bool
import Data.ByteString.Char8 (ByteString)
import Data.Char
import Data.Containers.ListUtils
import qualified Data.ByteString.Char8 as BS8
import qualified Data.IntSet as IS

main :: IO ()
main = do
    t <- readLn
    replicateM_ t $ do
        BS8.getLine
        putStrLn . solution . readIntsBS =<< BS8.getLine

solution :: [Int] -> String
solution = bool "YES" "NO" . uniqInt . scanl (+) 0 . zipWith ($) (cycle [id, negate])

readIntsBS :: ByteString -> [Int]
readIntsBS bs = case BS8.readInt bs of
    Nothing -> []
    Just (n, bs') -> n : readIntsBS (BS8.dropWhile isSpace bs')

uniqInt :: [Int] -> Bool
uniqInt = go IS.empty where
    go seen [] = True
    go seen (n:ns) = case IS.alterF (,True) n seen of
        (False, seen') -> go seen' ns
        _ -> False
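To make the point-free solution easier to follow, here is a sketch of the same pipeline split into named steps. The names alternate, allDistinct and answer are hypothetical, and allDistinct uses plain IS.member / IS.insert rather than alterF, so it is a variation on the code above, not a copy of it.

```haskell
-- Sketch only: alternate, allDistinct and answer are hypothetical names.
import qualified Data.IntSet as IS

-- Keep the 1st, 3rd, ... elements and negate the 2nd, 4th, ...
-- alternate [1,2,3,4] == [1,-2,3,-4]
alternate :: [Int] -> [Int]
alternate = zipWith ($) (cycle [id, negate])

-- True when no value occurs twice; stops at the first repeat.
allDistinct :: [Int] -> Bool
allDistinct = go IS.empty
  where
    go _ [] = True
    go seen (n:ns)
      | n `IS.member` seen = False
      | otherwise = go (IS.insert n seen) ns

-- "YES" when two running totals of the alternating sequence coincide.
answer :: [Int] -> String
answer xs = if allDistinct (scanl (+) 0 (alternate xs)) then "NO" else "YES"
```

For example, answer [1,3,2] gives "YES" because the running totals are [0,1,-2,0] and 0 appears twice, while answer [1,2,3,4] gives "NO" because the totals [0,1,-1,2,-2] are all distinct.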