jdo*_*jdo 6 recursion haskell tail sequencing levenshtein-distance
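I'm playing around with computing Levenshtein distances in Haskell, and I'm a bit frustrated by the following performance problem. If you implement it in the most "normal" way for Haskell, like dist below, everything works fine: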
{-# LANGUAGE NPlusKPatterns #-}  -- the (i+1, j+1) pattern below needs this extension on current GHC
import qualified Data.List as L

dist :: (Ord a) => [a] -> [a] -> Int
dist s1 s2 = ldist s1 s2 (L.length s1, L.length s2)

ldist :: (Ord a) => [a] -> [a] -> (Int, Int) -> Int
ldist _ _ (0, 0) = 0
ldist _ _ (i, 0) = i
ldist _ _ (0, j) = j
ldist s1 s2 (i+1, j+1) = output
  where output | (s1!!(i)) == (s2!!(j)) = ldist s1 s2 (i, j)
               | otherwise = 1 + L.minimum [ldist s1 s2 (i, j)
                                           , ldist s1 s2 (i+1, j)
                                           , ldist s1 s2 (i, j+1)]
However, if you bend your brain a little and implement it as dist', it runs MUCH faster (about 10x).
dist' :: (Ord a) => [a] -> [a] -> Int
dist' o1 o2 = (levenDist o1 o2 [[]])!!0!!0

levenDist :: (Ord a) => [a] -> [a] -> [[Int]] -> [[Int]]
levenDist s1 s2 arr@([[]]) = levenDist s1 s2 [[0]]
levenDist s1 s2 arr@([]:xs) = levenDist s1 s2 ([(L.length arr) -1]:xs)
levenDist s1 s2 arr@(x:xs) = let
    n1 = L.length s1
    n2 = L.length s2
    n_i = L.length arr
    n_j = L.length x
    match | (s2!!(n_j-1) == s1!!(n_i-2)) = True | otherwise = False
    minCost = if match then (xs!!0)!!(n2 - n_j + 1)
              else L.minimum [(1 + (xs!!0)!!(n2 - n_j + 1))
                             , (1 + (xs!!0)!!(n2 - n_j + 0))
                             , (1 + (x!!0))
                             ]
    dist | (n_i > n1) && (n_j > n2) = arr
         | n_j > n2 = []:arr `seq` levenDist s1 s2 $ []:arr
         | n_i == 1 = (n_j:x):xs `seq` levenDist s1 s2 $ (n_j:x):xs
         | otherwise = (minCost:x):xs `seq` levenDist s1 s2 $ (minCost:x):xs
  in dist
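I've tried all the usual seq tricks on the first version, but nothing seems to speed it up. This is a bit unsatisfying to me, because I expected the first version to be faster: it doesn't need to evaluate the entire matrix, only the parts it needs.

Does anyone know whether it is possible to get these two implementations to perform similarly, or am I just reaping the benefits of tail-recursion optimization in the latter, so that I have to live with its unreadability if I want performance?

Thanks, Orion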
In the past, I've used this very concise version with foldl and scanl from Wikibooks:
distScan :: (Ord a) => [a] -> [a] -> Int
distScan sa sb = last $ foldl transform [0 .. length sa] sb
  where
    transform xs@(x:xs') c = scanl compute (x + 1) (zip3 sa xs xs')
      where
        compute z (c', x, y) = minimum [y + 1, z + 1, x + fromEnum (c' /= c)]
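To see how the fold builds the matrix one row at a time, here is a small sketch that swaps the outer foldl for scanl so every intermediate row is kept. The name distRows is a hypothetical helper for illustration only; it is not part of the Wikibooks code, and the step function is copied from distScan above.

```haskell
-- Hypothetical helper (not from Wikibooks): the same row step as distScan,
-- but the outer scanl keeps every row instead of only the final one.
distRows :: Eq a => [a] -> [a] -> [[Int]]
distRows sa sb = scanl transform [0 .. length sa] sb
  where
    -- xs is the previous row; it always has at least one element,
    -- because the initial row [0 .. length sa] is never empty.
    transform xs@(x:xs') c = scanl compute (x + 1) (zip3 sa xs xs')
      where
        -- z: cell to the left, x: diagonal cell, y: cell above
        compute z (c', x, y) = minimum [y + 1, z + 1, x + fromEnum (c' /= c)]
```

In GHCi, distRows "ab" "b" gives [[0,1,2],[1,1,1]]; the last element of the last row is the distance that distScan returns.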
I just ran this simple benchmark using Criterion:
import Criterion.Main (bench, defaultMain, nf)

test :: ([Int] -> [Int] -> Int) -> Int -> Int
test f n = f up up + f up down + f up half + f down half
  where
    up = [1..n]
    half = [1..div n 2]
    down = reverse up

main = let n = 20 in defaultMain
    [ bench "Scan" $ nf (test distScan) n
    , bench "Fast" $ nf (test dist') n
    , bench "Slow" $ nf (test dist) n
    ]
The Wikibooks version beats both of your versions quite dramatically:
benchmarking Scan
collecting 100 samples, 51 iterations each, in estimated 683.7163 ms...
mean: 137.1582 us, lb 136.9858 us, ub 137.3391 us, ci 0.950
benchmarking Fast
collecting 100 samples, 11 iterations each, in estimated 732.5262 ms...
mean: 660.6217 us, lb 659.3847 us, ub 661.8530 us, ci 0.950...
Slow was still running after several minutes.
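I haven't fully understood your second attempt yet, but as far as I remember, the idea behind the Levenshtein algorithm is to avoid repeated computation by using a matrix. In the first piece of code you are not sharing any computation, so you end up repeating a lot of work. For example, when computing ldist s1 s2 (5,5), you will compute ldist s1 s2 (4,4) at least three separate times (once directly, once via ldist s1 s2 (4,5), and once via ldist s1 s2 (5,4)).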
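One way to keep the shape of the first version while sharing the subproblems is to store the results in a lazy array, so that each cell is computed at most once. The following is only a sketch under that assumption: distMemo and its local names are hypothetical and do not appear in the question, and it relies on Data.Array from the array package that ships with GHC.

```haskell
import Data.Array (Array, array, listArray, (!))

-- Hypothetical name: a memoised variant of the recursive ldist.
distMemo :: Eq a => [a] -> [a] -> Int
distMemo s1 s2 = table ! (n1, n2)
  where
    n1 = length s1
    n2 = length s2
    -- Arrays give O(1) indexing, unlike (!!) on lists.
    a1 = listArray (1, n1) s1
    a2 = listArray (1, n2) s2
    -- The array elements are lazy thunks, so each cell is evaluated
    -- at most once and then shared by every cell that refers to it.
    table :: Array (Int, Int) Int
    table = array ((0, 0), (n1, n2))
      [ ((i, j), cell i j) | i <- [0 .. n1], j <- [0 .. n2] ]
    cell i 0 = i
    cell 0 j = j
    cell i j
      | a1 ! i == a2 ! j = table ! (i - 1, j - 1)
      | otherwise = 1 + minimum [ table ! (i - 1, j - 1)
                                , table ! (i, j - 1)
                                , table ! (i - 1, j) ]
```

The recurrence is the same as in ldist; the only change is that recursive calls become lookups into table, which is what removes the repeated work described above.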
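What you should do is define an algorithm that generates the matrix (as a list of lists, if you like). I think this is what your second piece of code is doing, but it seems to focus on computing the matrix in a top-down manner rather than building it cleanly in an inductive style (the recursive calls in the base cases are quite unusual, in my opinion). Unfortunately I don't have time to write the whole thing out, but thankfully someone else has: look at the first version at this address: http://en.wikibooks.org/wiki/Algorithm_implementation/Strings/Levenshtein_distance#Haskell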
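Two more things. First, I'm not sure the Levenshtein algorithm can ever use only part of the matrix, since each entry depends on its diagonal, vertical, and horizontal neighbours. When you need the value in one corner, you will inevitably have to evaluate the matrix all the way to the opposite corner. Second, the match | foo = True | otherwise = False line can simply be replaced by match = foo.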