Understanding State Monad

For some time now I have been trying to understand how the bind function (>>=) for the State monad is defined.

I have read several explanations, among others Learn You a Haskell, which has great explanations, but in this case it wasn't enough. I have also looked at Three Useful Monads, which tries to visualize difficult concepts in Haskell. Finally I read Real World Haskell, and maybe I get it now. I'm not sure, but maybe.

The type of bind for State is as follows.

(>>=) :: State s a -> (a -> State s b) -> State s b

First you need to understand that State s a is really a state transformer, not a regular state value. Ignoring the newtype wrapper, the State s a type is defined as follows.

type State s a = s -> (a, s)   -- simplified; the real type wraps this function in a newtype

In other words, the state transformer is a function that takes a state of type s and returns a result of type a together with a new state.
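To make this concrete, here is a state transformer written as a plain function, before any monad is involved. The name tick and the choice of an Int counter as the state are hypothetical, for illustration only. Note how the caller has to thread the intermediate state s1 by hand.

```haskell
-- A hypothetical state transformer written as a plain function.
-- The state is an Int counter: tick returns the current value as a
-- String (the result) together with the incremented counter (the new state).
tick :: Int -> (String, Int)
tick n = (show n, n + 1)

main :: IO ()
main = do
  let (res1, s1) = tick 0   -- res1 = "0", s1 = 1
      (res2, s2) = tick s1  -- res2 = "1", s2 = 2
  print (res1, res2, s2)
```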

So the bind function above really takes one state transformer (let's call it trans1) and a function (let's call it makeTrans) that takes the result of trans1 (let's call it res1) and creates a new state transformer (let's call it trans2). The resulting state transformer, trans3, should bind trans1 and trans2 together into a single state transformer.

If we call the initial state s0, the intermediate state s1 (after trans1) and the resulting state s2, the new state transformer should transform s0 into s2 and give the result from trans2, i.e. s0 -> (res2, s2).

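(>>=) trans1 makeTrans = \s0 ->
    let (res1, s1) = trans1 s0        -- run trans1 on the initial state
        trans2     = makeTrans res1   -- create trans2 from trans1's result
        (res2, s2) = trans2 s1        -- run trans2 on the intermediate state
    in (res2, s2)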

Inlining trans2 and renaming the variables gives the compact form used in the libraries (again ignoring the newtype wrapper).

(>>=) m k = \s -> let (a, s') = m s in (k a) s'
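The definition can be tried out directly with a small, hypothetical re-implementation of State. This is a sketch for learning, not the mtl/transformers code (which builds State on top of StateT s Identity); getState, putState and example are made-up names. The Monad instance uses the trans1/makeTrans naming from above, and the Functor and Applicative instances are required by modern GHC.

```haskell
-- Hypothetical, minimal State monad mirroring the explanation above.
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f trans = State $ \s0 ->
    let (res, s1) = runState trans s0
    in (f res, s1)

instance Applicative (State s) where
  pure x = State $ \s -> (x, s)
  transF <*> transX = State $ \s0 ->
    let (f, s1) = runState transF s0
        (x, s2) = runState transX s1
    in (f x, s2)

instance Monad (State s) where
  -- trans1 >>= makeTrans is the combined transformer trans3
  trans1 >>= makeTrans = State $ \s0 ->
    let (res1, s1) = runState trans1 s0  -- run trans1 on the initial state
        trans2     = makeTrans res1      -- use its result to create trans2
    in runState trans2 s1                -- run trans2 on the intermediate state

-- Hypothetical helpers: read the whole state / replace the state.
getState :: State s s
getState = State $ \s -> (s, s)

putState :: s -> State s ()
putState s = State $ \_ -> ((), s)

-- Read the counter, store counter + 1, and return the old value doubled.
example :: State Int Int
example = getState >>= \n -> putState (n + 1) >>= \_ -> pure (n * 2)

main :: IO ()
main = print (runState example 5)  -- prints (10,6)
```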

I found it easier to visualize this as boxes.

[Figure: trans3 drawn as a box containing trans1 and trans2. The initial state s0 enters trans1; trans1 passes the intermediate state s1 to trans2 and its result res1 to makeTrans, which "creates" trans2; trans2 outputs (res2, s2).]

Motivation for the State monad and monads in general

I have also had a hard time understanding what the big deal about the State monad is.

This is actually explained really well in Tasteful stateful computations (LYaH), but I will give my angle anyhow.

Consider two stateful computations, push and pop, written without the State monad (from LYaH).

type Stack = [Int]

pop :: Stack -> (Int,Stack)
pop (x:xs) = (x,xs)

push :: Int -> Stack -> ((),Stack)
push a xs = ((),a:xs)

The definitions of push and pop are very clear and understandable when the state is explicitly stated as input and output parameters.

The problem arises when you want to compose several of these stateful computations into a single stateful computation.

stackManip :: Stack -> (Int, Stack)
stackManip stack = let
    ((),newStack1) = push 3 stack
    (a ,newStack2) = pop newStack1
    in pop newStack2

Because stateful computations are ordered, the intermediate state must be stated explicitly in each step, which becomes tedious. This is where the State monad comes in.
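The tedious part of stackManip is always the same plumbing: run one step, then hand its result and the new stack to the next step. As a sketch, that plumbing can be factored into a helper; andThen is a hypothetical name, and it is essentially bind for the plain-function representation s -> (a, s). The empty-stack case of pop is an added assumption so the function is total.

```haskell
type Stack = [Int]

pop :: Stack -> (Int, Stack)
pop (x:xs) = (x, xs)
pop []     = error "pop: empty stack"

push :: Int -> Stack -> ((), Stack)
push a xs = ((), a:xs)

-- Hypothetical helper: run the first step, then pass its result and
-- the intermediate state on to the next step.
andThen :: (s -> (a, s)) -> (a -> s -> (b, s)) -> s -> (b, s)
andThen step1 makeStep2 s0 =
  let (res1, s1) = step1 s0
  in makeStep2 res1 s1

-- stackManip without naming newStack1 and newStack2.
stackManip :: Stack -> (Int, Stack)
stackManip =
  push 3 `andThen` \_ ->
  pop    `andThen` \_ ->
  pop

main :: IO ()
main = print (stackManip [5,8,2,1])  -- prints (5,[8,2,1])
```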

The State monad allows a state to be manipulated without explicitly naming each intermediate state in a state transformer.

With the State monad, the stateful computations above are defined as state transformers.

-- current mtl provides the `state` function; LYaH used the State constructor of older mtl
pop :: State Stack Int
pop = state $ \(x:xs) -> (x,xs)

push :: Int -> State Stack ()
push a = state $ \xs -> ((),a:xs)

The push and pop operations are state transformers that take an input state (the stack, implemented as a regular list), transform the state and possibly produce a result value.

The difference shows when these transformers are used in a composed stateful computation. The stackManip is translated to the composed state transformer below. Note that the individual steps don't need to name the intermediate states explicitly, because this is taken care of by the bind (>>=) operation of the State monad.

stackManip :: State Stack Int
stackManip = do
    push 3
    a <- pop
    pop

The actual state transformation function is wrapped in the State type and is extracted with the runState field accessor. In LYaH and older versions of mtl the type is defined as:

newtype State s a = State { runState :: s -> (a,s) }

To run the computation, apply runState to the stateful computation and an initial state.

ghci> runState stackManip [5,8,2,1]
(5,[8,2,1])
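For a complete program, here is the stack example as a sketch assuming the mtl package and its Control.Monad.State module. Besides runState it shows evalState and execState, which keep only the result or only the final state.

```haskell
import Control.Monad.State

type Stack = [Int]

-- `state` lifts a plain s -> (a, s) function into the State monad.
pop :: State Stack Int
pop = state $ \(x:xs) -> (x, xs)

push :: Int -> State Stack ()
push a = state $ \xs -> ((), a:xs)

stackManip :: State Stack Int
stackManip = do
  push 3
  _ <- pop   -- the popped 3 is not used
  pop

main :: IO ()
main = do
  print (runState  stackManip [5,8,2,1])  -- result and final stack: (5,[8,2,1])
  print (evalState stackManip [5,8,2,1])  -- only the result: 5
  print (execState stackManip [5,8,2,1])  -- only the final stack: [8,2,1]
```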

So using the State monad removes the need to state each intermediate state explicitly in a composed stateful computation.

Comparison with other monads

Reader and Writer monads have similarities with the State monad.

The Reader monad is a function whose input value is implicit. You can also see it as a way to build functions where a common variable (or environment) is available in each of the steps that make up a composite function.

instance Monad ((->) r) where
    return x = \_ -> x
    h >>= f = \w -> f (h w) w

In the example from LYaH (the section "Reader? Ugh, not this joke again"), a set of functions is applied to an implicit input value.

addStuff :: Int -> Int
addStuff = do
    a <- (*2)
    b <- (+10)
    return (a+b)

The same composite function with an explicit input parameter looks like the following.

addStuff :: Int -> Int
addStuff x = let
    a = (*2) x
    b = (+10) x
    in a+b

This is similar to the State monad in that the input state is implicit in each step of a composite state transformation.
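To make the similarity concrete, here is a sketch (assuming the mtl package) of addStuff written three ways: with the function monad as above, with mtl's Reader, and with State used read-only via gets. The names addStuffR and addStuffS are hypothetical. All three should give 19 for the input 3.

```haskell
import Control.Monad.Reader (Reader, asks, runReader)
import Control.Monad.State (State, gets, evalState)

-- 1. The function monad: the Int argument is the implicit input.
addStuff :: Int -> Int
addStuff = do
  a <- (*2)
  b <- (+10)
  return (a + b)

-- 2. mtl's Reader: asks applies a function to the environment.
addStuffR :: Reader Int Int
addStuffR = do
  a <- asks (*2)
  b <- asks (+10)
  return (a + b)

-- 3. State used read-only: gets reads the implicit state without
--    changing it, so State can play the role of Reader.
addStuffS :: State Int Int
addStuffS = do
  a <- gets (*2)
  b <- gets (+10)
  return (a + b)

main :: IO ()
main = do
  print (addStuff 3)
  print (runReader addStuffR 3)
  print (evalState addStuffS 3)
```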

The State monad can also implicitly change the state in each step of a composite transformation. The Writer monad also has an implicit value (e.g. a log), which each step can add to.

instance (Monoid w) => Monad (Writer w) where
    return x = Writer (x, mempty)
    (Writer (x,v)) >>= f = let (Writer (y, v')) = f x in Writer (y, v `mappend` v')

However, while the State monad can apply arbitrary transformations to the implicit state, the Writer monad only allows appending to the value (using mappend).
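A small sketch of the difference, assuming the mtl package: with Writer each step can only tell something that gets appended, while with State a step can transform the "log" arbitrarily, for instance throw it away. logW and logS are hypothetical names.

```haskell
import Control.Monad.State (State, modify, execState)
import Control.Monad.Writer (Writer, tell, execWriter)

-- Writer: each step can only append to the log (via mappend).
logW :: Writer [String] ()
logW = do
  tell ["start"]
  tell ["done"]

-- State: each step can transform the log arbitrarily,
-- including something Writer cannot express: clearing it.
logS :: State [String] ()
logS = do
  modify (++ ["start"])
  modify (const [])        -- throw the log away
  modify (++ ["done"])

main :: IO ()
main = do
  print (execWriter logW)    -- ["start","done"]
  print (execState logS [])  -- ["done"]
```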

So you could use the State monad to serve the purposes of both the Reader monad and the Writer monad: the state could hold an environment to read from, as in the Reader monad, and transformations of the state could keep a log of the computations, as in the Writer monad. Reader and Writer are, however, more restrictive than State, which is an advantage: when you don't need both aspects, you can choose just the one you need.
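As a sketch of that idea, assuming the mtl package: below, the state is a hypothetical pair of an environment (only read, with gets) and a log (only appended to, with modify). The names Env, Log, askEnv, logMsg and compute are made up for illustration. Nothing in the types stops compute from changing the environment or rewriting the log; that discipline is exactly what Reader and Writer enforce for you.

```haskell
import Control.Monad.State (State, gets, modify, runState)

-- Hypothetical state: a read-only environment paired with an append-only log.
type Env = Int
type Log = [String]

-- Reader-like: read the environment, leave the state untouched.
askEnv :: State (Env, Log) Env
askEnv = gets fst

-- Writer-like: append a message to the log part of the state.
logMsg :: String -> State (Env, Log) ()
logMsg msg = modify (\(env, logs) -> (env, logs ++ [msg]))

compute :: State (Env, Log) Int
compute = do
  x <- askEnv
  logMsg ("env is " ++ show x)
  let y = x * 2
  logMsg ("doubled to " ++ show y)
  return y

main :: IO ()
main = print (runState compute (21, []))
```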

The IO monad can also be compared with the State monad. IO operations can be seen as state transformations where the state is the environment outside of the program (aka “the world”). A read operation reads from the world state and a write operation writes to the world state.

In theory, the main function of a Haskell program can be seen as running the state transformation defined by the program (main) with the world as input state.

((), world') = runState main world   -- conceptual only, not real Haskell

The output state (world') has all the changes (e.g. write operations) that this program made to the world.

It is especially useful to have an implicit state when the world is the state. It would be impractical to pass the world around explicitly as a value, since its format and shape aren't defined in practice; it is a mere abstraction that explains how a Haskell program relates to the world during execution.
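A toy model of this view, purely for illustration and assuming the mtl package: the "world" is a hypothetical record with input still to be read and output written so far, and reading and writing are State operations on it. World, readLine, writeLine and toyMain are made-up names; real programs never see the world as a value. (GHC's internal definition of IO does follow a similar shape, threading an abstract State# RealWorld token, but that is not something user code manipulates.)

```haskell
import Control.Monad.State (State, state, modify, runState)

-- A hypothetical, fully explicit "world": pending input and output so far.
data World = World { input :: [String], output :: [String] }
  deriving Show

-- Reading consumes one line from the world's input ("" if none is left).
readLine :: State World String
readLine = state $ \w -> case input w of
  (l:ls) -> (l, w { input = ls })
  []     -> ("", w)

-- Writing appends a line to the world's output.
writeLine :: String -> State World ()
writeLine s = modify (\w -> w { output = output w ++ [s] })

-- A toy "main" that greets whoever the world supplies.
toyMain :: State World ()
toyMain = do
  name <- readLine
  writeLine ("Hello, " ++ name)

main :: IO ()
main = do
  let ((), world') = runState toyMain (World ["Haskell"] [])
  print (output world')  -- ["Hello, Haskell"]
```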

Understanding bind for functions

The first step is to understand how bind (>>=) works for functions.

instance Monad ((->) r) where  
  return x = \_ -> x  
  h >>= f = \w -> f (h w) w 
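To see the instance at work, here is the same bind written as an ordinary function. The name bindFn is hypothetical, chosen to avoid clashing with the instance that base already provides; the sketch checks that it agrees with the built-in (>>=).

```haskell
-- Bind for functions as an ordinary function:
-- h reads the shared input w, and f gets h's result plus the same w.
bindFn :: (r -> a) -> (a -> r -> b) -> r -> b
bindFn h f = \w -> f (h w) w

-- The same composition via the hypothetical bindFn and via (>>=).
-- For functions, return is const, so const is used with bindFn.
viaBindFn :: Int -> Int
viaBindFn = (*2) `bindFn` \a -> (+10) `bindFn` \b -> const (a + b)

viaBind :: Int -> Int
viaBind = (*2) >>= \a -> (+10) >>= \b -> return (a + b)

main :: IO ()
main = print (viaBindFn 3, viaBind 3)  -- (19,19)
```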

Here are some examples of bind for functions, compared with the Maybe monad.

Just 1 >>= (\a -> Just (a+2)) >>= (\b -> Just (b+3))    -- not using previous results
Just 6
Just 1 >>= (\a -> (Just (a+2) >>= (\b -> Just (a+b))))  -- using previous results 
Just 4

let f = (*2) >>= (\a -> (+10)) >>= (return)             -- the result of (*2) is not used
f 3
13
let f = (*2) >>= (\a -> (+10) >>= (\b -> return (a+b)))  -- using previous results
f 3
19
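Why the last example gives 19 can be checked by unfolding the instance definition h >>= f = \w -> f (h w) w one step at a time. In this sketch every step is a Haskell expression (step0 to step5 are hypothetical names), so the compiler can confirm that each rewrite preserves the value.

```haskell
-- Unfolding  h >>= f = \w -> f (h w) w  step by step for
--   ((*2) >>= (\a -> (+10) >>= (\b -> return (a+b)))) 3
-- Every step should evaluate to the same number.
step0, step1, step2, step3, step4, step5 :: Int

step0 = ((*2) >>= (\a -> (+10) >>= (\b -> return (a+b)))) 3

-- unfold the outer bind with h = (*2) and w = 3
step1 = (\a -> (+10) >>= (\b -> return (a+b))) ((*2) 3) 3

-- (*2) 3 = 6; substitute a = 6
step2 = ((+10) >>= (\b -> return (6+b))) 3

-- unfold the inner bind with h = (+10) and w = 3
step3 = (\b -> return (6+b)) ((+10) 3) (3 :: Int)

-- (+10) 3 = 13; substitute b = 13
step4 = return (6+13) (3 :: Int)

-- return for functions ignores its argument: return x = \_ -> x
step5 = 6 + 13

main :: IO ()
main = print [step0, step1, step2, step3, step4, step5]  -- [19,19,19,19,19,19]
```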