Mémorisation à Haskell?

136

Tout pointeur sur la façon de résoudre efficacement la fonction suivante dans Haskell, pour les grands nombres (n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

J'ai vu des exemples de mémorisation dans Haskell pour résoudre des nombres de fibonacci, qui impliquaient de calculer (paresseusement) tous les nombres de fibonacci jusqu'au n requis. Mais dans ce cas, pour un n donné, il suffit de calculer très peu de résultats intermédiaires.

Merci

Ange de Vicente
la source
110
Seulement dans le sens où c'est un travail que je fais à la maison :-)
Angel de Vicente

Réponses:

256

Nous pouvons le faire très efficacement en créant une structure que nous pouvons indexer en temps sous-linéaire.

Mais d'abord,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

Définissons f, mais faisons-lui utiliser la «récursion ouverte» plutôt que de s'appeler directement.

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

Vous pouvez obtenir un démo fen utilisantfix f

Cela vous permettra de tester cela ffait ce que vous voulez dire pour de petites valeurs de fen appelant, par exemple:fix f 123 = 144

Nous pourrions mémoriser cela en définissant:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

Cela fonctionne passablement bien et remplace ce qui allait prendre du temps O (n ^ 3) par quelque chose qui mémorise les résultats intermédiaires.

Mais il faut toujours un temps linéaire pour indexer pour trouver la réponse mémorisée mf. Cela signifie que des résultats comme:

*Main Data.List> faster_f 123801
248604

sont tolérables, mais le résultat ne s'adapte pas beaucoup mieux que cela. On peut faire mieux!

Tout d'abord, définissons un arbre infini:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

Et puis nous définirons un moyen de s'y indexer, afin que nous puissions trouver un nœud avec un index nen temps O (log n) à la place:

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

... et nous pouvons trouver un arbre plein de nombres naturels pour être pratique pour ne pas avoir à jouer avec ces indices:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

Puisque nous pouvons indexer, vous pouvez simplement convertir un arbre en liste:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

Vous pouvez vérifier le travail jusqu'à présent en vérifiant que toList natsvous[0..]

Maintenant,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

fonctionne comme avec la liste ci-dessus, mais au lieu de prendre un temps linéaire pour trouver chaque nœud, peut le poursuivre en temps logarithmique.

Le résultat est considérablement plus rapide:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

En fait, c'est tellement plus rapide que vous pouvez passer et remplacer Intpar Integerci-dessus et obtenir des réponses ridiculement volumineuses presque instantanément

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
Edward KMETT
la source
3
J'ai essayé ce code et, fait intéressant, f_faster semblait être plus lent que f. Je suppose que ces références de liste ont vraiment ralenti les choses. La définition des nats et des index me paraissait assez mystérieuse, j'ai donc ajouté ma propre réponse qui pourrait clarifier les choses.
Pitarou
5
Le cas de la liste infinie doit traiter une liste chaînée de 111111111 éléments. Le cas de l'arborescence traite du log n * le nombre de nœuds atteints.
Edward KMETT
2
c'est-à-dire que la version liste doit créer des thunks pour tous les nœuds de la liste, alors que la version arborescente évite d'en créer beaucoup.
Tom Ellis
7
Je sais que c'est un article plutôt ancien, mais ne devrait pas f_treeêtre défini dans une whereclause pour éviter d'enregistrer des chemins inutiles dans l'arborescence entre les appels?
dfeuer
17
La raison de le mettre dans un CAF était que vous pouviez obtenir une mémorisation à travers les appels. Si j'avais un appel coûteux que je mémorisais, je le laisserais probablement dans un CAF, d'où la technique illustrée ici. Dans une application réelle, il y a bien sûr un compromis entre les avantages et les coûts de la mémorisation permanente. Cependant, étant donné que la question portait sur la façon de réaliser la mémorisation, je pense qu'il serait trompeur de répondre avec une technique qui évite délibérément la mémorisation entre les appels, et si rien d'autre, ce commentaire ici indiquera aux gens qu'il y a des subtilités. ;)
Edward KMETT
17

La réponse d'Edward est un bijou si merveilleux que je l'ai dupliquée et fourni des implémentations memoListet des memoTreecombinateurs qui mémorisent une fonction sous forme ouverte-récursive.

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f
Tom Ellis
la source
12

Ce n'est pas le moyen le plus efficace, mais mémorise:

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

lors de la demande f !! 144, il est vérifié qu'il f !! 143existe, mais sa valeur exacte n'est pas calculée. Il est toujours défini comme un résultat inconnu d'un calcul. Les seules valeurs exactes calculées sont celles nécessaires.

Donc, au départ, pour ce qui a été calculé, le programme ne sait rien.

f = .... 

Lorsque nous faisons la demande f !! 12, il commence à faire une correspondance de modèle:

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Maintenant, il commence à calculer

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

Cela fait récursivement une autre demande sur f, donc nous calculons

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

Maintenant, nous pouvons en remonter

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

Ce qui signifie que le programme sait maintenant:

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Continuant à ruisseler:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

Ce qui signifie que le programme sait maintenant:

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Maintenant, nous continuons notre calcul de f!!6:

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

Ce qui signifie que le programme sait maintenant:

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Maintenant, nous continuons notre calcul de f!!12:

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

Ce qui signifie que le programme sait maintenant:

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

Le calcul se fait donc assez paresseusement. Le programme sait qu'une valeur de f !! 8existe, à laquelle elle est égale g 8, mais il n'a aucune idée de ce que g 8c'est.

rampion
la source
Merci pour celui-ci. Comment créeriez-vous et utiliseriez-vous un espace de solution à 2 dimensions? Serait-ce une liste de listes? etg n m = (something with) f!!a!!b
vikingsteve
1
Bien sûr, vous pourriez. Pour une vraie solution, cependant, j'utiliserais probablement une bibliothèque de mémorisation
rampion
C'est O (n ^ 2) malheureusement.
Qumeric
8

Ceci est un addendum à l'excellente réponse d'Edward Kmett.

Quand j'ai essayé son code, les définitions de natset indexme paraissaient assez mystérieuses, alors j'écris une version alternative que j'ai trouvée plus facile à comprendre.

Je définis indexet natsen termes de index'et nats'.

index' t nest défini sur la plage [1..]. (Rappelez-vous qui index test défini sur la plage [0..].) Cela fonctionne recherche l'arbre en traitant ncomme une chaîne de bits et en lisant les bits à l'envers. Si le bit est 1, il prend la branche de droite. Si le bit est 0, il prend la branche de gauche. Il s'arrête lorsqu'il atteint le dernier bit (qui doit être a 1).

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

Tout comme il natsest défini pour indexainsi qui index nats n == nest toujours vrai, nats'est défini pour index'.

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

Maintenant, natset indexsont simplement nats'et index'mais avec les valeurs décalées de 1:

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'
Pitarou
la source
Merci. Je mémorise une fonction multivariée, et cela m'a vraiment aidé à comprendre ce que faisaient vraiment l'index et les nats.
Kittsil
8

Comme indiqué dans la réponse d'Edward Kmett, pour accélérer les choses, vous devez mettre en cache des calculs coûteux et pouvoir y accéder rapidement.

Pour garder la fonction non monadique, la solution consistant à construire un arbre paresseux infini, avec une manière appropriée de l'indexer (comme indiqué dans les articles précédents) remplit cet objectif. Si vous renoncez à la nature non monadique de la fonction, vous pouvez utiliser les conteneurs associatifs standard disponibles dans Haskell en combinaison avec des monades «de type état» (comme State ou ST).

Alors que le principal inconvénient est que vous obtenez une fonction non monadique, vous n'avez plus à indexer la structure vous-même et vous pouvez simplement utiliser des implémentations standard de conteneurs associatifs.

Pour ce faire, vous devez d'abord réécrire votre fonction pour accepter tout type de monade:

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <$> [2, 3, 4]
   return $ max n (sum recs)

Pour vos tests, vous pouvez toujours définir une fonction qui ne fait aucune mémorisation à l'aide de Data.Function.fix, bien qu'elle soit un peu plus verbeuse:

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

Vous pouvez ensuite utiliser State Monad en combinaison avec Data.Map pour accélérer les choses:

import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

Avec des modifications mineures, vous pouvez adapter le code pour qu'il fonctionne avec Data.HashMap à la place:

import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

Au lieu de structures de données persistantes, vous pouvez également essayer des structures de données mutables (comme le Data.HashTable) en combinaison avec la monade ST:

import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

Par rapport à l'implémentation sans aucune mémorisation, n'importe laquelle de ces implémentations vous permet, pour des entrées énormes, d'obtenir des résultats en micro-secondes au lieu d'avoir à attendre plusieurs secondes.

En utilisant Criterion comme référence, j'ai pu observer que l'implémentation avec Data.HashMap fonctionnait en fait légèrement mieux (environ 20%) que celle Data.Map et Data.HashTable pour lesquelles les horaires étaient très similaires.

J'ai trouvé les résultats du benchmark un peu surprenants. Mon sentiment initial était que le HashTable surpasserait l'implémentation de HashMap car il est modifiable. Il peut y avoir un défaut de performance caché dans cette dernière implémentation.

Quentin
la source
2
GHC fait un très bon travail d'optimisation autour de structures immuables. L'intuition de C ne fonctionne pas toujours.
John Tyree
3

Quelques années plus tard, j'ai regardé cela et j'ai réalisé qu'il y avait un moyen simple de mémoriser cela en temps linéaire en utilisant zipWith et une fonction d'aide:

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilatea la propriété pratique que dilate n xs !! i == xs !! div i n.

Donc, en supposant qu'on nous donne f (0), cela simplifie le calcul à

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

Ressemblant beaucoup à notre description originale du problème, et donnant une solution linéaire ( sum $ take n fsprendra O (n)).

rampion
la source
2
c'est donc une solution de programmation générative (corécursive?) ou dynamique. Prendre un temps O (1) pour chaque valeur générée, comme le fait habituellement Fibonacci. Génial! Et la solution d'EKMETT est comme le big-Fibonacci logarithmique, atteignant les grands nombres beaucoup plus rapidement, sautant une grande partie des entre-deux. Est-ce à peu près correct?
Will Ness
ou peut-être que c'est plus proche de celui des nombres de Hamming, avec les trois pointeurs arrière dans la séquence qui est produite, et les différentes vitesses pour chacun d'eux qui avancent le long de celle-ci. vraiment jolie.
Will Ness
2

Encore un autre addendum à la réponse d'Edward Kmett: un exemple autonome:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

Utilisez-le comme suit pour mémoriser une fonction avec un seul entier arg (par exemple fibonacci):

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

Seules les valeurs des arguments non négatifs seront mises en cache.

Pour mettre également en cache les valeurs des arguments négatifs, utilisez memoInt, défini comme suit:

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

Pour mettre en cache les valeurs des fonctions avec deux arguments entiers, utilisez memoIntInt, défini comme suit:

memoIntInt f = memoInt (\n -> memoInt (f n))
Neal Young
la source
2

Une solution sans indexation et non basée sur celle d'Edward KMETT.

Je délimite les sous-arbres communs à un parent commun ( f(n/4)est partagé entre f(n/2)et f(n/4)et f(n/6)est partagé entre f(2)et f(3)). En les sauvegardant comme une seule variable dans le parent, le calcul du sous-arbre est effectué une fois.

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

    main =
      print (f 123801)
      -- Should print 248604.

Le code ne s'étend pas facilement à une fonction de mémorisation générale (du moins, je ne saurais pas comment le faire), et vous devez vraiment réfléchir à la façon dont les sous-problèmes se chevauchent, mais la stratégie devrait fonctionner pour plusieurs paramètres généraux non entiers. . (Je l'ai pensé pour deux paramètres de chaîne.)

Le mémo est supprimé après chaque calcul. (Encore une fois, je pensais à deux paramètres de chaîne.)

Je ne sais pas si c'est plus efficace que les autres réponses. Chaque recherche ne comporte techniquement qu'une ou deux étapes («Regardez votre enfant ou l'enfant de votre enfant»), mais il peut y avoir beaucoup d'utilisation supplémentaire de la mémoire.

Edit: Cette solution n'est pas encore correcte. Le partage est incomplet.

Edit: Cela devrait être le partage des sous-enfants correctement maintenant, mais j'ai réalisé que ce problème avait beaucoup de partage non trivial: n/2/2/2et n/3/3pourrait être le même. Le problème ne correspond pas bien à ma stratégie.

leewz
la source