Levenshtein距离的Haskell尾递归性能问题

jdo*_*jdo 6 recursion haskell tail sequencing levenshtein-distance

我正在玩Haskell 计算Levenshtein距离,并对下面的性能问题感到有点沮丧.如果你为Haskell实现最"正常"的方式,就像下面(dist)一样,一切正常:

dist :: (Ord a) => [a] -> [a] -> Int
dist s1 s2 = ldist s1 s2 (L.length s1, L.length s2)

ldist :: (Ord a) => [a] -> [a] -> (Int, Int) -> Int
ldist _ _ (0, 0) = 0
ldist _ _ (i, 0) = i
ldist _ _ (0, j) = j
ldist s1 s2 (i+1, j+1) = output
  where output | (s1!!(i)) == (s2!!(j)) = ldist s1 s2 (i, j)
               | otherwise = 1 + L.minimum [ldist s1 s2 (i, j)
                                          , ldist s1 s2 (i+1, j)
                                          , ldist s1 s2 (i, j+1)]
Run Code Online (Sandbox Code Playgroud)

但是,如果你弯曲你的大脑一点,并实现它为DIST",它执行MUCH更快(约10倍).

dist' :: (Ord a) => [a] -> [a] -> Int
dist' o1 o2 = (levenDist o1 o2 [[]])!!0!!0 

levenDist :: (Ord a) => [a] -> [a] -> [[Int]] -> [[Int]]
levenDist s1 s2 arr@([[]]) = levenDist s1 s2 [[0]]
levenDist s1 s2 arr@([]:xs) = levenDist s1 s2 ([(L.length arr) -1]:xs)
levenDist s1 s2 arr@(x:xs) = let
    n1 = L.length s1
    n2 = L.length s2
    n_i = L.length arr
    n_j = L.length x
    match | (s2!!(n_j-1) == s1!!(n_i-2)) = True | otherwise = False
    minCost = if match      then (xs!!0)!!(n2 - n_j + 1) 
                            else L.minimum [(1 + (xs!!0)!!(n2 - n_j + 1))
                                          , (1 + (xs!!0)!!(n2 - n_j + 0))
                                          , (1 + (x!!0))
                                          ]
    dist | (n_i > n1) && (n_j > n2)  = arr 
         | n_j > n2  = []:arr `seq` levenDist s1 s2 $ []:arr
         | n_i == 1 = (n_j:x):xs `seq` levenDist s1 s2 $ (n_j:x):xs
         | otherwise = (minCost:x):xs `seq` levenDist s1 s2 $ (minCost:x):xs
    in dist 
Run Code Online (Sandbox Code Playgroud)

seq在第一个版本中尝试了所有常用的技巧,但似乎没有什么能加速它.这对我来说有点不满意,因为我期望第一个版本更快,因为它不需要评估整个矩阵,只需要它需要的部分.

有谁知道是否有可能让这两个实现类似地执行,或者我只是在后者中获得尾递归优化的好处,因此如果我想要性能,需要忍受其不可读性?

谢谢,猎户座

Tra*_*own 5

在过去,我用这个非常简洁的版本foldl,并scanl维基教科书:

distScan :: (Ord a) => [a] -> [a] -> Int
distScan sa sb = last $ foldl transform [0 .. length sa] sb
  where
    transform xs@(x:xs') c = scanl compute (x + 1) (zip3 sa xs xs')
       where
         compute z (c', x, y) = minimum [y + 1, z + 1, x + fromEnum (c' /= c)]
Run Code Online (Sandbox Code Playgroud)

我只是使用Criterion运行这个简单的基准测试:

test :: ([Int] -> [Int] -> Int) -> Int -> Int
test f n = f up up + f up down + f up half + f down half
  where
    up = [1..n]
    half = [1..div n 2]
    down = reverse up

main = let n = 20 in defaultMain
  [ bench "Scan" $ nf (test distScan) n
  , bench "Fast" $ nf (test dist') n
  , bench "Slow" $ nf (test dist) n
  ]
Run Code Online (Sandbox Code Playgroud)

Wikibooks版本非常引人注目地击败了你们两个版本:

benchmarking Scan
collecting 100 samples, 51 iterations each, in estimated 683.7163 ms...
mean: 137.1582 us, lb 136.9858 us, ub 137.3391 us, ci 0.950

benchmarking Fast
collecting 100 samples, 11 iterations each, in estimated 732.5262 ms...
mean: 660.6217 us, lb 659.3847 us, ub 661.8530 us, ci 0.950...
Run Code Online (Sandbox Code Playgroud)

Slow 几分钟后仍在运行.


Nei*_*own 2

我还没有完全理解您的第二次尝试,但据我记得 Levenshtein 算法背后的想法是通过使用矩阵来节省重复计算。在第一段代码中,您没有共享任何计算,因此您将重复大量计算。例如,在计算时,您将进行至少三次单独的ldist s1 s2 (5,5)计算(一次直接,一次通过,一次通过)。ldist s1 s2 (4,4)ldist s1 s2 (4,5)ldist s1 s2 (5,4)

您应该做的是定义一个生成矩阵的算法(如果您愿意,可以作为列表的列表)。我认为这就是您的第二段代码正在做的事情,但它似乎专注于以自上而下的方式计算矩阵,而不是以归纳风格干净地构建矩阵(基本情况中的递归调用非常不寻常)在我看来)。不幸的是,我没有时间写出整个内容,但值得庆幸的是其他人有:查看此地址的第一个版本:http://en.wikibooks.org/wiki/Algorithm_implementation/Strings/Levenshtein_distance#Haskell

还有两件事:第一,我不确定 Levenshtein 算法是否只能使用矩阵的一部分,因为每个条目都依赖于对角线、垂直和水平邻居。当您需要一个角的值时,您将不可避免地必须一直计算矩阵到另一个角。其次,该match | foo = True | otherwise = False行可以简单地替换为match = foo.