Haskell 中的记忆化?

发布于 2024-09-08 21:34:34 字数 241 浏览 8 评论 0原文

关于如何在 Haskell 中有效求解以下函数的任何指针,对于大数 (n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

我已经看过 Haskell 中记忆化求解斐波那契数列的示例 数字,涉及(惰性地)计算所有斐波那契数 直到所需的n。但在这种情况下,对于给定的 n,我们只需要 计算很少的中间结果。

谢谢

Any pointers on how to solve efficiently the following function in Haskell, for large numbers (n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

I've seen examples of memoization in Haskell to solve fibonacci
numbers, which involved computing (lazily) all the fibonacci numbers
up to the required n. But in this case, for a given n, we only need to
compute very few intermediate results.

Thanks

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(8

白云不回头 2024-09-15 21:34:34

我们可以通过创建一个可以在亚线性时间内索引的结构来非常有效地做到这一点。

但首先,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

让我们定义f,但让它使用“开放递归”而不是直接调用自身。

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

您可以使用 fix f 获得未记忆的 f

这将让您测试 f 是否符合您对 f 小值的意思。 通过调用,例如: fix f 123 = 144

我们可以通过定义来记住这一点:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

表现还算不错,并取代了将要花费的时间 O(n^3) 花一些时间来记住中间结果。

但仅通过索引来查找 mf 的记忆答案仍然需要线性时间。这意味着像这样的结果

*Main Data.List> faster_f 123801
248604

是可以容忍的,但结果并没有比这更好。我们可以做得更好!

首先,让我们定义一个无限树:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

然后我们将定义一种对其进行索引的方法,这样我们就可以在 O(log n) 中找到索引为 n 的节点时间:

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

...我们可能会发现一棵充满自然数的树很方便,因此我们不必摆弄这些索引:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

因为我们可以索引,所以您可以将树转换为列表:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

您可以检查到目前为止,通过验证 toList nats 为您提供 [0..]

现在,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

工作方式与上面的列表类似,但不是花费线性时间来查找每个节点,可以在对数时间内追到它。

结果要快得多:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

事实上,它的速度要快得多,您可以遍历上面的 Int 并将其替换为 Integer ,几乎可以立即获得大得离谱的答案

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358

。 - 实现基于树的记忆化的盒子库,使用 MemoTrie

$ stack repl --package MemoTrie
Prelude> import Data.MemoTrie
Prelude Data.MemoTrie> :set -XLambdaCase
Prelude Data.MemoTrie> :{
Prelude Data.MemoTrie| fastest_f' :: Integer -> Integer
Prelude Data.MemoTrie| fastest_f' = memo $ \case
Prelude Data.MemoTrie|   0 -> 0
Prelude Data.MemoTrie|   n -> max n (fastest_f'(n `div` 2) + fastest_f'(n `div` 3) + fastest_f'(n `div` 4))
Prelude Data.MemoTrie| :}
Prelude Data.MemoTrie> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358

We can do this very efficiently by making a structure that we can index in sub-linear time.

But first,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

Let's define f, but make it use 'open recursion' rather than call itself directly.

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

You can get an unmemoized f by using fix f

This will let you test that f does what you mean for small values of f by calling, for example: fix f 123 = 144

We could memoize this by defining:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

That performs passably well, and replaces what was going to take O(n^3) time with something that memoizes the intermediate results.

But it still takes linear time just to index to find the memoized answer for mf. This means that results like:

*Main Data.List> faster_f 123801
248604

are tolerable, but the result doesn't scale much better than that. We can do better!

First, let's define an infinite tree:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

And then we'll define a way to index into it, so we can find a node with index n in O(log n) time instead:

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

... and we may find a tree full of natural numbers to be convenient so we don't have to fiddle around with those indices:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

Since we can index, you can just convert a tree into a list:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

You can check the work so far by verifying that toList nats gives you [0..]

Now,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

works just like with list above, but instead of taking linear time to find each node, can chase it down in logarithmic time.

The result is considerably faster:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

In fact it is so much faster that you can go through and replace Int with Integer above and get ridiculously large answers almost instantaneously

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358

For an out-of-the-box library that implements the tree based memoization, use MemoTrie:

$ stack repl --package MemoTrie
Prelude> import Data.MemoTrie
Prelude Data.MemoTrie> :set -XLambdaCase
Prelude Data.MemoTrie> :{
Prelude Data.MemoTrie| fastest_f' :: Integer -> Integer
Prelude Data.MemoTrie| fastest_f' = memo $ \case
Prelude Data.MemoTrie|   0 -> 0
Prelude Data.MemoTrie|   n -> max n (fastest_f'(n `div` 2) + fastest_f'(n `div` 3) + fastest_f'(n `div` 4))
Prelude Data.MemoTrie| :}
Prelude Data.MemoTrie> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
陪你到最终 2024-09-15 21:34:34

Edward 的答案是如此精彩,我复制了它并提供了 memoListmemoTree 组合器,以开放递归形式记忆函数。

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f

Edward's answer is such a wonderful gem that I've duplicated it and provided implementations of memoList and memoTree combinators that memoize a function in open-recursive form.

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f
他不在意 2024-09-15 21:34:34

这不是最有效的方法,但确实会记住:

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

当请求 f !! 144,经检查f!! 143 存在,但未计算其确切值。它仍然被设置为某种未知的计算结果。计算出的唯一精确值是所需的值。

所以最初,对于计算了多少,程序一无所知。

f = .... 

当我们发出请求时 f !! 12,它开始做一些模式匹配:

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

现在它开始计算

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

这递归地对 f 提出了另一个要求,所以我们计算

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

现在我们可以滴流备份一些

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

这意味着程序现在知道:

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

继续滴流向上:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

这意味着程序现在知道:

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

现在我们继续计算 f!!6

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

意味着程序现在知道:

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

现在我们继续计算 f!!12

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

这 意味着程序现在知道:

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

所以计算是相当懒惰地完成的。程序知道 f 的某个值!! 8 存在,它等于 g 8,但它不知道 g 8 是什么。

Not the most efficient way, but does memoize:

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

when requesting f !! 144, it is checked that f !! 143 exists, but its exact value is not calculated. It's still set as some unknown result of a calculation. The only exact values calculated are the ones needed.

So initially, as far as how much has been calculated, the program knows nothing.

f = .... 

When we make the request f !! 12, it starts doing some pattern matching:

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Now it starts calculating

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

This recursively makes another demand on f, so we calculate

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

Now we can trickle back up some

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

Which means the program now knows:

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Continuing to trickle up:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

Which means the program now knows:

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Now we continue with our calculation of f!!6:

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

Which means the program now knows:

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Now we continue with our calculation of f!!12:

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

Which means the program now knows:

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

So the calculation is done fairly lazily. The program knows that some value for f !! 8 exists, that it's equal to g 8, but it has no idea what g 8 is.

白日梦 2024-09-15 21:34:34

这是爱德华·克梅特(Edward Kmett)出色答案的附录。

当我尝试他的代码时,natsindex 的定义似乎相当神秘,因此我编写了一个更容易理解的替代版本。

我根据index'nats'定义了indexnats

index' t n 是在 [1..] 范围内定义的。 (回想一下,index t 是在 [0..] 范围内定义的。)它通过将 n 视为字符串来搜索树位,并反向读取这些位。如果该位为1,则采用右侧分支。如果该位为0,则采用左侧分支。当到达最后一位(必须是1)时它会停止。

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

正如为 index 定义 nats 一样,index nats n == n 始终为 true,nats' 为为index'定义。

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

现在,natsindex 只是 nats'index',但值移动了 1:

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'

This is an addendum to Edward Kmett's excellent answer.

When I tried his code, the definitions of nats and index seemed pretty mysterious, so I write an alternative version that I found easier to understand.

I define index and nats in terms of index' and nats'.

index' t n is defined over the range [1..]. (Recall that index t is defined over the range [0..].) It works searches the tree by treating n as a string of bits, and reading through the bits in reverse. If the bit is 1, it takes the right-hand branch. If the bit is 0, it takes the left-hand branch. It stops when it reaches the last bit (which must be a 1).

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

Just as nats is defined for index so that index nats n == n is always true, nats' is defined for index'.

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

Now, nats and index are simply nats' and index' but with the values shifted by 1:

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'
GRAY°灰色天空 2024-09-15 21:34:34

正如 Edward Kmett 的回答所述,为了加快速度,您需要缓存昂贵的计算并能够快速访问它们。

为了保持函数非单子,构建无限惰性树的解决方案,并使用适当的方法对其进行索引(如之前的帖子所示)可以实现该目标。如果您放弃函数的非单子性质,则可以将 Haskell 中可用的标准关联容器与“类状态”单子(例如 State 或 ST)结合使用。

虽然主要缺点是您获得了非单子函数,但您不必再自己对结构进行索引,而只需使用关联容器的标准实现即可。

为此,您首先需要重写函数以接受任何类型的 monad:

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <
gt; [2, 3, 4]
   return $ max n (sum recs)

对于您的测试,您仍然可以使用 Data.Function.fix 定义一个不进行记忆的函数,尽管它有点冗长

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

:然后可以将 State monad 与 Data.Map 结合使用来加快速度:

import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <
gt; get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

通过微小的更改,您可以调整代码以与 Data.HashMap 配合使用:

import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <
gt; get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

除了持久数据结构,您还可以尝试可变数据结构(例如 Data .HashTable)与 ST monad 结合:与

import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

没有任何记忆的实现相比,这些实现中的任何一个都允许您在大量输入的情况下在微秒内获得结果,而不必等待几秒钟。

使用 Criterion 作为基准,我可以观察到 Data.HashMap 的实现实际上比 Data.Map 和 Data.HashTable 的执行稍好(大约 20%),两者的时间非常相似。

我发现基准测试的结果有点令人惊讶。我最初的感觉是 HashTable 会优于 HashMap 实现,因为它是可变的。最后的实现中可能隐藏着一些性能缺陷。

As stated in Edward Kmett's answer, to speed things up, you need to cache costly computations and be able to access them quickly.

To keep the function non monadic, the solution of building an infinite lazy tree, with an appropriate way to index it (as shown in previous posts) fulfills that goal. If you give up the non-monadic nature of the function, you can use the standard associative containers available in Haskell in combination with “state-like” monads (like State or ST).

While the main drawback is that you get a non-monadic function, you do not have to index the structure yourself anymore, and can just use standard implementations of associative containers.

To do so, you first need to re-write you function to accept any kind of monad:

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <
gt; [2, 3, 4]
   return $ max n (sum recs)

For your tests, you can still define a function that does no memoization using Data.Function.fix, although it is a bit more verbose:

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

You can then use State monad in combination with Data.Map to speed things up:

import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <
gt; get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

With minor changes, you can adapt the code to works with Data.HashMap instead:

import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <
gt; get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

Instead of persistent data structures, you may also try mutable data structures (like the Data.HashTable) in combination with the ST monad:

import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

Compared to the implementation without any memoization, any of these implementation allows you, for huge inputs, to get results in micro-seconds instead of having to wait several seconds.

Using Criterion as benchmark, I could observe that the implementation with the Data.HashMap actually performed slightly better (around 20%) than that the Data.Map and Data.HashTable for which the timings were very similar.

I found the results of the benchmark a bit surprising. My initial feeling was that the HashTable would outperform the HashMap implementation because it is mutable. There might be some performance defect hidden in this last implementation.

温馨耳语 2024-09-15 21:34:34

几年后,我看到了这个,并意识到有一种简单的方法可以使用 zipWith 和辅助函数在线性时间内记住它:

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilate 具有方便的属性, > 扩大 n xs !!我==xs!! div i n。

因此,假设我们给出 f(0),这将简化计算,

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

看起来很像我们原来的问题描述,并给出一个线性解决方案(sum $ take n fs 将花费 O(n) )。

A couple years later, I looked at this and realized there's a simple way to memoize this in linear time using zipWith and a helper function:

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilate has the handy property that dilate n xs !! i == xs !! div i n.

So, supposing we're given f(0), this simplifies the computation to

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

Looking a lot like our original problem description, and giving a linear solution (sum $ take n fs will take O(n)).

谈下烟灰 2024-09-15 21:34:34

Edward Kmett 答案的另一个附录:一个独立的示例:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

按如下方式使用它来记忆具有单个整数参数的函数(例如斐波那契):

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

仅缓存非负参数的值。

要缓存负参数的值,请使用 memoInt,定义如下:

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

要缓存具有两个整数参数的函数的值,请使用 memoIntInt,定义如下:

memoIntInt f = memoInt (\n -> memoInt (f n))

Yet another addendum to Edward Kmett's answer: a self-contained example:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

Use it as follows to memoize a function with a single integer arg (e.g. fibonacci):

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

Only values for non-negative arguments will be cached.

To also cache values for negative arguments, use memoInt, defined as follows:

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

To cache values for functions with two integer arguments use memoIntInt, defined as follows:

memoIntInt f = memoInt (\n -> memoInt (f n))
坠似风落 2024-09-15 21:34:34

没有索引的解决方案,并且不基于 Edward KMETT 的解决方案。

我将公共子树分解为公共父树(f(n/4)f(n/2)f(n/4)< 之间共享/code>,并且 f(n/6)f(2)f(3) 之间共享。通过将它们保存为父级中的单个变量,子树的计算只需完成一次。

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

    main =
      print (f 123801)
      -- Should print 248604.

该代码不容易扩展到一般的记忆功能(至少,我不知道如何做到这一点),并且您确实必须考虑子问题如何重叠,但策略应该起作用对于一般的多个非整数参数。 (我想到了两个字符串参数。)

每次计算后都会丢弃备忘录。 (再次,我在考虑两个字符串参数。)

我不知道这是否比其他答案更有效。从技术上讲,每次查找只需一两个步骤(“查看您的孩子或您孩子的孩子”),但可能会使用大量额外的内存。

编辑:此解决方案尚不正确。共享不完整。

编辑:现在应该正确共享子子项,但我意识到这个问题有很多重要的共享:n/2/2/2n/3/3 可能是相同的。这个问题不太适合我的策略。

A solution without indexing, and not based on Edward KMETT's.

I factor out common subtrees to a common parent (f(n/4) is shared between f(n/2) and f(n/4), and f(n/6) is shared between f(2) and f(3)). By saving them as a single variable in the parent, the calculation of the subtree is done once.

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

    main =
      print (f 123801)
      -- Should print 248604.

The code doesn't easily extend to a general memoization function (at least, I wouldn't know how to do it), and you really have to think out how subproblems overlap, but the strategy should work for general multiple non-integer parameters. (I thought it up for two string parameters.)

The memo is discarded after each calculation. (Again, I was thinking about two string parameters.)

I don't know if this is more efficient than the other answers. Each lookup is technically only one or two steps ("Look at your child or your child's child"), but there might be a lot of extra memory use.

Edit: This solution isn't correct yet. The sharing is incomplete.

Edit: It should be sharing subchildren properly now, but I realized that this problem has a lot of nontrivial sharing: n/2/2/2 and n/3/3 might be the same. The problem is not a good fit for my strategy.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文