Tuesday, June 26, 2012

Movimentum - Three basic algorithms: Constant Folding

Let us start bottom-up: We know that we need a simplification algorithm; we know that we need a matching algorithm; and we know that we need a replacement algorithm, at least for variables.

Here is a standard simplification that I have already mentioned, namely constant folding.

The design idea for constant folding is as follows: We visit the subexpressions bottom-up. For each compound expression (unary or binary expression), we have three possible cases:
  • If all subexpressions are constants, we replace the whole expression with a new constant computed from the sub-constants.
  • Otherwise, at least one of the subexpressions is not a constant, even after simplification. We now check whether at least one of them has changed, i.e., whether it or one of its descendants has been simplified. If this is the case, we must create a new expression of the same kind from the new subexpressions.
  • Else (no subexpression has changed, and at least one subexpression is not a constant), we return the original expression.
Maybe I should explain here a fundamental (and standard) design decision: All the expression objects will always and forever be immutable, i.e., nothing will ever be changed in them. The reason for this is that we can then share such objects between multiple data structures. Here is a diagram that shows the data structures that result from constant folding of (x + 1) + (2 + 3):
[Diagram: the original tree for (x + 1) + (2 + 3) and the folded tree (x + 1) + 5, both pointing to the same shared subtree for x + 1.]

The subexpression for x+1 is shared between the original expression tree and the new expression tree. Some consequences of this sharing are:
  • We reduce the memory footprint. For example, if we run constant folding on an expression tree that has already been folded, no new objects will be created.
  • We can share expressions even when we create them. For example, instead of writing new Constant(...) each time we need a constant, we could use a static factory method that looks into an internal set of already created constants and returns the existing one if the value is already in there (a sketch of such a factory is shown after this list). In the case of strings, this is called "interning" (and is done e.g. by Java and .NET compilers).
  • If we can make algorithms visit each shared structure only once, this also saves runtime; we will maybe see examples of this later.
  • A structural consequence is that we cannot have parent pointers, as each subexpression may have many parents. This restricts the design of the algorithms to some sort of functional style, which, on the whole, is a good thing: this restriction prevents us from building contrived, complex, "clever" algorithms that are hard to understand (and to prove correct). If, at some point, we actually do need doubly-linked expression structures, we will most probably consider building a parallel data structure; right now, there is certainly no need for this.
  • Another structural consequence is that we can have "diamonds" in expression trees: Identical subexpressions could be created only once. This is known as "common subexpression elimination" in compilers; if we don't forget this idea, we might check later whether it helps us in Movimentum.
  • Even though subexpressions are shared among expression trees, their immutability still allows us to work on expression trees in parallel. I will not explore parallel expression algorithms for Movimentum, but it is certainly an interesting topic.
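Here is the promised sketch of an interning factory for constants. This is only my illustration, not code from Movimentum; the method name Create and the backing dictionary (which needs using System.Collections.Generic) are assumptions:

class Constant {
    ...
    // Cache of already created constants, keyed by their value.
    private static readonly IDictionary<double, Constant> _interned =
        new Dictionary<double, Constant>();

    // Factory method that returns one shared Constant object per value.
    public static Constant Create(double value) {
        Constant result;
        if (!_interned.TryGetValue(value, out result)) {
            result = new Constant(value);
            _interned.Add(value, result);
        }
        return result;
    }
}

Callers would then write Constant.Create(1.5) instead of new Constant(1.5), and equal constants would automatically also be reference-equal.
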
At the moment, I will not follow the ideas presented in these items, because I want to concentrate on the solver algorithm proper. But when we (have to) optimize the solver later, we'll certainly return to some of them.

Incidentally, the design decision to make expression objects immutable means that a phrase like "a subexpression has changed" is actually meaningless. So what did I mean when I wrote "...whether at least one of [the subexpressions] has changed..." in the second case of our algorithm? Of course, I meant "...whether the visiting result is different from the subexpression..." The code will show this more clearly.

We start with two simple test cases:

    [Test]
    public void TestConstant() {
        Constant input = new Constant(1.5);
        AbstractExpr result = input.Accept(visitor, Ig.nore);
        Assert.AreEqual(new Constant(1.5), result);
        Assert.AreSame(input, result);
    }

    [Test]
    public void TestVariable() {
        NamedVariable input = new NamedVariable("a");
        AbstractExpr result = input.Accept(visitor, Ig.nore);
        Assert.AreSame(input, result);
    }

The assertion that we actually return the same expressions is essential, as the constant folding algorithm for the compound expressions expects this.

The corresponding implementation is as simple as it gets:

class ConstantFoldingVisitor : ISolverModelExprVisitor<AbstractExpr> {
    ...
    public AbstractExpr Visit(Constant constant, Ignore p) {
        return constant;
    }

    public AbstractExpr Visit(NamedVariable namedVariable, Ignore p) {
        return namedVariable;
    }

    public AbstractExpr Visit(AnchorVariable anchorVariable, Ignore p) {
        return anchorVariable;
    }
}

All other methods of the visitor, at this point, just throw a NotImplementedException. But of course, the tests are green.
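
For example, the stub for unary expressions is, at this point, nothing more than the following (shown only for illustration; it is replaced in the next step):

    public AbstractExpr Visit(UnaryExpression unaryExpression, Ignore p) {
        // Not yet needed for the leaf test cases above.
        throw new NotImplementedException();
    }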

Now, let us implement the actual folding. We start with the unary operators, where this simple test case checks that we fold for unary minus:

    [Test]
    public void TestDoFoldConstantInUnaryMinus() {
        AbstractExpr input = -new Constant(4);
        AbstractExpr result = input.Accept(visitor, Ig.nore);
        Assert.AreNotEqual(input, result);
        Assert.AreEqual(new Constant(-4), result);
    }

Here is the implementation of the three cases explained above:

    public AbstractExpr Visit(UnaryExpression unaryExpression, Ignore p) {
        AbstractExpr oldInner = unaryExpression.Inner;
        AbstractExpr newInner = oldInner.Accept(this, Ig.nore);
        if (newInner is Constant) {
            return new Constant(
                unaryExpression.Op.Accept(this, (Constant)newInner, Ig.nore)
            );
        } else if (newInner != oldInner) {
            return new UnaryExpression(newInner, unaryExpression.Op);
        } else {
            return unaryExpression;
        }
    }

The actual computation for constants is delegated to the operator. Therefore, we need at least the following additional code:

class ConstantFoldingVisitor : ...
          , ISolverModelUnaryOpVisitor<Constant, Ignore, double> {
    ...
    public double Visit(UnaryMinus op, Constant inner, Ignore p) {
        return -inner.Value;
    }
}

The test runs green, and so we can—after a few more test cases—complete the unary operator visits:

    public double Visit(UnaryMinus op, Constant inner, Ignore p) {
        return -inner.Value;
    }

    public double Visit(PositiveSquareroot op, Constant inner, Ignore p) {
        return Math.Sqrt(inner.Value);
    }

    public double Visit(Sin op, Constant inner, Ignore p) {
        return Math.Sin(inner.Value * Math.PI / 180);
    }

    public double Visit(Cos op, Constant inner, Ignore p) {
        return Math.Cos(inner.Value * Math.PI / 180);
    }

    public double Visit(Square op, Constant inner, Ignore p) {
        return inner.Value * inner.Value;
    }

    public double Visit(FormalSquareroot op, Constant inner, Ignore p) {
        ???
    }

But wait ... what do we do with the formal square root? Well, a formal square root cannot be evaluated, as it does not return a single value (except when the argument is zero). Therefore, we cannot simplify a formal square root! So, we must modify the code that visits unary expressions. Here is a simple version:

    public AbstractExpr Visit(UnaryExpression unaryExpression, Ignore p) {
       ...
        if (newInner is Constant
            && !(unaryExpression.Op is FormalSquareroot)) {
            return new Constant(
                unaryExpression.Op.Accept(this, (Constant)newInner, Ig.nore)
            );
        } else ...
    }

Instead of the specific type check, it might be better to introduce an IsFunction property for operators that is true if the operator is a true function, i.e., it returns at most one result for each argument. On the other hand, that might be YAGNI; I simply leave this remark as a comment in the code so that the IsFunction idea does not get completely lost.
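
Here is a sketched test for the new branch; the test name is mine, and I assume that a FormalSquareroot operator can simply be created with new FormalSquareroot() (in the real code it might just as well be a shared singleton):

    [Test]
    public void TestDontFoldConstantInFormalSquareroot() {
        // A formal square root over a constant must survive folding unchanged.
        AbstractExpr input =
            new UnaryExpression(new Constant(4), new FormalSquareroot());
        AbstractExpr result = input.Accept(visitor, Ig.nore);
        Assert.AreSame(input, result);
    }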

For the binary operators, we again write a few test cases, for example (there are more in the actual code):

    [Test]
    public void TestDoFoldConstant() {
        AbstractExpr input = (new Constant(1) + new Constant(2))
                             * new Constant(4);
        AbstractExpr result = input.Accept(visitor, Ig.nore);
        Assert.AreEqual(new Constant(12), result);
    }

    [Test]
    public void TestDontFoldConstant() {
        AbstractExpr input = (new Constant(1) + new NamedVariable("b"))
                             * new Constant(4);
        AbstractExpr result = input.Accept(visitor, Ig.nore);
        Assert.AreEqual(input, result);
        Assert.AreSame(input, result);
    }

And then we implement the folding algorithm:

class ConstantFoldingVisitor : ...
      , ISolverModelBinaryOpVisitor<Constant, Ignore, double> {
    public AbstractExpr Visit(BinaryExpression binaryExpression, Ignore p) {
        AbstractExpr oldLhs = binaryExpression.Lhs;
        AbstractExpr oldRhs = binaryExpression.Rhs;
        AbstractExpr newLhs = oldLhs.Accept(this, Ig.nore);
        AbstractExpr newRhs = oldRhs.Accept(this, Ig.nore);
        if (newLhs is Constant && newRhs is Constant) {
            return new Constant(
                binaryExpression.Op.Accept(this, (Constant)newLhs,
                                           (Constant)newRhs, Ig.nore)
            );
        } else if (newLhs != oldLhs || newRhs != oldRhs) {
            return new BinaryExpression(newLhs, binaryExpression.Op, newRhs);
        } else {
            return binaryExpression;
        }
    }

    public double Visit(Plus op, Constant lhs, Constant rhs, Ignore p) {
        return lhs.Value + rhs.Value;
    }

    public double Visit(Times op, Constant lhs, Constant rhs, Ignore p) {
        return lhs.Value * rhs.Value;
    }

    public double Visit(Divide op, Constant lhs, Constant rhs, Ignore p) {
        return lhs.Value / rhs.Value;
    }
}
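
Before we move on to constraints, the sharing shown in the diagram above can now be checked directly. The following test is only a sketch of mine (I assume that the + operator overload used in the earlier tests also accepts a NamedVariable on the left): folding (x + 1) + (2 + 3) must reuse the unchanged x + 1 subtree.

    [Test]
    public void TestFoldingSharesUnchangedSubexpressions() {
        AbstractExpr xPlus1 = new NamedVariable("x") + new Constant(1);
        AbstractExpr input = xPlus1 + (new Constant(2) + new Constant(3));
        AbstractExpr result = input.Accept(visitor, Ig.nore);
        // The folded tree is (x + 1) + 5 ...
        Assert.AreEqual(xPlus1 + new Constant(5), result);
        // ... and its left subtree is the very same object as in the input.
        Assert.AreSame(xPlus1, ((BinaryExpression)result).Lhs);
    }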

As a last step, we extend the visitor so that it simplifies complete constraints. This will most probably come in handy in the solver, when we want to match constraints to patterns. Here is the code for the EqualsZeroConstraint:

    public AbstractConstraint Visit(EqualsZeroConstraint equalsZero, Ignore p) {
        AbstractExpr result = equalsZero.Expr.Accept(this, Ig.nore);
        return result != equalsZero.Expr
                      ? new EqualsZeroConstraint(result)
                      : equalsZero;
    }
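
A sketched test for this method could look as follows (the test name is mine; I also assume that constraints support the same Accept call with this visitor, which the Visit method above suggests):

    [Test]
    public void TestFoldConstantsInEqualsZeroConstraint() {
        EqualsZeroConstraint input = new EqualsZeroConstraint(
            (new Constant(2) + new Constant(3)) * new NamedVariable("x"));
        AbstractConstraint result = input.Accept(visitor, Ig.nore);
        // The folded constraint contains 5 * x instead of (2 + 3) * x.
        Assert.AreEqual(new Constant(5) * new NamedVariable("x"),
                        ((EqualsZeroConstraint)result).Expr);
    }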

As you can see, the folding algorithm does not do any elaborate analysis of the expression: it works purely bottom-up and never reassociates operands, so only subtrees that consist entirely of constants are folded. This has the consequence that the constants in
v + c + c
c + v + c
(where v is a variable and c a constant) are not folded at all, whereas in
c + c + v
v + (c + c)
the constant sums are folded. It might be that we need better simplification algorithms later. Right now, we hope that this algorithm is good enough.
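
To pin down (and document) this behavior, a pair of sketched tests could look like this; the names are mine:

    [Test]
    public void TestConstantPrefixIsFolded() {
        // (1 + 2) + a : the purely constant subtree is folded bottom-up.
        AbstractExpr input = (new Constant(1) + new Constant(2))
                             + new NamedVariable("a");
        AbstractExpr result = input.Accept(visitor, Ig.nore);
        Assert.AreEqual(new Constant(3) + new NamedVariable("a"), result);
    }

    [Test]
    public void TestNoReassociation() {
        // (a + 1) + 2 : no subtree consists only of constants, so nothing is folded.
        AbstractExpr input = (new NamedVariable("a") + new Constant(1))
                             + new Constant(2);
        AbstractExpr result = input.Accept(visitor, Ig.nore);
        Assert.AreSame(input, result);
    }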
