Introduction to Monads With Scala 3


So, you are trying to grasp the concept of monads. If you are learning functional programming, understanding monads is a kind of rite of passage. Searching for monads on Google can be a bit daunting: the concept comes from category theory, so many explanations start from there. In this tutorial, however, we stay away from category theory and approach monads from the developer's point of view.

Motivation

In the beginning of functional programming, there were no monads. In the 1970s, functional programming was experimental, and people were trying to find out what could be done using only pure functions. These people were mainly researchers theorizing about programming languages. The main justification for functional programming was (and still is) its very solid mathematical foundations. If your programs are pure (pure means free of side effects in this context), then you can reason about them and even prove things about them. Sometimes, the compiler can prove things on your behalf through the type system.

The problem with pure functions is that, by definition, they do not produce side effects, which are a cornerstone of programming. Throwing an exception, modifying a variable, printing to the console, etc. are all effects. One solution is to encode the effects in the functions' arguments and return types. This, however, is extremely counterintuitive and not ergonomic, whereas an imperative programming language lets you do these things very easily.

Let us look at an example: we want a function that checks whether a String encoded in base 58 and a String encoded in hexadecimal represent the same bytes:

case object InvalidInputString extends Exception

val base58alphabet =
  "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz".zipWithIndex

val idxToCharBase58 = base58alphabet.map(_.swap).toMap // index -> character

val charToIdxBase58 = base58alphabet.toMap // character -> index

def decodeFromBase58(b58: String): Array[Byte] =
  try {
    val zeroCount = b58.takeWhile(_ == '1').length
    Array.fill(zeroCount)(0.toByte) ++
      b58
        .drop(zeroCount)
        .map(charToIdxBase58)
        .toList
        .foldLeft(BigInt(0))((acc, x) => acc * 58 + x)
        .toByteArray
        .dropWhile(_ == 0.toByte)
  } catch {
    case _: Exception => throw InvalidInputString // e.g. a character outside the alphabet
  }

def decodeFromHex(hex: String): Array[Byte] =
  try {
    hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
  } catch {
    case _: Exception => throw InvalidInputString // e.g. a non-hex character
  }

def checkEqualEncoding(hex: String, b58: String): Boolean =
  decodeFromHex(hex).sameElements(decodeFromBase58(b58))

checkEqualEncoding("0000", "11") // true

To do that, we decode both strings to byte arrays and then compare the results. Each decoding can throw an exception, but the error handling is managed behind the scenes by the exception mechanism: if one of the decoding functions throws an exception, the whole computation throws an exception. We can rewrite this program in a purely functional style as follows:

def decodeFromBase58Pure(b58: String): Array[Byte] | Exception =
  try {
    decodeFromBase58(b58)
  } catch {
    case _: Exception => InvalidInputString
  }

def decodeFromHexPure(hex: String): Array[Byte] | Exception =
  try {
    decodeFromHex(hex)
  } catch {
    case _: Exception => InvalidInputString
  }

def checkEqualEncodingPure(hex: String, b58: String): Boolean | Exception =
  val elt1 = decodeFromHexPure(hex)
  elt1 match
    case e: Exception => InvalidInputString
    case a1: Array[Byte] =>
      val elt2 = decodeFromBase58Pure(b58)
      elt2 match
        case e: Exception => InvalidInputString
        case a2: Array[Byte] =>
          a1.sameElements(a2)

checkEqualEncodingPure("0000", "11")

For simplicity's sake, we are reusing the functions that we defined before, but this time we encode the exception in the result type. As we want pure functions, we capture the exception and just return it as a value. The interesting part comes in checkEqualEncodingPure. Since we no longer rely on the exception mechanism, we need to handle the short-circuiting ourselves: we have encoded the failure in the return type of the function, so every caller has to inspect it.

In particular, we see the pattern appear twice:

val eltX = decode(enc)
eltX match
   case e: Exception => InvalidInputString
   case a1: Array[Byte] => // do something with the value

Now, let us consider the well-known bubble sort algorithm, written in an imperative programming style:

def bubblesort(l: Vector[Int]): Vector[Int] =
  val a = l.toArray
  while // the Scala 3 way of doing do/while loops: the body goes in the condition
    var swapped = false
    for i <- 1 until a.size do // compare each pair of adjacent elements
      if a(i - 1) > a(i) then
        val temp = a(i - 1)
        a.update(i - 1, a(i))
        a.update(i, temp)
        swapped = true
    swapped // keep looping while the last pass swapped something
  do ()
  a.toVector

bubblesort(Vector(5, 4, 3, 2, 1)) // Vector(1, 2, 3, 4, 5)

A purely functional version of this algorithm (without monads) looks something like this:

import scala.annotation.tailrec

def bubblesortPure(l: Vector[Int]): Vector[Int] =
  iterateUntilNotSwapped(l, false)

def checkAndSwap(i: Int, a: Vector[Int]) =
  if (a(i - 1) > a(i))
    val temp = a(i - 1)
    (a.updated(i - 1, a(i)).updated(i, temp), true)
  else (a, false)

@tailrec
final def iterateWith(
    a: Vector[Int],
    i: Int,
    initialSwapped: Boolean
): (Vector[Int], Boolean) =
  if (i < a.size)
    val (res, swapped) = checkAndSwap(i, a)
    iterateWith(res, i + 1, initialSwapped || swapped)
  else (a, initialSwapped)

@tailrec
final def iterateUntilNotSwapped(
    a: Vector[Int],
    initialSwapped: Boolean
): Vector[Int] =
  val (res, swapped) = iterateWith(a, 1, initialSwapped)
  if (swapped) iterateUntilNotSwapped(res, false)
  else res

bubblesortPure(Vector(4, 5, 3, 2, 1)) // Vector(1, 2, 3, 4, 5)

My guess is that the first version is easier to understand. There is no recursion, and it is probably how you would do it in your head: it is almost a copy-paste of the algorithm you find on Wikipedia. The second version is not only longer, but also difficult to map to the simple algorithm that we had before. Nevertheless, we have managed to write it in a tail-recursive way, which at least guarantees that it is stack-safe.

Please note that:

  • we have encoded the state in the parameters and return type of the different functions
  • each time we call a function, we need to extract the current state (res and swapped), then do something with it (for example, check whether swapped is true), and then pass the state to the next function

In particular, we see the pattern appear twice:

val (res, swapped) = functionCall(i, a)

In both cases (short-circuiting and state handling), each time we obtain the context from the result of a function call, we perform an action. For the exception, we check the result of the computation and pass it on if and only if the computation was successful; otherwise, we abort. For the state, we pass it on to the next computation. Let us call this operation, which happens between the two function calls, the intermediate action.

The problem that monads solve is twofold:

  • they unify the way the context is passed around
  • they define how the computation will continue, depending on the context, between function calls.

As you saw in the previous examples, it is not obvious what the context is. It can be a state, the status of a computation, a dependency provided from the outside, side effects, or something else entirely. That context can present itself in several forms.

The power of monads is that they give us a unified way to handle all these different contexts. There is a monad for exceptions, a monad for state, a monad for dependency injection, etc. That is also why monads are hard to understand: the same concept applies to situations that have nothing to do with each other.

So, in both of these examples, we assume a world beyond the simple expression that we are evaluating: a context. In the case of the state, the context is the value of the state, and we encode it in the parameters and in the return types. In the case of the exception, the context is the status of the computation (success or failure), and we encode it in the return types.

Definition

A monad F has three parts:

  • a part that transforms a type into another type (type constructor). For example, if we have type A, we will call the new type F[A]. This new type represents a computation with return type A and in the context F[_].

  • a function that takes a value a of type A and wraps it in the new type F[A]. In this tutorial, we call this function unit: A => F[A] (the cats library calls it pure).

  • a function that takes the computation fa: F[A], and a function f: A => F[B], and returns a new computation fb: F[B]. This function is commonly known (in Scala) as flatMap: (F[A], A => F[B]) => F[B]. Intuitively, the flatMap function will transform a computation fa to a new computation fb, where fb will: first, execute fa and obtain the result value and the context, then execute the intermediate action with the context, compute the new action to execute from the resulting value f(a), and finally, execute it.
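The three parts above can also be written down as a Scala 3 type class. The following is a minimal sketch, not the encoding used in the rest of this tutorial: the trait name Monad, the given instance for Option (with None standing for a failed computation), and the helper pairUp are our own illustrative choices. Libraries such as cats ship a full-featured Monad type class, where unit is called pure.

```scala
// Part 1: F[_] is the type constructor.
trait Monad[F[_]]:
  // Part 2: lift a plain value into the context.
  def unit[A](a: A): F[A]
  // Part 3: run fa, then use its result to choose the next computation.
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]

// Option as an example context: None means "the computation failed".
given Monad[Option] with
  def unit[A](a: A): Option[A] = Some(a)
  def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] =
    fa match
      case None    => None // short-circuit: f is never called
      case Some(a) => f(a) // success: continue with the result

// A generic helper that only relies on the Monad interface.
def pairUp[F[_], A, B](fa: F[A], fb: F[B])(using m: Monad[F]): F[(A, B)] =
  m.flatMap(fa)(a => m.flatMap(fb)(b => m.unit((a, b))))

val both = pairUp(Option(1), Option("one"))              // == Some((1, "one"))
val oneMissing = pairUp(Option(1), Option.empty[String]) // == None
```

Because pairUp only talks to the Monad interface, it would work unchanged for any other type constructor that has a Monad instance.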

In pseudocode, we can describe the action of flatMap as follows:

// this code is not valid Scala code, just pseudocode
flatMap(fa, f) = previousContext =>
  val resultAndNewContext = execute(previousContext, fa) // execute fa
  intermediateAction(resultAndNewContext, f) // perform the intermediate action

For example, for our state monad, flatMap is defined as follows:

// this code is not valid Scala code, just pseudocode
flatMap(fa, f) = previousState =>
  val resultAndNewState = execute(previousState, fa) // execute fa with the previous state
  intermediateAction(resultAndNewState, f) // perform the intermediate action

intermediateAction(resultAndNewState, f) =
  val (result, state) = resultAndNewState
  execute(state, f(result)) // run the next computation with the new state
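To run the state version for real, we can represent a stateful computation as a plain function S => (S, A), the same shape that the StateMonad class later in this tutorial uses. This is a minimal sketch; the names State, stateUnit, stateFlatMap, next and twoIds are hypothetical and only used here.

```scala
// A stateful computation: take the current state, return (new state, result).
type State[S, A] = S => (S, A)

// unit: the state passes through untouched
def stateUnit[S, A](a: A): State[S, A] = s => (s, a)

def stateFlatMap[S, A, B](fa: State[S, A], f: A => State[S, B]): State[S, B] =
  previousState =>
    val (newState, result) = fa(previousState) // execute fa with the previous state
    f(result)(newState) // intermediate action: run f(result) with the new state

// Example: a counter that returns its current value and then increments itself.
val next: State[Int, Int] = n => (n + 1, n)

// Call next twice, threading the counter through both calls.
val twoIds: State[Int, (Int, Int)] =
  stateFlatMap(next, a => stateFlatMap(next, b => stateUnit((a, b))))

val result = twoIds(10) // (12, (10, 11)): final counter 12, ids 10 and 11
```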

For our short-circuit monad, flatMap is defined as follows:

// this code is not valid Scala code, just pseudocode
flatMap(fa, f) = // there is no previous context
  val resultOrFailure = identity(fa) // execute = identity
  intermediateAction(resultOrFailure, f) // perform the intermediate action

intermediateAction(resultOrFailure, f) =
  if (resultOrFailure.isFailure)
    return Failure // short-circuit: f is never called
  else
    execute(f(resultOrFailure.getResult))
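The short-circuit pseudocode also becomes runnable if we let Scala's built-in Option play the role of the result-or-failure value, with None standing for Failure. A minimal sketch follows; scFlatMapOpt, parseIntOpt and sumOfTwo are hypothetical names, and note that Option, unlike an exception, does not record why a step failed.

```scala
// Short-circuiting flatMap written out by hand: None plays the role of Failure.
def scFlatMapOpt[A, B](fa: Option[A], f: A => Option[B]): Option[B] =
  fa match
    case None         => None      // failure: abort, f is never called
    case Some(result) => f(result) // success: continue with the result

// A step that can fail: parsing a string into an Int.
def parseIntOpt(s: String): Option[Int] = s.toIntOption

def sumOfTwo(x: String, y: String): Option[Int] =
  scFlatMapOpt(parseIntOpt(x), a => scFlatMapOpt(parseIntOpt(y), b => Some(a + b)))

val ok = sumOfTwo("40", "2")       // Some(42)
val failed = sumOfTwo("40", "two") // None
```

Scala's own Option.flatMap behaves the same way, so sumOfTwo could also be written as parseIntOpt(x).flatMap(a => parseIntOpt(y).map(b => a + b)).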

In the same way that we saw with functors, monads need to fulfill certain laws. The monad laws are the following:

def checkLeftIdentity[A, B, F[_]](
    a: A,
    h: A => F[B]
)(unit: A => F[A], flatMap: [X, Y] => (F[X], X => F[Y]) => F[Y]): Boolean =
  flatMap(unit(a), h) == h(a)

def checkRightIdentity[A, F[_]](
    m: F[A]
)(unit: A => F[A], flatMap: [X, Y] => (F[X], X => F[Y]) => F[Y]): Boolean =
  flatMap(m, unit) == m

def checkAssociativity[A, B, C, F[_]](
    m: F[A],
    g: A => F[B],
    h: B => F[C]
)(unit: A => F[A], flatMap: [X, Y] => (F[X], X => F[Y]) => F[Y]): Boolean =
  flatMap(flatMap(m, g), h) == flatMap(m, x => flatMap(g(x), h))

The first two laws guarantee that unit is an identity (neutral) element with respect to flatMap: no matter on which side you put it, it leaves the other argument of the operation unchanged. The third law guarantees that the way you nest the flatMaps does not change the result of the computation.
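To see the laws in action, the sketch below checks them for Option, using Some(_) as unit and Option's built-in flatMap. The functions g and h and the sample values are arbitrary choices for illustration, and checking a handful of inputs is evidence rather than a proof.

```scala
// unit for Option wraps a value in Some.
def unit[A](a: A): Option[A] = Some(a)

// Two arbitrary functions of shape A => Option[B] to test with.
val g: Int => Option[Int] = x => if x >= 0 then Some(x * 2) else None
val h: Int => Option[String] = y => if y < 100 then Some(s"#$y") else None

val samples: List[Option[Int]] = List(Some(3), Some(-1), Some(70), None)

// Left identity: flatMap(unit(a), g) == g(a)
val leftIdentity = List(3, -1, 70).forall(a => unit(a).flatMap(g) == g(a))

// Right identity: flatMap(m, unit) == m
val rightIdentity = samples.forall(m => m.flatMap(unit(_)) == m)

// Associativity: nesting the flatMaps differently gives the same result.
val associativity = samples.forall(m =>
  m.flatMap(g).flatMap(h) == m.flatMap(x => g(x).flatMap(h))
)
```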

Some Concrete Monads

Let us explore some concrete monads to better understand the action of flatMap. We start with the identity monad, where there is no context and the intermediateAction simply executes the next action.

type IdMonad[A] = A

def idUnit[A](a: A): IdMonad[A] = a

def idIntermediateAction[A, B](result: A, f: A => IdMonad[B]) =
  f(result)

def idFlatMap[A, B](fa: IdMonad[A], f: A => IdMonad[B]) =
  val result = fa
  idIntermediateAction(result, f)

idFlatMap(5, _ + 2) // 7

It simply takes the first value and passes it directly to the function. In procedural style, this is the same as executing one statement after another, something like this:

// idFlatMap(5, _ + 2)
val x = 5
x + 2 // 7

Now, let us explore the short-circuit monad.

type EitherMonad[A] = A | Exception

def scUnit[A](a: A): EitherMonad[A] = a

def scIntermediateAction[A, B](
    result: EitherMonad[A],
    f: A => EitherMonad[B]
): EitherMonad[B] =
  result match
    case e: Exception => e
    case _            => f(result.asInstanceOf[A])

def scFlatMap[A, B](fa: EitherMonad[A], f: A => EitherMonad[B]) =
  val result = fa
  scIntermediateAction[A, B](result, f)

def checkEqualEncodingPureMonad(hex: String, b58: String) =
  scFlatMap(
    decodeFromHexPure(hex),
    x => scFlatMap(decodeFromBase58Pure(b58), y => x.sameElements(y))
  )

checkEqualEncodingPureMonad("0000", "11") // true

As you can see, this monad makes the pure version of checkEqualEncoding very similar to its impure counterpart. The monad has indeed abstracted away the boilerplate and left us with more elegant code.

Finally, let us explore the state monad. This time, we will create the type constructor using a class instead of the type keyword.

case class StateMonad[S, A](execute: S => (S, A)):
  def unit(a: A): StateMonad[S, A] = StateMonad(s => (s, a)) // wrap a value, state untouched
  private def intermediateAction[B](
      stateAndResult: (S, A),
      f: A => StateMonad[S, B]
  ): (S, B) =
    val (newState, result) = stateAndResult
    f(result).execute(newState)

  def flatMap[B](f: A => StateMonad[S, B]) =
    StateMonad[S, B] { oldState =>
      val stateAndResult = execute(oldState)
      intermediateAction(stateAndResult, f)
    }
  def map[B](f: A => B): StateMonad[S, B] = StateMonad { oldS =>
    val (s, a) = execute(oldS)
    (s, f(a))
  }

end StateMonad

case class BState(vector: Vector[Int], swapped: Boolean, counter: Int)

Before going into the details: because StateMonad defines flatMap and map as methods, implementing it like this gives us some nice syntactic sugar, namely for comprehensions. This means that when we chain flatMaps as we did in checkEqualEncodingPureMonad, we can write them as follows:

//  scFlatMap(
//    decodeFromHexPure(hex),
//    x => scFlatMap(decodeFromBase58Pure(b58), y => x.sameElements(y))
//  )
// the code below does not compile,
// because we did not write our monad using the convention that Scala is expecting
// (flatMap and map defined as methods on the monadic value)
for {
  x <- decodeFromHexPure(hex)
  y <- decodeFromBase58Pure(b58)
} yield x.sameElements(y)
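For a for comprehension like this one to compile, Scala expects the value on the right of each <- to have a flatMap method (and a map method for the final yield). Scala's built-in Either already has both, so the sketch below shows the same short-circuiting idea with Either[Exception, A] instead of our union type. The helper parseHexByte and the functions sameByte and sameByteDesugared are hypothetical, written only for this illustration.

```scala
// A decoding step that can fail: Left carries the exception, Right the value.
def parseHexByte(s: String): Either[Exception, Byte] =
  try Right(Integer.parseInt(s, 16).toByte)
  catch case e: NumberFormatException => Left(e)

// Written as a for comprehension...
def sameByte(hex1: String, hex2: String): Either[Exception, Boolean] =
  for
    x <- parseHexByte(hex1)
    y <- parseHexByte(hex2)
  yield x == y

// ...which the compiler desugars into nested flatMap/map calls.
def sameByteDesugared(hex1: String, hex2: String): Either[Exception, Boolean] =
  parseHexByte(hex1).flatMap(x => parseHexByte(hex2).map(y => x == y))

val equal = sameByte("ff", "FF")  // Right(true)
val broken = sameByte("ff", "zz") // a Left holding the NumberFormatException
```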

Back to the State monad: it clearly implements the same structure that we have seen before, and its intermediate action just passes the new state around. You might be wondering whether we can write the bubble sort using the state monad. The answer is yes, and the code looks surprisingly similar to our impure implementation. The bad news is that we need to implement some generic boilerplate to make it as close as possible to the impure version. The good news is that you do not need to write this kind of boilerplate in the real world, as libraries such as cats already provide implementations for most of the things we implemented here. Here is the full code:

// some useful methods
object StateMonad:
  def get[S]: StateMonad[S, S] = StateMonad(s => (s, s))
  def set[S](newS: S): StateMonad[S, Unit] = StateMonad(_ => (newS, ()))
  def update[S](u: S => S): StateMonad[S, Unit] = StateMonad(s => (u(s), ()))
  def unit[S] = StateMonad[S, Unit](s => (s, ()))
  def ifThen[S](cond: => Boolean)(sm: StateMonad[S, Unit]) =
    if (cond) then sm else unit[S]
  def doWhile[S](
      sm: StateMonad[S, Unit]
  )(cond: S => Boolean): StateMonad[S, Unit] =
    for {
      _ <- sm
      s <- get[S]
      _ <- if (cond(s)) then doWhile(sm)(cond) else unit[S]
    } yield ()

type BSortState[A] = StateMonad[BState, A]

import StateMonad._

def setSwapped(v: Boolean) = update[BState](_.copy(swapped = v))
def setCounter(i: Int) = update[BState](_.copy(counter = i))

def swapWithMonad(a: Vector[Int], i: Int): BSortState[Unit] =
  val temp = a(i - 1)
  for {
    _ <- update[BState](s =>
      s.copy(vector = a.updated(i - 1, a(i)).updated(i, temp))
    )
    _ <- setSwapped(true)
  } yield ()

def incCounter() =
  update[BState](s => s.copy(counter = s.counter + 1))

def getVector() = get[BState].map(_.vector)

def getCounter() = get[BState].map(_.counter)

// the actual bubble sort
def bubblesortPureMonad(l: Vector[Int]): Vector[Int] =
  StateMonad
    .doWhile[BState](
      for {
        _ <- setSwapped(false)
        _ <- doWhile[BState] {
          for {
            a <- getVector()
            i <- getCounter()
            _ <- ifThen[BState](a(i - 1) > a(i))(swapWithMonad(a, i))
            _ <- incCounter()
          } yield ()
        }(x => x.counter < x.vector.size)
        _ <- setCounter(1)
      } yield ()
    )(x => x.swapped)
    .execute(BState(l, false, 1))
    ._1
    .vector

bubblesortPureMonad(Vector(3, 4, 5, 2, 1)) // Vector(1, 2, 3, 4, 5)

If you look at the code of the actual bubble sort, it is very similar to its impure counterpart. The magic lies in the for comprehension, which makes the code look imperative while, in the background, it is nothing more than a chain of flatMap calls (ending in a map).
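To see what a for comprehension over the state monad turns into, the sketch below uses a stripped-down copy of the StateMonad class (renamed MiniState so the snippet is self-contained) and writes the same small program twice: once as a for comprehension and once as the flatMap/map chain it desugars to. The tick helper is a hypothetical addition for illustration.

```scala
// A stripped-down copy of StateMonad, renamed so this sketch stands alone.
case class MiniState[S, A](execute: S => (S, A)):
  def flatMap[B](f: A => MiniState[S, B]): MiniState[S, B] =
    MiniState { s0 =>
      val (s1, a) = execute(s0)
      f(a).execute(s1) // intermediate action: pass the new state on
    }
  def map[B](f: A => B): MiniState[S, B] =
    MiniState { s0 =>
      val (s1, a) = execute(s0)
      (s1, f(a))
    }

// Hypothetical helper: return the current counter and increment it.
val tick: MiniState[Int, Int] = MiniState(n => (n + 1, n))

// Imperative-looking version.
val withFor: MiniState[Int, Int] =
  for
    a <- tick
    b <- tick
    c <- tick
  yield a + b + c

// What the for comprehension above desugars to.
val withFlatMaps: MiniState[Int, Int] =
  tick.flatMap(a => tick.flatMap(b => tick.map(c => a + b + c)))

val resultFor = withFor.execute(0)           // (3, 3): final state 3, result 0 + 1 + 2
val resultFlatMaps = withFlatMaps.execute(0) // (3, 3) as well
```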

What is The Intuition Behind All This?

I hope that the intuition behind monads is clear by now. Monads help functional programmers access features of imperative/impure languages in a pure and functional way. Intuitively, we can think of the laws as follows:

We can think of unit as just taking (or lifting) a value into a computation. If we were to write the unit function in an imperative program, it would be equivalent to the identity function. As such, if we were writing our code in an imperative way, we would expect that:

val x = identity(value)
val y = performComputation(x)

should be the same as:

val y = performComputation(value) // this is the left identity law

In the same way, we expect that:

val x = performComputation(value)
val y = identity(x)

is the same as:

val y = performComputation(value) // this is the right identity law

In other words, lifting a value with unit, before or after a computation, adds no effect of its own: the full computation computes the same result either way.

In the same way, if we were to translate the associativity law to imperative programming, we would have:

val y = {
  val x = performComputation1(value)
  performComputation2(x)
}
val z = performComputation3(y)

And we would expect that to be the same as:

val x = performComputation1(value)
val z = {
  val y = performComputation2(x)
  performComputation3(y)
} // this is the associativity law

As you can see, the monad laws just formalize what we would expect to hold when running an arbitrary imperative program.

Conclusion

We have explored monads, seen their formal definition, and implemented some real-world algorithms using imperative programming, functional programming without monads, and functional programming with monads. I hope that the examples and explanations have helped you improve your understanding of monads and that you can start using them in your code.

If you have read this far and liked this content, please consider subscribing to my mailing list. I write regularly about Blockchain Technology, Scala and Language Engineering. Have questions? Leave a comment below and I'll try to answer!
