Levenshtein 距离的 Haskell 尾递归性能问题

发布于 2024-09-25 09:56:39 字数 2066 浏览 12 评论 0原文

我正在 Haskell 中计算 Levenshtein 距离,并对以下性能感到有点沮丧问题。如果你用 Haskell 最“正常”的方式实现它,就像下面的(dist),一切都很好:

dist :: (Ord a) => [a] -> [a] -> Int
dist s1 s2 = ldist s1 s2 (L.length s1, L.length s2)

ldist :: (Ord a) => [a] -> [a] -> (Int, Int) -> Int
ldist _ _ (0, 0) = 0
ldist _ _ (i, 0) = i
ldist _ _ (0, j) = j
ldist s1 s2 (i+1, j+1) = output
  where output | (s1!!(i)) == (s2!!(j)) = ldist s1 s2 (i, j)
               | otherwise = 1 + L.minimum [ldist s1 s2 (i, j)
                                          , ldist s1 s2 (i+1, j)
                                          , ldist s1 s2 (i, j+1)]

但是,如果你稍微弯曲你的大脑并将其实现为 dist',它的执行速度快得多 (约 10 倍)。

dist' :: (Ord a) => [a] -> [a] -> Int
dist' o1 o2 = (levenDist o1 o2 [[]])!!0!!0 

levenDist :: (Ord a) => [a] -> [a] -> [[Int]] -> [[Int]]
levenDist s1 s2 arr@([[]]) = levenDist s1 s2 [[0]]
levenDist s1 s2 arr@([]:xs) = levenDist s1 s2 ([(L.length arr) -1]:xs)
levenDist s1 s2 arr@(x:xs) = let
    n1 = L.length s1
    n2 = L.length s2
    n_i = L.length arr
    n_j = L.length x
    match | (s2!!(n_j-1) == s1!!(n_i-2)) = True | otherwise = False
    minCost = if match      then (xs!!0)!!(n2 - n_j + 1) 
                            else L.minimum [(1 + (xs!!0)!!(n2 - n_j + 1))
                                          , (1 + (xs!!0)!!(n2 - n_j + 0))
                                          , (1 + (x!!0))
                                          ]
    dist | (n_i > n1) && (n_j > n2)  = arr 
         | n_j > n2  = []:arr `seq` levenDist s1 s2 $ []:arr
         | n_i == 1 = (n_j:x):xs `seq` levenDist s1 s2 $ (n_j:x):xs
         | otherwise = (minCost:x):xs `seq` levenDist s1 s2 $ (minCost:x):xs
    in dist 

我在第一个版本中尝试了所有常见的 seq 技巧,但似乎没有什么可以加快速度。这对我来说有点不满意,因为我预计第一个版本会更快,因为它不需要评估整个矩阵,只需评估它需要的部分。

有谁知道是否有可能让这两个实现以类似的方式执行,或者我只是在后者中获得尾递归优化的好处,因此如果我想要性能就需要忍受它的不可读性?

谢谢, 猎户座

I'm playing around with calculating Levenshtein distances in Haskell, and am a little frustrated with the following performance problem. If you implement it most 'normal' way for Haskell, like below (dist), everything works just fine:

dist :: (Ord a) => [a] -> [a] -> Int
dist s1 s2 = ldist s1 s2 (L.length s1, L.length s2)

ldist :: (Ord a) => [a] -> [a] -> (Int, Int) -> Int
ldist _ _ (0, 0) = 0
ldist _ _ (i, 0) = i
ldist _ _ (0, j) = j
ldist s1 s2 (i+1, j+1) = output
  where output | (s1!!(i)) == (s2!!(j)) = ldist s1 s2 (i, j)
               | otherwise = 1 + L.minimum [ldist s1 s2 (i, j)
                                          , ldist s1 s2 (i+1, j)
                                          , ldist s1 s2 (i, j+1)]

But, if you bend your brain a little and implement it as dist', it executes MUCH faster (about 10x).

dist' :: (Ord a) => [a] -> [a] -> Int
dist' o1 o2 = (levenDist o1 o2 [[]])!!0!!0 

