State Monadと正格性について
今回の話題は、State MonadはStrictなものを使おうという話である。
StateT s m a
HaskellでState Monadを使う際、一番よく使うのは
transformer packageのStateT s m aだと思う。
transformers: Concrete functor and monad transformers
あるいは、そのwrapper libraryとしてmtlを使うかもしれない。
mtl: Monad classes, using functional dependencies
どちらにしろ実体は同じだ。
このStateTにはLazyなバージョンとStrictなバージョンの2種類がある。
二つともデータ型の定義は同じである。
newtype StateT s m a = StateT { runStateT :: s -> m (a, s ) }
違いはMonad (StateT s m)のインスタンス宣言である。
-- Strict instance Monad m => Monad (StateT s m) where return x = StateT (\s -> return (x, s)) StateT mv >>= f = StateT $ \s -> do (v,s1) <- mv s runStateT (f v) s1
-- Lazy instance Monad m => Monad (StateT s m) where return x = StateT (\s -> return (x, s)) StateT mv >>= f = StateT $ \s -> do ~(v,s1) <- mv s runStateT (f v) s1
Lazyの方は~(チルダ)パターンという見慣れないものが使われているが、これは概ね以下と同じだ。
p <- mv s let (v,s1) = p
違いは、mv sの結果をバインドするときに、タプルのWHNFまで簡約するか、あるいは遅延させるかという点である。
正格性
ここでState MonadのExampleとして、状態を使って1からnまでの和を求めるプログラムを書いてみよう。
-- Main.hs import Data.Functor.Identity import Control.Monad.Trans.State.Lazy (StateT(..), get, put) main :: IO () main = do n <- readLn let (_,s) = runIdentity (runStateT (sumState n) 0) print s sumState :: Monad m => Int -> StateT Int m () sumState n = mapM_ (\i -> modify (+i)) [1..n] modify :: Monad m => (s -> s) -> StateT s m () modify f = do v <- get put (f v)
しかし、このプログラムをコンパイルして実行してみると非常に遅いことがわかる。
$ ghc -O Main.hs $ echo 10000000 | time ./Main 50000005000000 2.27 real 1.41 user 0.72 sys
これはなぜかというと、modify fが正格でないため、Stateに(...((0 + 1) + 2) + ... + n)というthunkが積まれてしまうからである。
この現象はControl.Monad.Trans.State.LazyとControl.Monad.Trans.State.Strictの両方で発生する。
状態を正格に評価するために、modifyの定義を少し変更する。
modify f = do v <- get put $! f v
($!)は正格な関数適用演算子である。こうすることで、modify (+i)の部分が実行されるたびにStateを正格評価するように思える。
実際、この関数はControl.Monad.Trans.State.(Strict | Lazy).modfy'としてライブラリにも定義されている。
実際、Control.Monad.Trans.State.Strict.StateTではこの変更で期待通りに動く。
$ echo 10000000 | time ./Main 50000005000000 0.02 real 0.01 user 0.00 sys
しかし、残念ながら、Control.Monad.Trans.State.Lazy.StateTでは上手くいかない。
echo 10000000 | time ./Main 50000005000000 1.94 real 1.27 user 0.50 sys
また、奇妙なことだが、LazyなStateTでも、baseモナドをIdentityからIOに変えたりすると状態が正格に計算される。
import Control.Monad.Trans.State.Lazy (StateT(..), get, put) main :: IO () main = do n <- readLn (_,s) <- runStateT (sumState n) 0 print s
$ echo 10000000 | time ./Main 50000005000000 0.02 real 0.01 user 0.00 sys
謎解き
これはなぜだろうか、この謎を解くためにはmodify fがどのようにコンパイルされるかをみていく必要がある。
modify f == get >>= (\v -> put $! (f v)) -- unfold "modify" == StateT (\s -> return (s,s)) >>= (\v -> put $! (f v)) -- unfold "get" == StateT (\s -> do -- unfold (>>=) ~(a,s1) <- return (s,s) runStateT ((\v -> put $! (f v)) a) s1) == StateT (\s -> do -- beta reduction ~(a,s1) <- return (s,s) runStateT (put $! (f a)) s1) == StateT (\s -> do -- unfold ($!) and "put" ~(a,s1) <- return (s,s) runStateT (let !v = f a in StateT (\_ -> return ((),v))) s1) == StateT (\s -> do -- simplify ~(a,s1) <- return (s,s) (let !v = f a in return ((),v))) == StateT (\s -> return (s,s) >>= (\p -> -- desugar do notation let (a, s1) = p in let !v = f a in return ((),v))) == StateT (\s -> let (a, s1) = (s,s) in -- apply monad law (left identity) let !v = f a in return ((),v))) == StateT (\s -> let !v = f s in return ((),v)) -- simplify
今の所、特に問題のある部分はない。
しかし、baseモナドによっては問題が明らかになる。
その前にIdentity Monadの定義を復習しておく。
newtype Identity a = Identity { runIdentity :: a } instance Monad Identity where return x = Identity x m >>= f = f (runIdentity a)
さて、次の項をIdentityモナド上で簡約してみよう。
(modify f) >> action :: StateT s Identity a == StateT (\s -> let !v = f s in return ((),v)) >> action == StateT (\s -> let !v = f s in return ((),v)) >>= (\_ -> action) -- unfold (>>) == StateT (\s -> do -- unfold (>>=) ~(_,s1) <- let !v = f s in return ((),v) runStateT action s1) == StateT (\s -> -- desugar do notation (let !v = f s in return ((),v)) >>= (\p -> let (_,s1) = p in runStateT action s1)) == StateT (\s -> -- unfold (>>=) (\p -> let (_,s1) = p in runStateT action s1) (let !v = f s in return ((),v))) == StateT (\s -> -- beta reduction let (_,s1) = let !v = f s in return ((),v) in runStateT action s1)
さて、簡約結果を見てみると、 let !v = f s in return ((),v)の部分がさらにlet (_,s1) = ... in ...でくるまれているため、
評価が遅延されていることがわかると思う。この問題はbaseモナドがIOの場合は発生しない。
なぜなら、IO a は RealWorld -> (# RealWorld, a #)という型として評価されるわけだが、(# RealWolrd, a #)というタプルはthunkになり得ないからだ。
ちなみに、Control.Monad.Trans.State.Strictの場合は、以下のようになる。
(modify f) >> action :: StateT s Identity a == StateT (\s -> let !v = f s in return ((),v)) >> action == StateT (\s -> let !v = f s in return ((),v)) >>= (\_ -> action) == StateT (\s -> do (_,s1) <- let !v = f s in return ((),v) runStateT action s1) == StateT (\s -> (let !v = f s in return ((),v)) >>= (\(_,s1) -> runStateT action s1)) == StateT (\s -> (let !v = f s in Identity ((),v)) >>= (\(_,s1) -> runStateT action s1)) == StateT (\s -> (\(_,s1) -> runStateT action s1) (let !v = f s in Identity ((),v))) == StateT (\s -> case (let !v = f s in Identity ((),v)) of (_,s1) -> runStateT action s1) == StateT (\s -> case f s of v -> case Identity ((),v) of (_,s1) -> runStateT action s1) == StateT (\s -> case f s of v -> runStateT action v)
従ってf sの評価はactionの評価より前に行われる。