State Monadと正格性について

今回の話題は、State MonadはStrictなものを使おうという話である。

StateT s m a

HaskellでState Monadを使う際、一番よく使うのは
transformer packageのStateT s m aだと思う。
transformers: Concrete functor and monad transformers
あるいは、そのwrapper libraryとしてmtlを使うかもしれない。
mtl: Monad classes, using functional dependencies
どちらにしろ実体は同じだ。
このStateTにはLazyなバージョンとStrictなバージョンの2種類がある。
二つともデータ型の定義は同じである。

newtype StateT s m a = StateT { runStateT :: s -> m (a, s ) }

違いはMonad (StateT s m)のインスタンス宣言である。

-- Strict
instance Monad m => Monad (StateT s m) where
    return x = StateT (\s -> return (x, s))
    StateT mv >>= f = StateT $ \s -> do
        (v,s1) <- mv s
        runStateT (f v) s1
-- Lazy
instance Monad m => Monad (StateT s m) where
    return x = StateT (\s -> return (x, s))
    StateT mv >>= f = StateT $ \s -> do
        ~(v,s1) <- mv s
        runStateT (f v) s1

Lazyの方は~(チルダ)パターンという見慣れないものが使われているが、これは概ね以下と同じだ。

    p <- mv s
    let (v,s1) = p

違いは、mv sの結果をバインドするときに、タプルのWHNFまで簡約するか、あるいは遅延させるかという点である。

正格性

ここでState MonadのExampleとして、状態を使って1からnまでの和を求めるプログラムを書いてみよう。

-- Main.hs
import Data.Functor.Identity
import Control.Monad.Trans.State.Lazy (StateT(..), get, put)
main :: IO ()
main = do
  n <- readLn
  let (_,s) = runIdentity (runStateT (sumState n) 0)
  print s

sumState :: Monad m => Int -> StateT Int m ()
sumState n = mapM_ (\i -> modify (+i)) [1..n]

modify :: Monad m => (s -> s) -> StateT s m ()
modify f = do
    v <- get
    put (f v)

しかし、このプログラムをコンパイルして実行してみると非常に遅いことがわかる。

$ ghc -O Main.hs
$ echo 10000000 | time ./Main
50000005000000
        2.27 real         1.41 user         0.72 sys

これはなぜかというと、modify fが正格でないため、Stateに(...((0 + 1) + 2) + ... + n)というthunkが積まれてしまうからである。
この現象はControl.Monad.Trans.State.LazyとControl.Monad.Trans.State.Strictの両方で発生する。
状態を正格に評価するために、modifyの定義を少し変更する。

modify f = do
    v <- get
    put $! f v

($!)は正格な関数適用演算子である。こうすることで、modify (+i)の部分が実行されるたびにStateを正格評価するように思える。
実際、この関数はControl.Monad.Trans.State.(Strict | Lazy).modfy'としてライブラリにも定義されている。

実際、Control.Monad.Trans.State.Strict.StateTではこの変更で期待通りに動く。

$ echo 10000000 | time ./Main
50000005000000
        0.02 real         0.01 user         0.00 sys

しかし、残念ながら、Control.Monad.Trans.State.Lazy.StateTでは上手くいかない。

echo 10000000 | time ./Main
50000005000000
        1.94 real         1.27 user         0.50 sys

また、奇妙なことだが、LazyなStateTでも、baseモナドをIdentityからIOに変えたりすると状態が正格に計算される。

import Control.Monad.Trans.State.Lazy (StateT(..), get, put)
main :: IO ()
main = do
  n <- readLn
  (_,s) <- runStateT (sumState n) 0
  print s
$ echo 10000000 | time ./Main 
50000005000000
        0.02 real         0.01 user         0.00 sys

謎解き

これはなぜだろうか、この謎を解くためにはmodify fがどのようにコンパイルされるかをみていく必要がある。

 modify f 
== get >>= (\v -> put $! (f v)) -- unfold "modify"
== StateT (\s -> return (s,s)) >>= (\v -> put $! (f v))  -- unfold "get"
== StateT (\s -> do  -- unfold (>>=)
            ~(a,s1) <- return (s,s)
            runStateT ((\v -> put $! (f v)) a) s1)