levenDist :: (Ord a) => [a] -> [a] -> [[Int]] -> [[Int]]
levenDist s1 s2 arr@([[]]) = levenDist s1 s2 [[0]]
levenDist s1 s2 arr@([]:xs) = levenDist s1 s2 ([(L.length arr) -1]:xs)
levenDist s1 s2 arr@(x:xs) = let
    n1 = L.length s1
    n2 = L.length s2
    n_i = L.length arr
    n_j = L.length x
    match | (s2!!(n_j-1) == s1!!(n_i-2)) = True | otherwise = False
    minCost = if match      then (xs!!0)!!(n2 - n_j + 1) 
                            else L.minimum [(1 + (xs!!0)!!(n2 - n_j + 1))
                                          , (1 + (xs!!0)!!(n2 - n_j + 0))
                                          , (1 + (x!!0))
                                          ]
    dist | (n_i > n1) && (n_j > n2)  = arr 
         | n_j > n2  = []:arr `seq` levenDist s1 s2 $ []:arr
         | n_i == 1 = (n_j:x):xs `seq` levenDist s1 s2 $ (n_j:x):xs
         | otherwise = (minCost:x):xs `seq` levenDist s1 s2 $ (minCost:x):xs
    in dist 

I've tried all the usual seq tricks in the first version, but nothing seems to speed it up. This is a little unsatisfying for me, because I expected the first version to be faster because it doesn't need to evaluate the entire matrix, only the parts it needs.

Does anyone know if it is possible to get these two implementations to perform similarly, or am I just reaping the benefits of tail-recursion optimizations in the latter, and therefore need to live with its unreadability if I want performance?

Thanks,
Orion

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(5

不爱素颜 2024-10-02 09:56:39

过去,我使用过这个非常简洁的版本以及来自 Wikibooks

distScan :: (Ord a) => [a] -> [a] -> Int
distScan sa sb = last $ foldl transform [0 .. length sa] sb
  where
    transform xs@(x:xs') c = scanl compute (x + 1) (zip3 sa xs xs')
       where
         compute z (c', x, y) = minimum [y + 1, z + 1, x + fromEnum (c' /= c)]

我刚刚使用 Criterion 运行了这个简单的基准测试

test :: ([Int] -> [Int] -> Int) -> Int -> Int
test f n = f up up + f up down + f up half + f down half
  where
    up = [1..n]
    half = [1..div n 2]
    down = reverse up

main = let n = 20 in defaultMain
  [ bench "Scan" $ nf (test distScan) n
  , bench "Fast" $ nf (test dist') n
  , bench "Slow" $ nf (test dist) n
  ]

维基教科书的版本明显击败了你们的版本:

benchmarking Scan
collecting 100 samples, 51 iterations each, in estimated 683.7163 ms...
mean: 137.1582 us, lb 136.9858 us, ub 137.3391 us, ci 0.950

benchmarking Fast
collecting 100 samples, 11 iterations each, in estimated 732.5262 ms...
mean: 660.6217 us, lb 659.3847 us, ub 661.8530 us, ci 0.950...

Slow 几分钟后仍在运行。

In the past I've used this very concise version with foldl and scanl from Wikibooks:

distScan :: (Ord a) => [a] -> [a] -> Int
distScan sa sb = last $ foldl transform [0 .. length sa] sb
  where
    transform xs@(x:xs') c = scanl compute (x + 1) (zip3 sa xs xs')
       where
         compute z (c', x, y) = minimum [y + 1, z + 1, x + fromEnum (c' /= c)]

I just ran this simple benchmark using Criterion:

test :: ([Int] -> [Int] -> Int) -> Int -> Int
test f n = f up up + f up down + f up half + f down half
  where
    up = [1..n]
    half = [1..div n 2]
    down = reverse up

