Introduction to Monads With Scala 3
So, you are trying to grasp the concept of monads. If you are learning functional programming, understanding monads is a kind of rite of passage. Looking for monads on Google can be a bit daunting. The concept comes from category theory, so there are a lot of explanations starting from there. However, in this tutorial we are staying away from category theory and addressing them from the developer’s point of view.
Motivation
In the beginning of functional programming there were no monads. In the 1970s, functional programming was experimental and people were trying to find out what can be done only with pure functions. These people were mainly researchers theorizing about programming languages. The main justification for functional programming was (and still is) that it has very solid mathematical foundations. If your programs are pure (pure means free of side-effects in this context), then you can reason about your program, and even prove things about it. Sometimes, the compiler can prove things on your behalf through the type system.
The problem with pure functions is that, by definition, they do not produce side-effects, and side-effects are a cornerstone of programming. Throwing an exception, modifying a variable, printing to the console, etc., are all effects. One solution is to encode the effects in the functions' arguments and return types. This, however, is extremely counterintuitive and not ergonomic, while an imperative programming language lets you do these things very easily.
Let us look at an example: we want a function that checks whether a String encoded in base 58 is equivalent to a String encoded in hexadecimal:
case object InvalidInputString extends Exception
val base58alphabet =
  "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz".zipWithIndex
val idxToCharBase58 = Map(base58alphabet.map(_.swap): _*)
val charToIdxBase58 = Map(base58alphabet: _*)
def decodeFromBase58(b58: String): Array[Byte] =
  try {
    // each leading '1' encodes a leading zero byte
    val zeroCount = b58.takeWhile(_ == '1').length
    Array.fill(zeroCount)(0.toByte) ++
      b58
        .drop(zeroCount)
        .map(charToIdxBase58) // throws for characters outside the alphabet
        .toList
        .foldLeft(BigInt(0))((acc, x) => acc * 58 + x)
        .toByteArray
        .dropWhile(_ == 0.toByte) // drop the sign byte added by BigInt
  } catch {
    case _: Exception => throw InvalidInputString
  }
def decodeFromHex(hex: String): Array[Byte] =
  try {
    hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
  } catch {
    case _: Exception => throw InvalidInputString
  }
def checkEqualEncoding(hex: String, b58: String): Boolean =
  decodeFromHex(hex).sameElements(decodeFromBase58(b58))
checkEqualEncoding("0000", "11") // true
To do that, we decode both to a byte array and then compare the result. Each decoding can throw an exception. However, the error handling is managed behind the scenes by the exception mechanism. If one of the decoding functions throws an exception the whole computation throws an exception. We can rewrite this program in a purely functional style as follows:
def decodeFromBase58Pure(b58: String): Array[Byte] | Exception =
  try {
    decodeFromBase58(b58)
  } catch {
    case _: Exception => InvalidInputString
  }
def decodeFromHexPure(hex: String): Array[Byte] | Exception =
  try {
    decodeFromHex(hex)
  } catch {
    case _: Exception => InvalidInputString
  }
def checkEqualEncodingPure(hex: String, b58: String): Boolean | Exception =
  val elt1 = decodeFromHexPure(hex)
  elt1 match
    case e: Exception => InvalidInputString // short-circuit: stop here
    case a1: Array[Byte] =>
      val elt2 = decodeFromBase58Pure(b58)
      elt2 match
        case e: Exception => InvalidInputString // short-circuit: stop here
        case a2: Array[Byte] =>
          a1.sameElements(a2)
checkEqualEncodingPure("0000", "11") // true
For simplicity's sake, we are reusing the functions that we had before, but this time we are encoding the exception in the result type. As we want pure functions, we catch the exception and just return it as a value. The interesting part, however, is checkEqualEncodingPure. This time, since we cannot rely on exception handling, we need to handle the short-circuiting ourselves: we have encoded the failure in the return type of the function, so the caller has to check for it.
In particular, we see the pattern appear twice:
val eltX = decode(enc)
eltX match
  case e: Exception => InvalidInputString
  case a1: Array[Byte] => // do something with the value
Now, let us consider the well-known bubble sort algorithm, written in imperative style:
def bubblesort(l: Vector[Int]): Vector[Int] =
  val a = l.toArray
  while // the Scala 3 way of doing do/while loops
    var swapped = false
    for i <- 1 until a.size do
      if a(i - 1) > a(i) then
        val temp = a(i - 1)
        a.update(i - 1, a(i))
        a.update(i, temp)
        swapped = true
    swapped // loop condition: repeat while at least one swap happened
  do ()
  a.toVector
bubblesort(Vector(5, 4, 3, 2, 1)) // Vector(1, 2, 3, 4, 5)
The functional version of this (without monads) looks something like this:
import scala.annotation.tailrec
def bubblesortPure(l: Vector[Int]): Vector[Int] =
  iterateUntilNotSwapped(l, false)
def checkAndSwap(i: Int, a: Vector[Int]) =
  if a(i - 1) > a(i) then
    val temp = a(i - 1)
    (a.updated(i - 1, a(i)).updated(i, temp), true)
  else (a, false)
@tailrec
final def iterateWith(
    a: Vector[Int],
    i: Int,
    initialSwapped: Boolean
): (Vector[Int], Boolean) =
  if i < a.size then
    val (res, swapped) = checkAndSwap(i, a)
    iterateWith(res, i + 1, initialSwapped || swapped)
  else (a, initialSwapped)
@tailrec
final def iterateUntilNotSwapped(
    a: Vector[Int],
    initialSwapped: Boolean
): Vector[Int] =
  val (res, swapped) = iterateWith(a, 1, initialSwapped)
  if swapped then iterateUntilNotSwapped(res, false)
  else res
bubblesortPure(Vector(4, 5, 3, 2, 1)) // Vector(1, 2, 3, 4, 5)
My guess is that the first version is easier to understand. There is no recursion, and it is probably how you would do it in your head: it is almost a copy-paste of the algorithm you find on Wikipedia. The second version is not only longer, but also difficult to map to the simple algorithm that we had before. Nevertheless, we have managed to write it in a tail-recursive way, which at least guarantees that it is stack-safe.
Please note that:
- we have encoded the state in the parameters and return types of the different functions
- each time we call a function, we need to extract the current state (res and swapped), then do something with it (for example, check whether swapped is true), and then pass the state to the next function
In particular, we see the pattern appear twice:
val (res, swapped) = functionCall(i, a)
In both cases, short-circuiting and state handling, each time we obtain the context from the result of a function call, we perform an action. For the exception, we check the result of the computation and pass it on if and only if the computation was successful; otherwise, we abort. For the state, we pass it to the next computation. Let us call this operation, which happens between the two function calls, the intermediate action.
The problem that monads solve is twofold:
- they unify the way the context is passed around
- they define how the computation will continue, depending on the context, between function calls.
As you saw in the previous examples, it is not obvious what the context is. It can be a state, the status of a computation, a dependency provided from the outside, side-effects, or something else entirely, and it can present itself in several forms.
The power of monads is that they give us a unified way to handle all these different contexts. There is a monad for exceptions, a monad for state, a monad for dependency injection, etc. This is also why monads are hard to understand: the same concept applies to situations that have nothing to do with each other.
So, in both of these examples there is the assumption of a world beyond the simple expression that we are evaluating: a context. In the case of the state, the context is the value of the state, and we encode it in the parameters and in the return types. In the case of the exception, the context is the status of the computation (success or failure), and we encode it in the return types.
Definition
A monad F has three parts:
- a type constructor, that is, a part that transforms a type into another type. For example, if we have type A, we call the new type F[A]. This new type represents a computation with return type A in the context F[_].
- a function that takes a value a of type A and wraps it in the new type F[A]. In Scala, this function is conventionally called unit: A => F[A] (the cats library calls it pure).
- a function that takes a computation fa: F[A] and a function f: A => F[B], and returns a new computation fb: F[B]. This function is commonly known in Scala as flatMap: (F[A], A => F[B]) => F[B]. Intuitively, flatMap transforms the computation fa into a new computation fb that will: first, execute fa and obtain the result value and the context; then, execute the intermediate action with the context; next, compute the new action to execute from the resulting value, f(a); and finally, execute it.
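The three parts above can be captured in Scala 3 as a type class. The following is a minimal sketch, not the API of any library: the names Monad, optionMonad, pairUp and monadDemo are hypothetical (cats offers a similar Monad type class, where unit is called pure). It uses Option as F, whose context is "the value may be missing":

```scala
// Hypothetical type class bundling the three parts of a monad:
// the type constructor F[_], unit, and flatMap.
trait Monad[F[_]]:
  def unit[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]

// Instance for Option: the intermediate action stops at None.
given optionMonad: Monad[Option] with
  def unit[A](a: A): Option[A] = Some(a)
  def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] =
    fa match
      case Some(a) => f(a)
      case None    => None

// Written only against the Monad interface, so it works for any F with an instance.
def pairUp[F[_], A, B](fa: F[A], fb: F[B])(using m: Monad[F]): F[(A, B)] =
  m.flatMap(fa)(a => m.flatMap(fb)(b => m.unit((a, b))))

@main def monadDemo(): Unit =
  println(pairUp(Option(1), Option("a")))          // Some((1,a))
  println(pairUp(Option(1), Option.empty[String])) // None
```

Note that pairUp never mentions Option: the same code would work for any type constructor that has a Monad instance, which is exactly the unification described in the motivation.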
In pseudo-code, we can describe the action of flatMap as follows:
// this code is not valid Scala code, just pseudo-code
flatMap(fa, f) = previousContext =>
  val resultAndNewContext = execute(previousContext, fa) // execute fa
  intermediateAction(resultAndNewContext, f) // perform the side effect
For example, for our state monad, flatMap is defined as follows:
// this code is not valid Scala code, just pseudo-code
flatMap(fa, f) = previousState =>
  val resultAndNewState = execute(previousState, fa) // execute fa with the previous state
  intermediateAction(resultAndNewState, f) // perform the side effect
intermediateAction(resultAndNewState, f) =
  val (result, state) = resultAndNewState
  execute(state, f(result))
For our short-circuit monad, flatMap is defined as follows:
// this code is not valid Scala code, just pseudo-code
flatMap(fa, f) = // there is no previous context
  val resultOrFailure = identity(fa) // execute = identity
  intermediateAction(resultOrFailure, f) // perform the side effect
intermediateAction(resultOrFailure, f) =
  if (resultOrFailure.isFailure)
    return Failure
  else
    execute(f(resultOrFailure.getResult))
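Both pseudo-code versions can be turned into runnable Scala 3 with plain functions. This is a sketch under our own assumptions: State, stateFlatMap, shortCircuitFlatMap, next, twoLabels and flatMapDemo are hypothetical names, the state is returned as (newState, result) like the StateMonad class later in this article, and Option stands in for the success/failure result instead of the union type used later:

```scala
// A stateful computation: from the previous state to (new state, result).
type State[S, A] = S => (S, A)

def stateFlatMap[S, A, B](fa: State[S, A], f: A => State[S, B]): State[S, B] =
  previousState =>
    val (newState, result) = fa(previousState) // execute fa with the previous state
    f(result)(newState)                         // intermediate action: hand the new state on

// A computation that may fail: Some(result) on success, None on failure.
def shortCircuitFlatMap[A, B](fa: Option[A], f: A => Option[B]): Option[B] =
  fa match
    case None         => None      // failure: abort, f is never called
    case Some(result) => f(result) // success: continue with the result

// Returns the current counter as the result and increments the state.
val next: State[Int, Int] = s => (s + 1, s)

// Two consecutive calls to next, threading the state through stateFlatMap.
val twoLabels: State[Int, (Int, Int)] =
  stateFlatMap(next, a => stateFlatMap(next, b => s => (s, (a, b))))

@main def flatMapDemo(): Unit =
  println(twoLabels(10)) // (12,(10,11))
  println(shortCircuitFlatMap(Option(4), (x: Int) => if x > 0 then Some(x * 2) else None)) // Some(8)
  println(shortCircuitFlatMap(Option.empty[Int], (x: Int) => Some(x * 2)))                // None
```

In twoLabels, the caller never handles the counter explicitly: stateFlatMap passes the state from the first call of next to the second.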
In the same way that we saw with functors, monads need to fulfill certain laws. The monad laws are the following:
def checkLeftIdentity[A, B, F[_]](
    a: A,
    h: A => F[B]
)(unit: A => F[A], flatMap: [X, Y] => (F[X], X => F[Y]) => F[Y]): Boolean =
  flatMap(unit(a), h) == h(a)
def checkRightIdentity[A, F[_]](
    m: F[A]
)(unit: A => F[A], flatMap: [X, Y] => (F[X], X => F[Y]) => F[Y]): Boolean =
  flatMap(m, unit) == m
def checkAssociativity[A, B, C, F[_]](
    m: F[A],
    g: A => F[B],
    h: B => F[C]
)(unit: A => F[A], flatMap: [X, Y] => (F[X], X => F[Y]) => F[Y]): Boolean =
  flatMap(flatMap(m, g), h) == flatMap(m, x => flatMap(g(x), h))
The first two laws guarantee that unit is an identity (neutral) element with respect to flatMap: no matter on which side you put it, it leaves the other argument of the operation unchanged. The third law guarantees that the way the flatMaps are nested does not change the result of the computation: running m, then g, then h gives the same result whether we first combine m with g, or first combine g with h.
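To see the laws in action, here is a small self-contained check for Option, whose flatMap short-circuits on None. It is a sketch under our own choices: unitOpt, half, show, lawsHold and lawsDemo are hypothetical names, and we compare the results directly with == instead of calling the check functions above:

```scala
// unit for Option: wrap a value in Some.
def unitOpt[A](a: A): Option[A] = Some(a)

// Two monadic functions: half fails (returns None) on odd numbers.
val half: Int => Option[Int] = x => if x % 2 == 0 then Some(x / 2) else None
val show: Int => Option[String] = x => Some(s"#$x")

def lawsHold(m: Option[Int], a: Int): Boolean =
  val leftIdentity = unitOpt(a).flatMap(half) == half(a)
  val rightIdentity = m.flatMap(x => unitOpt(x)) == m
  val associativity =
    m.flatMap(half).flatMap(show) == m.flatMap(x => half(x).flatMap(show))
  leftIdentity && rightIdentity && associativity

@main def lawsDemo(): Unit =
  // a success, a failure inside half, and a failure from the start
  println(List(lawsHold(Some(8), 4), lawsHold(Some(3), 3), lawsHold(None, 7))) // List(true, true, true)
```

A few sample values do not prove the laws, of course; property-based testing libraries are the usual way to check them over many generated inputs.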
Some Concrete Monads
Let us explore some concrete monads to better understand the action of flatMap. We start with the identity monad, where the intermediate action just executes the next action, and there is no context.
type IdMonad[A] = A
def idUnit[A](a: A): IdMonad[A] = a
def idIntermediateAction[A, B](result: A, f: A => IdMonad[B]) =
  f(result)
def idFlatMap[A, B](fa: IdMonad[A], f: A => IdMonad[B]) =
  val result = fa
  idIntermediateAction(result, f)
idFlatMap(5, _ + 2) // 7
Basically, it takes the first value and passes it directly to the function. In procedural style, this is the same as executing one statement after another, something like this:
// idFlatMap(5, _ + 2)
val x = 5
x + 2 // 7
Now, let us explore the short circuit monad.
type EitherMonad[A] = A | Exception
def scUnit[A](a: A): EitherMonad[A] = a
def scIntermediateAction[A, B](
    result: EitherMonad[A],
    f: A => EitherMonad[B]
): EitherMonad[B] =
  result match
    case e: Exception => e // failure: abort and propagate the exception
    case _ => f(result.asInstanceOf[A]) // success: continue with the value
def scFlatMap[A, B](fa: EitherMonad[A], f: A => EitherMonad[B]) =
  val result = fa
  scIntermediateAction[A, B](result, f)
def checkEqualEncodingPureMonad(hex: String, b58: String) =
  scFlatMap(
    decodeFromHexPure(hex),
    x => scFlatMap(decodeFromBase58Pure(b58), y => x.sameElements(y))
  )
checkEqualEncodingPureMonad("0000", "11") // true
As you can see, this monad makes the pure checkEqualEncoding very similar to its impure counterpart. The monad has indeed abstracted away the boilerplate and left us with more elegant code.
Finally, let us explore the state monad. This time, we will create the type constructor using a class instead of the type keyword.
case class StateMonad[S, A](execute: S => (S, A)):
  def unit(a: A): StateMonad[S, A] = StateMonad(s => (s, a))
  // the intermediate action: run the next computation with the new state
  private def intermediateAction[B](
      stateAndResult: (S, A),
      f: A => StateMonad[S, B]
  ): (S, B) =
    val (newState, result) = stateAndResult
    f(result).execute(newState)
  def flatMap[B](f: A => StateMonad[S, B]) =
    StateMonad[S, B] { oldState =>
      val stateAndResult = execute(oldState)
      intermediateAction(stateAndResult, f)
    }
  def map[B](f: A => B): StateMonad[S, B] = StateMonad { oldS =>
    val (s, a) = execute(oldS)
    (s, f(a))
  }
end StateMonad
case class BState(vector: Vector[Int], swapped: Boolean, counter: Int)
Before going into the details, note that implementing the state monad as a class with flatMap and map methods gives us some nice syntax sugar: we can use for comprehensions. For instance, if our short-circuit monad followed the same convention, the chain of flatMaps in checkEqualEncodingPureMonad could be written as follows:
// scFlatMap(
//   decodeFromHexPure(hex),
//   x => scFlatMap(decodeFromBase58Pure(b58), y => x.sameElements(y))
// )
// the code below does not compile,
// because we did not write our monad using the convention that Scala is expecting
// (flatMap and map methods on the monadic value)
for {
  x <- decodeFromHexPure(hex)
  y <- decodeFromBase58Pure(b58)
} yield x.sameElements(y)
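For comparison, Scala's standard Either has been right-biased since Scala 2.12 and defines flatMap and map, so the for comprehension above does compile when the decoders return Either instead of the union type. This is a sketch: decodeHexEither, decodeBase58Either, checkEqualEncodingEither, base58Alphabet and eitherDemo are hypothetical re-implementations of the earlier decoders, with failures captured through scala.util.Try:

```scala
import scala.util.Try

val base58Alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"

// Right(bytes) on success, Left(exception) on failure.
def decodeHexEither(hex: String): Either[Throwable, Array[Byte]] =
  Try(hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray).toEither

def decodeBase58Either(b58: String): Either[Throwable, Array[Byte]] =
  Try {
    val zeroCount = b58.takeWhile(_ == '1').length // leading '1's are zero bytes
    val value = b58.drop(zeroCount).foldLeft(BigInt(0)) { (acc, c) =>
      val idx = base58Alphabet.indexOf(c)
      if idx < 0 then throw IllegalArgumentException(s"invalid base 58 character: $c")
      acc * 58 + idx
    }
    // BigInt.toByteArray may add a leading sign byte, so drop leading zeros
    Array.fill(zeroCount)(0.toByte) ++ value.toByteArray.dropWhile(_ == 0.toByte)
  }.toEither

// Desugars to flatMap/map on Either: the first Left stops the computation.
def checkEqualEncodingEither(hex: String, b58: String): Either[Throwable, Boolean] =
  for
    x <- decodeHexEither(hex)
    y <- decodeBase58Either(b58)
  yield x.sameElements(y)

@main def eitherDemo(): Unit =
  println(checkEqualEncodingEither("0000", "11"))      // Right(true)
  println(checkEqualEncodingEither("61", "2g"))        // Right(true)
  println(checkEqualEncodingEither("zz", "11").isLeft) // true
```

The design choice here is to let the standard library provide flatMap and map instead of writing scFlatMap by hand; the behavior is the same short-circuiting we implemented manually.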
Back to the state monad: it clearly implements the same structure that we have seen before, and its intermediate action just passes the new state around. You might be wondering if we can write the bubble sort using the state monad. The answer is yes, and the code looks surprisingly similar to our impure implementation. The bad news is that we need to implement some generic boilerplate to make it as similar as possible to the impure version. The good news is that you do not need to write this kind of boilerplate in the real world, as libraries such as cats already provide implementations of most of what we implemented here. Here is the full code:
// some useful methods
object StateMonad:
  def get[S]: StateMonad[S, S] = StateMonad(s => (s, s))
  def set[S](newS: S): StateMonad[S, Unit] = StateMonad(_ => (newS, ()))
  def update[S](u: S => S): StateMonad[S, Unit] = StateMonad(s => (u(s), ()))
  def unit[S] = StateMonad[S, Unit](s => (s, ()))
  def ifThen[S](cond: => Boolean)(sm: StateMonad[S, Unit]) =
    if (cond) then sm else unit[S]
  def doWhile[S](
      sm: StateMonad[S, Unit]
  )(cond: S => Boolean): StateMonad[S, Unit] =
    for {
      _ <- sm
      s <- get[S]
      _ <- if (cond(s)) then doWhile(sm)(cond) else unit[S]
    } yield ()
type BSortState[A] = StateMonad[BState, A]
import StateMonad._
def setSwapped(v: Boolean) = update[BState](_.copy(swapped = v))
def setCounter(i: Int) = update[BState](_.copy(counter = i))
def swapWithMonad(a: Vector[Int], i: Int): BSortState[Unit] =
  val temp = a(i - 1)
  for {
    _ <- update[BState](s =>
      s.copy(vector = a.updated(i - 1, a(i)).updated(i, temp))
    )
    _ <- setSwapped(true)
  } yield ()
def incCounter() =
  update[BState](s => s.copy(counter = s.counter + 1))
def getVector() = get[BState].map(_.vector)
def getCounter() = get[BState].map(_.counter)
// the actual bubble sort
def bubblesortPureMonad(l: Vector[Int]): Vector[Int] =
  StateMonad
    .doWhile[BState](
      for {
        _ <- setSwapped(false)
        _ <- doWhile[BState] {
          for {
            a <- getVector()
            i <- getCounter()
            _ <- ifThen[BState](a(i - 1) > a(i))(swapWithMonad(a, i))
            _ <- incCounter()
          } yield ()
        }(x => x.counter < x.vector.size)
        _ <- setCounter(1)
      } yield ()
    )(x => x.swapped)
    .execute(BState(l, false, 1))
    ._1
    .vector
bubblesortPureMonad(Vector(3, 4, 5, 2, 1)) // Vector(1, 2, 3, 4, 5)
If you look at the code of the actual bubble sort, it is very similar to its impure counterpart. The magic is in the for comprehension, which makes the code look imperative, while behind the scenes it is nothing more than a chain of nested flatMap calls.
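To make the desugaring concrete, here is a sketch with Option, chosen only because it is in the standard library; the same rewriting applies to our StateMonad, since it also has flatMap and map. The names addBoth, addBothDesugared and desugarDemo are hypothetical. In this simple case, the compiler turns every generator except the last into a flatMap call, and the last one, followed by yield, into a map call:

```scala
// Written with a for comprehension...
def addBoth(x: Option[Int], y: Option[Int]): Option[Int] =
  for
    a <- x
    b <- y
  yield a + b

// ...and the equivalent chain of calls that the compiler generates.
def addBothDesugared(x: Option[Int], y: Option[Int]): Option[Int] =
  x.flatMap(a => y.map(b => a + b))

@main def desugarDemo(): Unit =
  println(addBoth(Some(1), Some(2)))          // Some(3)
  println(addBothDesugared(Some(1), Some(2))) // Some(3)
  println(addBoth(None, Some(2)))             // None
```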
What is The Intuition Behind All This?
I hope that the intuition behind monads is clear by now. Monads help functional programmers access features of imperative/impure languages, but in a pure and functional way. Intuitively, we can think of the laws as follows:
We can think of unit as just taking (or lifting) a value into a computation. If we were to write the unit function in an imperative program, it would be equivalent to the identity function. As such, if we were writing our code in an imperative way, we would expect that:
val x = identity(value)
val y = performComputation(x)
should be the same as:
val y = performComputation(value) // this is the left identity law
In the same way, we expect that:
val x = performComputation(value)
val y = identity(x)
is the same as:
val y = performComputation(value) // this is the right identity law
This means that unit never adds an intermediate action of its own: lifting a value into a computation, before or after another computation, does not change the result.
In the same way, if we were to translate the associativity law to imperative programming, we would have:
val y = {
  val x = performComputation1(value)
  performComputation2(x)
}
val z = performComputation3(y)
And we would expect that to be the same as:
val x = performComputation1(value)
val z = {
  val y = performComputation2(x)
  performComputation3(y)
} // this is the associativity law
As you can see, the monad laws just formalize the properties that we would expect to hold when running an arbitrary imperative program.
Conclusion
We have explored monads, seen their formal definition, and compared a real-world algorithm written with imperative programming, with functional programming without monads, and with functional programming with monads. I hope that the examples and explanations have helped you improve your understanding of monads and that you can start using them in your code.