October 12, 2019

Interpreting Go

After spending some time at work on tooling for keeping documentation in sync with Go struct definitions, I had enough exposure to Go's built-in parsing package that the next step was clear: write an interpreter. It's a great way to get more comfortable with a language's AST.

In this post we'll use the Go parser package and interpret the AST directly (as opposed to compiling to bytecode for a VM), supporting just enough of the language to run a recursive implementation of the Fibonacci algorithm:

package main

func fib(a int) int {
  if a == 1 {
    return 0
  }

  if a == 2 {
    return 1
  }

  return fib(a-1) + fib(a-2)
}

func main() {
  println(fib(15))
}

You'll note that we never define println. It is one of Go's predeclared builtin functions, so this program is valid Go as written, but our interpreter knows nothing about builtins. We'll provide println in the runtime ourselves to make things easier.

The Fibonacci algorithm is my go-to minimal program that forces us to deal with basic aspects of:

  • Function definitions
  • Function calls
  • Function arguments
  • Function return values
  • If/else
  • Assignment
  • Arithmetic and boolean operators

We'll do this in around 200 LoC. Project code is available on GitHub.

A follow-up post will cover support for an iterative fibonacci implementation with support for basic aspects of:

  • Local variables
  • Loops

First steps

I always start exploring an AST by practicing error-driven development. It's helpful to have the Go AST, parser, and token package docs handy as well.

We'll focus on single-file programs and start with parser.ParseFile. This function will return an *ast.File. This in turn contains a list of Decls. Unfortunately Go stops being helpful at this point because we have no clue what is going to implement this Decl interface. So we'll switch on the concrete type and error until we know what we need to know.

package main

import (
  "go/ast"
  "go/parser"
  "go/token"
  "io/ioutil"
  "log"
  "os"
  "reflect"
)

func interpret(f *ast.File) {
  for _, decl := range f.Decls {
    switch d := decl.(type) {
    default:
      log.Fatalf("Unknown decl type (%s): %+v", reflect.TypeOf(d), d)
    }
  }
}

func main() {
  fset := token.NewFileSet() // positions are relative to fset

  src, err := ioutil.ReadFile(os.Args[1])
  if err != nil {
    log.Fatalf("Unable to read file: %s", err.Error())
  }

  f, err := parser.ParseFile(fset, os.Args[1], src, 0)
  if err != nil {
    log.Fatalf("Unable to parse file: %s", err.Error())
  }

  interpret(f)
}

Build and run:

$ go build goi.go
$ ./goi fib.go
2019/10/12 09:43:48 Unknown decl type (*ast.FuncDecl): &{Doc:<nil> Recv:<nil> Name:fib Type:0xc000096320 Body:0xc00009a3c0}

Cool! This is the declaration of the fib function and its type is *ast.FuncDecl.
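
Error-driven development reveals one node type at a time. If you'd rather see the whole shape of a program up front, ast.Inspect from go/ast walks every node in the tree. Here is a minimal sketch, separate from the interpreter; nodeTypes and the inline source string are made up for illustration:

```go
package main

import (
  "fmt"
  "go/ast"
  "go/parser"
  "go/token"
)

// nodeTypes parses src as a Go file and returns the concrete type name of
// every AST node, in the depth-first order ast.Inspect visits them.
func nodeTypes(src string) ([]string, error) {
  fset := token.NewFileSet()
  f, err := parser.ParseFile(fset, "example.go", src, 0)
  if err != nil {
    return nil, err
  }

  var types []string
  ast.Inspect(f, func(n ast.Node) bool {
    // ast.Inspect also calls us with nil after a node's children; skip those.
    if n != nil {
      types = append(types, fmt.Sprintf("%T", n))
    }
    return true // keep descending into children
  })

  return types, nil
}

func main() {
  types, err := nodeTypes("package main\n\nfunc main() { println(1) }\n")
  if err != nil {
    panic(err)
  }

  for _, t := range types {
    fmt.Println(t)
  }
}
```

The list includes *ast.FuncDecl, *ast.BlockStmt, *ast.ExprStmt, *ast.CallExpr and friends, which is a handy preview of the cases the interpreter will need.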

Interpreting ast.FuncDecl

A function declaration needs to add its name to a context map, mapped to a function reference for use in function calls. Since Go throws everything into the same namespace, we can simply pass around a map of strings to values, where a value can be any Go value. To facilitate this, we'll define a value struct that holds an integer representing the "kind" and an empty interface holding the "value". When a value is referenced, the interpreter will have to switch on the "kind" and then cast (type-assert) the "value".

Additionally, and unlike a value-oriented language like Scheme, we'll need to track a set of return values at all stages through interpretation so, when set, we can short circuit execution.

type kind uint

const (
  i64 kind = iota
  fn
  bl
)

type value struct {
  kind  kind
  value interface{}
}

type context map[string]value

func (c context) copy() context {
  cpy := context{}
  for key, value := range c {
    cpy[key] = value
  }

  return cpy
}

type ret struct {
  set    bool
  values []value
}

func (r *ret) setValues(values []value) {
  r.values = values
  r.set = true
}

func interpretFuncDecl(ctx context, r *ret, fd *ast.FuncDecl) {
  ctx[fd.Name.String()] = value{
    fn,
    func(ctx context, r *ret, args []value) {},
  }
}

func interpret(ctx context, f *ast.File) {
  for _, decl := range f.Decls {
    switch d := decl.(type) {
    case *ast.FuncDecl:
      interpretFuncDecl(ctx, nil, d)
    default:
      log.Fatalf("Unknown decl type (%s): %+v", reflect.TypeOf(d), d)
    }
  }
}
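
To make "switch on the kind, then cast the value" concrete, here is a small standalone sketch. The kind and value types are repeated from above so it runs on its own; describe is a hypothetical helper for illustration and is not part of the interpreter:

```go
package main

import "fmt"

type kind uint

const (
  i64 kind = iota
  fn
  bl
)

type value struct {
  kind  kind
  value interface{}
}

// describe (hypothetical) switches on the kind tag first and only then
// type-asserts the payload, so each assertion matches what the tag promises.
func describe(v value) string {
  switch v.kind {
  case i64:
    return fmt.Sprintf("int64 %d", v.value.(int64))
  case bl:
    return fmt.Sprintf("bool %t", v.value.(bool))
  case fn:
    return "function"
  default:
    return "unknown"
  }
}

func main() {
  fmt.Println(describe(value{i64, int64(42)})) // int64 42
  fmt.Println(describe(value{bl, true}))       // bool true
}
```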

Now that we have the idea of return management and contexts set out, let's fill out the actual function declaration callback. Inside we'll need to copy the context so variables declared inside the function are not visible outside. Then we'll iterate over the parameters and map them in context to the associated argument. Finally we'll interpret the body.

func interpretBlockStmt(ctx context, r *ret, bs *ast.BlockStmt) {}

func interpretFuncDecl(ctx context, r *ret, fd *ast.FuncDecl) {
  ctx[fd.Name.String()] = value{
    fn,
    func(ctx context, r *ret, args []value) {
      childCtx := ctx.copy()
      for i, param := range fd.Type.Params.List {
        childCtx[param.Names[0].String()] = args[i]
      }

      interpretBlockStmt(childCtx, r, fd.Body)
    },
  }
}
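
One caveat: param.Names[0] is enough for fib, which has a single parameter, but Go groups parameters that share a type (a, b int) into one field with several names. The sketch below shows one way to flatten them, assuming you wanted to support that; paramNames and parseFunc are hypothetical helpers, not part of the interpreter above:

```go
package main

import (
  "fmt"
  "go/ast"
  "go/parser"
  "go/token"
)

// paramNames (hypothetical) flattens a function's parameter list into one
// name per positional argument, handling grouped declarations like "a, b int".
func paramNames(fd *ast.FuncDecl) []string {
  var names []string
  for _, field := range fd.Type.Params.List {
    for _, name := range field.Names {
      names = append(names, name.Name)
    }
  }
  return names
}

// parseFunc (hypothetical) returns the first function declaration in src.
func parseFunc(src string) (*ast.FuncDecl, error) {
  f, err := parser.ParseFile(token.NewFileSet(), "example.go", src, 0)
  if err != nil {
    return nil, err
  }

  for _, decl := range f.Decls {
    if fd, ok := decl.(*ast.FuncDecl); ok {
      return fd, nil
    }
  }
  return nil, fmt.Errorf("no function declaration found")
}

func main() {
  fd, err := parseFunc("package main\n\nfunc add(a, b int, c int) int { return a + b + c }\n")
  if err != nil {
    panic(err)
  }

  // Two fields ("a, b int" and "c int") but three parameter names.
  fmt.Println(len(fd.Type.Params.List), paramNames(fd)) // 2 [a b c]
}
```

With a flattened list, the i-th name lines up with args[i] regardless of how the parameters were grouped in the source.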

And we'll add a call to the interpreted main to the end of the interpreter's main:

func main() {
  fset := token.NewFileSet() // positions are relative to fset

  src, err := ioutil.ReadFile(os.Args[1])
  if err != nil {
    log.Fatalf("Unable to read file: %s", err.Error())
  }

  f, err := parser.ParseFile(fset, os.Args[1], src, 0)
  if err != nil {
    log.Fatalf("Unable to parse file: %s", err.Error())
  }

  ctx := context{}
  interpret(ctx, f)
  var r ret
  ctx["main"].value.(func(context, *ret, []value))(ctx, &r, []value{})
}

Next step!

Interpreting ast.BlockStmt

For this AST node, we'll iterate over each statement and interpret it. If the return value has been set, we'll exit the loop to short-circuit execution.

func interpretStmt(ctx context, r *ret, stmt ast.Stmt) {}

func interpretBlockStmt(ctx context, r *ret, bs *ast.BlockStmt) {
  for _, stmt := range bs.List {
    interpretStmt(ctx, r, stmt)
    if r.set {
      return
    }
  }
}

Next step!

Interpreting ast.Stmt

Since ast.Stmt is another interface, we're back to error-driven development.

func interpretStmt(ctx context, r *ret, stmt ast.Stmt) {
  switch s := stmt.(type) {
  default:
    log.Fatalf("Unknown stmt type (%s): %+v", reflect.TypeOf(s), s)
  }
}

And the trigger:

$ go build goi.go
$ ./goi fib.go
2019/10/12 10:15:14 Unknown stmt type (*ast.ExprStmt): &{X:0xc0000a02c0}

Great! Checking the docs on ast.ExprStmt we'll just skip directly to a call to a new function interpretExpr:

func interpretExpr(ctx context, r *ret, expr ast.Expr) {}

func interpretStmt(ctx context, r *ret, stmt ast.Stmt) {
  switch s := stmt.(type) {
  case *ast.ExprStmt:
    interpretExpr(ctx, r, s.X)
  default:
    log.Fatalf("Unknown stmt type (%s): %+v", reflect.TypeOf(s), s)
  }
}

Moving on!

Interpreting ast.Expr

We've got another interface. Let's error!

func interpretExpr(ctx context, r *ret, expr ast.Expr) {
  switch e := expr.(type) {
  default:
    log.Fatalf("Unknown expr type (%s): %+v", reflect.TypeOf(e), e)
  }
}

And the trigger:

$ go build goi.go
$ ./goi fib.go
2019/10/12 10:19:16 Unknown expr type (*ast.CallExpr): &{Fun:println Lparen:146 Args:[0xc0000a2280] Ellipsis:0 Rparen:154}

Cool! For a call we'll evaluate the function expression itself, evaluate the arguments, cast the resulting function value to a Go function, and call it.

func interpretCallExpr(ctx context, r *ret, ce *ast.CallExpr) {
  var fnr ret
  interpretExpr(ctx, &fnr, ce.Fun)
  fn := fnr.values[0]

  values := []value{}
  for _, arg := range ce.Args {
    var vr ret
    interpretExpr(ctx, &vr, arg)
    values = append(values, vr.values[0])
  }

  fn.value.(func(context, *ret, []value))(ctx, r, values)
}

All of this casting is unsafe because we aren't doing a type-checking stage. But we can ignore this because if a type-checking stage were introduced (which it needs to be at some point), it would prevent bad casts from happening.
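
Until such a stage exists, a failed type assertion panics with a Go runtime error rather than a message about the interpreted program. One possible stopgap is the two-value ("comma ok") form of the assertion. The sketch below is standalone; the types are repeated from earlier and asInt64 is a hypothetical helper, not something the interpreter in this post uses:

```go
package main

import "fmt"

type kind uint

const (
  i64 kind = iota
  fn
  bl
)

type value struct {
  kind  kind
  value interface{}
}

// asInt64 (hypothetical) uses the comma-ok assertion so a mismatched value
// produces an error instead of a runtime panic.
func asInt64(v value) (int64, error) {
  i, ok := v.value.(int64)
  if !ok {
    return 0, fmt.Errorf("expected int64, got %T", v.value)
  }
  return i, nil
}

func main() {
  if i, err := asInt64(value{i64, int64(7)}); err == nil {
    fmt.Println("ok:", i)
  }

  if _, err := asInt64(value{bl, true}); err != nil {
    fmt.Println("error:", err)
  }
}
```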

Handling more ast.Expr implementations

Let's give the interpreter a shot again:

$ go build goi.go
$ ./goi fib.go
2019/10/12 10:28:17 Unknown expr type (*ast.Ident): println

We'll need to add ast.Ident support to interpretExpr. This will be a simple lookup in context. We'll also add a setValue helper, since the ret value is serving double duty: it passes a single value between expressions, and it also carries a function's return values (the only place where multiple values are a thing).

...


func (r *ret) setValue(v value) {
  r.values = []value{v}
  r.set = true
}

...

func interpretExpr(ctx context, r *ret, expr ast.Expr) {
  switch e := expr.(type) {
  case *ast.CallExpr:
    interpretCallExpr(ctx, r, e)
  case *ast.Ident:
    r.setValue(ctx[e.Name])
  default:
    log.Fatalf("Unknown expr type (%s): %+v", reflect.TypeOf(e), e)
  }
}

This is also a good time to add the println builtin to our top-level context. Since it wraps fmt.Println, "fmt" needs to be added to the imports.

func main() {
  ...

  ctx := context{}
  interpret(ctx, f)
  ctx["println"] = value{
    fn,
    func(ctx context, r *ret, args []value) {
      var values []interface{}
      for _, arg := range args {
        values = append(values, arg.value)
      }

      fmt.Println(values...)
    },
  }

  var r ret
  ctx["main"].value.(func(context, *ret, []value))(ctx, &r, []value{})
}

More ast.Expr implementations

Running the interpreter again we get:

$ go build goi.go
$ ./goi fib.go
2019/10/12 10:41:59 Unknown expr type (*ast.BasicLit): &{ValuePos:151 Kind:INT Value:15}

Easy enough: we'll switch on the literal's "kind", parse the integer's source text into an int64 with strconv (another import to add), and wrap it in our value type.

func interpretExpr(ctx context, r *ret, expr ast.Expr) {
  switch e := expr.(type) {
  case *ast.CallExpr:
    interpretCallExpr(ctx, r, e)
  case *ast.Ident:
    r.setValue(ctx[e.Name])
  case *ast.BasicLit:
    switch e.Kind {
    case token.INT:
      i, err := strconv.ParseInt(e.Value, 10, 64)
      if err != nil {
        log.Fatalf("Invalid integer literal %q: %s", e.Value, err)
      }
      r.setValue(value{i64, i})
    default:
      log.Fatalf("Unknown basiclit type: %+v", e)
    }
  default:
    log.Fatalf("Unknown expr type (%s): %+v", reflect.TypeOf(e), e)
  }
}
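
Parsing with base 10 is all fib.go needs, but BasicLit.Value holds the literal exactly as written in the source, so something like 0x1F or 1_000 would not parse that way. One option, sketched here as an assumption rather than what the interpreter does, is to pass base 0 so strconv.ParseInt infers the base from the prefix and accepts underscore separators. parseIntLit is a hypothetical helper:

```go
package main

import (
  "fmt"
  "strconv"
)

// parseIntLit (hypothetical) converts the source text of a Go integer
// literal to an int64. Base 0 makes ParseInt infer the base from the
// literal's prefix and allow underscores, following Go's literal syntax.
func parseIntLit(lit string) (int64, error) {
  return strconv.ParseInt(lit, 0, 64)
}

func main() {
  for _, lit := range []string{"15", "0x1F", "1_000"} {
    i, err := parseIntLit(lit)
    if err != nil {
      panic(err)
    }
    fmt.Println(lit, "=", i)
  }
}
```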

Now we run again:

$ go build goi.go
$ ./goi fib.go
2019/10/12 10:48:46 Unknown stmt type (*ast.IfStmt): &{If:38 Init:<nil> Cond:0xc0000ac150 Body:0xc0000ac1b0 Else:<nil>}

Cool, more control flow!

Interpreting ast.IfStmt

For ast.IfStmt we interpret the condition and, depending on the condition, interpret the body or the else node. In order to make empty else interpreting easier, we'll also add a nil short-circuit to interpretStmt.

func interpretIfStmt(ctx context, r *ret, is *ast.IfStmt) {
  interpretStmt(ctx, nil, is.Init)

  var cr ret
  interpretExpr(ctx, &cr, is.Cond)
  c := cr.values[0]

  if c.value.(bool) {
    interpretBlockStmt(ctx, r, is.Body)
    return
  }

  interpretStmt(ctx, r, is.Else)
}


func interpretStmt(ctx context, r *ret, stmt ast.Stmt) {
  if stmt == nil {
    return
  }

  switch s := stmt.(type) {
  case *ast.IfStmt:
    interpretIfStmt(ctx, r, s)

  ...

Let's try it out:

$ go build goi.go
$ ./goi fib.go
2019/10/12 10:56:28 Unknown expr type (*ast.BinaryExpr): &{X:a OpPos:43 Op:== Y:0xc00008a120}

Great!

Interpreting ast.BinaryExpr

An ast.BinaryExpr has an Op field that we'll switch on to decide what operations to do. We'll interpret the left side and then the right side and finally perform the operation and return the result. The three binary operations we use in this program are ==, + and -. We'll look these up in go/token docs to discover the associated constants.

func interpretBinaryExpr(ctx context, r *ret, bexpr *ast.BinaryExpr) {
  var xr, yr ret
  interpretExpr(ctx, &xr, bexpr.X)
  x := xr.values[0]
  interpretExpr(ctx, &yr, bexpr.Y)
  y := yr.values[0]

  switch bexpr.Op {
  case token.ADD:
    r.setValue(value{i64, x.value.(int64) + y.value.(int64)})
  case token.SUB:
    r.setValue(value{i64, x.value.(int64) - y.value.(int64)})
  case token.EQL:
    r.setValue(value{bl, x.value.(int64) == y.value.(int64)})
  default:
    log.Fatalf("Unknown binary expression type: %+v", bexpr)
  }
}

func interpretExpr(ctx context, r *ret, expr ast.Expr) {
  switch e := expr.(type) {
  case *ast.BinaryExpr:
    interpretBinaryExpr(ctx, r, e)

  ...
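
If you want to confirm which go/token constant an operator maps to without reading through the docs, parser.ParseExpr makes for a quick experiment. A standalone sketch; opConstant is a hypothetical helper, and it only knows the three operators this post handles:

```go
package main

import (
  "fmt"
  "go/ast"
  "go/parser"
  "go/token"
)

// opConstant (hypothetical) parses src as a single expression and reports
// which go/token constant its top-level binary operator corresponds to.
func opConstant(src string) (string, error) {
  expr, err := parser.ParseExpr(src)
  if err != nil {
    return "", err
  }

  be, ok := expr.(*ast.BinaryExpr)
  if !ok {
    return "", fmt.Errorf("%q is not a binary expression", src)
  }

  switch be.Op {
  case token.ADD:
    return "token.ADD", nil
  case token.SUB:
    return "token.SUB", nil
  case token.EQL:
    return "token.EQL", nil
  default:
    return "", fmt.Errorf("unhandled operator %s", be.Op)
  }
}

func main() {
  for _, src := range []string{"a == 1", "a - 1", "fib(a-1) + fib(a-2)"} {
    name, err := opConstant(src)
    if err != nil {
      panic(err)
    }
    fmt.Println(src, "=>", name)
  }
}
```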

Let's try one more time!

$ go build goi.go
$ ./goi fib.go
2019/10/12 11:06:19 Unknown stmt type (*ast.ReturnStmt): &{Return:94 Results:[0xc000070540]}

Awesome, last step.

Interpreting ast.ReturnStmt

Based on the ast.ReturnStmt definition, we'll have to interpret each result expression and store all of the resulting values on the ret value.

func interpretReturnStmt(ctx context, r *ret, s *ast.ReturnStmt) {
  var values []value
  for _, expr := range s.Results {
    // A fresh ret per result expression, named so it doesn't shadow r.
    var er ret
    interpretExpr(ctx, &er, expr)
    values = append(values, er.values[0])
  }

  r.setValues(values)
}

func interpretStmt(ctx context, r *ret, stmt ast.Stmt) {
  if stmt == nil {
    return
  }

  switch s := stmt.(type) {
  case *ast.ReturnStmt:
    interpretReturnStmt(ctx, r, s)

  ...

And let's try one last time:

$ go build goi.go
$ ./goi fib.go
377

Looking good. :) Let's try with another input:

$ cat fib.go
package main

func fib(a int) int {
  if a == 1 {
    return 0
  }

  if a == 2 {
    return 1
  }

  return fib(a-1) + fib(a-2)
}

func main() {
  println(fib(14))
}
$ ./goi fib.go
233

We've got the basics of an interpreter for Go.