The State Monad

2012-03-09 21:07 UTC
  • Xyne

About

A quick disclaimer: I’m still learning Haskell so I may have misunderstood some things or described them incorrectly. I wrote this in one sitting and there are probably a few typos in the text and maybe even some technical errors. Let me know if you find any. Now on with the show.

I found the state monad difficult to understand at first and from my searches it seems that I’m not the only one. On a superficial level it is conceptually simple: pass around the state to keep your functions pure. If the state itself is passed to a function then there are no side effects because each input will be mapped to a single output and we can trust our function to work as expected. Simple, right?

The difficulty lay in understanding what was really going on under the surface of the state monad. Basically, this section is for anyone who keeps looking at something like this:

rollDie :: State StdGen Int
rollDie = do generator <- get
             let ( value, newGenerator ) = randomR (1,6) generator
             put newGenerator
             return value

and wondering how the get and put just seem to magically pull state out of thin air, alter it, then push it back before returning a value that appears to be unrelated to this pushing and pulling. I’ll run through the key elements to understanding what’s actually happening then I’ll try to put them all together in a way that makes everything clear.

The Code

The code in this section is taken from the Haskell wikibook’s state monad page. Note that I’ve changed the names in the definitions of evalState and execState to make it clearer that they accept a state as an argument. I’ve also formatted this code so that it will compile if you want to play around with it (System.Random comes from the random package, which you may need to install).

import System.Random


newtype State state result = State { runState :: state -> (result, state) }

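-- GHC 7.10 and later require Functor and Applicative instances
-- before a Monad instance can be declared.
instance Functor (State state_type) where
  --fmap :: (result_a -> result_b) -> State state result_a -> State state result_b
  fmap f processor = State $ \state ->
    let (result, state') = runState processor state
    in (f result, state')

instance Applicative (State state_type) where
  --pure :: result -> State state result
  pure r = State ( \s -> (r, s) )

  --(<*>) :: State state (result_a -> result_b) -> State state result_a -> State state result_b
  functionProcessor <*> processor = State $ \state ->
    let (f, state')       = runState functionProcessor state
        (result, state'') = runState processor state'
    in (f result, state'')

instance Monad (State state_type) where
  --return :: result -> State state result
  return r = State ( \s -> (r, s) )

  --(>>=) :: State state result_a -> (result_a -> State state result_b) -> State state result_b
  processor >>= processorGenerator = State $ \state ->
                                     let (result, state') = runState processor state
                                     in runState (processorGenerator result) state'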

put newState = State $ \_ -> ((), newState)
get = State $ \state -> (state, state)

evalState stateMonad state = fst ( runState stateMonad state )
execState stateMonad state = snd ( runState stateMonad state )

type GeneratorState = State StdGen

rollDie :: GeneratorState Int
rollDie = do generator <- get
             let ( value, newGenerator ) = randomR (1,6) generator
             put newGenerator
             return value

--evalState rollDie (mkStdGen 0)

The “State Monad” Is A State Transformer

The state monad doesn’t actually hold a state. It holds something that can transform a state. Let’s look at the type:

newtype State state result = State { runState :: state -> (result, state) }

The record syntax adds some sugar that gets in the way, so let’s rewrite the above without it:

newtype State state result = State ( state -> (result, state) )

runState :: (State state result) -> (state -> (result,state))
runState (State f) = f

So, the state monad contains something that takes a state and returns a tuple consisting of a result and a state. runState is just an accessor function to get the function inside the state monad. We need this because the function is wrapped in the State constructor and we don’t want to have to pattern-match against it all the time.

Actually, it’s probably not even right to say that it “contains” a function. It really is a function, but it’s wrapped in a constructor to give it a type that we can work with.

So, what does the state transformer function inside of the state monad do? It takes a state and returns a tuple containing a result and a new state. That’s all.
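To make that concrete, here is a minimal sketch of a state transformer that has nothing to do with random numbers. The name counterStep and the choice of an Int counter as the state are made up for illustration; only the State type comes from the code above.

```haskell
-- Same type as above, repeated so that the snippet stands alone.
newtype State state result = State { runState :: state -> (result, state) }

-- Hypothetical example: the state is a counter. The result is the
-- current count and the new state is the count plus one.
counterStep :: State Int Int
counterStep = State $ \count -> (count, count + 1)

-- Nothing happens until we unwrap the function and hand it a state.
firstStep :: (Int, Int)
firstStep = runState counterStep 10     -- (10, 11)

-- The same transformer applied to a different state.
secondStep :: (Int, Int)
secondStep = runState counterStep 11    -- (11, 12)
```

Note that counterStep itself never changes. It is the same function both times and only the state that we feed it differs.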

do Notation Obfuscates The Functional Nature Of The Code

do notation is syntactic sugar and it makes it easier for people from an imperative programming background to get started with monads, but it also obfuscates the functional nature of the code. It tricks you into thinking about the code as though it were imperative and thus it appears that some “assignments” are completely unrelated to others. I believe that this can impede the learning of Haskell.

Let’s look at the section of code that confuses people:

rollDie :: State StdGen Int
rollDie = do generator <- get
             let ( value, newGenerator ) = randomR (1,6) generator
             put newGenerator
             return value

With imperative eyes this looks simple. The type of rollDie is a state monad and walking through it we see that it gets a state, passes it to randomR, gets a random value along with a new state, puts the state back wherever it got it, then returns the value.

OK, so where did the state come from? How did get magically pull the state out of thin air and how did put put it back? The function is pure. It isn’t changing something outside of itself, and it doesn’t accept any arguments. What’s it actually doing?

Let’s look at the get and put functions again:

put newState = State $ \_ -> ((), newState)
get = State $ \state -> (state, state)

get is a state monad, and looking at the code we see that it takes no arguments and is always the same thing. It’s not some specific state that’s tracking what we’re doing. It’s just a generic state monad with a function that accepts a state and returns a 2-tuple with the unaltered state in both positions. Again, it takes no arguments and is always the same thing. (Well, not strictly… its type is polymorphic, so the state type gets instantiated to whatever is needed, but that’s irrelevant here. Read up about type variables and polymorphic types if you want to know more.)

put takes a state as an argument and returns a state monad. The function inside the state monad, as we saw above in the ‘The “State Monad” Is A State Transformer’ section, is a function that accepts a state and returns a tuple containing a result and a new state. Looking at the definition of put, we see that this state transforming function ignores its argument (a state) and returns a tuple with an empty result and the state that we passed to put.

It’s not “storing” a state anywhere. It’s actually building a state monad that can transform all states (of the same type) to a tuple with an empty result and the state that we passed to put.

Be careful not to confuse put with the state transforming function inside the state monad that it returns. put is a function that accepts an argument and returns a state monad. The state monad that it returns contains a function that accepts a state, ignores it, and returns a tuple containing the state that we originally passed to put.
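A quick way to convince yourself of this is to unwrap get and put with runState and apply the functions inside to a state by hand. This is just a sketch using the definitions from above with type signatures added. The states 5 and 3 are arbitrary values that I picked for the example.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

put :: state -> State state ()
put newState = State $ \_ -> ((), newState)

get :: State state state
get = State $ \state -> (state, state)

-- get copies whatever state it is eventually given into the result slot.
getExample :: (Int, Int)
getExample = runState get 5          -- (5, 5)

-- put 3 builds a transformer that discards the incoming state (5 here)
-- and replaces it with 3.
putExample :: ((), Int)
putExample = runState (put 3) 5      -- ((), 3)
```

In both cases the state only shows up when we apply the unwrapped function to it. Neither get nor put knows anything about it beforehand.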

So how does that preserve state and pass it around outside of the function? What is going on here?

It’s time to get rid of some syntactic sugar by removing the do notation. This is what the compiler does behind the scenes when it compiles our code. Transforming the notation, the rollDie function becomes:

rollDie :: State StdGen Int
rollDie = get >>= \generator ->
          let ( value, newGenerator ) = randomR (1,6) generator
          in put newGenerator >> return value

If the transformation isn’t clear, read up on do notation.

Understanding (>>=)

Now that we have the function expressed in terms of (>>=) and (>>) (the “bind” and “then” operators, respectively), we need to look back at how (>>=) is defined for state monads (you should read up on these operators if you are not familiar with them, but if that’s the case then you might want to come back to state monads later). As we’ve just seen above, get is returning a state monad (not a state), so the first argument to the first (>>=) is indeed a state monad. Let’s change the names in the definition above to make it a bit clearer too:

(>>=) :: State state result_a -> (result_a -> State state result_b) -> State state result_b
stateTransformer >>= stateTransformerGenerator = State $ \state ->
                                   let (result, state') = runState stateTransformer state
                                   in runState (stateTransformerGenerator result) state'

Let’s run through this and make some sense of it. The first argument is a state monad and the second argument is a function that takes a result and returns a state monad that might have a different result type. (>>=) takes these two arguments and returns a new state monad.

The definition begins with State, which is the constructor of the state monad. This just ensures that the following lambda expression, which must be a state transforming function, will be wrapped up in a state monad.

In the second line, runState stateTransformer pulls out the state transforming function from the first argument to (>>=), namely stateTransformer. This is the function that accepts a state and returns a tuple with a result and a new state. The returned function is applied to state and returns the tuple (result, state'), where state' is a new state.

Before continuing, remind yourself where state is coming from. Remember that we are defining a state transforming function that accepts a state and returns a tuple, which we are then wrapping in the state monad. state is therefore a state that will be passed to this function later. It’s saying “When you’re passed a state, take that state and pass it to the state transformer in the state monad of the first argument, then call the result and new state in the tuple result and state', respectively.” You can think of it as a set of instructions of what to do with a state if it’s passed. They’re not actually being done yet.

In the final line, we take the result from above (result) and we feed it into the second argument of (>>=), namely stateTransformerGenerator. This returns a new state monad and we then pull out the state transforming function in it with our accessor, runState. We now have a function that accepts a state and returns a tuple. We then pass this function the new state generated in the previous line (state') and get the tuple that we’re expecting to complete the definition of a state transformer started by the lambda expression on the first line (\state ->).

So, what do we have at the end? We have a state transformer that takes a state, runs it through one state transformer to get a new state and a result, then it uses the result to get another state transformer from our “state transformer generator”, and then runs the new state through that too to get a third state and a new result. We’re threading the state through a succession of state transformers (as we would thread a string through beads) and using the result of each transformation to determine the next state transformer.
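Here is a small sketch of that threading with concrete numbers. The function bind is the definition of (>>=) from above under a different name, so that the snippet doesn’t need a Monad instance. counterStep and addToState are hypothetical examples that I made up for this.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

-- The same definition as (>>=) above, under a hypothetical name.
bind :: State s a -> (a -> State s b) -> State s b
bind stateTransformer stateTransformerGenerator = State $ \state ->
  let (result, state') = runState stateTransformer state
  in runState (stateTransformerGenerator result) state'

-- Hypothetical transformer: return the counter and increment it.
counterStep :: State Int Int
counterStep = State $ \count -> (count, count + 1)

-- Hypothetical generator: given a result, build a transformer that adds
-- that result to the state and returns a description of what it did.
addToState :: Int -> State Int String
addToState n = State $ \count -> ("added " ++ show n, count + n)

combined :: State Int String
combined = counterStep `bind` addToState

-- runState combined 10 proceeds as follows:
--   counterStep applied to 10 gives (10, 11)
--   addToState 10 is then applied to 11, giving ("added 10", 21)
example :: (String, Int)
example = runState combined 10
```

The result of the first transformer (10) picks the second transformer, and the state produced by the first transformer (11) is what the second one receives.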

Replacing get and put

Now we need to fit this back into our definition of rollDie to see what it’s doing. Let’s look at the definition again (no need to scroll back):

rollDie :: State StdGen Int
rollDie = get >>= \generator ->
          let ( value, newGenerator ) = randomR (1,6) generator
          in put newGenerator >> return value

The rollDie function will return a state monad that can take a “random number generator state” and return an integer between 1 and 6, inclusive, along with a new “random number generator state”, which can then be threaded through rollDie again to get another number and another new state.

Let’s replace generator and newGenerator, which are the “random number generator states”, with gState and gState'. Let’s also replace get and put newGenerator with what they return.

rollDie :: State StdGen Int
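-- Parentheses replace the ($) used in the definitions of get and put,
-- because ($) would swallow everything to its right, including (>>=).
rollDie = State (\state -> (state, state)) >>= \gState ->
          let ( value, gState' ) = randomR (1,6) gState
          in State (\_ -> ((), gState')) >> return value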

Make sure that you understand the replacements by looking back at the definitions.

Understanding (>>)

There is still one last thing that we need to know before we can put all of this together. What does (>>) (then) do in this context? You’re probably familiar with it from the IO monad, which lets you do something like this:

(putStrLn "foo") >> (putStrLn "bar") >> (putStrLn "baz")

or in the more familiar do notation:

do putStrLn "foo"
   putStrLn "bar"
   putStrLn "baz"

With IO, this just prints “foo”, then it prints “bar”, then it prints “baz”, each on a new line. (>>) seems to just sequence the functions and there doesn’t seem to be any connection between them, but this is just a consequence of the definition of (>>) for IO monads. Actually, it’s a consequence of the definition of (>>=).

(>>) is actually defined in terms of (>>=):

m >> n = m >>= \_ -> n

Let’s replace (>>) in our rollDie function with this definition and take another look at it.

rollDie :: State StdGen Int
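-- Parentheses replace the ($) used in the definitions of get and put,
-- because ($) would swallow everything to its right, including (>>=).
rollDie = State (\state -> (state, state)) >>= \gState ->
          let ( value, gState' ) = randomR (1,6) gState
          in State (\_ -> ((), gState')) >>= \_ -> return value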

So far, so good.
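Before moving on, here is a sketch that shows that (>>) only throws away the result, not the state. bind and andThen are (>>=) and (>>) written out by hand under hypothetical names, so no Monad instance is needed. The values 42 and 0 are arbitrary.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

-- (>>=) from above under a hypothetical name.
bind :: State s a -> (a -> State s b) -> State s b
bind m f = State $ \state ->
  let (result, state') = runState m state
  in runState (f result) state'

-- (>>) from above under a hypothetical name.
andThen :: State s a -> State s b -> State s b
andThen m n = m `bind` \_ -> n

put :: s -> State s ()
put newState = State $ \_ -> ((), newState)

get :: State s s
get = State $ \state -> (state, state)

-- The () result of put is ignored by the lambda in andThen, but the
-- state that put installed is still threaded through to get.
example :: (Int, Int)
example = runState (put 42 `andThen` get) 0    -- (42, 42)
```

The initial state 0 is discarded by put, and get sees 42 even though the two look unconnected when written one after the other.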

Putting It Together

OK, we’re almost there. It’s time to wrap our heads around this. Here are the definitions of return and (>>=) again so you don’t have to scroll back:

return r = State ( \s -> (r, s) )

(>>=) :: State state result_a -> (result_a -> State state result_b) -> State state result_b
stateTransformer >>= stateTransformerGenerator = State $ \state ->
                                   let (result, state') = runState stateTransformer state
                                   in runState (stateTransformerGenerator result) state'

Looking at the definition of rollDie above in terms of (>>=), we see that it’s defined as x >>= y >>= z, where

x = State $ \state -> (state, state)

y = \gState -> let ( value, gState' ) = randomR (1,6) gState
               in State $ \_ -> ((), gState')

z = \_ -> return value

The monad laws say that (x >>= y) >>= z and x >>= (\w -> y w >>= z) are equivalent. The latter matches how the expression is actually parsed (a lambda extends as far to the right as it can), and it avoids confusion about how value is getting passed around, because z uses the value that is bound inside y. So that’s the one we’ll use. The second argument here must be a “state transforming function generator”, i.e. something with the type (result_a -> State state result_b), so that it can be passed to x >>=. We can rewrite it as \w -> (y w >>= z) to make it clear how we should apply the bind operator inside the lambda function.

\w -> (y w >>= z) = \w -> State $ \state ->
                          let (result, state') = runState (y w) state
                          in runState (z result) state'

The thing to note here is that the first argument to the bind operator is y w, not y. y takes a state and returns a state monad and the first argument of the bind operator must be a state monad, not a function that generates a state monad from a state (that’s what the second argument to the bind operator is, and that is why we could evaluate x >>= y first).
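If you want to see the two groupings agree on something simpler than rollDie, here is a sketch with three made-up pieces in which the last one doesn’t depend on anything bound inside the second, so that both groupings are valid. bind is (>>=) from above under a hypothetical name and counterStep, double and describe are invented for the example.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

-- (>>=) from above under a hypothetical name.
bind :: State s a -> (a -> State s b) -> State s b
bind m f = State $ \state ->
  let (result, state') = runState m state
  in runState (f result) state'

-- Hypothetical pieces: each one bumps the counter by one.
counterStep :: State Int Int
counterStep = State $ \count -> (count, count + 1)

double :: Int -> State Int Int
double n = State $ \count -> (n * 2, count + 1)

describe :: Int -> State Int String
describe n = State $ \count -> (show n, count + 1)

-- (x >>= y) >>= z
leftGrouped :: State Int String
leftGrouped = (counterStep `bind` double) `bind` describe

-- x >>= (\w -> y w >>= z)
rightGrouped :: State Int String
rightGrouped = counterStep `bind` (\w -> double w `bind` describe)

-- Starting from 5: counterStep gives (5, 6), double 5 gives (10, 7)
-- and describe 10 gives ("10", 8), whichever way we group it.
leftResult, rightResult :: (String, Int)
leftResult  = runState leftGrouped 5
rightResult = runState rightGrouped 5
```

This only checks one starting state, of course. It’s an illustration of the law, not a proof of it.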

Let’s step through some evaluations:

y = \gState -> let ( value, gState' ) = randomR (1,6) gState
               in State $ \_ -> ((), gState')

y w = let ( value, gState' ) = randomR (1,6) w
      in State $ \_ -> ((), gState')

runState (y w) = let ( value, gState' ) = randomR (1,6) w
                 in \_ -> ((), gState')

(runState (y w)) state = let ( value, gState' ) = randomR (1,6) w
                         in ((), gState')

runState pulls the state transformer function out of the state monad, which is equivalent to removing the constructor from the definition. Passing state to this function returns a tuple. Because the function was defined as \_ -> ((), gState'), it completely ignores the state that we pass it, but it still returns the tuple ((),gState'). Let’s put this back into our definition above:

\w -> (y w >>= z) = \w -> State $ \state ->
                          let ( value, gState' ) = randomR (1,6) w
                              (result, state') = ((),gState')
                          in runState (z result) state'

Now we step through evaluations starting with z by first substituting the definition of return:

z = \_ -> return value

z = \_ -> State ( \s -> (value, s) )

z result = State ( \s -> (value, s) )

runState (z result) = \s -> (value, s)

(runState (z result)) state' = (value, state')

z is a function that ignores its argument and returns a state monad. That state monad contains a function that accepts a state and returns a tuple with a fixed value and an unchanged state. As before, runState unwraps the monad to get the function inside it (that accepts a state and returns a tuple). We then pass state' to that function to get the tuple, which is the fixed value and state', unchanged.

Now we substitute back into our previous expression.

\w -> (y w >>= z) = \w -> State $ \state ->
                          let ( value, gState' ) = randomR (1,6) w
                              (result, state') = ((),gState')
                          in (value, state')

which we can simplify to

\w -> (y w >>= z) = \w -> State $ \_ -> randomR (1,6) w

because state and result are ignored in the definition, and gState' is only used in the definition of state'. As before, work through it to make sure you understand the transformation. Now we substitute into x >>= (\w -> y w >>= z), again using the definition of (>>=):

x >>= (\w -> y w >>= z) = State $ \state ->
                          let (result, state') = runState x state
                          in runState ((\w -> y w >>= z) result) state'

and step through some more evaluations:

\w -> y w >>= z = (\w -> State $ \_ -> randomR (1,6) w)

(\w -> y w >>= z) result = State $ \_ -> randomR (1,6) result

runState ((\w -> y w >>= z) result) = \_ -> randomR (1,6) result

(runState ((\w -> y w >>= z) result)) state' = randomR (1,6) result

and thus:

x >>= (\w -> y w >>= z) = State $ \state ->
                          let (result, state') = runState x state
                          in randomR (1,6) result

Now we just need to run through the evaluations starting with x (using s instead of state in the lambda expression from above to avoid confusion with state here).

x = State $ \s -> (s, s)

runState x = \s -> (s, s)

(runState x) state = (state, state)

and finally we substitute for x and then simplify

x >>= (\w -> y w >>= z) = State $ \state ->
                          let (result, state') = (state,state)
                          in randomR (1,6) result

x >>= (\w -> y w >>= z) = State $ \state -> randomR (1,6) state

Both result and state' are equal to state, so we replace them then remove the tuple assignment, which is redundant. Now remember that x >>= (\w -> y w >>= z) is equal to x >>= y >>= z, which is just our definition of rollDie.

Conclusion

In the end, we’re left with

rollDie = State $ \state -> randomR (1,6) state

Let’s check what the type of randomR (1,6) state will be:

ghci> :m System.Random
ghci> :t randomR
randomR :: (Random a, RandomGen g) => (a, a) -> g -> (a, g)

In our case, a is Int and g is StdGen, the type of our random generator state, so randomR (1,6) state will return a tuple containing an Int and a new random generator state. So our function accepts a state and returns a tuple containing a result and a new state, i.e. a state transformer.

State wraps this state transformer function into a state monad and so our rollDie function returns a state monad, as it’s supposed to.
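If you don’t trust all of those substitutions, you can check the end result by running both versions side by side. The sketch below does that with a toy generator, toyRoll, which is a stand-in that I made up for randomR (1,6) so that the snippet doesn’t depend on the random package. It is not a good random number generator and the Int state is only an assumption for the example.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

-- GHC 7.10 and later require Functor and Applicative instances
-- before a Monad instance can be declared.
instance Functor (State s) where
  fmap f m = State $ \state ->
    let (result, state') = runState m state
    in (f result, state')

instance Applicative (State s) where
  pure r = State $ \s -> (r, s)
  mf <*> m = State $ \state ->
    let (f, state')       = runState mf state
        (result, state'') = runState m state'
    in (f result, state'')

instance Monad (State s) where
  m >>= f = State $ \state ->
    let (result, state') = runState m state
    in runState (f result) state'

put :: s -> State s ()
put newState = State $ \_ -> ((), newState)

get :: State s s
get = State $ \state -> (state, state)

-- Hypothetical stand-in for randomR (1,6): returns a value from 1 to 6
-- and a new generator state for non-negative seeds.
toyRoll :: Int -> (Int, Int)
toyRoll seed = (seed `mod` 6 + 1, (seed * 1103515245 + 12345) `mod` 2147483648)

-- The do notation version from the start of this section.
rollDieDo :: State Int Int
rollDieDo = do
  generator <- get
  let (value, newGenerator) = toyRoll generator
  put newGenerator
  return value

-- The version that we ended up with after all of the substitutions.
rollDieDirect :: State Int Int
rollDieDirect = State $ \state -> toyRoll state

-- Both versions give the same result and new state for a range of seeds.
sameOnSeeds :: Bool
sameOnSeeds =
  all (\seed -> runState rollDieDo seed == runState rollDieDirect seed) [0 .. 100]
```

sameOnSeeds should evaluate to True, which is what the derivation above predicts.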

So, what can we do with it?

evalState rollDie (mkStdGen 0)

mkStdGen 0 returns a new random generator state (0 is the seed). evalState, which we saw way up above, takes a state monad and a state, pulls out the state transforming function from the state monad, applies it to the state to get the result and new state tuple, then pulls out the result from the tuple and returns it, which gives us an integer between 1 and 6, inclusive.
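To see evalState next to its sibling execState without any randomness getting in the way, here is a small sketch. counterStep is a made-up state transformer and the type signatures on evalState and execState are my own additions.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

-- Keep the result, drop the new state.
evalState :: State s a -> s -> a
evalState stateMonad state = fst ( runState stateMonad state )

-- Keep the new state, drop the result.
execState :: State s a -> s -> s
execState stateMonad state = snd ( runState stateMonad state )

-- Hypothetical transformer: return the counter and increment it.
counterStep :: State Int Int
counterStep = State $ \count -> (count, count + 1)

resultOnly :: Int
resultOnly = evalState counterStep 10    -- 10

stateOnly :: Int
stateOnly = execState counterStep 10     -- 11
```

Both functions run exactly the same transformer on the same state. They only differ in which half of the tuple they hand back.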

Yeah, that isn’t all that exciting. Don’t worry though, we didn’t go through all of that just to generate a single digit then call it a day and head home.

Going over the Haskell wikibook’s state monad page again, we find a definition for rollDice, which is a state monad whose result is a pair of integers instead of a single integer (i.e. a simulation of rolling 2 dice):

rollDice :: State StdGen (Int, Int)
rollDice = liftM2 (,) rollDie rollDie

If you’re wondering what liftM2 is, it’s just a convenience function defined in Control.Monad that leads to more obfuscation in this case, so let’s redefine rollDice without it:

rollDice :: State StdGen (Int, Int)
rollDice = rollDie >>= \result_1 ->
           rollDie >>= \result_2 ->
           return (result_1, result_2)

Look back at the definition of (>>=) and return to understand how this works. Here’s the same thing in do notation:

rollDice' :: State StdGen (Int, Int)
rollDice' = do result_1 <- rollDie
               result_2 <- rollDie
               return (result_1, result_2)

It appears deviously simple, doesn’t it? Let’s roll some dice:

ghci> evalState rollDice (mkStdGen 423752)
(5,1)

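Note that rollDice' would give the same result with the same seed (423752). It really is the same function. (The exact numbers that you see may differ from mine, because the generator behind StdGen has changed between versions of the random package.)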
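To see that nothing more than state threading is going on, here is a sketch that compares rollDice with a version in which the state is passed along by hand. As assumptions for the example, unit and bind are return and (>>=) under hypothetical names (so that no Monad instance is needed) and toyRoll is a toy stand-in for randomR (1,6) with an Int as the generator state. It is not a good random number generator.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

-- return and (>>=) from above under hypothetical names.
unit :: a -> State s a
unit r = State $ \s -> (r, s)

bind :: State s a -> (a -> State s b) -> State s b
bind m f = State $ \state ->
  let (result, state') = runState m state
  in runState (f result) state'

-- Hypothetical stand-in for randomR (1,6).
toyRoll :: Int -> (Int, Int)
toyRoll seed = (seed `mod` 6 + 1, (seed * 1103515245 + 12345) `mod` 2147483648)

rollDie :: State Int Int
rollDie = State toyRoll

-- rollDice as defined above, with bind and unit in place of (>>=) and return.
rollDice :: State Int (Int, Int)
rollDice = rollDie `bind` \result_1 ->
           rollDie `bind` \result_2 ->
           unit (result_1, result_2)

-- The same computation with the state passed along by hand.
rollDiceByHand :: Int -> ((Int, Int), Int)
rollDiceByHand state0 =
  let (result_1, state1) = toyRoll state0
      (result_2, state2) = toyRoll state1
  in ((result_1, result_2), state2)

-- Both versions agree on the results and on the final state.
agree :: Bool
agree = all (\seed -> runState rollDice seed == rollDiceByHand seed) [0 .. 100]
```

The second die is rolled with the state left behind by the first one (state1), which is why the two dice usually show different values even though rollDie appears twice.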

Finally, here’s the working code again with both definitions of rollDice and two roll functions. The roll functions accept a seed (Int) for the random generator and return a pair of Ints generated using randomR.

import System.Random


newtype State state result = State { runState :: state -> (result, state) }

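-- GHC 7.10 and later require Functor and Applicative instances
-- before a Monad instance can be declared.
instance Functor (State state_type) where
  --fmap :: (result_a -> result_b) -> State state result_a -> State state result_b
  fmap f processor = State $ \state ->
    let (result, state') = runState processor state
    in (f result, state')

instance Applicative (State state_type) where
  --pure :: result -> State state result
  pure r = State ( \s -> (r, s) )

  --(<*>) :: State state (result_a -> result_b) -> State state result_a -> State state result_b
  functionProcessor <*> processor = State $ \state ->
    let (f, state')       = runState functionProcessor state
        (result, state'') = runState processor state'
    in (f result, state'')

instance Monad (State state_type) where
  --return :: result -> State state result
  return r = State ( \s -> (r, s) )

  --(>>=) :: State state result_a -> (result_a -> State state result_b) -> State state result_b
  processor >>= processorGenerator = State $ \state ->
                                     let (result, state') = runState processor state
                                     in runState (processorGenerator result) state'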

put newState = State $ \_ -> ((), newState)
get = State $ \state -> (state, state)

evalState stateMonad state = fst ( runState stateMonad state )
execState stateMonad state = snd ( runState stateMonad state )

type GeneratorState = State StdGen

rollDie :: GeneratorState Int
rollDie = do generator <- get
             let ( value, newGenerator ) = randomR (1,6) generator
             put newGenerator
             return value

--evalState rollDie (mkStdGen 0)

rollDice :: State StdGen (Int, Int)
rollDice = rollDie >>= \result_1 ->
           rollDie >>= \result_2 ->
           return (result_1, result_2)

rollDice' :: State StdGen (Int, Int)
rollDice' = do result_1 <- rollDie
               result_2 <- rollDie
               return (result_1, result_2)

roll :: Int -> (Int,Int)
roll = \x -> evalState rollDice (mkStdGen x)

roll' :: Int -> (Int,Int)
roll' = \x -> evalState rollDice' (mkStdGen x)

I hope this helps you to understand the state monad. Feel free to contact me if you would like to give feedback such as suggestions for improvements or corrections (both technical and orthographical), or just to let me know if you found it useful.

Contact
echo xyne.archlinux.org | sed 's/\./@/'