May 14, 2019

Tail call elimination

In this post we'll explore what tail calls are, why they are useful, and how they can be eliminated in an interpreter, a compiler targeting C++, and a compiler targeting LLVM IR.

Tail calls

A tail call is a function call whose value is immediately returned by the caller, i.e. the call is the last thing the enclosing block does (some languages do not enforce the requirement that the value be returned). Here are a few examples.

function tailCallEx1() {
  // Loops forever but is a tail call.
  return tailCallEx1();
}

function tailCallEx2(x) {
  if (x) {
    return tailCallEx2(x - 1);
  }

  return 1;
}

function tailCallEx3(x) {
  return x && tailCallEx3(x - 1) || 1;
}

function tailCallEx4(x) {
  switch (x) {
    case 0:
      return 1;
    default:
      return tailCallEx4(x - 1);
  }
}

And here are some examples of non-tail calls.

function nonTailCallEx1(x) {
  if (x) {
    // Not a tail call because the call is not the value returned.
    return 1 + nonTailCallEx1(x - 1);
  }

  return 0;
}

function nonTailCallEx2(x) {
  if (x) {
    const r = nonTailCallEx2(x - 1);
    // Not a tail call because the value is not *immediately* returned.
    console.log(r);
    return r;
  }

  return 1;
}
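
A non-tail call can often be rewritten into a tail call by threading the intermediate result through an extra parameter. For example, here is one possible accumulator version of nonTailCallEx1 (my own illustration, not part of the original examples):

function nonTailCallEx1Acc(x, acc = 0) {
  if (x) {
    // The addition now happens *before* the call, so the call's value is
    // returned unchanged and the call is in tail position.
    return nonTailCallEx1Acc(x - 1, acc + 1);
  }

  return acc;
}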

Why is this important?

Some languages can rewrite a recursive tail call as a jump/branch/goto instead of a function call. This allows:

  1. A potential performance gain when function calls have significant overhead
  2. No stack overflows, since the recursion no longer grows the call stack (illustrated below)
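
For instance, tailCallEx2 from earlier makes a tail call, but in an engine that does not eliminate tail calls (Node.js, for example) each recursive call still pushes a new stack frame:

// Without tail call elimination, deep recursion exhausts the stack.
tailCallEx2(1e6); // RangeError: Maximum call stack size exceeded

// With tail call elimination the same call runs in constant stack space
// and simply returns 1.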

Implementation 1: Interpreter

Given a tail call recursive fibonacci:

function fibonacci(a, b, n) {
  if (n === 0) {
    return a;
  }

  if (n === 1) {
    return b;
  }

  return fibonacci(b, a + b, n - 1);
}

Here is how we could transform this by hand so that there is no tail call.

function fibonacci(a, b, n) {
  while (true) {
    if (n === 0) {
      return a;
    }

    if (n === 1) {
      return b;
    }

    const a1 = b;
    const b1 = a + b;
    const n1 = n - 1;
    a = a1;
    b = b1;
    n = n1;
  }
}

If this were written in a language with labels and goto, we could simplify the code slightly by using a label at the top of the function and a goto instead of the loop, but the effect is the same.

Since we're in an interpreter (one that isn't JIT compiling), we cannot pick between these two forms ahead of time, so we merge them: we put every function body inside a loop and break out immediately when the call wasn't a tail call. Otherwise we line up the new parameters and let the loop take us back to the top of the body.

Here is an example of this strategy used in a Scheme interpreter written in D.

// Define a new function with name `name` and add it to the context.
Value namedLambda(Value arguments, Context ctx, string name) {
  auto funArguments = car(arguments);
  auto funBody = cdr(arguments);

  Value defined(Value parameters, void** rest) {
    Context newCtx = ctx.dup;

    // Copy the runtime calling context to the new context.
    Context runtimeCtx = cast(Context)(*rest);
    auto runtimeCallingContext = runtimeCtx.callingContext;
    newCtx.callingContext = runtimeCallingContext.dup;

    Value result;
    // Loop forever, will break immediately if not a tail call
    bool tailCalling = false;
    while (true) {
      if (valueIsList(funArguments)) {
        auto keyTmp = valueToList(funArguments);
        auto valueTmp = valueToList(parameters);
        while (true) {
          auto key = valueToSymbol(keyTmp[0]);
          auto value = valueTmp[0];

          newCtx.set(key, value);

          // TODO: handle arg count mismatch
          if (valueIsList(keyTmp[1])) {
            keyTmp = valueToList(keyTmp[1]);
            valueTmp = valueToList(valueTmp[1]);
          } else {
            break;
          }
        }
      } else if (valueIsSymbol(funArguments)) {
        auto key = valueToSymbol(funArguments);
        newCtx.set(key, car(parameters));
      } else if (!valueIsNil(funArguments)) {
        error("Expected symbol or list in lambda formals", funArguments);
      }

      if (!tailCalling) {
        newCtx.callingContext.push(Tuple!(string, Delegate)(name, &defined));
      }

      result = eval(withBegin(funBody), cast(void**)[newCtx]);

      // In a tail call, let the loop carry us back.
      if (newCtx.doTailCall == &defined) {
        tailCalling = true;
        parameters = result;
        newCtx.doTailCall = null;
      } else {
        break; // Not a tail call; we're done with a regular function call.
      }
    }

    return result;
  }

  return makeFunctionValue(name, &defined, false);
}

We cannot eliminate mutually recursive tail calls with this method, because the loop only ever jumps back to the top of the same function. We could use continuation-passing style, but that would not address the real concern: not making a function call.
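
For example (my own illustration), each of these functions makes a tail call, but the call targets the other function, so the per-function loop above cannot absorb it and the stack still grows:

function isEven(n) {
  if (n === 0) return true;
  return isOdd(n - 1); // tail call, but into a different function
}

function isOdd(n) {
  if (n === 0) return false;
  return isEven(n - 1); // tail call back into isEven
}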

Implementation 2: Compiling to C++

The strategy here is the same as in the interpreter, except that tail call recursive functions are known at compile time, so we can generate specialized (rather than generic) code in their bodies.

Here is how a JavaScript compiler transforms the above fibonacci implementation into C++:

void tco_fib(const FunctionCallbackInfo<Value> &args) {
  Isolate *isolate = args.GetIsolate();
  double tco_n = toNumber(args[0]);
  double tco_a = toNumber(args[1]);
  double tco_b = toNumber(args[2]);

tail_recurse_1:

    ;

  bool sym_if_test_58 = (tco_n == 0);
  if (sym_if_test_58) {
    args.GetReturnValue().Set(Number::New(isolate, tco_a));
    return;
  }

  bool sym_if_test_70 = (tco_n == 1);
  if (sym_if_test_70) {
    args.GetReturnValue().Set(Number::New(isolate, tco_b));
    return;
  }

  Local<Value> sym_arg_83 = Number::New(isolate, (tco_n - 1));
  Local<Value> sym_arg_92 = Number::New(isolate, (tco_a + tco_b));
  tco_n = toNumber(sym_arg_83);
  tco_a = tco_b;
  tco_b = toNumber(sym_arg_92);
  goto tail_recurse_1;
}

This is implemented by checking every function call. If the function call is in tail call position, we generate code for jumping to the beginning of the function. Otherwise, we generate a call as usual.

Here is how the tail call check and code generation are done in the compiler:

function compileCall(
  context: Context,
  destination: Local,
  ce: ts.CallExpression,
) {
  let tcoLabel;
  let tcoParameters;
  if (ce.expression.kind === ts.SyntaxKind.Identifier) {
    const id = identifier(ce.expression as ts.Identifier);
    const safe = context.locals.get(mangle(context.moduleName, id));

    if (safe && context.tco[safe.getCode()]) {
      const safeName = safe.getCode();
      tcoLabel = context.tco[safeName].label;
      tcoParameters = context.tco[safeName].parameters;
    }
  }

  ...

  if (ce.expression.kind === ts.SyntaxKind.Identifier) {
    const id = identifier(ce.expression as ts.Identifier);
    const mangled = mangle(context.moduleName, id);
    const safe = context.locals.get(mangled);

    if (safe) {
      if (tcoLabel) {
        args.forEach((arg, i) => {
          compileParameter(
            context,
            tcoParameters[i],
            i,
            i === args.length - 1,
            arg,
          );
        });

        context.emitStatement(`goto ${tcoLabel}`);
        context.emit('', 0);
        destination.tce = true;
        return;
      }
    }
  }

  ...

  // Otherwise generate regular function call

This requires building up state throughout the AST walk so that, at any particular call, you know whether or not it is in tail position.
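
As a rough illustration of the kind of bookkeeping involved, here is a minimal sketch (my own, not this compiler's actual code) using the TypeScript compiler API. It only handles the simplest case, a call whose value is directly the expression of a return statement; compound expressions like the one in tailCallEx3 would need extra handling.

import * as ts from 'typescript';

// Collect the calls in a function whose result is immediately returned,
// i.e. calls of the form `return f(...)`.
function directTailCalls(fn: ts.FunctionDeclaration): ts.CallExpression[] {
  const calls: ts.CallExpression[] = [];

  function visit(node: ts.Node) {
    // Don't descend into nested functions; their returns belong to them.
    if (ts.isFunctionLike(node) && node !== fn) {
      return;
    }

    if (
      ts.isReturnStatement(node) &&
      node.expression &&
      ts.isCallExpression(node.expression)
    ) {
      calls.push(node.expression);
    }

    ts.forEachChild(node, visit);
  }

  visit(fn);
  return calls;
}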

Implementation 3: Compiling to LLVM IR

LLVM IR is the most boring case because all you do is mark each tail call with the tail marker. Then, so long as the call meets some requirements, the key one being that the result of the call must be returned immediately, LLVM will generate a jump instead of a call for you.

Given the following lisp-y implementation of the same tail call recursive fibonacci function (compiler here):

(def fib (a b n)
     (if (= n 0)
         a
       (fib b (+ a b) (- n 1))))

We generate the following LLVM IR:

define i64 @fib(i64 %a, i64 %b, i64 %n) {
  %ifresult13 = alloca i64, align 4
  %sym14 = add i64 %n, 0
  %sym15 = add i64 0, 0
  %sym12 = icmp eq i64 %sym14, %sym15
  br i1 %sym12, label %iftrue16, label %iffalse17
iftrue16:
  %sym18 = add i64 %a, 0
  store i64 %sym18, i64* %ifresult13, align 4
  br label %ifend19
iffalse17:
  %sym21 = add i64 %b, 0
  %sym23 = add i64 %a, 0
  %sym24 = add i64 %b, 0
  %sym22 = add i64 %sym23, %sym24
  %sym26 = add i64 %n, 0
  %sym27 = add i64 1, 0
  %sym25 = sub i64 %sym26, %sym27
  ; NOTE the `tail` before `call` here
  %sym20 = tail call i64 @fib(i64 %sym21, i64 %sym22, i64 %sym25)
  ret i64 %sym20
  store i64 %sym20, i64* %ifresult13, align 4
  br label %ifend19
ifend19:
  %sym11 = load i64, i64* %ifresult13, align 4
  ret i64 %sym11
}

The only difference between supporting tail call elimination and not supporting it is whether the call instruction is preceded by the tail marker. That makes the implementation very simple:

const isTailCall = module.exports.TAIL_CALL_ENABLED &&
                   context.tailCallTree.includes(validFunction.value);
const maybeTail = isTailCall ? 'tail ' : '';
this.emit(1, `%${destination.value} = ${maybeTail}call ${validFunction.type} @${validFunction.value}(${safeArgs})`);
if (isTailCall) {
  this.emit(1, `ret ${destination.type} %${destination.value}`);
}

Generated assembly

The resulting generated code (run through llc) for that call will be:

...
    add rax, rsi
    dec rdx
    mov rdi, rsi
    mov rsi, rax
    jmp _fib               ## TAILCALL
...

And if tail call elimination is disabled:

...
    add rax, rsi
    dec rdx
    mov rdi, rsi
    mov rsi, rax
    call    _fib
...

Summary

The last bit I haven't covered is how you track whether or not a call is in tail position. That is difficult to cover in a blog post because it comes down to propagating (or not propagating) tail state at each syntax node type. But generally speaking, if a syntax node is not in tail position (e.g. it is not the last expression in a block), you drop the tail state you've built up. When you make a function call, you add the function name to the tail state.
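
Here is a rough, self-contained sketch of that idea for a small expression language (my own illustration; the Expr type and the emit helpers are hypothetical, not from the compiler above). The name of the enclosing function is propagated only into positions whose value is returned directly, and dropped everywhere else:

type Expr =
  | { kind: 'literal'; value: number }
  | { kind: 'if'; test: Expr; then: Expr; else: Expr }
  | { kind: 'call'; callee: string; args: Expr[] };

declare function emitLiteral(value: number): void;
declare function emitCall(callee: string, argCount: number, isTail: boolean): void;

function compileExpr(expr: Expr, tailOf: string | null): void {
  switch (expr.kind) {
    case 'literal':
      emitLiteral(expr.value);
      break;
    case 'if':
      // The test's value is never returned, so drop the tail state there.
      compileExpr(expr.test, null);
      // Both branches produce the if's value, so they keep the tail state.
      compileExpr(expr.then, tailOf);
      compileExpr(expr.else, tailOf);
      break;
    case 'call':
      // Arguments are never in tail position.
      expr.args.forEach((a) => compileExpr(a, null));
      // Mark the call as a tail call only when it recurses on the
      // enclosing function and sits in tail position.
      emitCall(expr.callee, expr.args.length, expr.callee === tailOf);
      break;
  }
}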

But I will be covering this in detail in the LLVM case in the next post in my compiler basics series.