main = let n = 20 in defaultMain
  [ bench "Scan" $ nf (test distScan) n
  , bench "Fast" $ nf (test dist') n
  , bench "Slow" $ nf (test dist) n
  ]

And the Wikibooks version beats both of yours pretty dramatically:

benchmarking Scan
collecting 100 samples, 51 iterations each, in estimated 683.7163 ms...
mean: 137.1582 us, lb 136.9858 us, ub 137.3391 us, ci 0.950

benchmarking Fast
collecting 100 samples, 11 iterations each, in estimated 732.5262 ms...
mean: 660.6217 us, lb 659.3847 us, ub 661.8530 us, ci 0.950...

Slow is still running after a couple of minutes.

橘味果▽酱 2024-10-02 09:56:39

要计算长度,您需要评估整个列表。这是一个昂贵的 O(n) 操作。更重要的是,此后该列表将保留在内存中,直到您停止引用该列表(=>更大的内存占用)。经验法则是,如果预计列表很长,则不要在列表上使用 length(!!)也是如此,每次都是从列表的头部开始,所以也是O(n)。列表并非设计为随机访问数据结构。

使用 Haskell 列表的更好方法是部分使用它们。 折叠通常是解决类似问题的方法。编辑距离可以通过这种方式计算(请参阅下面的链接)。不知道有没有更好的算法。

另一种方法是使用不同的数据结构,而不是列表。例如,如果您需要随机访问、已知长度等,请查看 Data.Sequence.Seq

现有实现

第二种方法已用于此实现 Haskell 中的 Levenschtein 距离(使用数组)。您可以在第一个评论中找到基于 foldl 的实现。顺便说一句,foldl' 通常比 foldl 更好。

To calculate length you need to evaluate the whole list. It is an expensive, O(n), operation. And what's more important, after that the list will be kept in-memory until you stop referencing the list (=> bigger memory footprint). The rule of thumb is not to use length on lists if lists are expected to be long. The same refers to (!!), it goes from the very head of the list every time, so it is O(n) too. Lists are not designed as a random-access data structure.

Better approach with Haskell lists is to consume them partially. Folds are usually the way to go in similar problems. And Levenshtein distance can be calculated that way (see a link below). I don't know if there are better algorithms.

Another approach is to use a different data structure, not lists. For example, if you need random access, known length etc. take a look at Data.Sequence.Seq.

Existing implementations

The second approach has been used in this implementation of the Levenschtein distance in Haskell (using arrays). You can find foldl-based implementation in the first comment there. BTW, foldl' is usually better than foldl.

末蓝 2024-10-02 09:56:39

我还没有完全理解您的第二次尝试,但据我记得 Levenshtein 算法背后的想法是通过使用矩阵来节省重复计算。在第一段代码中,您没有共享任何计算,因此您将重复大量计算。例如,在计算 ldist s1 s2 (5,5) 时,您将至少单独计算三次 ldist s1 s2 (4,4)(一次直接,一次通过 ldist s1 s2 (4,5),一次通过 ldist s1 s2 (5,4))。

您应该做的是定义一个生成矩阵的算法(如果您愿意,可以作为列表的列表)。我认为这就是您的第二段代码正在做的事情,但它似乎专注于以自上而下的方式计算矩阵,而不是以归纳风格干净地构建矩阵(基本情况中的递归调用非常不寻常)在我看来)。不幸的是,我没有时间写出全部内容,但值得庆幸的是其他人有:查看以下地址的第一个版本:http://en.wikibooks.org/wiki/Algorithm_implementation/Strings/Levenshtein_distance#Haskell

还有两件事:第一,我不确定 Levenshtein 算法是否可以使用无论如何,这只是矩阵的一部分,因为每个条目都依赖于对角线、垂直和水平邻居。当您需要一个角的值时,您将不可避免地必须一直计算矩阵到另一个角。其次,匹配 | foo = 真 | else = False 行可以简单地替换为 match = foo

I don't follow all of your second attempt just yet, but as far as I recall the idea behind the Levenshtein algorithm is to save repeated calculation by using a matrix. In the first piece of code, you are not sharing any calculation and thus you will be repeating lots of calculations. For example, when calculating ldist s1 s2 (5,5) you'll make the calculation for ldist s1 s2 (4,4) at least three separate times (once directly, once via ldist s1 s2 (4,5), once via ldist s1 s2 (5,4)).

