Effects, memoization

Welcome to week 11 of CS-214: Software Construction!

As usual, ⭐️ indicates the most important exercises and questions and 🔥 indicates the most challenging ones. Exercises or questions marked 🧪 are intended to build up to concepts used in this week’s lab.

You do not need to complete all exercises to succeed in this class, and you do not need to do all exercises in the order they are written.

We strongly encourage you to solve the exercises on paper first, in groups. After completing a first draft on paper, you may want to check your solutions on your computer. To do so, you can download the scaffold code (ZIP).


The beauty of functional programming is that it provides referential transparency, the property that the results of a function depend only on its inputs. But many things can break referential transparency:

Traditionally, these are called effects. In the coming weeks we’ll see safe ways to encapsulate all these and recover referential transparency, using monads, but this week we’ll focus instead on special cases where local uses of these features do not affect referential transparency.

In fact, we have already seen one example: tail recursion elimination! Converting a recursive function to a loop that uses a mutable variable does not change anything for the callers of that function, but it saves space on the stack and prevents stack overflows. This transformation is so safe that the compiler does it automatically, just as we’d do it by hand.

This week we’ll focus on this and two more examples:

Avoiding recursion

In your past experience, you probably used loops to implement most of the algorithms you encountered. However, during this course, we have exclusively used recursion to implement looping behaviour. Is recursion useless? Is it only complicating things for the sake of it? Let’s find out! To do so, we will rewrite recursive functions as loops and vice versa.

Converting tail-recursive functions to loops ⭐️

Let us start by converting tail recursive functions to loops. As a reminder, a tail recursive function is a recursive function in which every recursive call is the last thing to be executed in the function. For example, the following function is tail recursive:

@tailrec
def fTR(x: Int, acc: Int): Int =
  if (x == 0) acc
  else fTR(x - 1, acc + 1)

Indeed, the last thing to be executed in the body of fTR is either acc or fTR(x - 1, acc + 1), which is the recursive call.

The @tailrec annotation is used to tell the compiler that the function is supposed to be tail recursive. The compiler will then emit an error if that is not the case. This is useful to make sure that you did not make a mistake when writing the function. When the compiler detects that a function is tail recursive, either because of the annotation or because it is able to infer it, it will compile it to a loop to improve the performance. Indeed a loop does not need to use space on the stack for each recursive call.

On the other hand, the following function is not tail recursive:

def f(x: Int): Int =
  if (x == 0) 0
  else 1 + f(x - 1)

In this case, the last thing executed in the case x != 0 is the addition 1 + f(x-1). This means that the recursive call must first be completed and then the result of the addition can be computed. This is not tail recursive.
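A small experiment can make the difference concrete. The sketch below uses two hypothetical functions, countRec and countTR, which are renamed copies of f and fTR above. The depth at which a non-tail-recursive function overflows the stack depends on the JVM configuration, so only the tail recursive version is called on a large input here.

```scala
import scala.annotation.tailrec

// Not tail recursive: the addition happens after the recursive call returns,
// so every call keeps a stack frame alive until the base case is reached.
def countRec(x: Int): Int =
  if x == 0 then 0
  else 1 + countRec(x - 1)

// Tail recursive: the pending addition is carried in `acc`,
// so the compiler can reuse a single stack frame.
@tailrec
def countTR(x: Int, acc: Int): Int =
  if x == 0 then acc
  else countTR(x - 1, acc + 1)

@main def compareDepth(): Unit =
  println(countRec(1000))          // small input: fine either way
  println(countTR(10_000_000, 0))  // large input: no stack growth
```

Trying countRec on a very large input is a quick way to observe a StackOverflowError on your own machine.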

Now that we have a better understanding of tail recursion, let us convert some tail recursive functions to loops. For this exercise, you will need to convert the functions reverseAppend, foldLeft and sum from the exercises of week 1. The goal is to understand the process that the compiler follows by replicating it manually, following a mechanical procedure.

Hint

Here is the procedure you can follow to convert a tail recursive function to a loop. We illustrate it with the tail recursive function fTR defined above.

  • Add a while true do loop:

    def fTRLoop(x: Int, acc: Int = 0): Int =
      while true do
    
  • Create a mutable variable for each parameter of the function and assign the values of the parameters to them.

    def fTRLoop(x: Int, acc: Int = 0): Int =
      var xVar = x
      var accVar = acc
      while true do
    
  • Create a mutable variable that will contain the return value of the function. Here it is the acc parameter. This parameter has a default value, so we can remove the parameter and assign the default value to the variable.

    def fTRLoop(x: Int): Int =
      var xVar = x
      var acc = 0
      while true do
    
  • Put the body of the function as the body of the loop.

    def fTRLoop(x: Int): Int =
      var xVar = x
      var acc = 0
      while true do
    
  • Replace the base case by a return statement returning the accumulator variable.

    def fTRLoop(x: Int): Int =
      var xVar = x
      var acc = 0
      while true do
        if xVar == 0 then
          return acc
    
  • Replace the recursive call by an assignment to the parameter variables.

    def fTRLoop(x: Int): Int =
      var acc = 0
      var xVar = x
      while true do
        if xVar == 0 then
          return acc
        acc = acc + 1
        xVar = xVar - 1
    
  • However, the compiler will not be happy with this code. Indeed, during type checking, it gives the while loop the type Unit, whereas the function must return an Int. To solve this issue, the last statement should either be of type Int or throw an exception, depending on your preference. Please note, however, that this last statement is unreachable and will therefore never be executed!

    def fTRLoop(x: Int): Int =
      var acc = 0
      var xVar = x
      while true do
        if xVar == 0 then
          return acc
        acc = acc + 1
        xVar = xVar - 1
      throw new AssertionError("Unreachable")
    

And there you go, the loop version of the tail recursive function!

def fTRLoop(x: Int): Int =
  var acc = 0
  var xVar = x
  while true do
    if xVar == 0 then
      return acc
    acc = acc + 1
    xVar = xVar - 1
  throw new AssertionError("Unreachable")

Here is a good resource about tail recursion elimination (the process of converting tail recursive functions to loops): Debunking the “expensive procedure calls” myth.
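As a further illustration, here is the same procedure applied to a function that is not part of the exercises. The names gcdTR and gcdLoop are hypothetical and chosen only for this sketch. It highlights one subtlety: when the new value of one parameter depends on the old value of another, compute all the new values before assigning any of them.

```scala
import scala.annotation.tailrec

// Tail recursive greatest common divisor (Euclid's algorithm).
@tailrec
def gcdTR(a: Int, b: Int): Int =
  if b == 0 then a
  else gcdTR(b, a % b)

// Loop version obtained by following the procedure step by step.
def gcdLoop(a: Int, b: Int): Int =
  var aVar = a
  var bVar = b
  while true do
    if bVar == 0 then
      return aVar
    // Compute both new arguments from the old values first...
    val nextA = bVar
    val nextB = aVar % bVar
    // ...and only then overwrite the variables.
    aVar = nextA
    bVar = nextB
  throw new AssertionError("Unreachable")
```

Writing aVar = bVar followed by bVar = aVar % bVar would be wrong, because the second assignment would read the already updated aVar.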

Now that you know the procedure, convert the following functions:

@tailrec
def reverseAppend(l1: List[Int], l2: List[Int]): List[Int] =
  if l1.isEmpty then l2
  else reverseAppend(l1.tail, l1.head :: l2)

def reverseAppendLoop(l1: List[Int], l2: List[Int]): List[Int] =
    ???

src/main/scala/tailRecursion/lists.scala

@tailrec
def foldLeft(l: List[Int], acc: Int)(f: (Int, Int) => Int): Int =
  if l.isEmpty then acc
  else foldLeft(l.tail, f(acc, l.head))(f)

def foldLeftLoop(l: List[Int], startValue: Int)(f: (Int, Int) => Int): Int =
    ???

src/main/scala/tailRecursion/lists.scala

@tailrec
def sum(l: List[Int], acc: Int = 0): Int =
  if l.isEmpty then acc
  else sum(l.tail, acc + l.head)

def sumLoop(l: List[Int]): Int =
    ???

src/main/scala/tailRecursion/lists.scala

Now we know how to mechanically transform a tail recursive function into a loop. So for the rest of this exercise set, feel free to write either loops or tail recursive functions, depending on what you prefer (when appropriate, of course).

Now the question is: can we write any recursive function as tail recursive? If so, why do we bother with recursive functions? Let us find out!

foldt

Let’s recall the foldt function of the SE exercises:

extension [T](l: List[T])
  def pairs(op: (T, T) => T): List[T] = l match
    case a :: b :: tl => op(a, b) :: tl.pairs(op)
    case _            => l
  def foldt(z: T)(op: (T, T) => T): T = l match
    case Nil       => z
    case List(t)   => t
    case _ :: tail => l.pairs(op).foldt(z)(op)
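To make the behaviour of foldt concrete, the following sketch repeats the definitions above and traces the successive rounds of pairs on a small list. The value names (xs, round1, and so on) are only for illustration.

```scala
extension [T](l: List[T])
  def pairs(op: (T, T) => T): List[T] = l match
    case a :: b :: tl => op(a, b) :: tl.pairs(op)
    case _            => l
  def foldt(z: T)(op: (T, T) => T): T = l match
    case Nil       => z
    case List(t)   => t
    case _ :: tail => l.pairs(op).foldt(z)(op)

val xs = List(1, 2, 3, 4, 5)
val round1 = xs.pairs(_ + _)      // List(3, 7, 5): neighbours combined, 5 left over
val round2 = round1.pairs(_ + _)  // List(10, 5)
val round3 = round2.pairs(_ + _)  // List(15)
val total = xs.foldt(0)(_ + _)    // 15
```

Each round roughly halves the length of the list, and foldt stops once a single element remains.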

How would you write the function foldt with loops? You can start from the following template:

extension [T](l: List[T])
  def foldt(z: T)(op: (T, T) => T): T =
    ???

src/main/scala/tailRecursion/lists.scala

Extra exercise: how would you write pairs with a while loop?

groupBy

Now, you will implement a function groupBy by yourself, without using the standard groupBy method.

Implement two versions of groupBy:

foldLeft versus foldRight

We commonly use foldLeft and foldRight to shorten simple recursive functions.

Can foldLeft be rewritten as a loop? How about foldRight? In both cases, write the code, or explain why it cannot be rewritten that way.
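Before answering, it may help to recall how the two folds nest their calls. This small sketch uses the standard library methods with a non-commutative operation, so that the difference in nesting is visible in the result. The value names are only for illustration.

```scala
val digits = List(1, 2, 3)

// foldLeft nests to the left: ((0 - 1) - 2) - 3
val leftResult = digits.foldLeft(0)(_ - _)   // -6

// foldRight nests to the right: 1 - (2 - (3 - 0))
val rightResult = digits.foldRight(0)(_ - _) // 2
```

Notice which element each fold needs first, and what has to stay pending while the rest of the list is processed.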

Tail recursion modulo context 🔥

We saw in the first exercise that a tail recursive function can be mechanically transformed into a loop. This transformation is mostly useful because it is performed automatically by the compiler.

In this exercise, we will explore a way to rewrite some non-tail recursive functions into tail-recursive ones.

To illustrate this technique, let us consider the map function on List[Int]:

def map(l: List[Int], f: Int => Int): List[Int] =
  if l.isEmpty then Nil
  else f(l.head) :: map(l.tail, f)

src/main/scala/tailRecursion/lists.scala

The only thing that happens after the call is the creation of a Cons instance. To create it, we need to know the head (which we know before the recursive call) but also the tail (which is computed recursively). So the tail is the difficult part: when the recursive call completes, the callee returns the tail. Does that suggest a solution?

Hint

The trick could be to shift responsibilities around so that the caller begins the construction of the Cons, and the callee finishes that construction by storing the computed tail.

To do so, we will create a list type with a mutable tail. This way we can construct the list before making the recursive call, and transfer the responsibility of setting the tail to the recursive call.
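To illustrate the idea on a function other than map, here is a minimal sketch that builds the list from, from + 1, …, until - 1 with a tail recursive worker. The type MList and the functions range and rangeWorker are hypothetical names introduced only for this example.

```scala
import scala.annotation.tailrec

// A list of Ints whose tail can be updated after construction.
enum MList:
  case Empty
  case Cell(val hd: Int, var tl: MList)

import MList.*

def range(from: Int, until: Int): MList =
  if from >= until then Empty
  else
    // The caller creates the first cell with a placeholder tail...
    val first: Cell = Cell(from, Empty)
    // ...and the worker fills in the rest by mutation.
    rangeWorker(from + 1, until, first)
    first

@tailrec
def rangeWorker(next: Int, until: Int, last: Cell): Unit =
  if next < until then
    val cell: Cell = Cell(next, Empty)
    last.tl = cell // finish the construction of the previous cell
    rangeWorker(next + 1, until, cell)
```

Nothing happens after the recursive call in rangeWorker, so the function is tail recursive, and the mutation is invisible to callers of range.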

Now that you have the idea, try to implement the mapTRWorker function:

enum MutableList:
  case Nil
  case Cons(val hd: Int, var tail: MutableList)

import MutableList.*

def mapTR(l: MutableList, f: Int => Int): MutableList =
  l match
    case Nil => Nil
    case Cons(hd, tl) =>
      val acc: Cons = Cons(f(hd), Nil)
      mapTRWorker(tl, f, acc)
      acc

// @tailrec uncomment when working on the exercise
def mapTRWorker(
    l: MutableList,
    f: Int => Int,
    acc: MutableList.Cons
): Unit =
  ???

src/main/scala/tailRecursion/lists.scala

Looping on Trees

In our quest to find a case in which recursion really is easier to use than loops, we will now look at trees. We will use the following definition of binary trees:

enum Tree[T]:
  case Leaf(value: T)
  case Node(left: Tree[T], right: Tree[T])

src/main/scala/tailRecursion/trees.scala

Sum of leaves - rotation

Let us start with a simple function that computes the sum of a tree’s leaves. The recursive version is the following:

def sumRec(t: Tree[Int]): Int =
  t match
    case Leaf(value)       => value
    case Node(left, right) => sumRec(left) + sumRec(right)

src/main/scala/tailRecursion/trees.scala

On right line trees ⭐️

In the 2023 midterm, we saw the concept of right line trees. As a reminder, a right line tree is a tree in which each node is either a leaf, or has a leaf child on the left. The following function checks whether a tree is a right line tree:

def isRightLineTree(t: Tree[Int]): Boolean =
  t match
    case Leaf(_)              => true
    case Node(Leaf(_), right) => isRightLineTree(right)
    case _                    => false

src/main/scala/tailRecursion/trees.scala

Can you see the similarity between a right line tree and a list in the context of tail recursive functions?

Before writing any code, think about this: what can (a + b) + c = a + (b + c) mean on trees? Can we exploit this to write a loop (or tail recursive) function?

Hint

This represents the right rotation on trees. This property of + is called associativity.

In our context, it means that the tree can be rearranged to compute the sum of leaves in a different way without affecting the result.
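The following sketch shows a single rotation step at the root of a tree and checks that it leaves the sum of the leaves unchanged. The function rotateRightOnce is a hypothetical helper written for this illustration; the definitions of Tree and sumRec are repeated from above so that the example is self-contained.

```scala
enum Tree[T]:
  case Leaf(value: T)
  case Node(left: Tree[T], right: Tree[T])

import Tree.*

def sumRec(t: Tree[Int]): Int =
  t match
    case Leaf(value)       => value
    case Node(left, right) => sumRec(left) + sumRec(right)

// (a + b) + c  becomes  a + (b + c); other shapes are left unchanged.
def rotateRightOnce[T](t: Tree[T]): Tree[T] =
  t match
    case Node(Node(a, b), c) => Node(a, Node(b, c))
    case _                   => t

val before: Tree[Int] = Node(Node(Leaf(1), Leaf(2)), Leaf(3))
val after: Tree[Int] = rotateRightOnce(before)
```

Here after is Node(Leaf(1), Node(Leaf(2), Leaf(3))), which is a right line tree, and sumRec returns 6 for both trees.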

Let us write an imperative version (or tail recursive, as you prefer) of the sum function that works only for right line trees. Do not forget to add the correct Scala call to ensure that the tree is indeed a right line tree before computing the sum 🙂:

def sumRightLineTree(tr: Tree[Int]): Int =
  ???

src/main/scala/tailRecursion/trees.scala

What would happen if the operation were not associative, like subtraction, for example?

Hint

You can take inspiration from the sum function on lists that we saw in the first exercise of this session. Think about the similarity between the structure of a list and that of a right line tree.

Using rotations ⭐️

A right rotation is an operation on a tree that gives a new tree with fewer leaves on the left-hand side. Can we use this operation to compute the sum of leaves on an arbitrary tree while reusing the idea of the sum we implemented on the right line tree? Let’s find out!

Implement the sumRotate function that computes the sum of leaves’ values using right rotations:

def sumRotate(tr: Tree[Int], acc: Int): Int =
  ???

src/main/scala/tailRecursion/trees.scala

Can you name which property the operation done on the leaves must satisfy for this to work?

Sum of leaves - DFS

Now, let us write an imperative version of the sum function. Before writing any code, think carefully about it. On what elements would you iterate? How do you make sure you visit all the nodes? How would you keep track of the nodes you still need to visit?

def sumLoop(t: Tree[Int]): Int =
  ???

src/main/scala/tailRecursion/trees.scala

Hint

As you might have realised, this is not straightforward. The main issue is that you need to keep track of the nodes to visit. What data structure would you use to store the nodes you encounter and will visit later?

Spoiler

You should indeed use a Stack to keep track of the nodes you have to visit. You can again use the Stack class from the Scala library.
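As a reminder of how the standard mutable stack behaves, here is a minimal sketch using scala.collection.mutable.Stack on plain integers. The function drainDemo is a hypothetical name used only for this illustration.

```scala
import scala.collection.mutable.Stack

// Pushes three values, then pops until the stack is empty,
// recording the order in which the values come out.
def drainDemo(): List[Int] =
  val pending = Stack[Int]()
  pending.push(1)
  pending.push(2)
  pending.push(3)
  var seen: List[Int] = Nil
  while pending.nonEmpty do
    seen = seen :+ pending.pop() // the most recently pushed value comes out first
  seen
```

Calling drainDemo() returns List(3, 2, 1): the loop keeps running as long as work remains on the stack, which is the same shape as a traversal loop that pushes newly discovered nodes.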

Reduce on tree 🔥

We will now take a look at another function on trees: reduce. As a reminder, reduce is defined recursively as follows on trees:

def reduce[T](tr: Tree[T], f: (T, T) => T): T =
  tr match
    case Leaf(value)       => value
    case Node(left, right) => f(reduce(left, f), reduce(right, f))

src/main/scala/tailRecursion/trees.scala

We will write an imperative version of this function.

To kickstart, let us implement a mutable Stack structure, just as you used in the previous exercises. Our MStack is based on a List and will extend the following trait:

trait MStackTrait[A]:
  def push(a: A): Unit
  def pop(): A
  def isEmpty: Boolean
  def size: Int
  def contains(a: A): Boolean

case class MStack[A](var l: List[A] = Nil) extends MStackTrait[A]:
  def push(a: A): Unit =
    ???
  def pop(): A =
    ???
  def isEmpty: Boolean =
    ???
  def size: Int =
    ???
  def contains(a: A): Boolean =
    ???

src/main/scala/tailRecursion/trees.scala

Now let us implement a post order traversal on trees. This function will return the subtrees in post order, which means first the left child, then the right child, then the node itself. For example, the post order traversal of the following tree:

val tree =
  Node(
    Node(
      Leaf(1),
      Leaf(2)
    ),
    Leaf(3)
  )

is the following list:

List(
  Leaf(1),
  Leaf(2),
  Node(Leaf(1), Leaf(2)),
  Leaf(3),
  Node(Node(Leaf(1), Leaf(2)), Leaf(3))
)
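For reference, post order can also be specified recursively. The sketch below is a hypothetical helper named postOrderRec, not the loop-based function requested in this exercise; it can serve as a specification to test a loop-based implementation against. The Tree definition is repeated so that the example is self-contained.

```scala
enum Tree[T]:
  case Leaf(value: T)
  case Node(left: Tree[T], right: Tree[T])

import Tree.*

// Left subtree first, then right subtree, then the node itself.
def postOrderRec[T](tr: Tree[T]): List[Tree[T]] =
  tr match
    case Leaf(_)           => List(tr)
    case Node(left, right) => (postOrderRec(left) ++ postOrderRec(right)) :+ tr

val example: Tree[Int] =
  Node(
    Node(
      Leaf(1),
      Leaf(2)
    ),
    Leaf(3)
  )
```

On example, postOrderRec produces the five-element list shown above, ending with the root.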

Now, implement the postOrderTraversal function using a while loop and the MStack type that you just implemented. Think hard before writing the function. How do you keep track of the nodes you will visit? How do you ensure that you add the nodes in the correct order?

def postOrderTraversal[T](tr: Tree[T]): List[Tree[T]] =
  ???

src/main/scala/tailRecursion/trees.scala

This postorder traversal should be enough to implement reduce!

Hint

You’ll need an intermediate data structure to keep track of partially reduced results while you go over the post order. You can use a Map that associates tree nodes to the result of reduce on them, or you can use a Stack with a bit more thinking about the order in which nodes appear in the post-order.

def reduceLoop[T](tr: Tree[T], f: (T, T) => T): T =
  ???

src/main/scala/tailRecursion/trees.scala

Map on tree

Now that you implemented reduce, you can implement map using the same principles.

Proof of correctness of reduce on trees 🔥

We will now revisit the reduce function that uses the post order traversal from the exercise Reduce on tree. If you did not do it, here is the implementation:

Solution
trait MStackTrait[A]:
  def push(a: A): Unit
  def pop(): A
  def isEmpty: Boolean
  def size: Int
  def contains(a: A): Boolean

case class MStack[A](var l: List[A] = Nil) extends MStackTrait[A]:
  def push(a: A): Unit =
    l = a :: l
  def pop(): A =
    val a = l.head
    l = l.tail
    a
  def isEmpty: Boolean =
    l.isEmpty
  def size: Int =
    l.size
  def contains(a: A): Boolean =
    l.contains(a)

src/main/scala/tailRecursion/trees.scala

def postOrderTraversal[T](tr: Tree[T]): List[Tree[T]] =
  var toVisit = MStack[Tree[T]]()
  toVisit.push(tr)
  var postOrderNodes: List[Tree[T]] = Nil
  while !toVisit.isEmpty do
    val n = toVisit.pop()
    postOrderNodes = n :: postOrderNodes
    n match
      case Node(left, right) =>
        toVisit.push(left)
        toVisit.push(right)
      case Leaf(_) =>
  postOrderNodes

src/main/scala/tailRecursion/trees.scala

def reduceLoop[T](tr: Tree[T], f: (T, T) => T): T =
  var cache: Map[Tree[T], T] = Map()

  for (t, idx) <- postOrderTraversal(tr).zipWithIndex do
    t match
      case Leaf(v) => cache = cache + (t -> v)
      case Node(left, right) =>
        val leftValue = cache(left)
        val rightValue = cache(right)
        cache = cache + (t -> f(leftValue, rightValue))
  cache(tr)

src/main/scala/tailRecursion/trees.scala

If you are interested in program verification and proofs, two courses are given at EPFL in this area:

Post order traversal

Let us start by proving the correctness of the post order traversal algorithm. In words, the algorithm is correct if the produced list contains all the nodes of the tree, and if the order of the nodes is indeed a post order traversal (i.e., children appear in the list at a smaller index than their parent). In particular, the list should end with the root.

Your task is to write the above postcondition in Scala code and a loop invariant for the postOrderTraversal function that proves it is indeed satisfied at the end. Be careful, the invariant must take the state of the stack into account.
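If you have not written preconditions, postconditions, or loop invariants in Scala before, the following sketch shows one possible way to express them on a much simpler loop, using the standard require, assert, and ensuring. The function sumTo is a hypothetical example and is unrelated to the traversal itself.

```scala
// Computes 0 + 1 + ... + n with a loop, checking its contract at run time.
def sumTo(n: Int): Int =
  require(n >= 0, "n must be non-negative") // precondition
  var i = 0
  var acc = 0
  while i < n do
    // Loop invariant: acc is the sum of the integers from 0 to i.
    assert(acc == i * (i + 1) / 2)
    i += 1
    acc += i
  // At loop exit, the invariant still holds and i == n,
  // which together imply the postcondition.
  assert(i == n && acc == i * (i + 1) / 2)
  acc.ensuring(res => res == n * (n + 1) / 2) // postcondition
```

The same pattern applies to the traversal: state what holds before each iteration, and check that the invariant combined with the exit condition of the loop implies the postcondition.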

reduce

Now that you proved the correctness of the post order traversal, you can prove the correctness of the reduce function. The postcondition of reduce in our case is that the cache contains the root, and that this value is equal to reduce(root, f).

Your task is to write the above postcondition in Scala code and a loop invariant for the reduce function that proves it is indeed satisfied at the end. You can write one invariant encoding the validity of the cache, i.e., that all values it contains are indeed correct with respect to the key and the function f, and one invariant that encodes the correctness of how the cache is updated in the loop.

Exceptional control flow

An exceptional contains method ⭐️

  1. Consider the following two implementations of contains:

    extension [T](l: List[T])
      final def containsRec(t0: T): Boolean =
        l match
          case Nil      => false
          case hd :: tl => hd == t0 || tl.containsRec(t0)
    

    src/main/scala/exceptions/Exceptions.scala

    extension [T](l: List[T])
      final def containsFold(t0: T): Boolean =
        l.foldRight(false)((hd, found) => found || hd == t0)
    

    src/main/scala/exceptions/Exceptions.scala

    Is one of them preferable? Why?

  2. Which mechanism do you know for interrupting a computation before it completes? Use it to rewrite contains using foreach.

    Hint

    Use an exception! They work just the same in Scala as in Java.

    What advantages does this approach have?
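As an illustration of the mechanism on a different function, here is a minimal sketch that stops a foreach early by throwing an exception. The names FoundNegative and allNonNegative are hypothetical and chosen only for this example.

```scala
// A dedicated exception object, used purely for control flow.
object FoundNegative extends Exception

// Returns false as soon as a negative element is seen,
// without visiting the rest of the list.
def allNonNegative(l: List[Int]): Boolean =
  try
    l.foreach(x => if x < 0 then throw FoundNegative)
    true
  catch
    case FoundNegative => false
```

The try block must enclose the whole traversal: if the catch clause were forgotten, the exception would propagate to the caller.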

Avoiding accidental escape: boundary/break ⭐️

Exceptions are great, but they risk escaping: if you forget to catch an exception raised for control flow, it will propagate to the caller of your function, and cause havoc there.

  1. Read the boundary/break documentation.

  2. Use a boundary to reimplement contains a fourth time.

  3. 🔥 Which of these four implementations of contains is fastest? Make a guess, then confirm it by writing a JMH benchmark.
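To get a feel for the API, here is a small sketch of boundary/break on a different function. It assumes Scala 3.3 or later, where scala.util.boundary is available; the function firstIndexOf is a hypothetical name used only for this illustration.

```scala
import scala.util.boundary
import scala.util.boundary.break

// Returns the index of the first occurrence of target, or -1 if there is none.
def firstIndexOf(l: List[Int], target: Int): Int =
  val notFound = -1
  boundary:
    var idx = 0
    for x <- l do
      // break jumps out of the enclosing boundary with the given value
      if x == target then break(idx)
      idx += 1
    notFound
```

A break is only accepted by the compiler where an enclosing boundary provides a label, which is what prevents the jump from escaping by accident.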

Value-carrying exceptions

  1. Define a custom error type to hold values. Use it to write an exception-based implementation of find.

  2. Use boundary/break instead of a custom error type.

Memoization

Briefly, memoization is the process of augmenting a function with a mutable cache that records the output of the function every time it is called. If the function is subsequently called again with a previously-seen input, the result can be returned from cache instead of being recomputed.

A step-by-step example ⭐️

To see why memoization may be useful, consider a simple example: the Fibonacci function, which we studied previously:

def fib(n: Int): Int =
  if n <= 1 then 1 else fib(n - 1) + fib(n - 2)

src/main/scala/memo/Fib.scala

To compute fib(4) we make two recursive calls: one to fib(3), and one to fib(2). To compute fib(3), we again make two recursive calls: one to fib(2), and one to fib(1). Without special precautions, we end up computing fib(2) twice. Other parts of the computation are similarly repeated.

       fib(4)
=== (  fib(3)                     + fib(2)           )
=== ( (fib(2)           + fib(1)) + (fib(1) + fib(0)))
=== (((fib(1) + fib(0)) + 1     ) + (1      + 1     ))
=== (((1      + 1     ) + 1     ) + (1      + 1     ))

Interestingly, the cost of computing fib(n) grows like fib(n) itself: if it takes $T(k)$ steps to compute fib(k), then, ignoring the constant cost of the addition, the cost of computing fib(n) is $T(n) = T(n - 1) + T(n - 2)$.
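To observe this growth directly, one possible experiment is to count the calls made by the naive function. The sketch below uses a hypothetical helper, fibCalls, that returns both the value of fib(n) and the number of calls to fib that were needed to compute it.

```scala
// Returns (fib(n), number of calls made to compute it),
// following the same recursion as fib above.
def fibCalls(n: Int): (Int, Int) =
  if n <= 1 then (1, 1)
  else
    val (v1, c1) = fibCalls(n - 1)
    val (v2, c2) = fibCalls(n - 2)
    (v1 + v2, c1 + c2 + 1) // +1 for the current call
```

For instance, fibCalls(4) returns (5, 9): the nine calls are exactly the occurrences of fib in the expansion above.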

Memoization

All this redundant computation is unnecessary. Instead, as our first attempt to address this problem, we can create a cache to store fib’s results:

import scala.collection.mutable.Map

def fibMemo(n: Int): Int =
  val cache: Map[Int, Int] = Map()
  def loop(idx: Int): Int =
    cache.getOrElseUpdate(
      idx,
      if idx <= 1 then 1
      else loop(idx - 1) + loop(idx - 2)
    )
  loop(n)

src/main/scala/memo/Fib.scala

  1. Can you convince yourself that this function behaves identically to the version without a cache?

  2. How large does the cache grow (i.e., how many entries get created in the cache) as we evaluate fib(k)? What entries does it contain when the computation completes?

Subproblem graph

To save space, we need to understand the structure of the subproblem graph of the Fibonacci function. The subproblem graph is a graph where:

For example, the nodes of the computation graph of fib(4) are 4, 3, 2, 1, 0 and its edges are 4 → 3, 4 → 2, 3 → 2, 3 → 1, 2 → 1, 2 → 0.

Here is one representation of the graph (notice that there are no edges from 1 to 0: they are both leaves):

      ┌─────┬─────┐
      │     v     v
4  →  3  →  2  →  1      0
│     ^     ^│    ^      ^
└─────┴─────┘└────┴──────┘

Dynamic programming

The subproblem graph captures precisely the notion of dependency: we cannot compute the output of a function on a given input node unless we know the outputs of the function on all the nodes it points to. Given this:

The key questions to be able to do dynamic programming are: How long do we need to remember cached values? Which computation order minimizes this time?

To answer, we proceed in three steps:

  1. Find a traversal of the subproblem graph, starting from the leaves, such that dependencies are always computed before their parents. This is called a reverse topological sort of the subproblem graph.

  2. Rewrite our algorithm to construct the memoization cache iteratively, in the order given by step 1.

    For Fibonacci, the order is very simple: 0, 1, 2, 3, 4, …. In other words, to compute fib(n), it is sufficient to know the values of all fib(k) where k < n. The result of step 2 is, hence, as follows:

    import scala.collection.mutable.Map
    
    def fibIter(n: Int): Int =
      val cache: Map[Int, Int] = Map()
      for idx <- 0 to n do
        cache(idx) =
          if idx <= 1 then 1 else cache(idx - 1) + cache(idx - 2)
      cache(n)
    

    src/main/scala/memo/Fib.scala

  3. Discard entries from the memoization cache as soon as they are not used any more. In the case of the Fibonacci function, we only need to keep the last two entries:

    import scala.collection.mutable.Map
    
    def fibIterOpt(n: Int): Int =
      val cache: Map[Int, Int] = Map()
      for idx <- 0 to n do
        cache(idx) =
          if idx <= 1 then 1 else cache(idx - 1) + cache(idx - 2)
        cache.remove(idx - 2)
      cache(n)
    

    src/main/scala/memo/Fib.scala

    And we can, as a last cleanup step, entirely eliminate the cache, keeping only two variables:

    import scala.collection.mutable.Map
    
    def fibIterFinal(n: Int): Int =
      var f0 = 1
      var f1 = 1
      for idx <- 2 to n do
        val f = f0 + f1
        f0 = f1
        f1 = f
      f1
    

    src/main/scala/memo/Fib.scala

First application: $\binom{n}{k}$ “$n$ choose $k$”

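The function choose(n, k) computes how many ways there are to pick k elements among n, without considering order and without allowing repetitions. This choice can be done in two ways: either we pick the first element, and must then pick k - 1 elements among the remaining n - 1; or we skip it, and must then pick k elements among the remaining n - 1. This gives the equation $\binom{n}{k} = \binom{n - 1}{k - 1} + \binom{n - 1}{k}$.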

  1. Implement choose as a recursive function using this equation. Mind the base cases!

    def choose(n: Int, k: Int): Int =
      ???
    

    src/main/scala/memo/Choose.scala

  2. Draw the subproblem graph of choose(5, 3) as a tree. Each node should have a pair of numbers, since the function takes two arguments. Do you notice repeated work?

  3. Write a memoized implementation of choose. The cache should map pairs of numbers (inputs) to single numbers (outputs):

    def chooseMemo(n: Int, k: Int): Int =
      ???
    

    src/main/scala/memo/Choose.scala

  4. Redraw the subproblem graph, but this time lay it out as an array with 6 columns and 4 rows: place node (i, j) at position x = i, y = j. What do you notice about the structure of the graph? Propose a reverse topological ordering of it.

  5. Replace the Map-based cache with a two-dimensional array, and rewrite the memoized algorithm to build the cache iteratively, without recursion.

  6. Is the whole cache needed at all times? Rewrite the algorithm to use less memory.

More applications: memoizing every previous CS-214 problem ⭐️

Train yourself to add memoization to functions by revisiting previous CS-214 problems. Particularly relevant are coinChange and des chiffres et des lettres.

Benchmarking 🔥

The original solution to the Anagrams lab was not exactly fast. Memoize the recursive part of the anagrams function, and measure the resulting speed improvements. How much faster does it get?

Tower of Hanoi

A popular item in dentist offices, children’s museums, and on “10 original gift ideas for the holidays” lists is the game called “Tower of Hanoi”.

The game has three pegs and 7 disks of increasing size, each with a hole in its center. In the initial configuration, the disks are stacked from largest to smallest on the leftmost peg:

      -|-              |               |
     --|--             |               |
    ---|---            |               |
   ----|----           |               |
  -----|-----          |               |
 ------|------         |               |
-------|-------        |               |
==== PEG 0 ==== ==== PEG 1 ==== ==== PEG 2 ====

The aim is to move all the disks to the rightmost peg, with one rule: a larger disk may never rest on top of a smaller disk. Hence, this is a valid move:

       |               |               |
     --|--             |               |
    ---|---            |               |
   ----|----           |               |
  -----|-----          |               |
 ------|------         |               |
-------|-------        |              -|-
==== PEG 0 ==== ==== PEG 1 ==== ==== PEG 2 ====

… and so is this:

       |               |               |
       |               |               |
    ---|---            |               |
   ----|----           |               |
  -----|-----          |               |
 ------|------         |               |
-------|-------      --|--            -|-
==== PEG 0 ==== ==== PEG 1 ==== ==== PEG 2 ====

… but after this the only valid moves are $2 \to 1$ (moving the disk from peg 2 to peg 1) as well as $2 \to 0$ and $1 \to 0$.
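The rule can be turned into a small executable check. The sketch below uses a hypothetical representation that is independent of the Peg and Move types used later: pegs are numbered 0 to 2, and each peg is a list of disk sizes from top to bottom. The names Pegs, isValidMove, validMoves, and afterTwoMoves are assumptions made for this illustration.

```scala
// Each peg lists its disk sizes from top to bottom.
type Pegs = Vector[List[Int]]

// A move is valid if there is a disk to move and it lands
// either on an empty peg or on a strictly larger disk.
def isValidMove(pegs: Pegs, from: Int, to: Int): Boolean =
  pegs(from) match
    case Nil => false
    case disk :: _ =>
      from != to && pegs(to).headOption.forall(top => disk < top)

// All valid (from, to) pairs in a given configuration.
def validMoves(pegs: Pegs): Seq[(Int, Int)] =
  for
    from <- pegs.indices
    to <- pegs.indices
    if isValidMove(pegs, from, to)
  yield (from, to)

// The configuration reached after the two moves shown above.
val afterTwoMoves: Pegs = Vector(List(3, 4, 5, 6, 7), List(2), List(1))
```

Calling validMoves(afterTwoMoves) yields the pairs (1, 0), (2, 0), and (2, 1), which matches the three moves listed above.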

A solution of the game is a list of moves $i \to j$ that moves all disks from peg 0 to peg 2. Here is a solution for the case of three disks:

  -|-      |       |
 --|--     |       |
---|---    |       |
=Left= =Middle= =Right=

Left → Right
   |       |       |
 --|--     |       |
---|---    |      -|-
=Left= =Middle= =Right=

Left → Middle
   |       |       |
   |       |       |
---|---  --|--    -|-
=Left= =Middle= =Right=

Right → Middle
   |       |       |
   |      -|-      |
---|---  --|--     |
=Left= =Middle= =Right=

Left → Right
   |       |       |
   |      -|-      |
   |     --|--  ---|---
=Left= =Middle= =Right=

Middle → Left
   |       |       |
   |       |       |
  -|-    --|--  ---|---
=Left= =Middle= =Right=

Middle → Right
   |       |       |
   |       |     --|--
  -|-      |    ---|---
=Left= =Middle= =Right=

Left → Right
   |       |      -|-
   |       |     --|--
   |       |    ---|---
=Left= =Middle= =Right=

These diagrams are generated by calling viewMoves(hanoi(3), 3). This function is provided in the code supplement to this exercise set.

  1. Any good text editor or IDE should have an implementation of tower of Hanoi (if you’re using Emacs, simply use M-x hanoi to start it). Use it to familiarize yourself with the game.

  2. Write a function that computes a solution to the problem with $n$ disks. Check below for an important hint, or skip the hint if you prefer a 🔥 exercise. In any case, remember the fundamental question of recursive problems: how do I express a solution to my problem in terms of smaller subproblems?

    enum Peg:
      case Left, Middle, Right
    
    case class Move(from: Peg, to: Peg)
    

    src/main/scala/memo/Hanoi.scala

    def hanoi(n: Int): Seq[Move] =
      ???
    

    src/main/scala/memo/Hanoi.scala

    Hint

    You need to solve a more general problem: how to move $n$ disks from peg $a$ to peg $b$. Can you solve the problem with 7 disks if you know how to move the first 6 disks to the middle peg?

    def hanoiHelper(src: Peg, dst: Peg, third: Peg, n: Int): Seq[Move] =
      ???
    

    src/main/scala/memo/Hanoi.scala

  3. Can this program benefit from memoization?

A memoizing fixpoint combinator 🔥

Take a look back at the combinator exercise from the polymorphism week.

As we have seen, the first step of memoization always looks the same: starting from a function def f(input: …) = …, we write the following:

def fMemo(input: …) =
  val cache: Map[…] = Map()
  def loop(input: …) =
    cache.getOrElseUpdate(input,
      …(body)…)
  loop(input)

It would be nice to be able to abstract over this pattern. Define a higher-order function memo to do so:

def memo[A, B](f: (A, A => B) => B)(a: A): B =
  ???

src/main/scala/memo/Combinator.scala

This function should be such that we can define fib and choose as follows:

val fib = memo: (n: Int, f: Int => Int) =>
  if n <= 1 then 1 else f(n - 1) + f(n - 2)

src/main/scala/memo/Combinator.scala

val choose = memo[(Int, Int), Int] {
  case ((n, k), f) =>
    if k <= 0 || k >= n then 1
    else f((n - 1, k - 1)) + f((n - 1, k))
}

src/main/scala/memo/Combinator.scala

Hint

Start from the fixpoint combinator from the previous exercise:

def fixpoint[A, B](f: (A, A => B) => B)(a: A): B =
  def loop(a: A): B = f(a, loop)
  f(a, loop)
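As a reminder of how this combinator is used, here is a small sketch that defines a factorial function with it. The name factorial and the parameter name self are only illustrative, and the definition of fixpoint is repeated so that the example is self-contained.

```scala
def fixpoint[A, B](f: (A, A => B) => B)(a: A): B =
  def loop(a: A): B = f(a, loop)
  f(a, loop)

// The function body receives `self` and calls it
// wherever it would normally recurse.
val factorial: Int => Int =
  fixpoint[Int, Int]((n, self) => if n <= 1 then 1 else n * self(n - 1))
```

With this definition, factorial(5) evaluates to 120. Every recursive call goes through loop, which makes it a natural place to consult and update a cache.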

src/main/scala/memo/Combinator.scala