
Scala Cookbook: Simplifying nested pattern matching

Pattern matching in Scala is a natural and intuitive way of expressing branching logic, and it can make complex branching much easier to read. But what if you need to branch within a branch? Nesting one match inside another can result in unsightly code and duplicated logic. Let’s see this first-hand with an example.

Let’s write a function removeAt(n, xs) that removes the element at zero-based index n from the list xs and returns the resulting list. If n is out of range, the list comes back unchanged. Our first attempt, transcribing our thoughts into code step by step, may look as follows.

  def removeAt[T](n: Int, xs: List[T]): List[T] = n match {
    case 0 => xs match {
      case Nil => Nil
      case _ :: ys => ys                        // drop the element at index n, keep the rest
    }
    case _ => xs match {
      case Nil => Nil
      case y :: ys => y :: removeAt(n - 1, ys)  // keep y and keep looking
    }
  }

Although the logic is correct and reasonably easy to follow, the Nil => Nil case is duplicated. The function could also be made easier to read and a little shorter. On our second attempt we might hoist the duplicated Nil => Nil cases out of the match, as below.

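  def removeAt[T](n: Int, xs: List[T]): List[T] =
    if (xs.isEmpty) Nil else n match {
      // Neither inner match has a Nil case. The if above rules Nil out,
      // but the compiler checks each match on its own and warns.
      case 0 => xs match {
        case _ :: ys => ys
      }
      case _ => xs match {
        case y :: ys => y :: removeAt(n - 1, ys)
      }
    }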

This also works and does achieve our goals: the function is shorter and no longer duplicates code. However, it now mixes if/else logic with pattern matching. Most importantly, it produces compiler warnings, shown here as reported by Scala IDE for Eclipse. The warnings state that our inner pattern matches do not handle the Nil case, and that is true.

Description	Resource	Path	Location	Type
match may not be exhaustive. It would fail on the following input: Nil	ListUtils.scala	/scala-test/src/week5	line 13	Scala Problem
match may not be exhaustive. It would fail on the following input: Nil	ListUtils.scala	/scala-test/src/week5	line 16	Scala Problem
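The sketch below isolates a match with no Nil case, like the inner matches above, to show what the warning is pointing at. MatchErrorDemo and firstElement are hypothetical names. With a non-empty list the match works; with Nil it throws scala.MatchError at run time, which is the failure the warning describes. In the second attempt the if (xs.isEmpty) guard keeps Nil from ever reaching the inner matches. However, the exhaustiveness check looks at each match on its own, so it still warns.

```scala
import scala.util.{Failure, Success, Try}

object MatchErrorDemo {
  // Like the inner matches in the second attempt, this match has no Nil case,
  // so scalac reports "match may not be exhaustive".
  def firstElement[T](xs: List[T]): T = xs match {
    case y :: _ => y
  }

  def main(args: Array[String]): Unit = {
    println(firstElement(List(1, 2, 3))) // prints 1
    // Reaching the match with Nil fails at run time.
    Try(firstElement(List.empty[Int])) match {
      case Failure(e) => println(e.getClass.getName) // prints scala.MatchError
      case Success(v) => println(v)
    }
  }
}
```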

Scala IDE may not see that, in the bigger picture, the logic is correct, but what it reports about each inner match taken on its own is accurate. I also consider it bad practice to leave warnings in code, so let’s make one more simplification attempt. How else can we restructure this function to achieve the same goals?

  def removeAt[T](n: Int, xs: List[T]): List[T] =
    xs match {
      case Nil => Nil
      case y :: ys => n match {
        case 0 => ys                          // drop the element at index n
        case _ => y :: removeAt(n - 1, ys)    // keep y and keep looking
      }
    }

The answer is to swap the order of the two match subjects. xs, the subject that was being matched redundantly in both branches, becomes the outer match, and n becomes the inner one. The Nil case is now handled exactly once and every match is exhaustive. And there we have the end result of the simplification.
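Another way to flatten the nesting, not covered in the original post, is to match on both subjects at once as a tuple. Each combination then gets its own flat case. Here is a minimal sketch; the object name RemoveAtTuple is hypothetical.

```scala
object RemoveAtTuple {
  def removeAt[T](n: Int, xs: List[T]): List[T] = (n, xs) match {
    case (_, Nil)     => Nil                       // ran out of elements
    case (0, _ :: ys) => ys                        // reached index n: drop the head
    case (_, y :: ys) => y :: removeAt(n - 1, ys)  // keep the head and keep looking
  }
}
```

Case order matters here. (0, _ :: ys) must come before the catch-all (_, y :: ys), otherwise the element at index n would never be dropped.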

Addendum

A couple of afterthoughts for completeness. Firstly, the progression of implementations above was chosen to illustrate how to simplify a function incrementally. It does not mean you can’t arrive at the final solution on the first attempt. Secondly, the recursive algorithm above is deliberately verbose, chosen to illustrate nested pattern matching. It isn’t the simplest or most concise solution; for that, see below.

  def removeAt[T](n: Int, xs: List[T]): List[T] =
    (xs take n) ::: (xs drop n + 1) // the elements before index n, then those after it
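One way to check that the concise version agrees with the recursive one is to compare them on several inputs. The range below includes an index past the end of the list and a negative index, where both return the list unchanged. This is a minimal sketch; RemoveAtCheck, removeAtRec and removeAtConcise are hypothetical names.

```scala
object RemoveAtCheck {
  // The final recursive version, with the n == 0 case returning ys directly.
  def removeAtRec[T](n: Int, xs: List[T]): List[T] = xs match {
    case Nil     => Nil
    case y :: ys => n match {
      case 0 => ys
      case _ => y :: removeAtRec(n - 1, ys)
    }
  }

  // Concise version: everything before index n, then everything after it.
  def removeAtConcise[T](n: Int, xs: List[T]): List[T] =
    (xs take n) ::: (xs drop n + 1)

  def main(args: Array[String]): Unit = {
    val xs = List('a', 'b', 'c', 'd')
    for (n <- -1 to 5) {
      // Both versions should produce the same list for every index tried.
      assert(removeAtRec(n, xs) == removeAtConcise(n, xs))
    }
    println(removeAtConcise(2, xs)) // prints List(a, b, d)
  }
}
```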

Scala Cookbook: Tail recursive factorial

A direct port of a human-readable description of the factorial algorithm may look as follows.

  def factorial(x: BigInt): BigInt = {
    if (x == 0) 1 else x * factorial(x - 1)
  }

However, this isn’t tail recursive. For sufficiently large inputs it fails with a StackOverflowError; how large depends on the thread's stack size.

Exception in thread "main" java.lang.StackOverflowError
	at java.math.BigInteger.getInt(BigInteger.java:3014)

The classic nested accumulator helper function idiom can be used to make this tail recursive.

  import scala.annotation.tailrec

  def factorial(x: BigInt): BigInt = {
    // Nested helper: acc holds the product accumulated so far.
    @tailrec
    def f(x: BigInt, acc: BigInt): BigInt = {
      if (x == 0) acc else f(x - 1, x * acc)
    }
    f(x, 1)
  }

So what’s going on here? Well, it’s quite simple really.

The first function uses a number of stack frames proportional to the size of its input, because its expression keeps expanding to the right. Its recursive call is not the last thing it does: after factorial(x - 1) returns, the result still has to be multiplied by x. Each caller's stack frame must therefore stay alive until the call beneath it returns. The stack grows until the recursion reaches x == 0, and only then does it unwind as the pending multiplications are performed.
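Writing out factorial(3) by hand shows the expansion described above. Each multiplication is deferred until the call inside it returns, so the pending work, and the stack with it, grows with x. For contrast, the last line traces the accumulator helper f from the second version. There each call carries the partial product forward and leaves nothing to do afterwards.

```latex
\begin{aligned}
\text{factorial}(3) &= 3 \times \text{factorial}(2) \\
&= 3 \times (2 \times \text{factorial}(1)) \\
&= 3 \times (2 \times (1 \times \text{factorial}(0))) \\
&= 3 \times (2 \times (1 \times 1)) = 6 \\[6pt]
f(3, 1) &= f(2, 3) = f(1, 6) = f(0, 6) = 6
\end{aligned}
```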

The second function is tail recursive: the recursive call to f is the final action it performs, with nothing left to do afterwards. The Scala compiler rewrites such self-recursive tail calls into a loop. The helper therefore uses a single stack frame regardless of the size of the input, and deep recursion cannot cause it to throw a StackOverflowError.
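The hand-written equivalent below uses a while loop and mutable variables to make "compiled to a loop" concrete. It is only a sketch of the idea, not the bytecode scalac actually emits. FactorialLoop and factorialLoop are hypothetical names.

```scala
import scala.annotation.tailrec

object FactorialLoop {
  // The accumulator version from above, for comparison.
  def factorial(x: BigInt): BigInt = {
    @tailrec
    def f(x: BigInt, acc: BigInt): BigInt =
      if (x == 0) acc else f(x - 1, x * acc)
    f(x, 1)
  }

  // Hand-written loop with the same behaviour as the helper f:
  // each pass through the loop plays the role of one call f(x - 1, x * acc).
  def factorialLoop(x: BigInt): BigInt = {
    var n = x
    var acc = BigInt(1)
    while (n != 0) {
      acc = n * acc
      n = n - 1
    }
    acc
  }

  def main(args: Array[String]): Unit = {
    // Both versions agree, and neither grows the stack with x.
    println((0 to 30).forall(i => factorial(i) == factorialLoop(i))) // prints true
    println(factorialLoop(20)) // prints 2432902008176640000
  }
}
```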

Note the @tailrec annotation (scala.annotation.tailrec). If a function carrying it compiles, the Scala compiler has verified that the function is tail recursive and has optimised it; if not, compilation fails with an error. It is a quick and easy way to check tail recursion at compile time. It also serves as a useful hint for other developers maintaining your code.
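One detail matters when applying @tailrec outside a local helper. scalac only optimises a self-recursive call if the method cannot be overridden. On a class, the method therefore needs to be final or private; otherwise @tailrec turns the missed optimisation into a compile error. Methods defined directly in an object cannot be overridden, so they do not need the modifier. Here is a minimal sketch; the class name MathUtils is hypothetical.

```scala
import scala.annotation.tailrec

class MathUtils {
  // Without `final` (or `private`), a subclass could override factorial,
  // so scalac would not turn the self-call into a loop and @tailrec
  // would report a compile error.
  @tailrec
  final def factorial(x: BigInt, acc: BigInt = 1): BigInt =
    if (x == 0) acc else factorial(x - 1, x * acc)
}
```

Usage: new MathUtils().factorial(5) returns 120, with the default accumulator of 1 supplied automatically.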

For fun, here is a pattern-matched equivalent of the tail-recursive function.

  import scala.annotation.tailrec

  def factorial(x: BigInt): BigInt = {
    @tailrec
    def f(x: BigInt, acc: BigInt): BigInt = x match {
      case y: BigInt if y == 0 => acc // a guard, since `case 0` does not type-check against BigInt
      case _ => f(x - 1, x * acc)
    }
    f(x, 1)
  }

Scala Cookbook: Pattern matching on BigInt

Let’s take the factorial algorithm as an example.

  def factorial(x: BigInt): BigInt = {
    if (x == 0) 1 else x * factorial(x - 1)
  }

One might expect to convert this into a pattern-matched equivalent as follows.

  def factorial(x: BigInt): BigInt = x match {
    case 0 => 1
    case _ => x * factorial(x - 1)
  }

In fact this will give you the error below.

Factorial.scala: type mismatch;
 found   : Int(0)
 required: BigInt

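The reason is that a literal pattern such as 0 must have a type that conforms to the type of the value being matched. Int(0) is not a BigInt, so the pattern is rejected. The implicit conversion from Int to BigInt, which lets you write x - 1 elsewhere in the function, is not applied to patterns.

However, a small modification allows the pattern-matching implementation to succeed.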

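  def factorial(x: BigInt): BigInt = x match {
    // A typed pattern plus a guard. The type test always succeeds (x is already
    // a BigInt), and the guard's == treats BigInt(0) and 0 as equal.
    case y: BigInt if y == 0 => 1
    case _ => x * factorial(x - 1)
  }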

Arguably the if/else version is simpler, but if you prefer to use pattern matching throughout, you’ll want this version.
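If you want a case that reads like case 0, one more option, not covered in the post above, is a stable identifier pattern. In a pattern, a value whose name starts with an upper-case letter (or a lower-case name written in backticks) is treated as a constant compared with ==, not as a new variable binding. Because its type here is BigInt, it type-checks against x. Here is a minimal sketch; FactorialMatch and Zero are hypothetical names.

```scala
object FactorialMatch {
  // Upper-case name, so `case Zero` is a stable identifier pattern
  // (compared with ==), not a fresh variable binding.
  val Zero: BigInt = BigInt(0)

  def factorial(x: BigInt): BigInt = x match {
    case Zero => 1
    case _    => x * factorial(x - 1)
  }
}
```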