April 12, 2020

Writing a SQL database from scratch in Go: 2. binary expressions and WHERE filters

Previously in database basics:
1. SELECT, INSERT, CREATE and a REPL

Next in database basics:
3. indexes
4. a database/sql driver

In this post, we'll extend gosql to support binary expressions and very simple filtering on SELECT results via WHERE. We'll introduce a general mechanism for interpreting an expression on a row in a table. The expression may be an identifier (where the result is the value of the cell corresponding to that column in the row), a numeric literal, a combination via a binary expression, etc.

The following interactions will be possible:

# CREATE TABLE users (name TEXT, age INT);
ok
# INSERT INTO users VALUES ('Stephen', 16);
ok
# SELECT name, age FROM users;
   name   | age
----------+------
  Stephen |  16
(1 result)
ok
# INSERT INTO users VALUES ('Adrienne', 23);
ok
# SELECT age + 2, name FROM users WHERE age = 23;
  age |   name
------+-----------
   25 | Adrienne
(1 result)
ok
# SELECT name FROM users;
    name
------------
  Stephen
  Adrienne
(2 results)
ok

The changes we'll make in this post are roughly a walkthrough of this commit.

Boilerplate updates

There are a few updates to pick up that I won't go into in this post. Grab the following files from the main repo:

  • lexer.go
    • The big change here is to use the same keyword matching algorithm for symbols. This allows us to support symbols that are longer than one character.
    • This file also now includes the following keywords and symbols: and, or, true, false, =, <>, ||, and +.
  • cmd/main.go
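The longest-match idea behind the lexer change can be sketched on its own. This is a minimal illustration and not the code from lexer.go: the function name longestPrefix and the symbol list (including a bare < that gosql does not define in this post) are assumptions made for the example.

```go
package main

import (
	"fmt"
	"strings"
)

// longestPrefix returns the longest option that source starts with at
// position cursor (compared case-insensitively), or "" if none matches.
// Hypothetical helper for illustration; the real lexer.go may differ.
func longestPrefix(source string, cursor int, options []string) string {
	if cursor < 0 || cursor > len(source) {
		return ""
	}
	rest := strings.ToLower(source[cursor:])

	best := ""
	for _, opt := range options {
		if strings.HasPrefix(rest, strings.ToLower(opt)) && len(opt) > len(best) {
			best = opt
		}
	}
	return best
}

func main() {
	symbols := []string{"=", "<", "<>", "||", "+", "(", ")", ",", ";"}

	// Both "<" and "<>" match here; the longer one wins.
	fmt.Printf("%q\n", longestPrefix("a <> b", 2, symbols))
	// "||" is matched as a single two-character symbol.
	fmt.Printf("%q\n", longestPrefix("x || y", 2, symbols))
	// No symbol starts at position 0.
	fmt.Printf("%q\n", longestPrefix("abc", 0, symbols))
}
```

Picking the longest candidate rather than the first is what keeps a multi-character symbol from being split into shorter ones, and the same routine can serve keywords and symbols alike.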

Parsing boilerplate

We'll redefine three helper functions in parser.go before going further: parseToken, parseTokenKind, and helpMessage.

The parseToken helper will consume a token if it matches the one provided as an argument (ignoring location).

func parseToken(tokens []*token, initialCursor uint, t token) (*token, uint, bool) {
    cursor := initialCursor

    if cursor >= uint(len(tokens)) {
        return nil, initialCursor, false
    }

    if p := tokens[cursor]; t.equals(p) {
        return p, cursor + 1, true
    }

    return nil, initialCursor, false
}

The parseTokenKind helper will consume a token if it is the same kind as an argument provided.

func parseTokenKind(tokens []*token, initialCursor uint, kind tokenKind) (*token, uint, bool) {
    cursor := initialCursor

    if cursor >= uint(len(tokens)) {
        return nil, initialCursor, false
    }

    current := tokens[cursor]
    if current.kind == kind {
        return current, cursor + 1, true
    }

    return nil, initialCursor, false
}

And the helpMessage helper will give an indication of where in a program something happened.

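func helpMessage(tokens []*token, cursor uint, msg string) {
    if len(tokens) == 0 {
        fmt.Println(msg)
        return
    }

    // Report the token after the cursor when there is one, clamping to
    // the last token so the index never goes out of range.
    i := cursor + 1
    if i >= uint(len(tokens)) {
        i = uint(len(tokens)) - 1
    }
    c := tokens[i]

    fmt.Printf("[%d,%d]: %s, near: %s\n", c.loc.line, c.loc.col, msg, c.value)
}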

Parsing binary expressions

Next we'll extend the AST structure in ast.go to support a "binary kind" of expression. The binary expression will have two sub-expressions and an operator.

const (
    literalKind expressionKind = iota
    binaryKind
)

type binaryExpression struct {
    a  expression
    b  expression
    op token
}

type expression struct {
    literal *token
    binary  *binaryExpression
    kind    expressionKind
}

We'll use Pratt parsing to handle operator precedence. There is an excellent introduction to this technique here.

If at the beginning of parsing we see a left parenthesis, we'll consume it and parse an expression within it. Then we'll look for a right parenthesis. Otherwise we'll look for a non-binary expression first (e.g. identifier, number).

func parseExpression(tokens []*token, initialCursor uint, delimiters []token, minBp uint) (*expression, uint, bool) {
    cursor := initialCursor

    var exp *expression
    _, newCursor, ok := parseToken(tokens, cursor, tokenFromSymbol(leftParenSymbol))
    if ok {
        cursor = newCursor
        rightParenToken := tokenFromSymbol(rightParenSymbol)

        exp, cursor, ok = parseExpression(tokens, cursor, append(delimiters, rightParenToken), minBp)
        if !ok {
            helpMessage(tokens, cursor, "Expected expression after opening paren")
            return nil, initialCursor, false
        }

        _, cursor, ok = parseToken(tokens, cursor, rightParenToken)
        if !ok {
            helpMessage(tokens, cursor, "Expected closing paren")
            return nil, initialCursor, false
        }
    } else {
        exp, cursor, ok = parseLiteralExpression(tokens, cursor)
        if !ok {
            return nil, initialCursor, false
        }
    }

    ...

    return exp, cursor, true
}

Then we'll look for a binary operator (e.g. =, and) or a delimiter. If we find an operator and it is of lesser "binding power" than the current minimum (minBp, passed as an argument to the parse function; top-level callers pass 0), we'll return the current expression.

    ...

    lastCursor := cursor
outer:
    for cursor < uint(len(tokens)) {
        for _, d := range delimiters {
            _, _, ok = parseToken(tokens, cursor, d)
            if ok {
                break outer
            }
        }

        binOps := []token{
            tokenFromKeyword(andKeyword),
            tokenFromKeyword(orKeyword),
            tokenFromSymbol(eqSymbol),
            tokenFromSymbol(neqSymbol),
            tokenFromSymbol(concatSymbol),
            tokenFromSymbol(plusSymbol),
        }

        var op *token = nil
        for _, bo := range binOps {
            var t *token
            t, cursor, ok = parseToken(tokens, cursor, bo)
            if ok {
                op = t
                break
            }
        }

        if op == nil {
            helpMessage(tokens, cursor, "Expected binary operator")
            return nil, initialCursor, false
        }

        bp := op.bindingPower()
        if bp < minBp {
            cursor = lastCursor
            break
        }

        ...
    }

    return exp, cursor, true

The bindingPower function on tokens can be defined for now such that sum and concatenation have the highest binding power, followed by equality operations, then boolean operators, and then everything else at zero.

func (t token) bindingPower() uint {
    switch t.kind {
    case keywordKind:
        switch keyword(t.value) {
        case andKeyword:
            fallthrough
        case orKeyword:
            return 1
        }
    case symbolKind:
        switch symbol(t.value) {
        case eqSymbol:
            fallthrough
        case neqSymbol:
            return 2
        case concatSymbol:
            fallthrough
        case plusSymbol:
            return 3
        }
    }

    return 0
}
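To see how binding power drives grouping, here is a small standalone sketch. It is not gosql's parser: it works on whitespace-separated strings, has no parentheses or delimiters, and the parse and group names and the numeric powers are assumptions of the example. It uses the same bp < minBp rule as parseExpression, with the ordering described above.

```go
package main

import (
	"fmt"
	"strings"
)

// bindingPower follows the ordering described above. The exact numbers
// are an assumption of this sketch; only their relative order matters.
var bindingPower = map[string]uint{
	"and": 1, "or": 1,
	"=": 2, "<>": 2,
	"+": 3, "||": 3,
}

// parse reads an operand at cursor, then keeps folding in operators whose
// binding power is at least minBp. It returns a fully parenthesized
// rendering of what it consumed and the new cursor position.
// Operands are assumed to be single whitespace-separated tokens.
func parse(tokens []string, cursor int, minBp uint) (string, int) {
	if cursor >= len(tokens) {
		return "", cursor
	}
	left := tokens[cursor]
	cursor++

	for cursor < len(tokens) {
		op := tokens[cursor]
		bp, isOp := bindingPower[op]
		if !isOp || bp < minBp {
			// Leave this token for the caller to handle.
			break
		}

		var right string
		right, cursor = parse(tokens, cursor+1, bp)
		left = "(" + left + " " + op + " " + right + ")"
	}

	return left, cursor
}

// group shows how a flat expression is grouped by binding power.
func group(src string) string {
	out, _ := parse(strings.Fields(src), 0, 0)
	return out
}

func main() {
	fmt.Println(group("age + 2 = 25"))    // ((age + 2) = 25)
	fmt.Println(group("a = 1 and b = 2")) // ((a = 1) and (b = 2))
}
```

Because the loop only stops when bp < minBp, operators of equal power group to the right in this sketch: 1 + 2 + 3 becomes (1 + (2 + 3)).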

Back in parseExpression, if the new operator's binding power is at least minBp, we'll parse the next operand expression (a recursive call, passing the binding power of the new operator as the new minBp).

Upon completion, the current expression is replaced by a new binary expression with the previous current expression on the left and the just-parsed expression (the return value of the recursive call) on the right.

        ...

        b, newCursor, ok := parseExpression(tokens, cursor, delimiters, bp)
        if !ok {
            helpMessage(tokens, cursor, "Expected right operand")
            return nil, initialCursor, false
        }
        exp = &expression{
            binary: &binaryExpression{
                *exp,
                *b,
                *op,
            },
            kind: binaryKind,
        }
        cursor = newCursor
        lastCursor = cursor
    }

    return exp, cursor, true
}

All together:

func parseExpression(tokens []*token, initialCursor uint, delimiters []token, minBp uint) (*expression, uint, bool) {
    cursor := initialCursor

    var exp *expression
    _, newCursor, ok := parseToken(tokens, cursor, tokenFromSymbol(leftParenSymbol))
    if ok {
        cursor = newCursor
        rightParenToken := tokenFromSymbol(rightParenSymbol)

        exp, cursor, ok = parseExpression(tokens, cursor, append(delimiters, rightParenToken), minBp)
        if !ok {
            helpMessage(tokens, cursor, "Expected expression after opening paren")
            return nil, initialCursor, false
        }

        _, cursor, ok = parseToken(tokens, cursor, rightParenToken)
        if !ok {
            helpMessage(tokens, cursor, "Expected closing paren")
            return nil, initialCursor, false
        }
    } else {
        exp, cursor, ok = parseLiteralExpression(tokens, cursor)
        if !ok {
            return nil, initialCursor, false
        }
    }

    lastCursor := cursor
outer:
    for cursor < uint(len(tokens)) {
        for _, d := range delimiters {
            _, _, ok = parseToken(tokens, cursor, d)
            if ok {
                break outer
            }
        }

        binOps := []token{
            tokenFromKeyword(andKeyword),
            tokenFromKeyword(orKeyword),
            tokenFromSymbol(eqSymbol),
            tokenFromSymbol(neqSymbol),
            tokenFromSymbol(concatSymbol),
            tokenFromSymbol(plusSymbol),
        }

        var op *token = nil
        for _, bo := range binOps {
            var t *token
            t, cursor, ok = parseToken(tokens, cursor, bo)
            if ok {
                op = t
                break
            }
        }

        if op == nil {
            helpMessage(tokens, cursor, "Expected binary operator")
            return nil, initialCursor, false
        }

        bp := op.bindingPower()
        if bp < minBp {
            cursor = lastCursor
            break
        }

        b, newCursor, ok := parseExpression(tokens, cursor, delimiters, bp)
        if !ok {
            helpMessage(tokens, cursor, "Expected right operand")
            return nil, initialCursor, false
        }
        exp = &expression{
            binary: &binaryExpression{
                *exp,
                *b,
                *op,
            },
            kind: binaryKind,
        }
        cursor = newCursor
        lastCursor = cursor
    }

    return exp, cursor, true
}

Now that we have this general parse expression helper in place, we can add support for parsing WHERE in SELECT statements.

Parsing WHERE

This part's pretty easy. We modify the existing parseSelectStatement to search for an optional WHERE token followed by an expression.

func parseSelectStatement(tokens []*token, initialCursor uint, delimiter token) (*SelectStatement, uint, bool) {
    var ok bool
    cursor := initialCursor
    _, cursor, ok = parseToken(tokens, cursor, tokenFromKeyword(selectKeyword))
    if !ok {
        return nil, initialCursor, false
    }

    slct := SelectStatement{}

    fromToken := tokenFromKeyword(fromKeyword)
    item, newCursor, ok := parseSelectItem(tokens, cursor, []token{fromToken, delimiter})
    if !ok {
        return nil, initialCursor, false
    }

    slct.item = item
    cursor = newCursor

    whereToken := tokenFromKeyword(whereKeyword)
    delimiters := []token{delimiter, whereToken}

    _, cursor, ok = parseToken(tokens, cursor, fromToken)
    if ok {
        from, newCursor, ok := parseFromItem(tokens, cursor, delimiters)
        if !ok {
            helpMessage(tokens, cursor, "Expected FROM item")
            return nil, initialCursor, false
        }

        slct.from = from
        cursor = newCursor
    }

    _, cursor, ok = parseToken(tokens, cursor, whereToken)
    if ok {
        where, newCursor, ok := parseExpression(tokens, cursor, []token{delimiter}, 0)
        if !ok {
            helpMessage(tokens, cursor, "Expected WHERE conditionals")
            return nil, initialCursor, false
        }

        slct.where = where
        cursor = newCursor
    }

    return &slct, cursor, true
}

Now we're all done with parsing binary expressions and WHERE filters! If in doubt, refer to parser.go in the project.

Re-thinking query execution

In the first post in this series, we didn't establish any standard way for interpreting an expression in any kind of statement. In SQL though, every expression is always run in the context of a row in a table. We'll handle cases like SELECT 1 and INSERT INTO users VALUES (1) by creating a table with a single empty row to act as the context.

This requires a bit of re-architecting. So we'll rewrite the memory.go implementation in this post from scratch.

We'll also stop panicking when things go wrong. Instead we'll print a message, which allows the REPL to keep going.

Memory cells

Again, the fundamental block of memory in the table will be an untyped array of bytes. We'll provide conversion methods from this memory cell into Go integers, strings, and booleans.

type MemoryCell []byte

func (mc MemoryCell) AsInt() int32 {
    var i int32
    err := binary.Read(bytes.NewBuffer(mc), binary.BigEndian, &i)
    if err != nil {
        fmt.Printf("Corrupted data [%s]: %s\n", mc, err)
        return 0
    }

    return i
}

func (mc MemoryCell) AsText() string {
    return string(mc)
}

func (mc MemoryCell) AsBool() bool {
    return len(mc) != 0
}

func (mc MemoryCell) equals(b MemoryCell) bool {
    // Seems verbose but need to make sure if one is nil, the
    // comparison still fails quickly
    if mc == nil || b == nil {
        return mc == nil && b == nil
    }

    return bytes.Equal(mc, b)
}

We'll also extend the Cell interface in backend.go to support the new boolean type.

package gosql

type ColumnType uint

const (
    TextType ColumnType = iota
    IntType
    BoolType
)

type Cell interface {
    AsText() string
    AsInt() int32
    AsBool() bool
}

...

Finally, we need a way to map a literal token into a memory cell.

func literalToMemoryCell(t *token) MemoryCell {
    if t.kind == numericKind {
        buf := new(bytes.Buffer)
        i, err := strconv.Atoi(t.value)
        if err != nil {
            fmt.Printf("Corrupted data [%s]: %s\n", t.value, err)
            return MemoryCell(nil)
        }

        // TODO: handle bigint
        err = binary.Write(buf, binary.BigEndian, int32(i))
        if err != nil {
            fmt.Printf("Corrupted data [%s]: %s\n", string(buf.Bytes()), err)
            return MemoryCell(nil)
        }
        return MemoryCell(buf.Bytes())
    }

    if t.kind == stringKind {
        return MemoryCell(t.value)
    }

    if t.kind == boolKind {
        if t.value == "true" {
            return MemoryCell([]byte{1})
        } else {
            return MemoryCell(nil)
        }
    }

    return nil
}
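The integer encoding is easy to check in isolation. The sketch below round-trips an int32 through four big-endian bytes using the same encoding/binary calls as above. The encodeInt and decodeInt names are assumptions for this example and are not part of gosql.

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// encodeInt stores an int32 as four big-endian bytes, the same layout
// literalToMemoryCell produces for numeric literals.
func encodeInt(i int32) ([]byte, error) {
	buf := new(bytes.Buffer)
	if err := binary.Write(buf, binary.BigEndian, i); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// decodeInt reverses encodeInt, as AsInt does for a memory cell.
func decodeInt(cell []byte) (int32, error) {
	var i int32
	err := binary.Read(bytes.NewReader(cell), binary.BigEndian, &i)
	return i, err
}

func main() {
	cell, err := encodeInt(23)
	if err != nil {
		fmt.Println("encode failed:", err)
		return
	}
	fmt.Println(cell) // [0 0 0 23]

	n, err := decodeInt(cell)
	fmt.Println(n, err) // 23 <nil>

	// A cell shorter than four bytes cannot be read as an int32.
	_, err = decodeInt([]byte("hi"))
	fmt.Println(err != nil) // true
}
```

Note that any four-byte text value would decode as an integer without error, which is one reason the table tracks a column type alongside the raw bytes.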

And we'll provide global true and false values:

var (
    trueToken  = token{kind: boolKind, value: "true"}
    falseToken = token{kind: boolKind, value: "false"}

    trueMemoryCell  = literalToMemoryCell(&trueToken)
    falseMemoryCell = literalToMemoryCell(&falseToken)
)

Tables

A table has a list of rows (each an array of memory cells) and a list of column names and types.

type table struct {
    columns     []string
    columnTypes []ColumnType
    rows        [][]MemoryCell
}

Finally we'll add a series of methods on table that, given a row index, interpret an expression AST against that row in the table.

Interpreting literals

First we'll implement evaluateLiteralCell that will look up an identifier or return the value of integers, strings, and booleans.

func (t *table) evaluateLiteralCell(rowIndex uint, exp expression) (MemoryCell, string, ColumnType, error) {
    if exp.kind != literalKind {
        return nil, "", 0, ErrInvalidCell
    }

    lit := exp.literal
    if lit.kind == identifierKind {
        for i, tableCol := range t.columns {
            if tableCol == lit.value {
                return t.rows[rowIndex][i], tableCol, t.columnTypes[i], nil
            }
        }

        return nil, "", 0, ErrColumnDoesNotExist
    }

    columnType := IntType
    if lit.kind == stringKind {
        columnType = TextType
    } else if lit.kind == boolKind {
        columnType = BoolType
    }

    return literalToMemoryCell(lit), "?column?", columnType, nil
}

Interpreting binary expressions

Now we can implement evaluateBinaryCell, which will evaluate its two sub-expressions and combine them according to the operator. The SQL operators we have defined so far do no coercion. Equality between operands of different types is simply false (and inequality true). The concatenation and addition operators fail immediately unless their arguments are both strings or both numbers, respectively, and the boolean operators require boolean operands.

func (t *table) evaluateBinaryCell(rowIndex uint, exp expression) (MemoryCell, string, ColumnType, error) {
    if exp.kind != binaryKind {
        return nil, "", 0, ErrInvalidCell
    }

    bexp := exp.binary

    l, _, lt, err := t.evaluateCell(rowIndex, bexp.a)
    if err != nil {
        return nil, "", 0, err
    }

    r, _, rt, err := t.evaluateCell(rowIndex, bexp.b)
    if err != nil {
        return nil, "", 0, err
    }

    switch bexp.op.kind {
    case symbolKind:
        switch symbol(bexp.op.value) {
        case eqSymbol:
            eq := l.equals(r)
            if lt == TextType && rt == TextType && eq {
                return trueMemoryCell, "?column?", BoolType, nil
            }

            if lt == IntType && rt == IntType && eq {
                return trueMemoryCell, "?column?", BoolType, nil
            }

            if lt == BoolType && rt == BoolType && eq {
                return trueMemoryCell, "?column?", BoolType, nil
            }

            return falseMemoryCell, "?column?", BoolType, nil
        case neqSymbol:
            if lt != rt || !l.equals(r) {
                return trueMemoryCell, "?column?", BoolType, nil
            }

            return falseMemoryCell, "?column?", BoolType, nil
        case concatSymbol:
            if lt != TextType || rt != TextType {
                return nil, "", 0, ErrInvalidOperands
            }

            return literalToMemoryCell(&token{kind: stringKind, value: l.AsText() + r.AsText()}), "?column?", TextType, nil
        case plusSymbol:
            if lt != IntType || rt != IntType {
                return nil, "", 0, ErrInvalidOperands
            }

            iValue := int(l.AsInt() + r.AsInt())
            return literalToMemoryCell(&token{kind: numericKind, value: strconv.Itoa(iValue)}), "?column?", IntType, nil
        default:
            // TODO
            break
        }
    case keywordKind:
        switch keyword(bexp.op.value) {
        case andKeyword:
            if lt != BoolType || rt != BoolType {
                return nil, "", 0, ErrInvalidOperands
            }

            res := falseMemoryCell
            if l.AsBool() && r.AsBool() {
                res = trueMemoryCell
            }

            return res, "?column?", BoolType, nil
        case orKeyword:
            if lt != BoolType || rt != BoolType {
                return nil, "", 0, ErrInvalidOperands
            }

            res := falseMemoryCell
            if l.AsBool() || r.AsBool() {
                res = trueMemoryCell
            }

            return res, "?column?", BoolType, nil
        default:
            // TODO
            break
        }
    }

    return nil, "", 0, ErrInvalidCell
}
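The no-coercion rule is easier to see without the byte encoding in the way. In the sketch below, Go's dynamic types stand in for ColumnType (int32 for INT, string for TEXT, bool for BOOL). The apply function and errInvalidOperands variable are assumptions for this example, not gosql code.

```go
package main

import (
	"errors"
	"fmt"
)

var errInvalidOperands = errors.New("invalid operands")

// apply combines two already-evaluated operands without any coercion.
func apply(op string, l, r interface{}) (interface{}, error) {
	switch op {
	case "+":
		li, lok := l.(int32)
		ri, rok := r.(int32)
		if !lok || !rok {
			return nil, errInvalidOperands
		}
		return li + ri, nil
	case "||":
		ls, lok := l.(string)
		rs, rok := r.(string)
		if !lok || !rok {
			return nil, errInvalidOperands
		}
		return ls + rs, nil
	case "=":
		// Interface comparison is true only when both the dynamic type
		// and the value match, so 1 = '1' is false rather than an error.
		return l == r, nil
	case "<>":
		return l != r, nil
	}
	return nil, errInvalidOperands
}

func main() {
	sum, err := apply("+", int32(23), int32(2))
	fmt.Println(sum, err) // 25 <nil>

	_, err = apply("||", "a", int32(1))
	fmt.Println(err) // invalid operands

	eq, err := apply("=", int32(1), "1")
	fmt.Println(eq, err) // false <nil>
}
```

As in evaluateBinaryCell, a type mismatch is an error for + and ||, but only a false result for =.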

Then we'll provide a generic evaluateCell method to wrap these two correctly:

func (t *table) evaluateCell(rowIndex uint, exp expression) (MemoryCell, string, ColumnType, error) {
    switch exp.kind {
    case literalKind:
        return t.evaluateLiteralCell(rowIndex, exp)
    case binaryKind:
        return t.evaluateBinaryCell(rowIndex, exp)
    default:
        return nil, "", 0, ErrInvalidCell
    }
}

Implementing SELECT

As before, each statement will operate on a backend of tables.

type MemoryBackend struct {
    tables map[string]*table
}

func NewMemoryBackend() *MemoryBackend {
    return &MemoryBackend{
        tables: map[string]*table{},
    }
}

When we implement SELECT, we'll iterate over each row in the table (we only support looking up one table for now). If the SELECT statement contains a WHERE block, we'll evaluate the WHERE expression against the current row and move on if the result is false.

Otherwise for each expression in the SELECT list of items we'll evaluate it against the current row in the table.

If there is no table selected, we provide a fake table with a single empty row.

func (mb *MemoryBackend) Select(slct *SelectStatement) (*Results, error) {
    t := &table{}

    if slct.from != nil && slct.from.table != nil {
        var ok bool
        t, ok = mb.tables[slct.from.table.value]
        if !ok {
            return nil, ErrTableDoesNotExist
        }
    }

    if slct.item == nil || len(*slct.item) == 0 {
        return &Results{}, nil
    }

    results := [][]Cell{}
    columns := []struct {
        Type ColumnType
        Name string
    }{}

    if slct.from == nil {
        t = &table{}
        t.rows = [][]MemoryCell{{}}
    }

    for i := range t.rows {
        result := []Cell{}
        isFirstRow := len(results) == 0

        if slct.where != nil {
            val, _, _, err := t.evaluateCell(uint(i), *slct.where)
            if err != nil {
                return nil, err
            }

            if !val.AsBool() {
                continue
            }
        }

        for _, col := range *slct.item {
            if col.asterisk {
                // TODO: handle asterisk
                fmt.Println("Skipping asterisk.")
                continue
            }

            value, columnName, columnType, err := t.evaluateCell(uint(i), *col.exp)
            if err != nil {
                return nil, err
            }

            if isFirstRow {
                columns = append(columns, struct {
                    Type ColumnType
                    Name string
                }{
                    Type: columnType,
                    Name: columnName,
                })
            }

            result = append(result, value)
        }

        results = append(results, result)
    }

    return &Results{
        Columns: columns,
        Rows:    results,
    }, nil
}
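Stripped of memory cells and column metadata, the shape of Select is a filter followed by a projection. The sketch below is a simplified stand-in under stated assumptions: row, evalFn, and selectRows are hypothetical names, and plain Go closures replace evaluating an expression AST against a row.

```go
package main

import "fmt"

// row is a simplified stand-in for the post's []MemoryCell.
type row []interface{}

// evalFn is a hypothetical stand-in for evaluating an expression AST
// against one row.
type evalFn func(r row) interface{}

// selectRows keeps the rows for which where is true (if where is given)
// and evaluates each item against every kept row. A nil rows slice means
// "no FROM clause"; an empty, non-nil slice means an empty table.
func selectRows(rows []row, where evalFn, items []evalFn) [][]interface{} {
	if rows == nil {
		// A single empty row gives the expressions a context to be
		// evaluated against exactly once.
		rows = []row{{}}
	}

	results := [][]interface{}{}
	for _, r := range rows {
		if where != nil {
			keep, ok := where(r).(bool)
			if !ok || !keep {
				continue
			}
		}

		out := []interface{}{}
		for _, item := range items {
			out = append(out, item(r))
		}
		results = append(results, out)
	}
	return results
}

func main() {
	users := []row{{"Stephen", 16}, {"Adrienne", 23}}

	name := func(r row) interface{} { return r[0] }
	agePlus2 := func(r row) interface{} { return r[1].(int) + 2 }
	ageIs23 := func(r row) interface{} { return r[1].(int) == 23 }

	// SELECT age + 2, name FROM users WHERE age = 23;
	fmt.Println(selectRows(users, ageIs23, []evalFn{agePlus2, name}))

	// SELECT 1;
	one := func(r row) interface{} { return 1 }
	fmt.Println(selectRows(nil, nil, []evalFn{one}))
}
```

The single empty row is what lets SELECT 1 go through the same loop as a query against a real table.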

Implementing INSERT, CREATE

The INSERT and CREATE statements stay mostly the same, except that we'll use the evaluateCell helper for every expression. Refer back to the first post if the implementation is otherwise unclear.

func (mb *MemoryBackend) Insert(inst *InsertStatement) error {
    t, ok := mb.tables[inst.table.value]
    if !ok {
        return ErrTableDoesNotExist
    }

    if inst.values == nil {
        return nil
    }

    row := []MemoryCell{}

    if len(*inst.values) != len(t.columns) {
        return ErrMissingValues
    }

    for _, value := range *inst.values {
        if value.kind != literalKind {
            fmt.Println("Skipping non-literal.")
            continue
        }

        emptyTable := &table{}
        value, _, _, err := emptyTable.evaluateCell(0, *value)
        if err != nil {
            return err
        }

        row = append(row, value)
    }

    t.rows = append(t.rows, row)
    return nil
}

func (mb *MemoryBackend) CreateTable(crt *CreateTableStatement) error {
    t := table{}
    mb.tables[crt.name.value] = &t
    if crt.cols == nil {
        return nil
    }

    for _, col := range *crt.cols {
        t.columns = append(t.columns, col.name.value)

        var dt ColumnType
        switch col.datatype.value {
        case "int":
            dt = IntType
        case "text":
            dt = TextType
        default:
            return ErrInvalidDatatype
        }

        t.columnTypes = append(t.columnTypes, dt)
    }

    return nil
}

Back to the REPL

Putting it all together, we run the following session:

# CREATE TABLE users (name TEXT, age INT);
ok
# INSERT INTO users VALUES ('Stephen', 16);
ok
# SELECT name, age FROM users;
   name   | age
----------+------
  Stephen |  16
(1 result)
ok
# INSERT INTO users VALUES ('Adrienne', 23);
ok
# SELECT age + 2, name FROM users WHERE age = 23;
  age |   name
------+-----------
   25 | Adrienne
(1 result)
ok
# SELECT name FROM users;
    name
------------
  Stephen
  Adrienne
(2 results)
ok

And that's it for now! In future posts we'll get into indices, joining tables, etc.