Welcome to the next article in the Create Your Own Programming Language series! In this one we're going to transform the abstract syntax tree (AST) to A Normal Form and also implement one of the coolest features of functional languages: tail call optimization.
As always, if you haven't read the previous article on iteration, do that first before continuing.
Ok, good? Let's do it.
A Normal Form
Unless you're a programming language geek, you've probably never heard of A Normal Form (ANF).
In a compiler, it's normal to have multiple passes that transform the AST into intermediate forms. ANF is one such form: it un-nests nested expressions (like call expressions inside of other call expressions) and flattens things out.
That means we'll take an expression like this one:
(slice (+ 1 3) (* 2 5) (concat (list 1 2 3) (list 4 5 6)))
And turn it into something like this:
(def _1 (+ 1 3))
(def _2 (* 2 5))
(def _3 (list 1 2 3))
(def _4 (list 4 5 6))
(def _5 (concat _3 _4))
(slice _1 _2 _5)
This simplification of nested expressions makes it possible to do all sorts of further transformations and optimizations on the program before emitting code.
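To make the idea concrete, here's a minimal sketch of ANF conversion on a toy representation rather than the compiler's real AST. It assumes s-expressions are plain nested JavaScript arrays, atoms are numbers or strings, and every array is a call whose first element is the function name. toANF and show are hypothetical helpers for illustration only.

```javascript
// Toy ANF conversion. Assumption: an expression is either an atom
// (number or string) or a call written as [fn, ...args].
function toANF(expr) {
  let counter = 0;
  const bindings = [];

  // Reduce an argument to an atom. Atoms are returned unchanged; a nested
  // call has its own arguments atomized first (left to right), then it is
  // bound to a fresh variable whose name is returned in its place.
  const atomize = (e) => {
    if (!Array.isArray(e)) return e;
    const [fn, ...args] = e;
    const call = [fn, ...args.map(atomize)];
    const name = `_${++counter}`;
    bindings.push(["def", name, call]);
    return name;
  };

  if (!Array.isArray(expr)) return [expr];

  // The outermost call stays as the final expression; only its arguments are unnested.
  const [fn, ...args] = expr;
  const flatArgs = args.map(atomize);
  return [...bindings, [fn, ...flatArgs]];
}

// Print a toy expression back out as an s-expression string.
const show = (e) => (Array.isArray(e) ? `(${e.map(show).join(" ")})` : String(e));

const program = [
  "slice",
  ["+", 1, 3],
  ["*", 2, 5],
  ["concat", ["list", 1, 2, 3], ["list", 4, 5, 6]],
];

for (const line of toANF(program)) console.log(show(line));
```

Because arguments are unnested depth-first and left to right, this prints the same six lines as the example above.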
Tail Call Optimization
Tail call optimization (TCO) is a feature most often found in functional languages like Scheme. It allows essentially unlimited recursion, as long as the recursive call is in tail position.
A call is in tail position when it's the last expression evaluated on one of a function's control flow paths and its result is returned as-is, with no other operation performed on it.
For example, here's a naive implementation of the factorial function:
(def fact (n)
  (if (= n 1)
    n
    (* n (fact (- n 1)))))
The recursive call to fact is not in tail position here, because it's part of the expression beginning with *.
Here's a tail recursive version of fact:
(def fact (n a)
  (if (= n 1)
    a
    (fact (- n 1) (* a n))))
As you can see, the tail recursive version carries the running result of the multiplication along as an extra parameter, so there's nothing left to do after the recursive call returns. This kind of parameter is called an accumulator, and you start it at 1, e.g. (fact 5 1).
Normally, a recursive call adds another frame to the call stack, and when the stack grows large enough you get a stack overflow error.
In a language like Scheme, the interpreter is optimized so that tail recursive calls don't add new frames to the stack.
Instead, the frame for the new function call replaces the previous frame on the stack, allowing essentially infinite tail recursion.
The technique we use to optimize tail calls will allow us to have proper tail call optimization even though most JavaScript engines don't implement TCO natively (which is a shame, because it IS actually in the ECMAScript specification).
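Here's a small plain-JavaScript illustration of why this matters, along with one common way to get the effect without engine support: rewriting a self tail call as a loop that reassigns the parameters. This is a hand-written sketch of the general idea, not necessarily the exact code the compiler in this series will emit; sumNaive, sumTail, and sumLoop are hypothetical examples.

```javascript
// Not in tail position: the addition still has to run after the recursive
// call returns, so every call keeps its stack frame alive.
const sumNaive = (n) => (n === 0 ? 0 : n + sumNaive(n - 1));

// Tail position with an accumulator. Without engine-level TCO (which most
// JavaScript engines don't ship), this still grows the stack one frame per call.
const sumTail = (n, acc = 0) => (n === 0 ? acc : sumTail(n - 1, acc + n));

// The same tail-recursive function rewritten by hand as a loop: the
// "recursive call" just rebinds the parameters and jumps back to the top,
// so the stack never grows.
const sumLoop = (n, acc = 0) => {
  while (true) {
    if (n === 0) return acc;
    [n, acc] = [n - 1, acc + n];
  }
};

console.log(sumLoop(1_000_000)); // 500000500000

try {
  sumTail(1_000_000);
} catch (e) {
  // engines without TCO (e.g. Node.js) throw a RangeError here
  console.log(e instanceof RangeError);
}
```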
Fixing a Bug in The Parser
Before we begin implementing the A Normal Form transformation we need to quickly fix a bug in the parser.
I've gone back and fixed this in the original post on functions, but in case you read it before I edited it to include the fix I'll go ahead and share it here too.
In parseFunction in src/parser/parse.js there's a bug that prevents the 2nd expression in the body of a function from being parsed properly. We need to change the code that checks whether there is a return type annotation. Starting from the top of the function, change it to this:
let retType, body;

if (maybeArrow.type === TokenTypes.Symbol && maybeArrow.value === "->") {
  // has return type annotation
  retType = parseTypeAnnotation(maybeRetType);
  body = maybeBody;
} else {
  // no return type annotation, so everything we tentatively treated as
  // the arrow and the return type is actually part of the body
  retType = null;
  body = maybeRetType
    ? [maybeArrow, maybeRetType, ...maybeBody]
    : [maybeArrow, ...maybeBody];
}
Then the function continues unchanged from const variadic =... to the end.
Preparing for A Normal Form
We also need to make a couple of changes to the core library in preparation for the transformation to A Normal Form. We'll be constructing new AST nodes in the transformation, including some function calls, and we need to make sure those functions are correct in the standard library.
First, we need to make a minor change to the get function in lib/js/core.js so that it also works for getting a value at an index from a vector:
get: rt.makeFunction(
  (n, obj) => {
    // lists expose a get method; vectors (and tuples) are arrays, so use indexing
    let value =
      obj.get && typeof obj.get === "function" ? obj.get(n) : obj[n];

    if (value === undefined) {
      throw new Exception(`Value for index ${n} not found on object`);
    }

    return value;
  },
  { contract: "(((list any) || (vector any)) -> any)", name: "get" }
),
We're also going to add a new function, slice, at the bottom of the module:
slice: rt.makeFunction(
  (start, end, obj = undefined) => {
    if (obj === undefined) {
      // didn't pass in an end value, which is valid:
      // the 2nd argument is actually the list/vector
      obj = end;
      end = [...obj].length;
    } else if (end < 0) {
      // a negative end counts back from the end of the sequence
      end = [...obj].length + end;
    }

    if (isList(obj)) {
      let values = [];
      let i = start;

      while (i < end && obj.get(i) !== undefined) {
        values.push(obj.get(i));
        // advance to the next index, otherwise the loop never terminates
        i++;
      }

      return Cons.from(values);
    } else if (Array.isArray(obj)) {
      return obj.slice(start, end);
    } else {
      throw new Exception(
        "slice can only take a list or vector as its argument"
      );
    }
  },
  {
    name: "slice",
    contract:
      "(number, number, ((vector any) || (list any)) -> ((vector any) || (list any)))",
  }
),
As you'll see when we get to the ANF transformation there's a little white lie here: we're also going to use it on tuples. But since tuples are arrays under the hood and at the point of transformation we're already done with the type checker, it's not a big deal to do that. Just remember that slice won't type check for tuples, so you can't actually use it to slice tuples in your Wanda code.
You'll also notice that we've made the 3rd parameter default to undefined, because if we didn't, makeFunction would curry it to 3 parameters. We need to be able to call it with only 2 arguments, as you'll see below when we get to transforming variable declarations.
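The default value helps because a function's length property only counts the parameters before the first one that has a default. The sketch below shows that, plus a hypothetical arity-based curry helper; it assumes makeFunction curries on fn.length in roughly this way, which the paragraph above implies but which isn't shown here. curryByLength and sliceImpl are illustrative names, and sliceImpl only handles arrays.

```javascript
// length stops counting at the first parameter with a default value.
const withDefault = (start, end, obj = undefined) => {};
const withoutDefault = (start, end, obj) => {};
console.log(withDefault.length, withoutDefault.length); // 2 3

// Hypothetical arity-based currying, standing in for what makeFunction does:
// keep collecting arguments until at least fn.length of them have arrived.
const curryByLength = (fn) => {
  const curried = (...args) =>
    args.length >= fn.length
      ? fn(...args)
      : (...more) => curried(...args, ...more);
  return curried;
};

// A simplified, array-only stand-in for slice's argument shuffling.
const sliceImpl = (start, end, obj = undefined) => {
  if (obj === undefined) {
    // called with 2 arguments: the 2nd one is the sequence itself
    obj = end;
    end = obj.length;
  } else if (end < 0) {
    end = obj.length + end;
  }
  return obj.slice(start, end);
};

const slice = curryByLength(sliceImpl);
console.log(slice(1, [10, 20, 30, 40])); // [ 20, 30, 40 ]
console.log(slice(1, -1, [10, 20, 30, 40])); // [ 20, 30 ]
```

With a length of 3, the same curry helper would return a function from slice(1, list) instead of calling it, which is exactly what the default parameter avoids.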
Transforming to A Normal Form
An expression is said to be in A Normal Form if its complex and nested subexpressions have been un-nested and bound in let expressions (we'll use variable declarations).
Primitive values and symbols are already considered to be in A Normal Form.
All arguments to a function must be trivial, which is to say they should already be terminally evaluated and be passed into the function as either primitives or symbols. Another way of putting it is that arguments to a function need to be in a form where evaluating them halts immediately instead of calling out to another function.
You can still have some nested expressions, e.g. the test expression for an if expression can be a call expression as long as its arguments have been terminally evaluated.
You'll start to develop an intuition for it as we go through the conversion process.
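As a rough check of the "trivial arguments" rule, here's a simplified sketch over a toy representation rather than the compiler's AST. It assumes atoms are numbers or strings, a call is an array [fn, ...args], and a variable binding is ["def", name, init]. It only covers calls and bindings, not if expressions or other special forms; isAtom, isFlatCall, and isANF are hypothetical names.

```javascript
// Toy representation (assumption): atoms are numbers or strings,
// a call is [fn, ...args], and a binding is ["def", name, init].
const isAtom = (e) => !Array.isArray(e);

// A call is in ANF when every argument is trivial, i.e. an atom.
const isFlatCall = (e) => isAtom(e) || e.slice(1).every(isAtom);

// A binding is in ANF when its initializer is an atom or a flat call.
const isANF = (e) =>
  Array.isArray(e) && e[0] === "def" ? isFlatCall(e[2]) : isFlatCall(e);

console.log(isANF(["slice", "_1", "_2", "_5"])); // true
console.log(isANF(["def", "_1", ["+", 1, 3]])); // true
console.log(isANF(["slice", ["+", 1, 3], 2, "v"])); // false
```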
Since we're unnesting expressions that are embedded within other expressions, a number of our conversion functions will have to return arrays of expression nodes. That means we're going to be doing a lot of checks for arrays in our conversion functions.
When a transformation produces an array of expressions, the last expression in the array will always be the target expression. So for example if we convert an if expression and it unnests additional expressions, the converted if expression will always be the last member of the array.
This makes array handling more straightforward, because we can always pop the last expression from the array and then concatenate the remainder of the array to our unnested expressions.
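This pop-and-concatenate pattern comes up in almost every transformer below. As a sketch, it could be factored into a small helper; splitResult is a hypothetical name that isn't part of the code in this article, and unlike pop it copies the array instead of mutating it.

```javascript
// Hypothetical helper: normalize a transformer's result into the target
// expression plus any unnested expressions that must come before it.
const splitResult = (result) =>
  Array.isArray(result)
    ? { unnested: result.slice(0, -1), expr: result[result.length - 1] }
    : { unnested: [], expr: result };

// Usage with strings standing in for real AST nodes.
const fromArray = splitResult(["def _1", "def _2", "(f _1 _2)"]);
console.log(fromArray.expr); // (f _1 _2)
console.log(fromArray.unnested); // [ 'def _1', 'def _2' ]

const fromNode = splitResult("x");
console.log(fromNode.expr, fromNode.unnested); // x []
```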
The Program node, do expressions, and when expressions all contain a body property that is an array of expressions, so when a function that processes one of those gets back an array of unnested expressions, we'll need to use flatMap to join those arrays into the body properties of those nodes.
We're also going to handle destructuring at this stage and unnest all the separate variables that are parts of the vector and record patterns, which means we'll be able to remove the code in the emitter that handles destructured assignment.
It sounds more convoluted than it actually is, so let's start with some code and you can see how it works.
The Dispatcher Function
First, here are the imports for src/transform/anf.js:
import { AST, ASTTypes } from "../parser/ast.js";
import { isPrimitive } from "../parser/utils.js";
import { Exception } from "../shared/exceptions.js";
import { makeGenSym } from "../runtime/makeSymbol.js";
import { Token } from "../lexer/Token.js";
import { TokenTypes } from "../lexer/TokenTypes.js";
We also need these 2 helper functions:
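// Create a new Symbol node with a unique, generated name
const createFreshSymbol = (srcloc) => {
  return AST.Symbol(Token.new(TokenTypes.Symbol, makeGenSym(), srcloc));
};

// Primitives and symbols are already in ANF, so they never need unnesting
const isPrimitiveOrSymbol = (node) => {
  return isPrimitive(node) || node.kind === ASTTypes.Symbol;
};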
As usual, we need a dispatcher function that handles each of the different node kinds:
export const anf = (node) => {
  switch (node.kind) {
    case ASTTypes.Program:
      return transformProgram(node);
    case ASTTypes.NumberLiteral:
    case ASTTypes.StringLiteral:
    case ASTTypes.BooleanLiteral:
    case ASTTypes.KeywordLiteral:
    case ASTTypes.NilLiteral:
    case ASTTypes.Symbol:
      return node;
    case ASTTypes.CallExpression:
      return transformCallExpression(node);
    case ASTTypes.LambdaExpression:
      return transformLambdaExpression(node);
    case ASTTypes.VariableDeclaration:
      return transformVariableDeclaration(node);
    case ASTTypes.SetExpression:
      return transformSetExpression(node);
    case ASTTypes.TypeAlias:
      // ignore
      return node;
    case ASTTypes.VectorLiteral:
      return transformVectorLiteral(node);
    case ASTTypes.RecordLiteral:
      return transformRecordLiteral(node);
    case ASTTypes.MemberExpression:
      return transformMemberExpression(node);
    case ASTTypes.DoExpression:
      return transformDoExpression(node);
    case ASTTypes.AsExpression:
      return transformAsExpression(node);
    case ASTTypes.IfExpression:
      return transformIfExpression(node);
    case ASTTypes.WhenExpression:
      return transformWhenExpression(node);
    case ASTTypes.LogicalExpression:
      return transformLogicalExpression(node);
    default:
      throw new Exception(`Unhandled node kind: ${node.kind}`);
  }
};
You may wonder why we're writing separate functions instead of using our generic visitor. I decided there was no benefit to using the visitor since we'd need to write a method for every single node type that's left after desugaring, so I wrote the transformer as a series of functions.
As you can see, for primitives and symbols it simply returns the node. There's nothing to do with type aliases either, so it returns that node as well. The rest of the nodes all dispatch to separate transformer functions.
Transforming The Program
transformProgram is very simple: it just produces a new body by flatMapping over the original body with anf as the map function:
const transformProgram = (node) => {
  let body = node.body.flatMap(anf);
  return { ...node, body };
};
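flatMap works here because it flattens exactly one level: a transformer that returns a single node contributes one element, and one that returns an array contributes each of its elements in order. A quick illustration, with strings standing in for AST nodes and fakeAnf as a hypothetical stand-in for anf:

```javascript
// Strings stand in for AST nodes; the "complex" expression expands into
// [unnested bindings..., target expression], like the ANF transformers do.
const fakeAnf = (node) =>
  node === "(f (g x))" ? ["(def _1 (g x))", "(f _1)"] : node;

const body = ["(def y 1)", "(f (g x))", "y"];
console.log(body.flatMap(fakeAnf));
// [ '(def y 1)', '(def _1 (g x))', '(f _1)', 'y' ]
```

Because only one level is flattened, each transformer has to return a flat array of nodes; an array nested inside the returned array would end up in the body as-is.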
Transforming Call Expressions
transformCallExpression is the first function to return an array of expressions rather than a single expression.
We start by creating the array. Then we transform the function itself. If the result is an array, we pop the function off the array then concatenate the rest of the array to the unnested expressions array.
Then we need to handle the arguments, which can be either simple or complex (nested) expressions, so they take some more work.
First, we create an array to hold the transformed arguments. Then we loop over the original node's arguments.
An argument to a function can be another call expression, in which case we need to process the subcall and unnest any of its arguments.
So if the argument is a call expression we loop over its args and unnest any nested expressions then push the bare arguments themselves onto a subargs array. We transform complex arguments by unnesting them and creating new variable assignments for the nested expressions, pushing the symbols for the new variables onto the subargs array and adding the unnested expressions and new variable declarations to the unnested expressions array.
We take the transformed subcall arguments and construct a new call expression, using it as the expression for a new variable declaration, then add that variable to the arguments array for the original call expression. The variable declaration gets added to the unnested expressions array.
Processing a subcall serves as a base case for recursively processing call expressions.
If an argument to the main call expression is not another call expression, it simply gets processed by the anf function, and the transformed expression is handled by popping the main expression off any array that was returned and concatenating any unnested expressions onto the unnested expressions array.
It sounds convoluted, but the code is fairly straightforward. Here's the code for transformCallExpression, where I've annotated everything with comments so you can relate the code to the explanation above:
const transformCallExpression = (node) => {
  // create an array for unnested expressions from the call expression
  let unnestedExprs = [];
  // transform the function
  let func = anf(node.func);

  // if func has been transformed into an array, get the actual function
  // which will be the last expression in the array from the transformer
  if (Array.isArray(func)) {
    const transformedFunc = func;
    func = transformedFunc.pop();
    // add the remaining unnested expressions to our unnested expressions array
    // (concat returns a new array, so we have to reassign the result)
    unnestedExprs = unnestedExprs.concat(transformedFunc);
  }

  let args = [];

  for (let arg of node.args) {
    // if it's a call expression, we need to unnest any subexpressions from
    // the arguments to the sub-call expression and create a new call expr
    if (arg.kind === ASTTypes.CallExpression) {
      // we'll need unnested arguments for the subcall
      let subArgs = [];

      for (let a of arg.args) {
        // primitives and symbols are already in ANF
        if (isPrimitiveOrSymbol(a)) {
          subArgs.push(a);
          // otherwise, we need to unnest the expression and bind the result to a new
          // variable, then replace the expression in the call body with that variable
        } else {
          const freshLet = AST.VariableDeclaration(
            createFreshSymbol(a.srcloc),
            a,
            a.srcloc,
            null
          );
          const transformedLet = transformVariableDeclaration(freshLet);
          // the actual declaration will always be the last node in the array
          // we're going to need to add this to the unnested expressions
          // for the parent call expression, so we don't pop it
          const actualLet = transformedLet[transformedLet.length - 1];
          // add the unnested expressions from the VariableDeclaration
          // to the unnested expressions from the call expression
          unnestedExprs = unnestedExprs.concat(transformedLet);
          // add the variable that's been assigned to
          // its relative place in the subcall args
          subArgs.push(actualLet.lhv);
        }
      }

      // create a new CallExpression with unnested sub-arguments
      let subCall = AST.CallExpression(arg.func, subArgs, arg.srcloc);
      // create a fresh variable symbol
      const callSymbol = createFreshSymbol(subCall.srcloc);
      // assign the result of the unnested call expression to the fresh variable
      const callLet = AST.VariableDeclaration(
        callSymbol,
        subCall,
        subCall.srcloc
      );
      // add the assignment to the unnested expressions
      unnestedExprs.push(callLet);
      // the argument to the parent call expression should now be the fresh variable
      arg = callSymbol;
    } else {
      arg = anf(arg);
    }

    if (Array.isArray(arg)) {
      // the actual arg will always be the last element in this array, the rest
      // are all unnested expressions and should be concatenated to that array
      args.push(arg.pop());
      unnestedExprs = unnestedExprs.concat(arg);
    } else {
      // the anfed arg is a single node
      args.push(arg);
    }
  }

  const newCallExpr = AST.CallExpression(func, args, node.srcloc);
  return [...unnestedExprs, newCallExpr];
};
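One detail worth keeping in mind when reading or writing code like this: Array.prototype.concat never modifies the array it's called on, it returns a new one. That's why the transformers write unnestedExprs = unnestedExprs.concat(...), and why a bare unnestedExprs.concat(...) statement silently does nothing. pop, on the other hand, does mutate. A quick illustration with strings standing in for nodes:

```javascript
let unnested = ["a"];

// Discarded result: unnested is unchanged.
unnested.concat(["b"]);
console.log(unnested); // [ 'a' ]

// Reassigning keeps the new array.
unnested = unnested.concat(["b"]);
console.log(unnested); // [ 'a', 'b' ]

// pop mutates in place and returns the removed last element, which is why
// the transformers pop the target expression off first and then concat
// whatever is left of the array.
const transformed = ["(def _1 (g x))", "(f _1)"];
const target = transformed.pop();
console.log(target, transformed); // (f _1) [ '(def _1 (g x))' ]
```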
Transforming Lambda Expressions
Transforming lambda expressions is extremely simple. We don't need to worry about parameters because there's no nesting there, so all we need to do is flatMap over the body with anf as the mapping function:
const transformLambdaExpression = (node) => {
  const body = node.body.flatMap(anf);
  return { ...node, body };
};
Transforming Variable Declarations
Transforming variable declarations is complicated by the need to unnest variables used in destructuring.
First, as usual, we create an array for unnested expressions. Then we ANF transform the initializer expression for the declaration node.
If the ANFed expression is an array, we pop off the actual expression then concatenate the rest to the unnested expressions array.
Whether or not it's an array, we construct a new declaration using the transformed expression.
If it's a simple variable declaration using a single symbol as the variable name, we're done here and can just return an array containing the unnested expressions and ending with the newly constructed declaration.
const transformVariableDeclaration = (node) => {
  let unnestedExprs = [];
  const anfedExpr = anf(node.expression);
  let anfedDecl;
  let expression;

  if (Array.isArray(anfedExpr)) {
    expression = anfedExpr.pop();
    anfedDecl = { ...node, expression };
    unnestedExprs = unnestedExprs.concat(anfedExpr);
  } else {
    expression = anfedExpr;
    anfedDecl = { ...node, expression };
  }

  // rest of function which handles destructured declarations

  // If we get here, it's a simple variable declaration with a symbol as LHV
  return [...unnestedExprs, anfedDecl];
};
If it's a destructured variable declaration, we've got some work to do.
Transforming Vector Destructuring
If it's vector destructuring, we need to loop over the members of the vector pattern.
For each member of the vector pattern, we construct a new variable declaration with a call to the get function for the current index of the list/vector/tuple to get the correct value for the destructured variable.
Note that this function call will throw an error if the object being destructured doesn't have enough members to satisfy the number of destructured variables.
If the last member is a rest variable, then instead of using the get function we use slice, passing it the current index and the object being destructured. This is why we needed to be able to call slice with only 2 arguments, passing it only a start value instead of requiring both start and end.
if (node.lhv.kind === ASTTypes.VectorPattern) {
  // is vector pattern destructuring
  /** @type {import("../parser/ast.js").VectorPattern} */
  const pattern = node.lhv;
  // bind the object being destructured to a fresh variable so a complex RHV is
  // evaluated only once and every get/slice call below receives a trivial argument
  const objSymbol = createFreshSymbol(expression.srcloc);
  const objDecl = AST.VariableDeclaration(
    objSymbol,
    expression,
    expression.srcloc
  );
  unnestedExprs.push(objDecl);
  let i = 0;

  for (let mem of pattern.members) {
    if (i === pattern.members.length - 1 && pattern.rest) {
      // need to slice off the rest of the list/vector/tuple and assign it to the last member
      const destructuredDecl = AST.VariableDeclaration(
        mem,
        AST.CallExpression(
          AST.Symbol(Token.new(TokenTypes.Symbol, "slice", mem.srcloc)),
          [
            AST.NumberLiteral(
              Token.new(TokenTypes.Number, i.toString(), mem.srcloc)
            ),
            objSymbol,
          ],
          mem.srcloc
        ),
        mem.srcloc
      );
      // and push it onto the unnestedExprs array
      unnestedExprs.push(destructuredDecl);
    } else {
      // need to get the value from the current index of the list/vector/tuple and assign it to the current pattern member
      const destructuredDecl = AST.VariableDeclaration(
        mem,
        AST.CallExpression(
          AST.Symbol(Token.new(TokenTypes.Symbol, "get", mem.srcloc)),
          [
            AST.NumberLiteral(
              Token.new(TokenTypes.Number, i.toString(), mem.srcloc)
            ),
            objSymbol,
          ],
          mem.srcloc
        ),
        mem.srcloc
      );
      // and push it onto the unnestedExprs array
      unnestedExprs.push(destructuredDecl);
    }

    i++;
  }

  return unnestedExprs;
}
// handle record destructuring...
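To see the shape of what this loop produces, here's a sketch that builds the same kind of declarations as display strings for a pattern with members a and b plus a rest variable named rest. desugarVectorPattern is a hypothetical helper that only mirrors the loop above; it doesn't build real AST nodes, and the s-expression output is just for readability.

```javascript
// Mirror of the loop above, emitting s-expression strings instead of AST nodes.
// members: variable names in the pattern; hasRest: whether the last one is a
// rest variable; source: the (already trivial) expression being destructured.
const desugarVectorPattern = (members, hasRest, source) =>
  members.map((name, i) =>
    i === members.length - 1 && hasRest
      ? `(def ${name} (slice ${i} ${source}))`
      : `(def ${name} (get ${i} ${source}))`
  );

console.log(desugarVectorPattern(["a", "b", "rest"], true, "v").join("\n"));
// (def a (get 0 v))
// (def b (get 1 v))
// (def rest (slice 2 v))
```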
Transforming Record Destructuring
If it's record destructuring, we start by assigning the actual object being destructured to a fresh variable name and pushing the assignment node onto the unnested expressions array. We save the symbol node for this fresh variable so we can use it in what follows.
Now we loop over the properties of the record pattern. We use an array to keep track of which properties have been used, so we can use the unused properties to create the object assigned to the rest variable. Since this step comes after the type checker, we can use the type information we got from the type checker to do that.
If it's not a rest variable, we simply create a new variable declaration node and construct a member expression node using the original object and the current property being looped over, then push that declaration onto the unnested expressions array.
Then we push the property name we've just used onto the used properties array.
If it's a rest variable, we create a new array of all the unused properties by filtering over the RHV's type properties then mapping those properties to new symbol nodes.
Then we build the properties for the remainder object by reducing the unused properties array into a new array of Property nodes, and use that array to construct a new RecordLiteral node.
Then we create a new variable declaration assigning the remainder object to the rest variable name.
Finally, we push the final declaration onto the unnested expressions array.
// closing brace is the last one from the previous code block, do not duplicate
} else if (node.lhv.kind === ASTTypes.RecordPattern) {
  // is record pattern destructuring
  // remember, we have the RHV's type at this point
  /** @type {import("../parser/ast.js").RecordPattern} */
  const pattern = node.lhv;
  // first we need to assign the actual object to a fresh variable name
  const objSymbol = createFreshSymbol(expression.srcloc);
  const objDecl = AST.VariableDeclaration(
    objSymbol,
    expression,
    expression.srcloc
  );
  unnestedExprs.push(objDecl);
  let i = 0;
  let used = [];

  for (let prop of pattern.properties) {
    if (i === pattern.properties.length - 1 && pattern.rest) {
      // need to get the rest of the object's properties and assign them to the rest variable
      // this maps the array of unused properties from the type to an array of Symbol
      // nodes with each property name as the node name
      const unusedProps = expression.type.properties
        .filter((p) => {
          return !used.includes(p.name);
        })
        .map((p) =>
          AST.Symbol(Token.new(TokenTypes.Symbol, p.name, prop.srcloc))
        );
      // now create an object using the properties
      const properties = unusedProps.reduce((props, p) => {
        return [
          ...props,
          AST.Property(
            p,
            AST.MemberExpression(objSymbol, p, p.srcloc),
            p.srcloc
          ),
        ];
      }, []);
      const remainingObject = AST.RecordLiteral(properties, prop.srcloc);
      // and a variable declaration using the remainder object assigning it to the rest variable
      const restDecl = AST.VariableDeclaration(
        prop,
        remainingObject,
        prop.srcloc
      );
      unnestedExprs.push(restDecl);
    } else {
      // need to assign the current variable's object property
      const currentDecl = AST.VariableDeclaration(
        prop,
        AST.MemberExpression(objSymbol, prop, prop.srcloc),
        prop.srcloc
      );
      // and push it onto the unnestedExprs array
      unnestedExprs.push(currentDecl);
      used.push(prop.name);
    }

    i++;
  }

  return unnestedExprs;
}
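Here's a runtime-level sketch of what those declarations compute, using plain JavaScript objects: named properties are copied out and recorded as used, and the rest variable gets a new object built from the properties listed in the static type that weren't used. destructureRecord and its parameters are hypothetical, and typeProps stands in for expression.type.properties.

```javascript
// value: the record being destructured; names: the non-rest pattern variables;
// restName: the rest variable's name, or null; typeProps: property names from
// the record's static type (this is why the transform needs type information).
const destructureRecord = (value, names, restName, typeProps) => {
  const bindings = {};
  const used = [];

  for (const name of names) {
    bindings[name] = value[name];
    used.push(name);
  }

  if (restName !== null) {
    const unused = typeProps.filter((p) => !used.includes(p));
    bindings[restName] = Object.fromEntries(unused.map((p) => [p, value[p]]));
  }

  return bindings;
};

const point = { x: 1, y: 2, z: 3 };
console.log(destructureRecord(point, ["x"], "rest", ["x", "y", "z"]));
// { x: 1, rest: { y: 2, z: 3 } }
```

Driving the rest object from the type's property list, rather than from the runtime keys, means it contains exactly the fields the type checker knows about.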
Here's the entire transformVariableDeclaration function:
const transformVariableDeclaration = (node) => {
  let unnestedExprs = [];
  const anfedExpr = anf(node.expression);
  let anfedDecl;
  let expression;

  if (Array.isArray(anfedExpr)) {
    expression = anfedExpr.pop();
    anfedDecl = { ...node, expression };
    unnestedExprs = unnestedExprs.concat(anfedExpr);
  } else {
    expression = anfedExpr;
    anfedDecl = { ...node, expression };
  }

  if (node.lhv.kind === ASTTypes.VectorPattern) {
    // is vector pattern destructuring
    /** @type {import("../parser/ast.js").VectorPattern} */
    const pattern = node.lhv;
    // bind the object being destructured to a fresh variable so a complex RHV is
    // evaluated only once and every get/slice call below receives a trivial argument
    const objSymbol = createFreshSymbol(expression.srcloc);
    const objDecl = AST.VariableDeclaration(
      objSymbol,
      expression,
      expression.srcloc
    );
    unnestedExprs.push(objDecl);
    let i = 0;

    for (let mem of pattern.members) {
      if (i === pattern.members.length - 1 && pattern.rest) {
        // need to slice off the rest of the list/vector/tuple and assign it to the last member
        const destructuredDecl = AST.VariableDeclaration(
          mem,
          AST.CallExpression(
            AST.Symbol(Token.new(TokenTypes.Symbol, "slice", mem.srcloc)),
            [
              AST.NumberLiteral(
                Token.new(TokenTypes.Number, i.toString(), mem.srcloc)
              ),
              objSymbol,
            ],
            mem.srcloc
          ),
          mem.srcloc
        );
        // and push it onto the unnestedExprs array
        unnestedExprs.push(destructuredDecl);
      } else {
        // need to get the value from the current index of the list/vector/tuple and assign it to the current pattern member
        const destructuredDecl = AST.VariableDeclaration(
          mem,
          AST.CallExpression(
            AST.Symbol(Token.new(TokenTypes.Symbol, "get", mem.srcloc)),
            [
              AST.NumberLiteral(
                Token.new(TokenTypes.Number, i.toString(), mem.srcloc)
              ),
              objSymbol,
            ],
            mem.srcloc
          ),
          mem.srcloc
        );
        // and push it onto the unnestedExprs array
        unnestedExprs.push(destructuredDecl);
      }

      i++;
    }

    return unnestedExprs;
  } else if (node.lhv.kind === ASTTypes.RecordPattern) {
    // is record pattern destructuring
    // remember, we have the RHV's type at this point
    /** @type {import("../parser/ast.js").RecordPattern} */
    const pattern = node.lhv;
    // first we need to assign the actual object to a fresh variable name
    const objSymbol = createFreshSymbol(expression.srcloc);
    const objDecl = AST.VariableDeclaration(
      objSymbol,
      expression,
      expression.srcloc
    );
    unnestedExprs.push(objDecl);
    let i = 0;
    let used = [];

    for (let prop of pattern.properties) {
      if (i === pattern.properties.length - 1 && pattern.rest) {
        // need to get the rest of the object's properties and assign them to the rest variable
        // this maps the array of unused properties from the type to an array of Symbol
        // nodes with each property name as the node name
        const unusedProps = expression.type.properties
          .filter((p) => {
            return !used.includes(p.name);
          })
          .map((p) =>
            AST.Symbol(Token.new(TokenTypes.Symbol, p.name, prop.srcloc))
          );
        // now create an object using the properties
        const properties = unusedProps.reduce((props, p) => {
          return [
            ...props,
            AST.Property(
              p,
              AST.MemberExpression(objSymbol, p, p.srcloc),
              p.srcloc
            ),
          ];
        }, []);
        const remainingObject = AST.RecordLiteral(properties, prop.srcloc);
        // and a variable declaration using the remainder object assigning it to the rest variable
        const restDecl = AST.VariableDeclaration(
          prop,
          remainingObject,
          prop.srcloc
        );
        unnestedExprs.push(restDecl);
      } else {
        // need to assign the current variable's object property
        const currentDecl = AST.VariableDeclaration(
          prop,
          AST.MemberExpression(objSymbol, prop, prop.srcloc),
          prop.srcloc
        );
        // and push it onto the unnestedExprs array
        unnestedExprs.push(currentDecl);
        used.push(prop.name);
      }

      i++;
    }

    return unnestedExprs;
  }

  // If we get here, it's a simple variable declaration with a symbol as LHV
  return [...unnestedExprs, anfedDecl];
};
Remember, we're only supporting top-level destructuring without nested patterns. Adding nested patterns would make this even more complex.
Transforming Set Expressions
The code for transformSetExpression is similar to the simple, symbol-only part of transformVariableDeclaration:
const transformSetExpression = (node) => {
  const anfedExpr = anf(node.expression);

  if (Array.isArray(anfedExpr)) {
    let expression = anfedExpr.pop();
    return [...anfedExpr, { ...node, expression }];
  }

  return [{ ...node, expression: anfedExpr }];
};
Transforming Vector Literals
To transform a vector literal we just need to unnest any nested expressions that make up the vector's members:
const transformVectorLiteral = (node) => {
  let unnestedExprs = [];
  let members = [];

  for (let mem of node.members) {
    let anfed = anf(mem);

    if (Array.isArray(anfed)) {
      members.push(anfed.pop());
      unnestedExprs = unnestedExprs.concat(anfed);
    } else {
      members.push(anfed);
    }
  }

  return [...unnestedExprs, { ...node, members }];
};
Transforming Record Literals
Same goes for transforming record literals where we unnest any nested expressions used in creating the record's properties:
const transformRecordLiteral = (node) => {
  let unnestedExprs = [];
  let properties = [];

  for (let prop of node.properties) {
    let anfed = anf(prop.value);

    if (Array.isArray(anfed)) {
      let value = anfed.pop();
      properties.push({ ...prop, value });
      unnestedExprs = unnestedExprs.concat(anfed);
    } else {
      properties.push({ ...prop, value: anfed });
    }
  }

  return [...unnestedExprs, { ...node, properties }];
};
Transforming Member Expressions
The only complication in transforming a member expression is that the object can be practically any node type as long as the expression produces an object:
const transformMemberExpression = (node) => {
  let unnestedExprs = [];
  let anfedObject = anf(node.object);

  if (Array.isArray(anfedObject)) {
    const object = anfedObject.pop();
    unnestedExprs = unnestedExprs.concat(anfedObject);
    return [...unnestedExprs, { ...node, object }];
  }

  return [{ ...node, object: anfedObject }];
};
Transforming Do Expressions
Transforming a do expression is pretty much the same as transforming the Program:
const transformDoExpression = (node) => {
  const body = node.body.flatMap(anf);
  return { ...node, body };
};
Transforming As Expressions
In transforming as expressions, we're just going to transform the expression part and return it. That means we'll no longer need to handle as expressions in the emitter, just like we no longer need to handle destructuring in the emitter:
const transformAsExpression = (node) => {
  const anfed = anf(node.expression);

  if (Array.isArray(anfed)) {
    return anfed;
  }

  return [anfed];
};
Transforming If Expressions
If expressions are straightforward. Transform the test, then append any unnested expressions to the unnested expressions array. Then do the same with the consequent (then) and alternate (else) branches:
const transformIfExpression = (node) => {
  let unnestedExprs = [];

  let transformedCondition = anf(node.test);
  let test;

  if (Array.isArray(transformedCondition)) {
    test = transformedCondition.pop();
    unnestedExprs = unnestedExprs.concat(transformedCondition);
  } else {
    test = transformedCondition;
  }

  let transformedThen = anf(node.then);
  let then;

  if (Array.isArray(transformedThen)) {
    then = transformedThen.pop();
    unnestedExprs = unnestedExprs.concat(transformedThen);
  } else {
    then = transformedThen;
  }

  let transformedElse = anf(node.else);
  let elseBranch;

  if (Array.isArray(transformedElse)) {
    elseBranch = transformedElse.pop();
    unnestedExprs = unnestedExprs.concat(transformedElse);
  } else {
    elseBranch = transformedElse;
  }

  return [...unnestedExprs, { ...node, test, then, else: elseBranch }];
};
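A caveat about this approach: whatever gets unnested from the then and else branches is placed in the array before the if expression, so it runs before the test, no matter which branch is taken. That's harmless for pure computations, but it changes behavior when a branch has side effects or recurses. In the naive fact from earlier, for instance, the recursive call would be hoisted out of the else branch and run even when n is 1. The same applies to the right operand of a logical expression below, which would lose short-circuiting. Here's a plain-JavaScript illustration of the difference; expensive, original, and hoisted are hypothetical names:

```javascript
const log = [];
const expensive = (x) => {
  log.push(x); // record that the call actually happened
  return x * 2;
};

// Original shape: the else branch only runs when the test is false.
const original = (test) => (test ? 0 : expensive(21));

// Shape after hoisting the else branch's call above the conditional.
const hoisted = (test) => {
  const _1 = expensive(21);
  return test ? 0 : _1;
};

original(true);
console.log(log.length); // 0
hoisted(true);
console.log(log.length); // 1, the call ran even though its result was unused
```

One way to avoid this would be to keep a branch's unnested expressions inside that branch, for example by wrapping them in a do expression, although tail call detection would then have to look inside that extra node.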
Transforming When Expressions
When expressions are like a combination between if and do expressions. Transform the test like an if expression, then transform the body like a do expression:
const transformWhenExpression = (node) => {
  let unnestedExprs = [];
  let transformedCondition = anf(node.test);
  let test;

  if (Array.isArray(transformedCondition)) {
    test = transformedCondition.pop();
    unnestedExprs = unnestedExprs.concat(transformedCondition);
  } else {
    // the test was already a single node, so use it as-is
    test = transformedCondition;
  }

  let body = node.body.flatMap(anf);
  return [...unnestedExprs, { ...node, test, body }];
};
Transforming Logical Expressions
Logical expressions are straightforward. Transform the left, transform the right, add anything that needs it to the unnested expressions array:
const transformLogicalExpression = (node) => {
  let unnestedExprs = [];

  let transformedLeft = anf(node.left);
  let left;

  if (Array.isArray(transformedLeft)) {
    left = transformedLeft.pop();
    unnestedExprs = unnestedExprs.concat(transformedLeft);
  } else {
    left = transformedLeft;
  }

  // transform the node's right operand
  let transformedRight = anf(node.right);
  let right;

  if (Array.isArray(transformedRight)) {
    right = transformedRight.pop();
    unnestedExprs = unnestedExprs.concat(transformedRight);
  } else {
    right = transformedRight;
  }

  return [...unnestedExprs, { ...node, left, right }];
};
That's all the transformation functions!
Figuring out how to do AST transformations vexed me for a long time, so I really hope this helps you understand how transforming the tree works.
Tail Call Optimization
Now let's take the transformed tree and use it to detect recursive tail calls.
To detect recursive tail calls we traverse the whole transformed AST. When we come to a lambda expression, we check if it has a name property (which gets added when function declarations are desugared into variable declarations with lambdas).
If the last (tail position) expression in the lambda body is a call expression, or if it's an if, do, or logical expression that contains a call expression in tail position, we check whether the func property of that call expression is a Symbol node with the same name. If it is, then it's a recursive tail call and we can mark it as such.
We'll mark the call expression and the lambda expression with a boolean flag indicating whether they're tail recursive, and if the tail call is part of an if expression, we'll mark the if expression too.
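Here's a simplified sketch of that check on toy node objects rather than the real AST: func is a plain string instead of a Symbol node, and only call, if, do, and logical nodes are considered. sym, num, call, ifExpr, hasTailCallTo, and isTailRecursive are hypothetical names, and the sketch only answers yes or no instead of marking nodes.

```javascript
// Toy node constructors (assumption: simplified stand-ins for AST nodes).
const sym = (name) => ({ kind: "Symbol", name });
const num = (value) => ({ kind: "Number", value });
const call = (func, ...args) => ({ kind: "Call", func, args });
const ifExpr = (test, then, otherwise) => ({ kind: "If", test, then, else: otherwise });

// Does `node`, evaluated in tail position, end in a call to `name`?
// Arguments of a call are never in tail position, so they aren't searched.
const hasTailCallTo = (node, name) => {
  switch (node.kind) {
    case "Call":
      return node.func === name;
    case "If":
      return hasTailCallTo(node.then, name) || hasTailCallTo(node.else, name);
    case "Do":
      return node.body.length > 0 && hasTailCallTo(node.body[node.body.length - 1], name);
    case "Logical":
      return hasTailCallTo(node.right, name);
    default:
      return false;
  }
};

// A named lambda is tail recursive if its last body expression ends in a call to itself.
const isTailRecursive = (lambda) =>
  lambda.body.length > 0 &&
  hasTailCallTo(lambda.body[lambda.body.length - 1], lambda.name);

const n = sym("n");
const a = sym("a");

// (if (= n 1) n (* n (fact (- n 1))))
const naiveFact = {
  name: "fact",
  body: [ifExpr(call("=", n, num(1)), n, call("*", n, call("fact", call("-", n, num(1)))))],
};

// (if (= n 1) a (fact (- n 1) (* a n)))
const tailFact = {
  name: "fact",
  body: [ifExpr(call("=", n, num(1)), a, call("fact", call("-", n, num(1)), call("*", a, n)))],
};

console.log(isTailRecursive(naiveFact)); // false
console.log(isTailRecursive(tailFact)); // true
```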
We'll use a class that extends Visitor for this, because we only need to override the methods for a couple of node types but we still need to traverse the whole tree.
First the imports, in src/transform/tco.js:
import { ASTTypes } from "../parser/ast.js";
import { Visitor } from "../visitor/Visitor.js";
We need a couple of helper functions. First, we need one to swap out the last node in an expression body, since we replace tail call nodes with new nodes that mark them as tail calls. Note that the first argument to this is the node with the body property, not the body itself:
const swapLastExpr = (node, expr) => {
node.body[node.body.length - 1] = expr;
};
We also need a function to inspect if expressions. We'll use a standalone function for this instead of the regular visitor method because we only need to check if expressions that are in tail position in the body of a lambda or do expression. All other if expressions can be handled as normal.
This function also recursively checks any if expression that is the then or else branch of the if expression being checked:
const checkIfExpression = (node, name, visitor) => {
let isTailRec = false;
if (
node.then.kind === ASTTypes.CallExpression &&
node.then.func.name === name
) {
node.then = visitor.visitCallExpression(node.then, true);
isTailRec = true;
}
if (
node.else.kind === ASTTypes.CallExpression &&
node.else.func.name === name
) {
node.else = visitor.visitCallExpression(node.else, true);
isTailRec = true;
}
if (node.then.kind === ASTTypes.IfExpression) {
node.then = checkIfExpression(node.then, name, visitor);
if (node.then.isTailRec) {
isTailRec = true;
}
}
if (node.else.kind === ASTTypes.IfExpression) {
node.else = checkIfExpression(node.else, name, visitor);
if (node.else.isTailRec) {
isTailRec = true;
}
}
return { ...node, isTailRec };
};
Ok, first let's stub out the class for the TCO transformer:
class TCOTransformer extends Visitor {
constructor(program) {
super(program);
}
static new(program) {
return new TCOTransformer(program);
}
}
Now let's add an override for visitCallExpression that takes an extra, optional parameter isTailRec that defaults to false. It needs a default because only our overridden visitLambdaExpression (directly or through checkIfExpression) ever calls it with true:
visitCallExpression(node, isTailRec = false) {
return { ...node, isTailRec };
}
Ok, now let's turn to visitLambdaExpression. It's a lot of code, but it's straightforward. If the lambda node has a name property, save it to a variable. Then get the last expression from the lambda body.
If the last expression is a call expression, we check whether its func property is a Symbol node with the same name as the lambda.
If it's a do, if, or logical expression, we check whether the last expression (of a do expression), either branch (of an if expression), or the right operand (of a logical expression) is a call in tail position to a function with the same name. Only the right operand of a logical expression counts, because the result of the left operand has to be tested before the logical expression can return.
If we detect a recursive tail call, we mark both the call expression and the lambda expression with an isTailRec flag set to true.
visitLambdaExpression(node) {
const name = node.name ?? "";
const lastExpr = node.body[node.body.length - 1];
// Could be tail recursive: CallExpression, DoExpression, IfExpression, LogicalExpression
if (lastExpr.kind === ASTTypes.CallExpression) {
// should only work if func is symbol
if (lastExpr.func.name === name) {
const newCall = this.visitCallExpression(lastExpr, true);
swapLastExpr(node, newCall);
return { ...node, isTailRec: true };
}
lastExpr.isTailRec = false;
return { ...node, isTailRec: false };
} else if (lastExpr.kind === ASTTypes.DoExpression) {
const lastBodyExpr = lastExpr.body[lastExpr.body.length - 1];
if (
lastBodyExpr.kind === ASTTypes.CallExpression &&
lastBodyExpr.func.name === name
) {
const newCall = this.visitCallExpression(lastBodyExpr, true);
swapLastExpr(lastExpr, newCall);
return { ...node, isTailRec: true };
} else if (lastBodyExpr.kind === ASTTypes.IfExpression) {
const newIf = checkIfExpression(lastBodyExpr, name, this);
swapLastExpr(lastExpr, newIf);
return { ...node, isTailRec: newIf.isTailRec };
} else if (lastBodyExpr.kind === ASTTypes.LogicalExpression) {
// Only the right operand is in tail position: the left operand's
// result is tested before the logical expression returns
if (
lastBodyExpr.right.kind === ASTTypes.CallExpression &&
lastBodyExpr.right.func.name === name
) {
lastBodyExpr.right = this.visitCallExpression(lastBodyExpr.right, true);
return { ...node, isTailRec: true };
}
return { ...node, isTailRec: false };
} else {
lastExpr.isTailRec = false;
return { ...node, isTailRec: false };
}
} else if (lastExpr.kind === ASTTypes.IfExpression) {
const newIf = checkIfExpression(lastExpr, name, this);
swapLastExpr(node, newIf);
return { ...node, isTailRec: newIf.isTailRec };
} else if (lastExpr.kind === ASTTypes.LogicalExpression) {
// Same rule here: only the right operand can be a tail call
if (
lastExpr.right.kind === ASTTypes.CallExpression &&
lastExpr.right.func.name === name
) {
lastExpr.right = this.visitCallExpression(lastExpr.right, true);
return { ...node, isTailRec: true };
}
return { ...node, isTailRec: false };
}
return { ...node, isTailRec: false };
}
Finally, we export a function that constructs the TCO transformer and runs it on a program:
export const tco = (program) => TCOTransformer.new(program).visit();
Now we need a function to handle all transformations, in src/transform/transform.js:
import { anf } from "./anf.js";
import { tco } from "./tco.js";
export const transform = (program) => tco(anf(program));
Changes to The Runtime
Now that we've added the ability to detect tail recursive calls, we need to optimize them.
The simplest way to do that in JavaScript is with a trampoline.
A trampoline is a loop that runs recursive calls one after another instead of nesting them on the call stack. We're going to trampoline our tail recursive functions by making the tail call return a special object that holds the function and the arguments for the tail call as properties. You'll see how we rewrite the tail recursive call to achieve this when we get to the emitter.
The setup actually has two parts: a recur function that returns the special object and a trampoline function that returns the function that runs the loop.
It sounds complicated, but the two functions are actually very simple. Here they are, in src/runtime/trampoline.js:
import { makeWandaValue } from "./conversion.js";
// based on this Stack Overflow answer: https://stackoverflow.com/a/50493099
export const recur = (f, ...args) => ({ tag: recur, f, args });
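export const trampoline = (f) => {
const trampolined = (...args) => {
let t = f(...args);
// Keep calling until we get something other than a recur object
while (t && t.tag === recur) {
t = t.f(...t.args);
}
// Convert after the loop so base cases reached on the first call
// and falsy results (false, 0, nil) are returned too
return makeWandaValue(t);
};
trampolined.f = f;
return trampolined;
};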
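To see the pattern without the rest of the Wanda runtime, here's a standalone sketch of recur and trampoline with no currying and no makeWandaValue conversion. sumTo and isEven are illustrative functions, not part of the runtime. The loop exits as soon as a call returns something other than a recur object, and that value is returned as-is, including base cases reached on the very first call and falsy results like false.

```javascript
// Standalone sketch of the recur/trampoline pattern (no Wanda runtime)
const recur = (f, ...args) => ({ tag: recur, f, args });

const trampoline = (f) => {
  const trampolined = (...args) => {
    let t = f(...args);
    // Keep bouncing while the result is a recur request
    while (t && t.tag === recur) {
      t = t.f(...t.args);
    }
    return t;
  };
  // Expose the untrampolined function so tail calls can bypass the wrapper
  trampolined.f = f;
  return trampolined;
};

// Tail-recursive sum of 1..n with an accumulator
const sumTo = trampoline((n, acc = 0) => (n === 0 ? acc : recur(sumTo.f, n - 1, acc + n)));

// A tail-recursive predicate whose final answer can be false
const isEven = trampoline((n) => (n === 0 ? true : n === 1 ? false : recur(isEven.f, n - 2)));

console.log(sumTo(100000)); // 5000050000, far deeper than the default call stack allows
console.log(isEven(100001)); // false
```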
We also need to make a minor change to our makeFunction function in the runtime.
Currently the function we create in makeFunction returns a call to makeWandaValue, which means that if we rewrite the recursive call to use rt.recur, the object it returns will be transformed into a Wanda object. That will mess up the trampoline function's ability to recognize and process the special object.
So let's add a tailRec property to the options argument that defaults to false. When we detect a tail recursive function, setting it prevents makeFunction from transforming the object returned by the tail call. We'll let the trampoline function handle converting the final value to a Wanda value.
Here's our new makeFunction function in src/runtime/makeFunction.js:
import objectHash from "object-hash";
import { curryN } from "ramda";
import { makeWandaValue } from "./conversion.js";
import { addMetaField } from "./object.js";
import { parseContract } from "./parseContract.js";
export const makeFunction = (
func,
{ contract = "", name = "", tailRec = false } = {}
) => {
let fn = curryN(func.length, (...args) => {
const val = tailRec ? func(...args) : makeWandaValue(func(...args));
if (typeof val === "function") {
return makeFunction(val);
}
return val;
});
const hash = objectHash(func);
addMetaField(fn, "wanda", true);
addMetaField(fn, "arity", func.length);
addMetaField(fn, "name", name || hash);
if (contract !== "") {
Object.defineProperty(fn, "contract", {
enumerable: false,
configurable: false,
writable: false,
value: parseContract(contract),
});
}
Object.defineProperty(fn, "name", {
enumerable: false,
configurable: false,
writable: false,
value: name || hash,
});
return fn;
};
Then we need to add trampoline and recur to makeRuntime in src/runtime/makeRuntime.js:
// other imports
import { trampoline, recur } from "./trampoline.js";
export const makeRuntime = () => {
return {
// other members
trampoline,
recur,
};
};
Now that our trampoline is in place, we need to modify the emitter to use it.
Changes to The Emitter
The first thing we're going to do to our emitter is remove everything related to handling destructuring and as expressions, since thanks to our ANF transformation they won't be making it to the emitter anymore.
You can delete the emitAsExpression method and the case for it from the emit method. This leaves you with this for emit in src/emitter/Emitter.js:
emit(node = this.program, ns = this.ns) {
switch (node.kind) {
case ASTTypes.Program:
return this.emitProgram(node, ns);
case ASTTypes.NumberLiteral:
return this.emitNumber(node, ns);
case ASTTypes.StringLiteral:
return this.emitString(node, ns);
case ASTTypes.BooleanLiteral:
return this.emitBoolean(node, ns);
case ASTTypes.KeywordLiteral:
return this.emitKeyword(node, ns);
case ASTTypes.NilLiteral:
return this.emitNil(node, ns);
case ASTTypes.Symbol:
return this.emitSymbol(node, ns);
case ASTTypes.CallExpression:
return this.emitCallExpression(node, ns);
case ASTTypes.VariableDeclaration:
return this.emitVariableDeclaration(node, ns);
case ASTTypes.SetExpression:
return this.emitSetExpression(node, ns);
case ASTTypes.DoExpression:
return this.emitDoExpression(node, ns);
case ASTTypes.TypeAlias:
return this.emitTypeAlias(node, ns);
case ASTTypes.MemberExpression:
return this.emitMemberExpression(node, ns);
case ASTTypes.RecordLiteral:
return this.emitRecordLiteral(node, ns);
case ASTTypes.RecordPattern:
return this.emitRecordPattern(node, ns);
case ASTTypes.VectorLiteral:
return this.emitVectorLiteral(node, ns);
case ASTTypes.VectorPattern:
return this.emitVectorPattern(node, ns);
case ASTTypes.LambdaExpression:
return this.emitLambdaExpression(node, ns);
case ASTTypes.LogicalExpression:
return this.emitLogicalExpression(node, ns);
case ASTTypes.IfExpression:
return this.emitIfExpression(node, ns);
case ASTTypes.WhenExpression:
return this.emitWhenExpression(node, ns);
default:
throw new SyntaxException(node.kind, node.srcloc);
}
}
You can also delete the emitVariableDeclarationAssignment method since we won't be using it anymore.
Finally, you can vastly simplify emitVariableDeclaration so it looks like this:
emitVariableDeclaration(node, ns) {
const name = node.lhv.name;
const translatedName = makeSymbol(name);
if (ns.has(name)) {
throw new ReferenceException(
`Name ${name} has already been accessed in the current namespace; cannot access name before its definition`,
node.srcloc
);
}
ns.set(name, translatedName);
return `var ${translatedName} = ${this.emit(node.expression, ns)}`;
}
Now to handle tail recursion and the trampoline.
In emitLambdaExpression, change the last line beginning with code += to handle the new tailRec option for rt.makeFunction:
code += `${
node.name
? `, { name: "${node.name}"${node.isTailRec ? ", tailRec: true" : ""} }`
: ""
})`;
Also change the return statement to use the trampoline if node.isTailRec is true:
return node.isTailRec ? `rt.trampoline(${code})` : code;
Now in emitCallExpression we're going to separate the call expression's function from its arguments, since the rt.recur function needs them separately. Then, if it's a tail call, we call rt.recur and pass it the original function (NOT the trampolined version) along with the arguments. The original function is stored in the f property that trampoline adds to the trampolined function. We pass that one because calling the trampolined function would start a new trampoline loop inside the current one on every recursive call, and we'd run into the same recursion limits. Here's the new version of emitCallExpression:
emitCallExpression(node, ns) {
const func = `(${this.emit(node.func, ns)})`;
const args = `${node.args.map((a) => this.emit(a, ns)).join(", ")}`;
return node.isTailRec ? `rt.recur(${func}.f, ${args})` : `${func}(${args})`;
}
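Putting the emitter changes together, the JavaScript generated for the tail recursive fact has roughly the shape below. This is a hand-written approximation, not actual compiler output: the rt object is a minimal stand-in (the real rt.makeFunction also curries the function, attaches metadata, and converts values), and the real output uses mangled symbol names and Wanda's own code for the conditional and arithmetic. The detail to notice is that the tail call hands (fact).f, the untrampolined function, to rt.recur.

```javascript
// Minimal stand-in for the runtime members the emitted code uses (hypothetical)
const rt = {
  recur: (f, ...args) => ({ tag: rt.recur, f, args }),
  trampoline: (f) => {
    const trampolined = (...args) => {
      let t = f(...args);
      while (t && t.tag === rt.recur) {
        t = t.f(...t.args);
      }
      return t;
    };
    trampolined.f = f;
    return trampolined;
  },
  // The real makeFunction also curries and converts; this one passes func through
  makeFunction: (func, options = {}) => func,
};

// Approximate emitted code for:
// (def fact (n a) (if (= n 1) a (fact (- n 1) (* a n))))
var fact = rt.trampoline(
  rt.makeFunction((n, a) => (n === 1 ? a : rt.recur((fact).f, n - 1, a * n)), {
    name: "fact",
    tailRec: true,
  })
);

console.log(fact(5, 1)); // 120
console.log(fact(100000, 1)); // Infinity, without a stack overflow
```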
That's it for the emitter! Now we just need to add the transformation to our compilation pipeline and it will be done.
Changes to The CLI
In src/cli/compile.js we need to import our transform function:
import { transform } from "../transform/transform.js";
And finally we need to add transform to the compile function. Here's the new compile function:
export const compile = (
input,
file = "stdin",
ns = undefined,
typeEnv = undefined
) =>
emit(
transform(
desugar(typecheck(parse(expand(read(tokenize(input, file)))), typeEnv))
),
ns
);
And with that we have implemented TCO in our compiler. Were you expecting it to be more difficult? I was certainly surprised at how simple and straightforward it turned out to be.
Trying It Out
Ok, let's fire up a REPL and try it out. First, try the naive factorial function:
(def fact (n)
(if (= n 1)
n
(* n (fact (- n 1)))))
Try it with (fact 1000). Ok, yeah, the answer is Infinity (1000! is far too large for a JavaScript number), so we're not going to get better numeric answers for bigger numbers, but that's not the point.
Try it with (fact 2000). If your Node instance is configured like mine, you just got a stack overflow error.
Close your REPL and then open it again (note to self: add a command to refresh the REPL state), then try the tail recursive version of fact:
(def fact (n a)
(if (= n 1)
a
(fact (- n 1) (* a n))))
Start with (fact 1000 1). Now keep incrementing by 1000 and see what happens.
I got to 10,000 and then decided to do something crazy: (fact 100000 1).
The result: Infinity.
No stack overflow error, even with 100,000 calls!
I'd say the trampoline works pretty damned well.
Conclusion
I'm thrilled that I was able to get this working the way I wanted to. Now we have tail call optimization for self-recursive functions, with virtually unlimited tail recursion! I'm excited about where we are right now.
Like I said, AST transformations are something I struggled with conceptually for a long time before I finally figured out how to make this work.
As always, you can see the current state of the code at the relevant tag in the GitHub repo.
This is currently the last planned post in the series, though it's possible I'll come back and add more later. I hope you've had as much fun reading as I've had writing!
I also hope this inspires you to go out and create your own languages and language related tools. Let me know if you make something you think is cool!