A quick disclaimer: I’m still learning Haskell so I may have misunderstood some things or described them incorrectly. I wrote this in one sitting and there are probably a few typos in the text and maybe even some technical errors. Let me know if you find any. Now on with the show.
I found the state monad difficult to understand at first and from my searches it seems that I’m not the only one. On a superficial level it is conceptually simple: pass around the state to keep your functions pure. If the state itself is passed to a function then there are no side effects because each input will be mapped to a single output and we can trust our function to work as expected. Simple, right?
The difficulty lay in understanding what was really going on under the surface of the state monad. Basically, this section is for anyone who keeps looking at something like this:
    rollDie :: State StdGen Int
    rollDie = do
      generator <- get
      let ( value, newGenerator ) = randomR (1,6) generator
      put newGenerator
      return value
and wondering how the get
and put
just seem
to magically pull state out of thin air, alter it, then push it back
before returning a value that appears to be unrelated to this pushing
and pulling. I’ll run through the key elements to understanding what’s
actually happening then I’ll try to put them all together in a way that
makes everything clear.
The code in this section is taken from the Haskell wikibook’s state monad page. Note that I’ve changed the names in the definitions of evalState and execState to make it clearer that they accept a state as an argument. I’ve also formatted this code so that it will compile if you want to play around with it.
    import System.Random
    import Control.Monad (ap, liftM)

    newtype State state result = State { runState :: state -> (result, state) }

    -- Since GHC 7.10 every Monad must also be a Functor and an Applicative.
    -- Both instances are derived from the Monad instance below.
    instance Functor (State state_type) where
      fmap = liftM

    instance Applicative (State state_type) where
      --pure :: result -> State state result
      pure r = State ( \s -> (r, s) )
      (<*>) = ap

    instance Monad (State state_type) where
      --return :: result -> State state result
      --return is just pure, i.e. return r = State ( \s -> (r, s) )
      return = pure
      --(>>=) :: State state result_a -> (result_a -> State state result_b) -> State state result_b
      processor >>= processorGenerator = State $ \state ->
        let (result, state') = runState processor state
        in runState (processorGenerator result) state'

    put newState = State $ \_ -> ((), newState)
    get = State $ \state -> (state, state)

    evalState stateMonad state = fst ( runState stateMonad state )
    execState stateMonad state = snd ( runState stateMonad state )

    type GeneratorState = State StdGen

    rollDie :: GeneratorState Int
    rollDie = do
      generator <- get
      let ( value, newGenerator ) = randomR (1,6) generator
      put newGenerator
      return value

    --evalState rollDie (mkStdGen 0)
The state monad doesn’t actually hold a state. It holds something that can transform a state. Let’s look at the type:
newtype State state result = State { runState :: state -> (result, state) }
The record syntax adds some sugar that gets in the way, so let’s rewrite the above without it:
    newtype State state result = State ( state -> (result, state) )

    runState :: (State state result) -> (state -> (result,state))
    runState (State f) = f
So, the state monad contains something that takes a state and returns
a tuple consisting of a result and a state. runState
is
just an accessor function to get the function inside the state monad. We
need this because the function is wrapped in the State
constructor and we don’t want to have to pattern-match against it all
the time.
Actually, it’s probably not even right to say that it “contains” a function. It really is a function, but it’s wrapped in a constructor to give it a type that we can work with.
So, what does the state transformer function inside of the state monad do? It takes a state and returns a tuple containing a result and a new state. That’s all.
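To make that concrete, here is a minimal sketch that uses a plain counter as the state instead of a random generator. The names tick and tickState are made up for this illustration, and the newtype is redefined locally so that the snippet stands on its own.

```haskell
-- The same newtype as in the text, redefined so this snippet is self-contained.
newtype State state result = State { runState :: state -> (result, state) }

-- A bare state transformer (hypothetical example): the result is the current
-- counter value and the new state is the counter plus one.
tick :: Int -> (Int, Int)
tick n = (n, n + 1)

-- Wrapping the function in the constructor gives us a "state monad" value.
-- Nothing is stored in it; it is still just the function.
tickState :: State Int Int
tickState = State tick

main :: IO ()
main = do
  print (tick 5)                -- (5,6)
  print (runState tickState 5)  -- (5,6), runState only unwraps the function
```

Both lines print the same pair, which is the point: the wrapper adds a type, not behaviour.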
do
notation is syntactic sugar and it makes it easier
for people from an imperative programming background to get started with
monads, but it also obfuscates the functional nature of the code. It
tricks you into thinking about the code as though it were imperative and
thus it appears that some “assignments” are completely unrelated to
others. I believe that this can impede the learning of Haskell.
Let’s look at the section of code that confuses people:
    rollDie :: State StdGen Int
    rollDie = do
      generator <- get
      let ( value, newGenerator ) = randomR (1,6) generator
      put newGenerator
      return value
With imperative eyes this looks simple. The type of
rollDie
is a state monad and walking through it we see that
it gets a state, passes it to randomR, gets a random value along with a
new state, puts the state back wherever it got it, then returns the
value.
OK, so where did the state come from? How did get magically pull the state out of thin air and how did put put it back? The function is pure. It isn’t changing something outside of itself, and it doesn’t accept any arguments. What’s it actually doing?
Let’s look at the get and put functions again:
    put newState = State $ \_ -> ((), newState)
    get = State $ \state -> (state, state)
get returns a state monad, but looking at the code we see that it takes no arguments and always returns the same thing. It’s not returning some specific state that’s tracking what we’re doing. It’s just a generic state monad with a function that accepts a state and returns a 2-tuple with the unaltered state in both positions. Again, it takes no arguments and always returns the same thing. (Well, not strictly: get is polymorphic, so its type is specialised to whatever state type is needed, but that’s irrelevant here. Read up about type variables and polymorphic types if you want to know more.)
put takes a state as an argument and returns a state monad. The function inside the state monad, as we saw above in the ‘The “State Monad” Is A State Transformer’ section, is a function that accepts a state and returns a tuple containing a result and a new state. Looking at the definition of put, we see that this state transforming function ignores its argument (a state) and returns a tuple with an empty result and the state that we passed to put.
It’s not “storing” a state anywhere. It’s actually getting a state monad that can transform all states (of the same type) to a tuple with an empty result and the state that we passed to put.
Be careful not to confuse put with the state transforming function inside the state monad that it returns. put is a function that accepts an argument and returns a state monad. The state monad that it returns contains a function that accepts a state, ignores it, and returns a tuple containing the state that we originally passed to put.
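A quick way to convince yourself of this is to unwrap get and put with runState and feed them a state by hand. The sketch below repeats the definitions from above (with type signatures added) so that it compiles on its own; the Int states are arbitrary values chosen for the illustration.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

put :: state -> State state ()
put newState = State $ \_ -> ((), newState)

get :: State state state
get = State $ \state -> (state, state)

main :: IO ()
main = do
  -- get copies whatever state it is eventually handed into the result slot.
  print (runState get (5 :: Int))        -- (5,5)
  -- put 9 ignores the state it is handed and installs 9 instead.
  print (runState (put 9) (5 :: Int))    -- ((),9)
  print (runState (put 9) (100 :: Int))  -- ((),9)
```

Neither function touches any state until runState hands one over.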
So how does that preserve state and pass it around outside of the function? What is going on here?
It’s time to get rid of some syntactic sugar by removing the do notation. This is what the compiler does behind the scenes when it compiles our code. Transforming the notation, the rollDie function becomes:
    rollDie :: State StdGen Int
    rollDie = get >>= \generator ->
      let ( value, newGenerator ) = randomR (1,6) generator
      in put newGenerator >> return value
If the transformation isn’t clear, read up on do
notation.
(>>=)
Now that we have the function expressed in terms of
(>>=)
and (>>)
(the “bind” and
“then” operators, respectively), we need to look back at how
(>>=)
is defined for state monads (you should read up
on these operators if you are not familiar with them, but if that’s the
case then you might want to come back to state monads later). As we’ve
just seen above, get
is returning a state monad (not a
state), so the first argument to the first (>>=)
is
indeed a state monad. Let’s change the names in the definition above to
make it a bit clearer too:
    (>>=) :: State state result_a -> (result_a -> State state result_b) -> State state result_b
    stateTransformer >>= stateTransformerGenerator = State $ \state ->
      let (result, state') = runState stateTransformer state
      in runState (stateTransformerGenerator result) state'
Let’s run through this and make some sense of it. The first argument
is a state monad and the second argument is a function that takes a
result and returns a state monad that might have a different result
type. (>>=)
takes these two arguments and returns a
new state monad.
The definition begins with State, which is the constructor of the state monad. This just ensures that the following lambda expression, which must be a state transforming function, will be wrapped up in a state monad.
In the second line, runState stateTransformer pulls out the state transforming function from the first argument to (>>=), namely stateTransformer. This is the function that accepts a state and returns a tuple with a result and a new state. The returned function is applied to state and returns the tuple (result, state'), where state' is a new state.
Before continuing, remind yourself where state is coming from. Remember that we are defining a state transforming function that accepts a state and returns a tuple, which we are then wrapping in the state monad. state is therefore a state that will be passed to this function later. It’s saying “When you’re passed a state, take that state and pass it to the state transformer in the state monad of the first argument, then call the result and new state in the tuple result and state', respectively.” You can think of it as a set of instructions of what to do with a state if it’s passed. They’re not actually being done yet.
In the final line, we take the result from above (result) and we feed it into the second argument of (>>=), namely stateTransformerGenerator. This returns a new state monad and we then pull out the state transforming function in it with our accessor, runState. We now have a function that accepts a state and returns a tuple. We then pass this function the new state generated in the previous line (state') and get the tuple that we’re expecting to complete the definition of a state transformer started by the lambda expression on the first line (\state ->).
So, what do we have at the end? We have a state transformer that takes a state, runs it through one state transformer to get a new one and a result, then it uses the result to get another state transformer from our “state transformer generator”, and then runs the new state through that too to get a third state and a new result. We’re threading the state through a succession of state transformers (as we would thread a string through beads) and using the result of each transformation to determine the next state transformer.
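The threading can be seen side by side in a small sketch. Here bind is written as an ordinary function (the name bindState is made up) so that no Monad instance is needed, and the state is a simple Int counter rather than a random generator. The hypothetical twoTicks chains two transformers with bindState, while twoTicksByHand passes the state along explicitly.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

-- The bind operator from the text, written as a plain function.
bindState :: State s a -> (a -> State s b) -> State s b
bindState stateTransformer stateTransformerGenerator = State $ \state ->
  let (result, state') = runState stateTransformer state
  in runState (stateTransformerGenerator result) state'

-- Hypothetical transformer: return the counter and increment it.
tick :: State Int Int
tick = State $ \n -> (n, n + 1)

-- Two ticks chained with bindState; the final result is the sum of both.
twoTicks :: State Int Int
twoTicks = tick `bindState` \a ->
           tick `bindState` \b ->
           State (\s -> (a + b, s))

-- The same computation with the state threaded by hand.
twoTicksByHand :: Int -> (Int, Int)
twoTicksByHand s0 =
  let (a, s1) = runState tick s0
      (b, s2) = runState tick s1
  in (a + b, s2)

main :: IO ()
main = do
  print (runState twoTicks 10)  -- (21,12)
  print (twoTicksByHand 10)     -- (21,12)
```

Starting from 10, the first tick yields 10 and the second yields 11, so both versions end with the result 21 and the state 12.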
get
and put
Now we need to fit this back into our definition of rollDie to see what it’s doing. Let’s look at the definition again (no need to scroll back):
    rollDie :: State StdGen Int
    rollDie = get >>= \generator ->
      let ( value, newGenerator ) = randomR (1,6) generator
      in put newGenerator >> return value
The rollDie function will return a state monad that can take a “random number generator state” and return an integer between 1 and 6, inclusive, along with a new “random number generator state”, which can then be threaded through rollDie again to get another number and another new state.
Let’s replace generator and newGenerator, which are the “random number generator states”, with gState and gState'. Let’s also replace get and put newGenerator with what they return.
    rollDie :: State StdGen Int
    rollDie = State (\state -> (state, state)) >>= \gState ->
      let ( value, gState' ) = randomR (1,6) gState
      in State (\_ -> ((), gState')) >> return value
Make sure that you understand the replacements by looking back at the definitions.
(>>)
There is still one last thing that we need to know before we can put all of this together. What does (>>) (then) do in this context? You’re probably familiar with it from the IO monad, which lets you do something like this:
(putStrLn "foo") >> (putStrLn "bar") >> (putStrLn "baz")
or in the more familiar do
notation:
    do putStrLn "foo"
       putStrLn "bar"
       putStrLn "baz"
With IO, this just prints “foo”, then it prints “bar”, then it prints “baz”, each on a new line. (>>) seems to just sequence the functions and there doesn’t seem to be any connection between them, but this is just a consequence of the definition of (>>) for IO monads. Actually, it’s a consequence of the definition of (>>=). (>>) is actually defined in terms of (>>=):
m >> n = m >>= \_ -> n
Let’s replace (>>)
in our rollDie function with
this definition and take another look at it.
    rollDie :: State StdGen Int
    rollDie = State (\state -> (state, state)) >>= \gState ->
      let ( value, gState' ) = randomR (1,6) gState
      in State (\_ -> ((), gState')) >>= \_ -> return value
So far, so good.
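For the state monad, the important consequence is that (>>) throws away the result of its first argument but not the state. Here is a small sketch of that, with bind and then written as plain functions (bindState and thenState are invented names) so that the snippet needs no Monad instance.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

bindState :: State s a -> (a -> State s b) -> State s b
bindState m k = State $ \state ->
  let (result, state') = runState m state
  in runState (k result) state'

-- (>>) written as a plain function: bind, but ignore the result.
thenState :: State s a -> State s b -> State s b
thenState m n = m `bindState` \_ -> n

put :: s -> State s ()
put newState = State $ \_ -> ((), newState)

get :: State s s
get = State $ \state -> (state, state)

main :: IO ()
main = do
  -- The result of put (the unit value) is discarded, but the state that it
  -- produced is still handed on to get.
  print (runState (put 42 `thenState` get) (0 :: Int))  -- (42,42)
```

The initial state 0 never reaches get, because put has already replaced it by the time get runs.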
OK, we’re almost there. It’s time to wrap our heads around this. Here are the definitions of return and (>>=) again so you don’t have to scroll back:
    return r = State ( \s -> (r, s) )

    (>>=) :: State state result_a -> (result_a -> State state result_b) -> State state result_b
    stateTransformer >>= stateTransformerGenerator = State $ \state ->
      let (result, state') = runState stateTransformer state
      in runState (stateTransformerGenerator result) state'
Looking at the definition of rollDie above in terms of (>>=), we see that it’s defined as x >>= y >>= z, where
    x = State $ \state -> (state, state)
    y = \gState -> let ( value, gState' ) = randomR (1,6) gState
                   in State $ \_ -> ((), gState')
    z = \_ -> return value
The monad laws let us break this up as either (x >>= y) >>= z or x >>= (\w -> y w >>= z). To avoid confusion about how value is getting passed around, we’ll use the latter. The lambda \w -> y w >>= z must be a “state transforming function generator”, i.e. something with the type (result_a -> State state result_b), so that it can be passed to x >>=. We can rewrite it as \w -> (y w >>= z) to make it clear how we should apply the bind operator inside the lambda function.
    \w -> (y w >>= z) = \w -> State $ \state ->
      let (result, state') = runState (y w) state
      in runState (z result) state'
The thing to note here is that the first argument to the bind operator is y w, not y. y takes a state and returns a state monad, and the first argument of the bind operator must be a state monad, not a function that generates a state monad from a state. (That’s what the second argument to the bind operator is, and that is why we could evaluate x >>= y first.)
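If you want to check that the two groupings really do agree, here is a small sketch under simplified assumptions: the state is an Int instead of a generator, bind is a plain function called bindState (an invented name), and x, y and z are stand-ins that play the roles of get, put and return.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

bindState :: State s a -> (a -> State s b) -> State s b
bindState m k = State $ \state ->
  let (result, state') = runState m state
  in runState (k result) state'

x :: State Int Int
x = State $ \s -> (s, s)          -- plays the role of get

y :: Int -> State Int ()
y n = State $ \_ -> ((), n * 2)   -- plays the role of put

z :: () -> State Int String
z _ = State $ \s -> ("done", s)   -- plays the role of return

leftGrouping, rightGrouping :: State Int String
leftGrouping  = (x `bindState` y) `bindState` z
rightGrouping = x `bindState` (\w -> y w `bindState` z)

main :: IO ()
main = do
  print (runState leftGrouping 7)   -- ("done",14)
  print (runState rightGrouping 7)  -- ("done",14)
```

This only tests a few inputs, of course; it illustrates the associativity law rather than proving it.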
Let’s step through some evaluations:
    y = \gState -> let ( value, gState' ) = randomR (1,6) gState
                   in State $ \_ -> ((), gState')

    y w = let ( value, gState' ) = randomR (1,6) w
          in State $ \_ -> ((), gState')

    runState (y w) = let ( value, gState' ) = randomR (1,6) w
                     in \_ -> ((), gState')

    (runState (y w)) state = let ( value, gState' ) = randomR (1,6) w
                             in ((), gState')
runState pulls the state transformer function out of the state monad, which is equivalent to removing the constructor from the definition. Passing state to this function returns a tuple. Because the function was defined as \_ -> ((), gState'), it completely ignores the state that we pass it, but it still returns the tuple ((),gState'). Let’s put this back into our definition above:
    \w -> (y w >>= z) = \w -> State $ \state ->
      let ( value, gState' ) = randomR (1,6) w
          (result, state') = ((),gState')
      in runState (z result) state'
Now we step through evaluations starting with z by first substituting the definition of return:
    z = \_ -> return value
    z = \_ -> State ( \s -> (value, s) )
    z result = State ( \s -> (value, s) )
    runState (z result) = \s -> (value, s)
    (runState (z result)) state' = (value, state')
z is a function that ignores its argument and returns a state monad. That state monad contains a function that accepts a state and returns a tuple with a fixed value and an unchanged state. As before, runState unwraps the monad to get the function inside it (that accepts a state and returns a tuple). We then pass state' to that function to get the tuple, which is the fixed value and state', unchanged.
Now we substitute back into our previous expression.
    \w -> (y w >>= z) = \w -> State $ \state ->
      let ( value, gState' ) = randomR (1,6) w
          (result, state') = ((),gState')
      in (value, state')
which we can simplify to
\w -> (y w >>= z) = \w -> State $ \_ -> randomR (1,6) w
because state and result are ignored in the definition, and gState' is only used in the definition of state'. As before, work through it to make sure you understand the transformation. Now we substitute into x >>= (\w -> y w >>= z), again using the definition of (>>=):
    x >>= (\w -> y w >>= z) = State $ \state ->
      let (result, state') = runState x state
      in runState ((\w -> y w >>= z) result) state'
and step through some more evaluations:
    \w -> y w >>= z = (\w -> State $ \_ -> randomR (1,6) w)
    (\w -> y w >>= z) result = State $ \_ -> randomR (1,6) result
    runState ((\w -> y w >>= z) result) = \_ -> randomR (1,6) result
    (runState ((\w -> y w >>= z) result)) state' = randomR (1,6) result
and thus:
    x >>= (\w -> y w >>= z) = State $ \state ->
      let (result, state') = runState x state
      in randomR (1,6) result
Now we just need to run through the evaluations starting with
x
(using s
instead of state
in
the lambda expression from above to avoid confusion with
state
here).
    x = State $ \s -> (s, s)
    runState x = \s -> (s, s)
    (runState x) state = (state, state)
and finally we substitute for x and then simplify:
    x >>= (\w -> y w >>= z) = State $ \state ->
      let (result, state') = (state,state)
      in randomR (1,6) result

    x >>= (\w -> y w >>= z) = State $ \state -> randomR (1,6) state
Both result and state' are equal to state, so we replace them then remove the tuple assignment, which is redundant. Now remember that x >>= (\w -> y w >>= z) is equal to x >>= y >>= z, which is just our definition of rollDie.
In the end, we’re left with
rollDie = State $ \state -> randomR (1,6) state
Let’s check what the type of randomR (1,6) state
will
be:
    ghci> :m System.Random
    ghci> :t randomR
    randomR :: (Random a, RandomGen g) => (a, a) -> g -> (a, g)
In our case, a is Int (an instance of Random) and g is StdGen, our random generator state (an instance of RandomGen), so randomR (1,6) state will return a tuple containing an Int and a new random generator state. So our function accepts a state and returns a tuple containing a result and a new state, i.e. a state transformer.
State
wraps this state transformer function into a state
monad and so our rollDie
function returns a state monad, as
it’s supposed to.
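You don’t have to take the derivation on trust. The sketch below compares the do-notation version with the one-line version and checks that they behave identically. System.Random lives in the separate random package, so to keep the snippet dependency-free it swaps in a toy generator: Gen and toyRoll are invented stand-ins for StdGen and randomR (1,6), not the real thing, and the arithmetic inside toyRoll is arbitrary.

```haskell
import Control.Monad (ap, liftM)

newtype State state result = State { runState :: state -> (result, state) }

instance Functor (State s) where
  fmap = liftM
instance Applicative (State s) where
  pure r = State $ \s -> (r, s)
  (<*>)  = ap
instance Monad (State s) where
  m >>= k = State $ \state ->
    let (result, state') = runState m state
    in runState (k result) state'

get :: State s s
get = State $ \state -> (state, state)

put :: s -> State s ()
put newState = State $ \_ -> ((), newState)

-- Toy stand-in for StdGen (assumption: NOT the real System.Random type).
newtype Gen = Gen Int deriving (Eq, Show)

-- Toy stand-in for randomR (1,6): a die value plus the next generator.
toyRoll :: Gen -> (Int, Gen)
toyRoll (Gen n) = (n `mod` 6 + 1, Gen ((n * 5 + 3) `mod` 16))

-- The version written with do notation, as in the text.
rollDieDo :: State Gen Int
rollDieDo = do
  generator <- get
  let (value, newGenerator) = toyRoll generator
  put newGenerator
  return value

-- The version that the derivation ends with.
rollDieDirect :: State Gen Int
rollDieDirect = State $ \state -> toyRoll state

-- Do both versions agree for the generator built from this seed?
agree :: Int -> Bool
agree n = runState rollDieDo (Gen n) == runState rollDieDirect (Gen n)

main :: IO ()
main = do
  print (runState rollDieDo     (Gen 0))  -- (1,Gen 3)
  print (runState rollDieDirect (Gen 0))  -- (1,Gen 3)
  print (all agree [0 .. 15])             -- True
```

The get, put and return in the do block are just plumbing around the one function that does the work.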
So, what can we do with it?
evalState rollDie (mkStdGen 0)
mkStdGen 0 returns a new random generator state (0 is the seed). evalState, which we saw way up above, takes a state monad and a state, pulls out the state transforming function from the state monad, applies it to the state to get the result and new state tuple, then pulls out the result from the tuple and returns it, which gives us an integer between 1 and 6, inclusive.
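execState does the same thing but keeps the other half of the tuple. The difference is easy to see in a sketch where the result and the state have different types; freshLabel is a hypothetical transformer invented for this example, with an Int counter as the state and a String as the result.

```haskell
newtype State state result = State { runState :: state -> (result, state) }

evalState :: State state result -> state -> result
evalState stateMonad state = fst (runState stateMonad state)

execState :: State state result -> state -> state
execState stateMonad state = snd (runState stateMonad state)

-- Hypothetical transformer: build a label from the counter, then bump it.
freshLabel :: String -> State Int String
freshLabel prefix = State $ \n -> (prefix ++ show n, n + 1)

main :: IO ()
main = do
  print (runState  (freshLabel "x") 3)  -- ("x3",4)
  print (evalState (freshLabel "x") 3)  -- "x3"
  print (execState (freshLabel "x") 3)  -- 4
```

evalState is what you want when you only care about the answer, and execState when you only care about where the state ended up.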
Yeah, that isn’t all that exciting. Don’t worry though, we didn’t go through all of that just to generate a single digit then call it a day and head home.
Going over the Haskell wikibook’s state monad page again, we find a
definition for rollDice
, which is a state monad whose
result is a pair of integers instead of a single integer (i.e. a
simulation of rolling 2 dice):
    rollDice :: State StdGen (Int, Int)
    rollDice = liftM2 (,) rollDie rollDie
If you’re wondering what liftM2 is, it’s just a convenience function defined in Control.Monad that leads to more obfuscation in this case, so let’s rewrite rollDice without it:
    rollDice :: State StdGen (Int, Int)
    rollDice = rollDie >>= \result_1 ->
               rollDie >>= \result_2 ->
               return (result_1, result_2)
Look back at the definition of (>>=)
and
return
to understand how this works. Here’s the same thing
in do
notation:
    rollDice' :: State StdGen (Int, Int)
    rollDice' = do
      result_1 <- rollDie
      result_2 <- rollDie
      return (result_1, result_2)
It appears deceptively simple, doesn’t it? Let’s roll some dice:
    ghci> evalState rollDice (mkStdGen 423752)
    (5,1)
Note that rollDice' would give the same result with the same seed (423752). It really is the same function.
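The same threading scales to any number of dice. As a rough sketch, replicateM from Control.Monad repeats a monadic action and collects the results, passing the state from each roll to the next. To stay free of external packages, the snippet again uses a toy generator: Gen, toyRoll and rollMany are names invented for this example, and toyRoll is only a stand-in for randomR (1,6).

```haskell
import Control.Monad (ap, liftM, replicateM)

newtype State state result = State { runState :: state -> (result, state) }

instance Functor (State s) where
  fmap = liftM
instance Applicative (State s) where
  pure r = State $ \s -> (r, s)
  (<*>)  = ap
instance Monad (State s) where
  m >>= k = State $ \state ->
    let (result, state') = runState m state
    in runState (k result) state'

-- Toy stand-in for StdGen (assumption: NOT the real System.Random type).
newtype Gen = Gen Int deriving (Eq, Show)

-- Toy stand-in for randomR (1,6): a die value plus the next generator.
toyRoll :: Gen -> (Int, Gen)
toyRoll (Gen n) = (n `mod` 6 + 1, Gen ((n * 5 + 3) `mod` 16))

rollDie :: State Gen Int
rollDie = State toyRoll

-- Roll n dice, threading the generator from each roll into the next.
rollMany :: Int -> State Gen [Int]
rollMany n = replicateM n rollDie

main :: IO ()
main = print (runState (rollMany 3) (Gen 0))  -- ([1,4,3],Gen 13)
```

Each roll sees the generator left behind by the previous one, which is why the three values differ even though rollDie is the same action every time.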
Finally, here’s the working code again with both definitions of
rollDice
and two roll
functions. The
roll
functions accept a seed (Int) for the random generator
and return a pair of Ints generated using randomR.
    import System.Random
    import Control.Monad (ap, liftM)

    newtype State state result = State { runState :: state -> (result, state) }

    -- Since GHC 7.10 every Monad must also be a Functor and an Applicative.
    -- Both instances are derived from the Monad instance below.
    instance Functor (State state_type) where
      fmap = liftM

    instance Applicative (State state_type) where
      --pure :: result -> State state result
      pure r = State ( \s -> (r, s) )
      (<*>) = ap

    instance Monad (State state_type) where
      --return :: result -> State state result
      --return is just pure, i.e. return r = State ( \s -> (r, s) )
      return = pure
      --(>>=) :: State state result_a -> (result_a -> State state result_b) -> State state result_b
      processor >>= processorGenerator = State $ \state ->
        let (result, state') = runState processor state
        in runState (processorGenerator result) state'

    put newState = State $ \_ -> ((), newState)
    get = State $ \state -> (state, state)

    evalState stateMonad state = fst ( runState stateMonad state )
    execState stateMonad state = snd ( runState stateMonad state )

    type GeneratorState = State StdGen

    rollDie :: GeneratorState Int
    rollDie = do
      generator <- get
      let ( value, newGenerator ) = randomR (1,6) generator
      put newGenerator
      return value

    --evalState rollDie (mkStdGen 0)

    rollDice :: State StdGen (Int, Int)
    rollDice = rollDie >>= \result_1 ->
               rollDie >>= \result_2 ->
               return (result_1, result_2)

    rollDice' :: State StdGen (Int, Int)
    rollDice' = do
      result_1 <- rollDie
      result_2 <- rollDie
      return (result_1, result_2)

    roll :: Int -> (Int,Int)
    roll = \x -> evalState rollDice (mkStdGen x)

    roll' :: Int -> (Int,Int)
    roll' = \x -> evalState rollDice' (mkStdGen x)
I hope this helps you to understand the state monad. Feel free to contact me if you would like to give feedback such as suggestions for improvements or corrections (both technical and orthographical), or just to let me know if you found it useful.