What you should do is define an algorithm for generating the matrix (as a list of lists, if you like). I think this is what your second piece of code is doing, but it seems to focus on calculating the matrix in a top-down manner rather than building up the matrix cleanly in an inductive style (the recursive calls in the base case are quite unusual to my eye). Unfortunately I don't have time to write out the whole thing, but thankfully someone else has: look at the first version at this address: http://en.wikibooks.org/wiki/Algorithm_implementation/Strings/Levenshtein_distance#Haskell

Two more things: one, I'm not sure the Levenshtein algorithm can ever use only part of the matrix anyway, as each entry is dependent on the diagonal, vertical and horizontal neighbour. When you need the value for one corner, you'll inevitably have to evaluate the matrix all the way to the other corner. Secondly, that match | foo = True | otherwise = False line can be replaced by simply match = foo.

花开雨落又逢春i 2024-10-02 09:56:39

可以有 O(N*d) 算法,其中 d 是编辑距离。这是 Lloyd 在 Lazy ML 中的实现 Allison 利用惰性来提高复杂性。这是通过仅计算矩阵的一部分来实现的,即主对角线周围的区域,其宽度与编辑距离成正比。

编辑:我刚刚注意到这已被翻译向 haskell 提供一个漂亮的图像,显​​示计算了矩阵的哪些元素。当序列非常相似时,这应该比上述实现快得多。使用上述基准:

benchmarking Scan
collecting 100 samples, 100 iterations each, in estimated 1.410004 s
mean: 141.8836 us, lb 141.4112 us, ub 142.5126 us, ci 0.950

benchmarking LAllison.d
collecting 100 samples, 169 iterations each, in estimated 1.399984 s
mean: 82.93505 us, lb 82.75058 us, ub 83.19535 us, ci 0.950

It is possible to have an O(N*d) algorithm, where d is the Levenshtein distance. Here's a implementation in Lazy ML by Lloyd Allison which exploits laziness to achieve the improved complexity. This works by only computing part of the matrix, that is, a region around the main diagonal that is proportional in width to the Levenshtein distance.

Edit: I just noticed this has been translated to haskell with a nice image showing which elements of the matrix are computed. This should be significantly faster than the above implementations when the sequences are quite similar. Using the above benchmark:

benchmarking Scan
collecting 100 samples, 100 iterations each, in estimated 1.410004 s
mean: 141.8836 us, lb 141.4112 us, ub 142.5126 us, ci 0.950

benchmarking LAllison.d
collecting 100 samples, 169 iterations each, in estimated 1.399984 s
mean: 82.93505 us, lb 82.75058 us, ub 83.19535 us, ci 0.950
清晨说晚安 2024-10-02 09:56:39

使用 data-memocombinators 包的更直观的解决方案。归功于这个答案。欢迎使用基准测试,因为这里提供的所有解决方案似乎都比 python-Levenshtein< 慢得多/a>,大概是用 C 编写的。请注意,我尝试用字符数组代替字符串,但没有效果。

import Data.MemoCombinators (memo2, integral)

levenshtein :: String -> String -> Int
levenshtein a b = levenshtein' (length a) (length b) where
  levenshtein' = memo2 integral integral levenshtein'' where
    levenshtein'' x y -- take x characters from a and y characters from b
      | x==0 = y
      | y==0 = x
      | a !! (x-1) == b !! (y-1) = levenshtein' (x-1) (y-1)
      | otherwise = 1 + minimum [ levenshtein' (x-1) y, 
        levenshtein' x (y-1), levenshtein' (x-1) (y-1) ]

A more intuitive solution using the data-memocombinators package. Credit goes to this answer. Benchmarks are welcome, as all solutions presented here appear to be much, much slower than python-Levenshtein, which was presumably written in C. Note that I tried substituting arrays of chars instead of strings to no effect.

import Data.MemoCombinators (memo2, integral)

levenshtein :: String -> String -> Int
levenshtein a b = levenshtein' (length a) (length b) where
  levenshtein' = memo2 integral integral levenshtein'' where
    levenshtein'' x y -- take x characters from a and y characters from b
      | x==0 = y
      | y==0 = x
      | a !! (x-1) == b !! (y-1) = levenshtein' (x-1) (y-1)
      | otherwise = 1 + minimum [ levenshtein' (x-1) y, 
        levenshtein' x (y-1), levenshtein' (x-1) (y-1) ]
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文