DEV Community

Leo Pfeiffer


Writing an equation parser in Scala

I recently started diving into Scala and as a first learning project, I decided to build a functional mathematical expression tokenizer, parser and evaluator. In this post I will walk through the implementation and some of the core algorithms. Feel free to reach out with any questions!

You can find the final project here.

Overview

The goal of the project is to be able to pass a mathematical expression (e.g. 1 + 2 * 3 - 4) to the program and get the value of the expression back (i.e. 3).

Internally, this involves three steps. First, the raw string expression needs to be tokenized into a well-defined format as preparation for further processing. Second, a parser reads the tokens and parses them into a mathematical expression following the usual precedence rules. Lastly, the expression is evaluated by combining its components.

Allowed operations

Let's quickly define what the final program should be able to do.

Numbers should be specified as integers (e.g. 42) or decimals (e.g. 1.23) and should be evaluated as doubles.

Negative numbers should be evaluated correctly, but should always be wrapped in parentheses (i.e. write (-42)+2 instead of -42+2).

The following operations should be allowed:

+    Addition
-    Difference
*    Multiplication
/    Division
^    Power

Left ( and right ) parentheses should also be allowed, and the grouping they express should override the usual precedence rules.

Defining Expressions

We start by defining the expressions. In total, we will need to define the expressions Number, Sum, Difference, Product, Division, and Power.

In fact, each component of an expression is itself an expression. For example, 1 + 2 consists of two expressions of type Number, namely Number(1) and Number(2), as well as a Sum operator.

Moreover, every expression is either of type Number (which wraps a double) or Operator (which has a left and a right expression). Continuing with the example from above, the Sum operator has the left expression Number(1) and the right expression Number(2).

Lastly, we differentiate between regular operators and Commutative operators. A commutative operator does not differentiate between left and right, as the order doesn't matter. For example, x+y always equals y+x (commutative) but x-y does not necessarily equal y-x.

With that in mind, we can define the Expression trait and the Operator trait, which extends Expression.

/** Represents arithmetic expression. */
trait Expression

/** Arithmetic operator */
trait Operator extends Expression {
    val left: Expression
    val right: Expression
}

As laid out, a Commutative expression is simply an Operator whose left and right sides can be swapped. We can achieve this behavior by extending Operator and overriding equals and hashCode accordingly.

/** 
 * Represents operators whose left and right side are commutative,
 * i.e. the order of the LHS and RHS expression does not matter.
 * */
trait Commutative extends Operator {
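    // Compare runtime classes so that, e.g., Sum(a, b) never equals Product(a, b);
    // checking only isInstanceOf[Commutative] would treat them as equal.
    def canEqual(a: Any): Boolean = a != null && a.getClass == this.getClass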

    override def equals(that: Any): Boolean =
        that match {
            case that: Commutative => {
                that.canEqual(this) &&
                ((this.left == that.left && this.right == that.right) ||
                    (this.left == that.right && this.right == that.left))
            }
            case _ => false
        }

    override def hashCode: Int = {
        val prime = 31
        prime + left.hashCode * right.hashCode
    }
}

The actual expression types can now be implemented as case classes that extend any of the Expression traits.

case class Number(n: Double) extends Expression {
    def value = n
}

case class Sum(left: Expression, right: Expression) extends Commutative
case class Difference(left: Expression, right: Expression) extends Operator
case class Product(left: Expression, right: Expression) extends Commutative
case class Division(left: Expression, right: Expression) extends Operator
case class Power(left: Expression, right: Expression) extends Operator
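To see what the overridden equals buys us, here is a small self-contained sketch that copies the trait layout above. Two deviations are my own assumptions rather than the original code: only Number, Sum, Product and Difference are included, and canEqual compares runtime classes so that a Sum is never equal to a Product with the same operands.

```scala
// Minimal copy of the expression hierarchy above, for experimentation only.
trait Expression

trait Operator extends Expression:
  val left: Expression
  val right: Expression

trait Commutative extends Operator:
  // Assumption (deviates from the original): compare runtime classes,
  // so that Sum(a, b) is never equal to Product(a, b).
  def canEqual(other: Any): Boolean =
    other != null && other.getClass == this.getClass

  override def equals(that: Any): Boolean = that match
    case that: Commutative =>
      that.canEqual(this) &&
        ((left == that.left && right == that.right) ||
          (left == that.right && right == that.left))
    case _ => false

  // Multiplication is symmetric, so swapped operands give the same hash.
  override def hashCode: Int = 31 + left.hashCode * right.hashCode

case class Number(n: Double) extends Expression
case class Sum(left: Expression, right: Expression) extends Commutative
case class Product(left: Expression, right: Expression) extends Commutative
case class Difference(left: Expression, right: Expression) extends Operator

val one = Number(1)
val two = Number(2)

println(Sum(one, two) == Sum(two, one))               // true
println(Sum(one, two) == Product(one, two))           // false
println(Difference(one, two) == Difference(two, one)) // false
```

Because the case classes inherit a concrete equals and hashCode from Commutative, the compiler does not synthesize its own versions for Sum and Product, while Difference keeps the usual field-by-field case-class equality.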

Evaluating these expressions is now almost trivial. Any operator can be evaluated by simply applying the underlying mathematical operation to the left and right sides (which in turn need to be evaluated). For Number expressions, we simply return the underlying value. Scala's pattern matching comes in handy here.

/**
 * Evaluate an expression.
 * @param expression: expression to evaluate
 * @return value of the expression
 * */
def evaluate(expression: Expression): Double = expression match
    case Number(n) => n
    case Sum(left, right) => evaluate(left) + evaluate(right)
    case Difference(left, right) => evaluate(left) - evaluate(right)
    case Product(left, right) => evaluate(left) * evaluate(right)
    case Division(left, right) => evaluate(left) / evaluate(right)
    case Power(left, right) => scala.math.pow(evaluate(left), evaluate(right))

That's the expression implementation done. If we wanted to represent the example 1 + 2 * 3 - 4 from earlier using our implementation, it would give us

Difference(Sum(Number(1), Product(Number(2), Number(3))), Number(4))

We can also visualize these expressions as a syntax tree. For example, parsing the expression 3+4*2/(1-5)^2^3 produces the following syntax tree.

[Figure: syntax tree of 3+4*2/(1-5)^2^3]
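To make the tree structure explicit in code, the sketch below rebuilds it by hand, together with the earlier 1 + 2 * 3 - 4 example. The tree is my own reconstruction from the precedence rules (^ binds tightest and associates right, so (1-5)^2^3 means (1-5)^(2^3)). It uses a simplified, sealed copy of the classes above without the Commutative equality, plus a hypothetical show helper that prints one pair of parentheses per subtree.

```scala
// Simplified, sealed copy of the expression classes (Commutative omitted).
sealed trait Expression
case class Number(n: Double) extends Expression
case class Sum(left: Expression, right: Expression) extends Expression
case class Difference(left: Expression, right: Expression) extends Expression
case class Product(left: Expression, right: Expression) extends Expression
case class Division(left: Expression, right: Expression) extends Expression
case class Power(left: Expression, right: Expression) extends Expression

def evaluate(expression: Expression): Double = expression match
  case Number(n)        => n
  case Sum(l, r)        => evaluate(l) + evaluate(r)
  case Difference(l, r) => evaluate(l) - evaluate(r)
  case Product(l, r)    => evaluate(l) * evaluate(r)
  case Division(l, r)   => evaluate(l) / evaluate(r)
  case Power(l, r)      => math.pow(evaluate(l), evaluate(r))

/** Fully parenthesized rendering: every pair of parentheses is one subtree. */
def show(expression: Expression): String = expression match
  case Number(n)        => n.toString
  case Sum(l, r)        => s"(${show(l)} + ${show(r)})"
  case Difference(l, r) => s"(${show(l)} - ${show(r)})"
  case Product(l, r)    => s"(${show(l)} * ${show(r)})"
  case Division(l, r)   => s"(${show(l)} / ${show(r)})"
  case Power(l, r)      => s"(${show(l)} ^ ${show(r)})"

// 1 + 2 * 3 - 4, as built by hand earlier
val simple = Difference(Sum(Number(1), Product(Number(2), Number(3))), Number(4))

// 3+4*2/(1-5)^2^3
val tree = Sum(
  Number(3),
  Division(
    Product(Number(4), Number(2)),
    Power(Difference(Number(1), Number(5)), Power(Number(2), Number(3)))
  )
)

println(evaluate(simple)) // 3.0
println(show(tree))       // (3.0 + ((4.0 * 2.0) / ((1.0 - 5.0) ^ (2.0 ^ 3.0))))
println(evaluate(tree))   // 3.0001220703125
```

Making the trait sealed is a small design change worth considering: the compiler can then warn if evaluate or show forgets a case.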

The final version of Expression.scala can be found here.

Tokenizer

Next, we will take care of tokenizing the raw string expression.

The tokenizer splits the raw expression into its components and encodes each component as a token. The tokens closely follow the Expression classes defined earlier; however, we also have to handle parentheses here.

We start by defining a simple abstract class Token and an abstract subclass OperatorToken, which represents operators. Each operator has a precedence: operators with a higher precedence bind more tightly and are applied first (e.g. * before +).

/** Token of an expression */
abstract class Token()

/** Token representing an operator */
abstract class OperatorToken() extends Token {
    /** Precedence value of the operator */
    def precedence: Int
}

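Next, we have to handle how the different operators associate, i.e. how a chain of operators with the same precedence is grouped. A left-associative operator groups from left to right (1-2-3 means (1-2)-3), while a right-associative operator groups from right to left (2^3^2 means 2^(3^2)).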
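A quick way to see why associativity matters is to group the same three operands both ways. This sketch is not part of the project; it simply uses the standard reduceLeft and reduceRight methods on List as stand-ins for left and right association.

```scala
// Three operands joined by the same operator, grouped both ways.
val operands = List(1.0, 2.0, 3.0)
val leftDiff  = operands.reduceLeft[Double](_ - _)  // (1 - 2) - 3
val rightDiff = operands.reduceRight[Double](_ - _) // 1 - (2 - 3)

val powOperands = List(2.0, 3.0, 2.0)
val rightPow = powOperands.reduceRight[Double](math.pow) // 2 ^ (3 ^ 2)
val leftPow  = powOperands.reduceLeft[Double](math.pow)  // (2 ^ 3) ^ 2

println(s"1-2-3: left-assoc = $leftDiff, right-assoc = $rightDiff") // -4.0 vs 2.0
println(s"2^3^2: right-assoc = $rightPow, left-assoc = $leftPow")   // 512.0 vs 64.0
```

Only the left grouping gives the conventional answer for -, and only the right grouping gives the conventional answer for ^, which is exactly why PowerToken is the one operator that extends Right.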

This is implemented using the following three traits.

trait Associates

/** Associates left */
trait Left extends Associates

/** Associates right */
trait Right extends Associates

This gives us all the components to define our actual tokens. NumberToken, LeftParensToken, and RightParensToken extend Token directly. All others are OperatorToken subclasses and also extend an Associates trait. Only the power token (^) associates right; all others associate left.

/** Token representing sum */
case class SumToken() extends OperatorToken, Left {
    def precedence = 2
}

/** Token representing difference */
case class DifferenceToken() extends OperatorToken, Left {
    def precedence = 2
}

/** Token representing product */
case class ProductToken() extends OperatorToken, Left {
    def precedence = 3
}

/** Token representing division */
case class DivisionToken() extends OperatorToken, Left {
    def precedence = 3
}

/** Token representing power */
case class PowerToken() extends OperatorToken, Right {
    def precedence = 4
}

/** Token representing number */
case class NumberToken(n: Double) extends Token

/** Token representing left parenthesis */
case class LeftParensToken() extends Token

/** Token representing right parenthesis */
case class RightParensToken() extends Token
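As a design aside, a Scala 3 enum could carry the operator information more compactly. The sketch below is a hypothetical alternative (the names Op, symbol, rightAssoc and fromSymbol are mine), not what the project uses.

```scala
/** Hypothetical alternative: each operator carries its symbol, precedence and associativity. */
enum Op(val symbol: Char, val precedence: Int, val rightAssoc: Boolean):
  case Add extends Op('+', 2, false)
  case Sub extends Op('-', 2, false)
  case Mul extends Op('*', 3, false)
  case Div extends Op('/', 3, false)
  case Pow extends Op('^', 4, true)

object Op:
  /** Look up an operator by its symbol, if there is one. */
  def fromSymbol(c: Char): Option[Op] = values.find(_.symbol == c)

println(Op.fromSymbol('^')) // Some(Pow)
```

The trade-off: the case-class hierarchy in the post lets NumberToken and the parenthesis tokens share the single Token type, whereas this enum would still need a separate wrapper type for numbers and parentheses.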

What remains is the function to actually convert a raw string expression into a list of tokens.

The function will have the following signature (we will fill in the body step by step).

/** 
 * Tokenize string expression.
 * 
 * Tokenize a string representation of an arithmetic expression. 
 * 
 * @param rawExpression: String representation of expression.
 * @return List of tokens of the expression.
 * */
def tokenize(rawExpression: String): List[Token] = ???

In the method body, we will have to add the following.

Splitting the raw string expression

    val splitted = rawExpression
        .filterNot(_.isWhitespace)
        .split("(?=[)(+/*-])|(?<=[)(+/*-])|(?=[\\^])|(?<=[\\^])")
        .map(_.trim)

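This removes all whitespace from the raw string and then splits it using a regular expression. The lookahead (?=...) and lookbehind (?<=...) groups match the empty positions just before and just after each operator and parenthesis, so the string is cut there while the operators and parentheses themselves are kept as separate elements.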
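To see what the split produces, the sketch below wraps the same regular expression in a hypothetical splitExpression helper. It relies on documented behavior of Java's String.split: a zero-width match at the very beginning does not produce an empty leading element (Java 8 and later), and trailing empty strings are dropped.

```scala
// Same regular expression as in tokenize: zero-width matches before and after
// every operator and parenthesis, so the delimiters survive as elements.
val delimiters = "(?=[)(+/*-])|(?<=[)(+/*-])|(?=[\\^])|(?<=[\\^])"

/** Hypothetical helper mirroring the first step of tokenize. */
def splitExpression(raw: String): List[String] =
  raw.filterNot(_.isWhitespace).split(delimiters).toList

println(splitExpression("3 * (1 + 2) ^ 7")) // List(3, *, (, 1, +, 2, ), ^, 7)
println(splitExpression("(-42)^3+(-42)"))   // List((, -, 42, ), ^, 3, +, (, -, 42, ))
```

Since whitespace is removed before splitting, the .map(_.trim) step in the original is a harmless no-op.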

Tokenizing string components

The val splitted now contains an array of strings that can be converted into tokens. For that, we define the tokenizeOne function and map it over each element of splitted.

// Regex representing a double.
val numPattern = "(\\-?\\d*\\.?\\d+)".r

/**
 * Tokenize a single string.
 * 
 * @param x: String to tokenize
 * @return Corresponding token
 * */
def tokenizeOne(x: String): Token = x match {
    case "+" => SumToken()
    case "-" => DifferenceToken()
    case "*" => ProductToken()
    case "/" => DivisionToken()
    case "^" => PowerToken()
    case "(" => LeftParensToken()
    case ")" => RightParensToken()
    case numPattern(c: String) => NumberToken(c.toDouble)
    case _ => throw RuntimeException(s""""$x" is not legal""")
}

// tokenize each element
val tokenized = splitted.map(tokenizeOne).toList
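One detail of the numPattern case: a Scala Regex used as an extractor in a pattern match is anchored, so the entire string has to match, not just a substring. The sketch below isolates that case in a hypothetical parseNumber helper that returns an Option instead of throwing.

```scala
// The number regex from tokenizeOne.
val numPattern = "(\\-?\\d*\\.?\\d+)".r

/** Hypothetical helper: the number case of tokenizeOne, without the exception. */
def parseNumber(s: String): Option[Double] = s match
  case numPattern(c) => Some(c.toDouble)
  case _             => None

println(parseNumber("42"))    // Some(42.0)
println(parseNumber("1.23"))  // Some(1.23)
println(parseNumber(".5"))    // Some(0.5)
println(parseNumber("1.2.3")) // None: the whole string must match
```

Note that "1." is rejected because the pattern requires at least one digit after an optional decimal point, even though "1.".toDouble would succeed.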

Handling negative numbers

At this stage, the tokenizer can handle all (valid) input and convert it into a list of tokens. However, we would run into problems later if we parsed negative numbers (e.g. 1+(-1)), since the - is converted into a DifferenceToken, yet there is no NumberToken to its left.

To handle this, we will walk through the list of tokens, and if we find a (- sequence in the list, we insert a NumberToken(0) between the ( and the -, producing a valid negative number.

The algorithm could be implemented more concisely with a loop; however, I wanted to find a functional implementation. Thus, I came up with the following:

/**
 * Handle negative numbers.
 * 
 * Negative numbers are prefixed with a zero
 * e.g. (-1) -> (0-1)
 * to maintain both a left and right expression of the Difference operator.
 * 
 * @param tokens: list of tokens to handle
 * @return tokens with inserted zeros
 * */
def handleNegative(tokens: List[Token]): List[Token] = 

    /**
     * Insert prefix zero.
     * 
     * If the token list starts with "(, -" insert a zero.
     * 
     * @param tokens: list of tokens to check
     * @return token list with inserted zero
     * */
    def insert(tokens: List[Token]): List[Token] = tokens match {
        case a :: b :: rest => {
            a match {
                case _a: LeftParensToken => {
                    b match {
                        case _b: DifferenceToken => _a :: NumberToken(0) :: Nil
                        case _ => _a :: Nil
                    }
                }
                case _ => a :: Nil
            }
        }
        case a :: Nil => a :: Nil
        case _ => Nil
    }

    /** Recursively insert zeros where necessary */ 
    def recur(tokens: List[Token]): List[Token] =
        if tokens.isEmpty then Nil
        else if tokens.tail.isEmpty then tokens
        else insert(tokens) ++ recur(tokens.tail)

    recur(tokens)

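The algorithm walks through the list recursively. At each step, insert looks at the first two tokens: if they form the pattern (-, it emits the ( followed by a NumberToken(0); otherwise, it emits just the first token. recur then appends the result of processing the tail, so the - itself is emitted in the following step.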
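For comparison, nested pattern matching on the list can express the same idea in a single function. This is my own sketch, using a stripped-down Tok enum in place of the Token classes; like the original, it is not tail recursive, so very long token lists could overflow the stack.

```scala
/** Simplified token type, for this sketch only. */
enum Tok:
  case LParen, RParen, Minus, Plus
  case Num(n: Double)

import Tok.*

/** Insert a zero between "(" and "-", turning "(-x)" into "(0-x)". */
def handleNegative(tokens: List[Tok]): List[Tok] = tokens match
  case LParen :: Minus :: rest => LParen :: Num(0.0) :: Minus :: handleNegative(rest)
  case t :: rest               => t :: handleNegative(rest)
  case Nil                     => Nil

// 1+(-1) becomes 1+(0-1)
println(handleNegative(List(Num(1.0), Plus, LParen, Minus, Num(1.0), RParen)))
// List(Num(1.0), Plus, LParen, Num(0.0), Minus, Num(1.0), RParen)
```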

In the body of tokenize we now simply call

// handle negative numbers
handleNegative(tokenized)

as the last line.

The final Tokenizer.scala file can be found here.

Parser

In this last step, we can finally write the actual parser, which converts the token list into an expression that can be evaluated.

The parser will work in two steps.

  1. Convert the token list from infix to postfix notation
  2. Parse the postfix token list into an expression

Converting infix to postfix

When we input the expression as a raw string, we use what is called infix notation. While this makes sense for human readability, the parsing algorithm of the implementation requires postfix notation.

To give an example, the infix expression (5-6) * 7 is written 5 6 - 7 * in postfix notation; note that postfix notation needs no parentheses.

The conversion uses the shunting yard algorithm. My implementation is based on the pseudocode provided on Wikipedia, but implemented recursively. I won't go into much detail about how the algorithm works, as the Wikipedia article does a good job of explaining it.
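For readers who want to see the algorithm in isolation, here is a compact, purely functional sketch of shunting yard. It is my own simplification, not the project's code: tokens are plain strings that are already split, and an immutable List plays the role of the stack (its head is the top), which avoids the mutable Stack used below.

```scala
import scala.annotation.tailrec

// Precedence values mirror the token classes above; only "^" associates right.
val precedence = Map("+" -> 2, "-" -> 2, "*" -> 3, "/" -> 3, "^" -> 4)
def isRightAssoc(op: String): Boolean = op == "^"

/** Infix to postfix on pre-split string tokens. */
def toPostfix(tokens: List[String]): List[String] =
  // `out` holds the output in reverse, so that appending is a cheap prepend.
  @tailrec
  def loop(in: List[String], stack: List[String], out: List[String]): List[String] =
    in match
      case Nil => out.reverse ++ stack
      case "(" :: rest => loop(rest, "(" :: stack, out)
      case ")" :: rest =>
        // Move operators to the output until the matching "(", then drop it.
        val (ops, remaining) = stack.span(_ != "(")
        loop(rest, remaining.drop(1), ops.reverse ++ out)
      case op :: rest if precedence.contains(op) =>
        // Pop operators that must be applied before `op`.
        def mustPop(top: String): Boolean = precedence.get(top).exists { p =>
          p > precedence(op) || (p == precedence(op) && !isRightAssoc(op))
        }
        val (popped, remaining) = stack.span(mustPop)
        loop(rest, op :: remaining, popped.reverse ++ out)
      case number :: rest => loop(rest, stack, number :: out)
  loop(tokens, Nil, Nil)

val example = "3 + 4 * 2 / ( 1 - 5 ) ^ 2 ^ 3".split(" ").toList
println(toPostfix(example).mkString(" ")) // 3 4 2 * 1 5 - 2 3 ^ ^ / +
```

The mustPop condition is the heart of it: a left-associative operator pops stacked operators of equal or higher precedence, while a right-associative one pops only strictly higher ones, so the two ^ operators stay on the stack together.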

Let's start by defining the core recursive procedure of the function.

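// The stack used by the parser is the mutable Stack from the standard library.
import scala.annotation.tailrec
import scala.collection.mutable.Stack

/**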
 * Shunting Yard algorithm.
 * 
 * Converts a list of tokens from infix to postfix notation.
 * https://en.wikipedia.org/wiki/Shunting_yard_algorithm
 * 
 * @param tokens: List of tokens in infix notation
 * @return list of tokens in postfix notation
 * */
def shuntingYard(tokens: List[Token]): List[Token] =

    // todo: we will fill in the helper functions later

    /**
     * Recursive method of shunting yard.
     * 
     * @param stack: Stack of tokens left to place
     * @param postfix: Tokens converted to postfix notation
     * @param tokens: Tokens in infix notation
     * @return tokens in postfix notation
     * */
    @tailrec
    def recur(stack: Stack[Token], postfix: List[Token], tokens: List[Token]): List[Token] = 
        tokens match {
            case Nil => postfix ++ stack
            case t :: rest => {
                t match {
                    case n: NumberToken => recur(stack, postfix :+ n, rest)
                    case o: OperatorToken => {
                        if (stack.isEmpty) then recur(stack.push(o), postfix, rest)
                        else {
                            val updated = operatorUpdate(postfix, stack, o)
                            recur(updated._1, updated._2, rest)
                        }
                    }
                    case l: LeftParensToken => recur(stack.push(l), postfix, rest)
                    case r: RightParensToken => {
                        val updated = rightParensUpdate(postfix, stack)
                        recur(updated._1, updated._2, rest)
                    }
                }
            }
        }

    recur(new Stack[Token], List(), tokens)

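Essentially, this procedure recursively walks over the token list and handles each token according to its type. Numbers go straight to the output list. Left parentheses are pushed onto the stack. An operator first moves to the output any stacked operators that must be applied before it (based on precedence and associativity) and is then pushed onto the stack itself. A right parenthesis moves operators from the stack to the output until the matching left parenthesis is found, which is then discarded. Once the input is exhausted, the remaining operators on the stack are appended to the output.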

Notice that we're still missing some helper functions, which were created to make the core procedure more readable. Insert the following helper functions into shuntingYard before the recur function.

/** Helper method to determine if token is left associative. */
def isLeftAssoc(t: Token): Boolean = t match {
    case a: Left => true
    case _ => false
}

/** Helper method to determine if token is right associative. */
def isRightAssoc(t: Token): Boolean = t match {
    case a: Right => true
    case _ => false
}

/** Helper method to determine if token is left parenthesis. */
def isLeftParens(t: Token): Boolean = t match {
    case a: LeftParensToken => true
    case _ => false
}

/** Helper method to update postfix and stack during operator parsing. */
@tailrec
def operatorUpdate(postfix: List[Token], stack: Stack[Token], o: OperatorToken): (Stack[Token], List[Token]) =
    def matchCond(o: OperatorToken, stack: Stack[Token]): Boolean = stack.head match {
        case o2: OperatorToken => {
            (isLeftAssoc(o) && (o.precedence <= o2.precedence)) ||
            (isRightAssoc(o) && (o.precedence < o2.precedence))
        }
        case _ => false
    }
    if (stack.isEmpty || !matchCond(o, stack)) (stack.push(o), postfix)
    else operatorUpdate(postfix :+ stack.pop, stack, o)

/** Helper method to update postfix and stack during right parens parsing. */
@tailrec
def rightParensUpdate(postfix: List[Token], stack: Stack[Token]): (Stack[Token], List[Token]) =
    if (isLeftParens(stack.head)) {stack.pop; (stack, postfix)}
    else rightParensUpdate(postfix :+ stack.pop, stack)

The shuntingYard function returns a new token list in postfix notation.

Parsing postfix token list into an expression

Given the postfix token list, the second step of the parser converts it into a full expression.

The algorithm is quite simple. It recursively walks through the token list. If a number token is encountered, it is pushed onto a stack. If an operator is encountered, the top two expressions are popped from the stack and combined (for the non-commutative operators, the first expression popped becomes the right operand), and the resulting expression is pushed back onto the stack. Eventually, the stack contains only one expression, which is the final result.
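The same stack discipline can be seen in a minimal sketch that skips the Expression tree and computes a Double straight away. This is a simplification of mine, not the project's parser: tokens are strings and the stack is an immutable List whose head is the top. Popping order matters here too, because the first value popped is the right operand.

```scala
/** Evaluate postfix (RPN) string tokens directly to a Double. */
def evalPostfix(tokens: List[String]): Double =
  val finalStack = tokens.foldLeft(List.empty[Double]) { (stack, token) =>
    (token, stack) match
      // The first value popped (the head) is the RIGHT operand.
      case ("+", r :: l :: rest) => (l + r) :: rest
      case ("-", r :: l :: rest) => (l - r) :: rest
      case ("*", r :: l :: rest) => (l * r) :: rest
      case ("/", r :: l :: rest) => (l / r) :: rest
      case ("^", r :: l :: rest) => math.pow(l, r) :: rest
      // Anything else is treated as a number (toDouble throws if it is not one).
      case (number, _) => number.toDouble :: stack
  }
  finalStack match
    case result :: Nil => result
    case _ => throw new IllegalArgumentException(s"malformed postfix input: $tokens")

println(evalPostfix(List("5", "6", "-", "7", "*"))) // -7.0
```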

/**
 * Parses RPN to expression.
 * 
 * Takes a list of tokens in RPN and parses the expression representation.
 * https://en.wikipedia.org/wiki/Reverse_Polish_notation
 * 
 * @param tokens: List of tokens in postfix notation
 * @return parsed expression
 * */
def parsePostfix(tokens: List[Token]): Expression = 

    /** Helper method to determine if token is operator. */
    def isOperator(t: Token): Boolean = t match {
        case a: OperatorToken => true
        case _ => false
    }

    /**
     * Recursive method of the algorithm.
     * 
     * @param stack: Stack of expressions to parse
     * @param tokens: Tokens to parse.
     * @return parsed expression
     * */
    @tailrec
    def recur(stack: Stack[Expression], tokens: List[Token]): Expression = tokens match {
        case Nil => stack.pop
        case t :: rest => {
            if (isOperator(t)) t match {
                case t: SumToken => stack.push(Sum(stack.pop, stack.pop))
                case t: DifferenceToken => val x = stack.pop; stack.push(Difference(stack.pop, x))
                case t: ProductToken => stack.push(Product(stack.pop, stack.pop))
                case t: DivisionToken => val x = stack.pop; stack.push(Division(stack.pop, x))
                case t: PowerToken => val x = stack.pop; stack.push(Power(stack.pop, x))
                case _ => throw new RuntimeException(s""""$t" is not an operator""")
            }
            else t match {
                case t: NumberToken => stack.push(Number(t.n))
                case _ => throw new RuntimeException(s""""$t" is not valid here""")   
            }
            recur(stack, rest)
        }
    }

    recur(new Stack[Expression], tokens)

Putting it together

With the shuntingYard and parsePostfix functions in place, we can define the parse function, which we can call to actually perform the parsing.

/**
 * Run the parser.
 * 
 * Converts the tokens to postfix notation (reverse Polish notation, RPN) and then parses the result into an expression.
 *
 * @param tokens: List of tokens in infix notation
 * @return Parsed expression
 * */
def parse(tokens: List[Token]): Expression =
    val postfix = shuntingYard(tokens)
    parsePostfix(postfix)

And that is the parser done. You can find the final version of Parser.scala here.

Main.scala

To make the project runnable, we need a main method in Main.scala. This is just a few lines:

object Main {
    def main(args: Array[String]) = args match
        case Array(x: String) => printResult(getResult(x))
        case Array() => throw new java.lang.IllegalArgumentException("Too few arguments!")
        case _ => throw new java.lang.IllegalArgumentException("Too many arguments!")

    /** Tokenize and parse expression. */
    def getResult(rawExpr: String): Expression = parse(tokenize(rawExpr))
    /** Print the result of the evaluation. */
    def printResult(expr: Expression): Unit = println(evaluate(expr))
}

Testing the parser

Let's try out the parser on some expressions. Fire up your sbt server and run some examples.

sbt:equation-parser> run "1+2"
3.0
sbt:equation-parser> run "3 * (1 + 2) ^ 7"
6561.0
sbt:equation-parser> run "100 / 8 - (2 * 3) + 4 ^ 3"
70.5
sbt:equation-parser> run "((3 + 2) * (2 + 1)) ^ 2"
225.0
sbt:equation-parser> run "(-42)^3+(-42)"
-74130.0

Seems to work!

(We should probably write some more thorough tests to make sure it actually works, which I did here.)

Conclusion

Implementing this tokenizer/parser/evaluator in a functional way in Scala proved to be a very insightful and fun learning project. Maybe it will inspire someone to build their own parser or improve upon mine.

Feel free to reach out with any comments or feedback!