== StateT (\s -> do  -- beta reduction
            ~(a,s1) <- return (s,s)
            runStateT (put $! (f a)) s1)
== StateT (\s -> do -- unfold ($!) and "put"
            ~(a,s1) <- return (s,s)
            runStateT (let !v = f a in StateT (\_ -> return ((),v))) s1)
== StateT (\s -> do -- simplify
            ~(a,s1) <- return (s,s)
            (let !v = f a in return ((),v)))
== StateT (\s -> return (s,s) >>= (\p -> -- desugar do notation
                 let (a, s1) = p in
                 let !v = f a in
                 return ((),v)))
== StateT (\s -> let (a, s1) = (s,s) in -- apply monad law (left identity)
                 let !v = f a in
                 return ((),v)))
== StateT (\s -> let !v = f s in return ((),v)) -- simplify

今の所、特に問題のある部分はない。
しかし、baseモナドによっては問題が明らかになる。
その前にIdentity Monadの定義を復習しておく。

newtype Identity a = Identity { runIdentity :: a }
instance Monad Identity where
  return x = Identity x
  m >>= f = f (runIdentity a)

さて、次の項をIdentityモナド上で簡約してみよう。

(modify f) >> action :: StateT s Identity a
== StateT (\s -> let !v = f s in return ((),v)) >> action 
== StateT (\s -> let !v = f s in return ((),v)) >>= (\_ -> action) -- unfold (>>)
== StateT (\s -> do  -- unfold (>>=)
            ~(_,s1) <- let !v = f s in return ((),v)
            runStateT action s1)
== StateT (\s -> -- desugar do notation
           (let !v = f s in return ((),v)) >>= (\p ->
           let (_,s1) = p in
           runStateT action s1))
== StateT (\s -> -- unfold (>>=)
           (\p ->
           let (_,s1) = p in
           runStateT action s1) (let !v = f s in return ((),v)))
== StateT (\s -> -- beta reduction
           let (_,s1) = let !v = f s in return ((),v) in
           runStateT action s1)

さて、簡約結果を見てみると、 let !v = f s in return ((),v)の部分がさらにlet (_,s1) = ... in ...でくるまれているため、
評価が遅延されていることがわかると思う。この問題はbaseモナドがIOの場合は発生しない。
なぜなら、IO a は RealWorld -> (# RealWorld, a #)という型として評価されるわけだが、(# RealWolrd, a #)というタプルはthunkになり得ないからだ。

ちなみに、Control.Monad.Trans.State.Strictの場合は、以下のようになる。

(modify f) >> action :: StateT s Identity a
== StateT (\s -> let !v = f s in return ((),v)) >> action
== StateT (\s -> let !v = f s in return ((),v)) >>= (\_ -> action)
== StateT (\s -> do
            (_,s1) <- let !v = f s in return ((),v)
            runStateT action s1)
== StateT (\s -> 
           (let !v = f s in return ((),v)) >>= (\(_,s1) ->
           runStateT action s1))
== StateT (\s -> 
           (let !v = f s in Identity ((),v)) >>= (\(_,s1) ->
           runStateT action s1))
== StateT (\s ->
           (\(_,s1) ->
           runStateT action s1) (let !v = f s in Identity ((),v)))
== StateT (\s ->
           case (let !v = f s in Identity ((),v)) of
             (_,s1) -> runStateT action s1)
== StateT (\s ->
           case f s of
             v -> case Identity ((),v) of
               (_,s1) -> runStateT action s1)
== StateT (\s ->
           case f s of
             v -> runStateT action v)

従ってf sの評価はactionの評価より前に行われる。

結論

LazyなStateモナドでmodify'を使うべきではない。使う必要が出てきた場合にはStrictなStateモナドを使おう。

普通はStrictなものを使っておけば良いと思う。LazyなStateモナドが役に立つ場面は本当にあるのだろうか。甚だ疑問である。