Sunday, April 5, 2009

Tail calls, @tailrec and trampolines

Recursion is an essential part of functional programming. But if each call allocates a stack frame, then too much recursion will overflow the stack. Most functional programming languages solve this problem by eliminating stack frames through a process called tail-call optimisation. Unfortunately for Scala programmers, the JVM doesn't perform this optimisation.

Here's a Scala program that tries to work out whether 9999 is even or odd by calling even1 and odd1 recursively. When it runs, the stack overflows before we can make 9999 calls.

def odd1(n: Int): Boolean = {
  if (n == 0) false
  else even1(n - 1)
}
def even1(n: Int): Boolean = {
  if (n == 0) true
  else odd1(n - 1)
}
even1(9999)

All the calls in our example program are in tail position, so if the JVM did support tail-call optimisation, then the program would be able to complete successfully.

Luckily, even without JVM support, the Scala compiler can perform tail-call optimisation in some cases. The compiler looks for certain types of tail calls and translates them automatically into loops. At the moment it can optimise self calls in final methods and in local functions. It cannot optimise non-final methods (because they might be overridden by a subclass), and it cannot optimise tail calls that are made to different methods.
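To make the two optimisable cases concrete, here is a sketch with one final method and one local function. The class and method names are made up for illustration. Each one calls only itself, in tail position. It uses the @tailrec annotation described below so that the compiler confirms each call really becomes a loop, which assumes Scala 2.8 or later.

```scala
import scala.annotation.tailrec

// Hypothetical example class, not from the original post.
class Counter {
  // final: no subclass can override sumTo, so its self call can become a loop.
  @tailrec
  final def sumTo(n: Int, acc: Long): Long =
    if (n <= 0) acc
    else sumTo(n - 1, acc + n) // self call in tail position

  // countEvens itself is public and non-final, but the recursion lives in a
  // local function, which cannot be overridden.
  def countEvens(n: Int): Int = {
    @tailrec
    def loop(i: Int, count: Int): Int =
      if (i > n) count
      else loop(i + 1, if (i % 2 == 0) count + 1 else count)
    loop(0, 0)
  }
}
```

Calling new Counter().sumTo(1000000, 0L) should run to completion without a StackOverflowError, because the compiled code loops instead of stacking a million frames.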

What this means

Because of these limitations, you need to be careful when using recursion in Scala. When writing programs, you will need to keep in mind how both the compiler and the JVM work. One safe approach is to use code from the standard library, where possible. For example, you'll find that many recursive algorithms can easily be rewritten in terms of standard operations like map and fold.
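As a small illustration (the names are hypothetical), here is a recursive list sum next to the same computation written with the standard foldLeft. The recursive version needs one stack frame per element. In contrast, foldLeft on a List iterates, so it should cope with very long lists.

```scala
object SumExamples {
  // Plain structural recursion: the addition happens after the recursive
  // call returns, so each element costs a stack frame.
  def sumRec(xs: List[Long]): Long = xs match {
    case Nil       => 0L
    case x :: rest => x + sumRec(rest)
  }

  // The same computation with foldLeft, which walks the list iteratively
  // and does not grow the stack with the length of the list.
  def sumFold(xs: List[Long]): Long = xs.foldLeft(0L)(_ + _)
}
```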

In Scala 2.8, you will also be able to use the new @tailrec annotation to get information about which methods are optimised. This annotation lets you mark specific methods that you hope the compiler will optimise. If the compiler cannot optimise a marked method, it reports an error at compile time. In Scala 2.7 or earlier, you will need to rely on manual testing, or inspection of the bytecode, to work out whether a method has been optimised.
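Here is a minimal sketch of the annotation in use, assuming Scala 2.8 or later. The Maths class and its gcd method are hypothetical. The comment points out the change that would make the compiler reject the annotated method.

```scala
import scala.annotation.tailrec

class Maths {
  // @tailrec asks the compiler to confirm that gcd is compiled as a loop.
  // It is accepted because gcd is final and calls only itself, in tail position.
  // Deleting `final` would make this a compile error, because a subclass
  // could then override gcd.
  @tailrec
  final def gcd(a: Int, b: Int): Int =
    if (b == 0) a
    else gcd(b, a % b)
}
```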

If you do find a call that you think should be optimised by the compiler, but isn't, then you should check that the call:

  1. is a tail call,
  2. is in a final method or local function, and
  3. is to itself.

For example, the code for factorial below would not be optimised. The call is not in tail position (the last operation performed is the multiplication, not the call), and the method is public and non-final, so it could be overridden by a subclass.

class Factorial1 {
  def factorial(n: Int): Int = {
    if (n <= 1) 1
    else n * factorial(n - 1)
  }
}

You can make simple changes to factorial to eliminate both of these problems. First, you could move the recursive code into a local function within the method, so that it cannot be overridden. Second, you could introduce an accumulator so that multiplication happens before the recursive call. Finally, you could add a @tailrec annotation so that you can be sure that your changes have worked.

import scala.annotation.tailrec

class Factorial2 {
  def factorial(n: Int): Int = {
    @tailrec def factorialAcc(acc: Int, n: Int): Int = {
      if (n <= 1) acc
      else factorialAcc(n * acc, n - 1)
    }
    factorialAcc(1, n)
  }
}
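One caveat, separate from the tail-call question: Int overflows silently once n exceeds 12, because 13! is larger than Int.MaxValue. The sketch below applies the same accumulator pattern to BigInt. The class name Factorial3 is just for illustration.

```scala
import scala.annotation.tailrec

// Hypothetical variant of Factorial2 that cannot overflow.
class Factorial3 {
  def factorial(n: Int): BigInt = {
    // Local function, accumulator, and @tailrec, exactly as in Factorial2.
    @tailrec def factorialAcc(acc: BigInt, n: Int): BigInt =
      if (n <= 1) acc
      else factorialAcc(acc * n, n - 1)
    factorialAcc(BigInt(1), n)
  }
}
```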

But there are some types of recursive code that the compiler will not be able to optimise. For example, if your code is mutually recursive, as it is with odd1 and even1, then you will need to try something else. One thing you might consider is a trampoline.

Trampolines

A trampoline is a loop that repeatedly runs functions. Each function, called a thunk, returns the next function for the loop to run. The trampoline never runs more than one thunk at a time, so if you break up your program into small enough thunks and bounce each one off the trampoline, then you can be sure the stack won't grow too big.

Here is our program again, rewritten in trampolined style. Call objects contain the thunks and a Done object contains the final result. Instead of making a tail call directly, each method now returns its call as a thunk for the trampoline to run. This frees up the stack after each iteration. The effect is very similar to tail-call optimisation.

def even2(n: Int): Bounce[Boolean] = {
  if (n == 0) Done(true)
  else Call(() => odd2(n - 1))
}
def odd2(n: Int): Bounce[Boolean] = {
  if (n == 0) Done(false)
  else Call(() => even2(n - 1))
}
trampoline(even2(9999))

It only takes a few lines of code to implement a trampoline.

sealed trait Bounce[A]
case class Done[A](result: A) extends Bounce[A]
case class Call[A](thunk: () => Bounce[A]) extends Bounce[A]

def trampoline[A](bounce: Bounce[A]): A = bounce match {
  case Call(thunk) => trampoline(thunk())
  case Done(x) => x
}
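For experimenting, here is a sketch that gathers these pieces into one self-contained object. The object name Trampolined and the method trampolineLoop are my own additions, not part of the original discussion. Note that the recursive trampoline depends on the compiler optimising its own self tail call. That works here because a method on an object cannot be overridden. trampolineLoop writes the same loop with while, so it does not depend on the compiler at all.

```scala
import scala.annotation.tailrec

object Trampolined {
  sealed trait Bounce[A]
  final case class Done[A](result: A) extends Bounce[A]
  final case class Call[A](thunk: () => Bounce[A]) extends Bounce[A]

  // Recursive trampoline; @tailrec checks that the self call becomes a loop.
  @tailrec
  def trampoline[A](bounce: Bounce[A]): A = bounce match {
    case Call(thunk) => trampoline(thunk())
    case Done(x)     => x
  }

  // The same loop written with while, independent of the compiler.
  def trampolineLoop[A](start: Bounce[A]): A = {
    var current: Bounce[A] = start
    var result: Option[A] = None
    while (result.isEmpty) {
      current match {
        case Call(thunk) => current = thunk() // run one thunk, keep its successor
        case Done(x)     => result = Some(x)  // finished
      }
    }
    result.get
  }

  def even2(n: Int): Bounce[Boolean] =
    if (n == 0) Done(true) else Call(() => odd2(n - 1))

  def odd2(n: Int): Bounce[Boolean] =
    if (n == 0) Done(false) else Call(() => even2(n - 1))
}
```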

Trampolined code is harder to read and write, and it executes more slowly. However, trampolines can be invaluable when your program would otherwise run out of stack space and the only alternative is to convert it into an imperative style. There has recently been talk of including a trampoline implementation in Scala 2.8. (The code in this article is based on the code from that discussion.)
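Current Scala versions ship a trampoline of this kind in the standard library, scala.util.control.TailCalls. In that library, done plays the role of Done and tailcall plays the role of Call. Here is a sketch of the even/odd example on top of it (the object name is hypothetical):

```scala
import scala.util.control.TailCalls._

object EvenOddTailCalls {
  // done(x) ends the computation; tailcall(...) defers the next step.
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(isEven(n - 1))
}
```

Calling .result on the returned TailRec runs the trampoline, so EvenOddTailCalls.isEven(9999).result evaluates to false without a deep stack.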

Postscript: Continuations

I've been writing about continuations quite a bit recently, so I think it's fitting to mention their relationship to trampolines. It turns out that a thunk can be easily manufactured from a continuation. You can create thunks automatically using shift and reset. In fact my recent implementation of goto used a form of trampoline. I'll close here by showing how goto can be implemented using the trampoline that we defined above.
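The relationship can also be shown without the continuations plugin by writing the continuation out by hand. In this sketch, sumK (a hypothetical example) computes 0 + 1 + ... + n in continuation-passing style. Whenever it would call its continuation k directly, it returns Call(() => k(x)) instead. That is a thunk built from a continuation, and the trampoline runs it without growing the stack.

```scala
import scala.annotation.tailrec

object CpsTrampoline {
  sealed trait Bounce[A]
  final case class Done[A](result: A) extends Bounce[A]
  final case class Call[A](thunk: () => Bounce[A]) extends Bounce[A]

  @tailrec
  def trampoline[A](bounce: Bounce[A]): A = bounce match {
    case Call(thunk) => trampoline(thunk())
    case Done(x)     => x
  }

  // Continuation-passing sum; every call to a continuation is wrapped in a
  // thunk, so no continuation ever calls the next one directly.
  def sumK(n: Int, k: Long => Bounce[Long]): Bounce[Long] =
    if (n == 0) Call(() => k(0L))
    else Call(() => sumK(n - 1, r => Call(() => k(n + r))))

  // The final continuation simply stops the trampoline with the result.
  def sum(n: Int): Long = trampoline(sumK(n, r => Done(r)))
}
```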

import scala.continuations.cps
import scala.continuations.ControlContext.{shift,reset}

def trampolineK[A,B1<:C,C1<:C,C](body: => B1 @cps[B1,Bounce[C1]]): C =
  trampoline(reset(Done(body)))

case class Label[A](k: Label[A] => Bounce[A])

def label[A]: Label[A] @cps[Bounce[A],Bounce[A]] =
  shift((k: Label[A] => Bounce[A]) => k(Label(k)))

def goto[A](l: Label[A]): Unit @cps[Bounce[A],Bounce[A]] =
  shift((k: Nothing => Bounce[A]) => Call(() => l.k(l)))

trampolineK {
  var sum = 0
  var i = 0
  val beforeLoop = label
  if (i < 10000) {
    println(i)
    sum += i
    i += 1
    goto(beforeLoop)
  }
  println(sum)
}

Update: Fixed image scaling.

Sunday, March 15, 2009

Goto in Scala

Ever wished you could use goto in your Scala programs? :-) You'll be glad to know that a new compiler plugin, currently in development, makes it possible.

import Goto._

reset {
  var sum = 0
  var i = 0
  val beforeLoop = label // Create label for goto
  if (i < 10000) {
    println(i)
    sum += i
    i += 1
    goto(beforeLoop) // Jump to label
  }
  println(sum)
}

The new plugin brings delimited continuations to Scala. With continuation support, it becomes possible for end users to extend the language in new and interesting ways. For example, it is possible to write new control flow operations. Here's some code that implements the famous goto operation, which was for some reason omitted from the Scala language.

object Goto {

  case class Label(k: Label => Unit)

  private case class GotoThunk(label: Label) extends Throwable

  def label: Label @suspendable =
    shift((k: Label => Unit) => executeFrom(Label(k)))

  def goto(l: Label): Nothing =
    throw new GotoThunk(l)

  private def executeFrom(label: Label): Unit = {
    val nextLabel = try {
      label.k(label)
      None
    } catch {
      case g: GotoThunk => Some(g.label)
    }
    if (nextLabel.isDefined) executeFrom(nextLabel.get)
  }

}

To use this code, you would first call label to get a Label object. You could then jump to the label by calling goto and passing the Label object as an argument.

Here's how it works.

Calling label captures the continuation with shift. We store the continuation in a Label object. When we run the continuation, the Label is passed as an argument. This allows it to be used by code within the continuation—by calls to goto, for example.

Calling goto throws a GotoThunk that contains the Label. We then catch the thunk. We can catch it because we made sure to run our continuation in a try/catch block. Once we catch the thunk, we extract its Label and run the Label's continuation, starting the process all over again. In this way we can continue calling goto indefinitely…
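To see this control flow without the plugin, here is a sketch in which the continuation that label would capture with shift is written by hand as the function passed to Label. GotoByHand and sumBelow are hypothetical names. The catch-and-restart loop in executeFrom follows the post's code. The @tailrec annotation and NoStackTrace are additions to the original.

```scala
import scala.annotation.tailrec
import scala.util.control.NoStackTrace

object GotoByHand {
  // A Label holds a continuation: "the rest of the block" after the label.
  final case class Label(k: Label => Unit)

  // Thrown by goto; NoStackTrace skips building a stack trace on every jump.
  private final case class GotoThunk(label: Label) extends Throwable with NoStackTrace

  def goto(l: Label): Nothing = throw GotoThunk(l)

  // Run the label's continuation. If it calls goto, catch the thunk, leave
  // the try block, and start again from the target label.
  @tailrec
  def executeFrom(label: Label): Unit = {
    val next: Option[Label] =
      try { label.k(label); None }
      catch { case GotoThunk(target) => Some(target) }
    next match {
      case Some(target) => executeFrom(target)
      case None         => ()
    }
  }

  // Hand-written version of the loop at the top of this post: the function
  // given to Label is the code that follows `val beforeLoop = label`.
  def sumBelow(limit: Int): Int = {
    var sum = 0
    var i = 0
    executeFrom(Label { beforeLoop =>
      if (i < limit) {
        sum += i
        i += 1
        goto(beforeLoop)
      }
    })
    sum
  }
}
```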

Goto is only one of many control flow operations that we can construct with delimited continuations. For example, we can make conditionals, loops, generators, coroutines, etc.

If you're interested in playing with the plugin yourself, you can find some brief usage instructions in my previous blog entry. Remember that the plugin is still under development! You may also wish to read a description of the plugin's status that was posted to scala-user by Tiark Rompf, the plugin's author.

Tuesday, February 24, 2009

Delimited continuations in Scala

In a recent thread on scala-user it was announced that Tiark Rompf is writing a Scala compiler plugin for working with delimited continuations. The plugin is still in development, but it is scheduled to be included in Scala 2.8.

The plugin transforms the parts of the program contained within a call to reset into continuation-passing form. Within the transformed part of the program, continuations can be accessed by calling shift.

All of this is done (rather elegantly, I think) by introducing a new @cps type annotation. At compile time the plugin converts anything that produces a value of type A @cps[B,C] into a ControlContext[A,B,C] object that contains the computation. The A here represents the output of the computation, which is also the input to its continuation. The B represents the return type of that continuation, and the C represents its "final" return type, because shift can do further processing on the returned value and change its type. (Hopefully someone will read this and let me know if I'm wrong about any of this!)
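To make the roles of A, B and C more concrete, here is a small, hypothetical model of a ControlContext-like type in plain Scala. It is not the plugin's code and ignores everything the plugin does at compile time. Its shift, reset and map follow the type description above. In the second example, the body of shift turns the Int returned by the continuation into a String, which is why C can differ from B.

```scala
object MiniCps {
  // Hypothetical model: given a continuation of type A => B, produce a C.
  final class Ctx[A, B, C](val fun: (A => B) => C) {
    // Extend the continuation with f (the code that follows a shift).
    def map[A1](f: A => A1): Ctx[A1, B, C] =
      new Ctx[A1, B, C]((k: A1 => B) => fun((a: A) => k(f(a))))

    // Sequence two shifting computations.
    def flatMap[A1, B1, C1 <: B](f: A => Ctx[A1, B1, C1]): Ctx[A1, B1, C] =
      new Ctx[A1, B1, C]((k: A1 => B1) => fun((a: A) => f(a).fun(k)))
  }

  // shift hands the captured continuation k to body.
  def shift[A, B, C](body: (A => B) => C): Ctx[A, B, C] = new Ctx(body)

  // reset supplies the identity function as the final continuation.
  def reset[A, C](ctx: Ctx[A, A, C]): C = ctx.fun((a: A) => a)

  // Models reset { shift(k => k(k(10))) + 1 }: k is (x => x + 1), giving 12.
  val twice: Int = reset(shift((k: Int => Int) => k(k(10))).map(_ + 1))

  // B is Int but C is String: the shift body post-processes k's result.
  val labelled: String = reset(shift((k: Int => Int) => "result: " + k(5)).map(_ + 1))
}
```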

I'm excited about this plugin because I think it is going to be very helpful for the asynchronous IO work I'm doing as part of Scala OTP. For some time now I've been manually converting parts of Scala OTP to use continuation-passing style (so that those parts can be suspended while waiting for IO). This is painstaking work, and also makes the code very hard to read. Using the plugin will simultaneously make my life a lot easier and make the code more accessible for others.

The main shortcoming I can see with the plugin at the moment is that it doesn't really handle exceptions. However, I've been looking at the plugin's source and I think it should be possible to add exception support without too much trouble.

Building the plugin

This is pretty easy to do.

First you will need a recent build of Scala. You can either install a nightly binary or build it from source yourself.

$ wget http://www.scala-lang.org/archives/downloads/distrib/files/nightly/scala-2.8.0.r17146-b20090219020925.tgz
$ tar xzf scala-2.8.0.r17146-b20090219020925.tgz

Once you have a build, set the SCALA_HOME environment variable to point to your new installation.

$ export SCALA_HOME=.../scala-2.8.0.r17146-b20090219020925

Next get the source for the continuations plugin.

$ svn co http://lampsvn.epfl.ch/svn-repos/scala/compiler-plugins/continuations/trunk continuations
$ cd continuations

Then build it.

$ ANT_OPTS=-Xmx512m ant test

You should see the plugin's JAR being built (selectivecps-plugin.jar) and its tests should run and pass.

Trying it out

You might want to have a go at writing code that uses delimited continuations. At the moment there isn't any documentation for the plugin, so the best way to work things out is to look at the code in doc/examples and in test.

It's also sometimes helpful to pass -Xprint:selectivecps to scalac when you're compiling your programs. This option tells the compiler to print out the program immediately after the plugin has done its work. By this point in the compilation all the continuation-passing has been made explicit so you're looking at normal Scala code again. This can make it a lot easier to work out what's going on.

Here's an example that demonstrates how an asynchronous alternative to Thread.sleep can be written using shift and reset. The call to shift inside the sleep method captures the continuation up to the end of reset. The function passed to shift starts a new thread that first sleeps, then runs the continuation.

import scala.continuations.ControlContext._

object Example {
  def sleep(delay: Long) = shift { k: (Unit => Unit) =>
    val runnable = new Runnable {
      def run = {
        Thread.sleep(delay)
        k()
      }
    }
    val thread = new Thread(runnable)
    thread.start
  }
  def main(args: Array[String]) {
    println("Before reset")
    reset {
      println("Before sleep")
      sleep(2000)
      println("After sleep")
    }
    println("After reset")
  }
}

First we compile the program. We need to tell the compiler to run the new plugin.

$ $SCALA_HOME/bin/scalac -Xplugin:build/pack/selectivecps-plugin.jar -classpath build/build.library example.scala

Then we run it.

$ $SCALA_HOME/bin/scala -classpath build/build.library:. Example
Before reset
Before sleep
After reset
After sleep

Look at the output. You can see that the order of "After sleep" and "After reset" has been reversed. This is because "After sleep" has been captured by shift and runs later, in a different thread.

Finally, you can see exactly what the plugin has done to the program by compiling with -Xprint:selectivecps. I won't explain the transformation here, but I do recommend you try to understand it if you're going to do serious work with delimited continuations.

object Example {
  def sleep(delay: Long) = ControlContext.shift[Unit, Unit, Unit](((k: (Unit) => Unit) => {
    val runnable: java.lang.Runnable = new Runnable {
      def run: Unit = {
        Thread.sleep(delay);
        k(())
      }
    }
    val thread = new Thread(runnable);
    thread.start()
  }))
  def main(args: Array[String]): Unit = {
    println("Before reset");
    ControlContext.reset[Unit, Unit]({
      println("Before sleep");
      sleep(2000L).map[Unit](((tmp1: Unit) => {
        println("After sleep")
      }))
    })
    println("After reset")
  }
}
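Without ControlContext, the transformation amounts to passing the rest of the reset block to sleep as an explicit function. Here is a hand-written sketch of that. The SleepByHand object, its log and demo are hypothetical, and the delay is shortened to 500 ms. It should produce the same reordered output without the plugin.

```scala
import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch}

object SleepByHand {
  // Records each line as well as printing it, so the order can be checked.
  val log = new ConcurrentLinkedQueue[String]()
  private def say(s: String): Unit = { println(s); log.add(s); () }

  // The continuation k is an explicit parameter here; shift arranges this
  // for us in the plugin version. sleep returns at once and k runs later.
  def sleep(delay: Long)(k: () => Unit): Unit = {
    val thread = new Thread(new Runnable {
      def run(): Unit = {
        Thread.sleep(delay)
        k()
      }
    })
    thread.start()
  }

  def demo(finished: CountDownLatch): Unit = {
    say("Before reset")
    // Body of the former reset block: the code after sleep is its continuation.
    say("Before sleep")
    sleep(500) { () =>
      say("After sleep")
      finished.countDown()
    }
    say("After reset")
  }
}
```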

So that's a simple example of using the plugin. Shift and reset are extremely powerful primitives (this example doesn't really do them justice). By supporting delimited continuations natively, Scala 2.8 will make it possible for users to extend the language in new and exciting ways.

Update: Fixed repository URL.