Tail Call Recursion in Java with ASM

One kind of optimization offered by some compilers is tail call optimization. At first sight it does not bring much, since the programmer can always write the code without recursion, especially in an imperative language. On the other hand, recursive code is often more elegant, so why not let the compiler do the nasty stuff when it is possible? In this article I will present a neat way to implement tail call optimization in Java using byte code manipulation with ASM.

What is tail call recursion?

Tail call recursion is a special form of recursion where the last operation is the recursive call. It belongs to the more general class of tail calls, where the last operation is a method call. I will limit myself to the more restrictive case of tail recursion. Let’s illustrate with an example.

long factorial(int n, long f) {
    if(n<2) {
        return f;
    }
    return factorial(n-1, f*n);
}

As one can see, the last operation is a call to the same function, but with different parameters. The next example is not a tail recursion.

long factorial(int n) {
    if(n<2) {
        return 1;
    }
    return n * factorial(n-1);
}

The reason why the previous example is not a tail call recursion is that the last operation is not the recursive call, but the multiplication operation. Multiplication happens after the recursive call returns.

A tail recursion has a specific form which allows faster execution: since the execution can reuse the current stack frame, the allocation of a new frame for each call can be avoided.
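To make the idea concrete, here is a hand-written iterative equivalent of the tail-recursive factorial above. This loop is essentially what a tail call optimizer produces; the class name is mine, for illustration only.

```java
// Iterative equivalent of the tail-recursive factorial.
// A tail call optimizer effectively rewrites the recursion into this loop,
// reusing the same local variables instead of allocating a new frame per call.
public class FactorialLoop {
    public static long factorial(int n, long f) {
        while (n >= 2) {   // negation of the stopping rule n < 2
            f = f * n;     // same computation as the recursive argument f*n
            n = n - 1;     // same computation as the recursive argument n-1
        }
        return f;          // the base case returns the accumulator
    }
}
```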

Anatomy of a method call

If you don’t know much about how the Java Virtual Machine makes method calls, here is a brief overview. The idea is almost universal in programming; however, the details presented are specific to the JVM.

In order to execute, a method needs a space called a frame, which contains some specific things:

  • local variables space: a fixed array of entries with values of various types
  • operand stack: a stack where the current operands are stored

There is also an execution stack managed by the JVM, which collects frames. When a method is called, a new frame is created, initialized properly and pushed on the JVM execution stack. The eventual parameters of the method call are collected from the current operand stack and used for the initialization of the new frame. After the method ends its execution, the return value (if any) is collected, the frame allocated for that method call is removed from the JVM stack, the previous frame becomes the current one, and the collected return value is pushed on its operand stack.

The size of the local variables and operand stack parts depends on the method’s code; it is computed at compile time and stored along with the bytecode instructions in compiled classes. All the frames that correspond to invocations of the same method are identical in size, but frames that correspond to different methods can have different sizes.

When it is created, a frame is initialized with an empty stack, and its local variables are initialized with the target object this (for non-static methods) and with the method’s arguments (in this order). For instance, calling the method a.equals(b) creates a frame with an empty stack and with the first two local variables initialized to a and b (other local variables are uninitialized).

As you can see, chained method calls increase the space required by the JVM execution stack, since for each inner method call a new frame is pushed on the stack. Since the stack is limited, deeply nested calls can fill the JVM execution stack to the point where it throws the well-known StackOverflowError.

Exhausting the stack space is often encountered when one implements algorithms in a recursive fashion.
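A tiny demonstration of the problem: the following method (my own example, not from the article's benchmark code) recurses far deeper than the default stack allows, so the JVM throws StackOverflowError.

```java
// A deep recursion without tail call optimization eventually exhausts
// the JVM execution stack, because every call pushes a new frame.
public class DeepRecursion {
    static long countDown(long n) {
        if (n == 0) {
            return 0;
        }
        return countDown(n - 1); // each call pushes a new frame
    }

    public static void main(String[] args) {
        try {
            countDown(100_000_000L); // far deeper than the default stack allows
        } catch (StackOverflowError e) {
            System.out.println("stack exhausted");
        }
    }
}
```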

A practical example of a compiled class

Let’s study the generated code for the following class.

public class Factorial implements FactorialInterface {
    public long fact(int n) {
        return factTailRec(n, 1L);
    }
    private long factTailRec(int n, long ret) {
        if (n < 1) {
            return ret;
        }
        return factTailRec(n - 1, ret * n);
    }
}

As you can see, we have a simple class which implements the factorial function in a tail-recursive fashion. The first method is a facade for a better user experience, while the second method implements the actual computation.

The generated class looks like the following:

// class version 63.0 (63)
// access flags 0x21
public class rapaio/experiment/asm/Factorial implements rapaio/experiment/asm/FactorialInterface {
  // compiled from: Factorial.java
  // access flags 0x1
  public <init>()V
   L0
    LINENUMBER 34 L0
    ALOAD 0
    INVOKESPECIAL java/lang/Object.<init> ()V
    RETURN
   L1
    LOCALVARIABLE this Lrapaio/experiment/asm/Factorial; L0 L1 0
    MAXSTACK = 1
    MAXLOCALS = 1
  // access flags 0x1
  public fact(I)J
   L0
    LINENUMBER 37 L0
    ALOAD 0
    ILOAD 1
    LCONST_1
    INVOKEVIRTUAL rapaio/experiment/asm/Factorial.factTailRec (IJ)J
    LRETURN
   L1
    LOCALVARIABLE this Lrapaio/experiment/asm/Factorial; L0 L1 0
    LOCALVARIABLE n I L0 L1 1
    MAXSTACK = 4
    MAXLOCALS = 2
  // access flags 0x2
  private factTailRec(IJ)J
   L0
    LINENUMBER 41 L0
    ILOAD 1
    ICONST_1
    IF_ICMPGE L1
   L2
    LINENUMBER 42 L2
    LLOAD 2
    LRETURN
   L1
    LINENUMBER 44 L1
   FRAME SAME
    ALOAD 0
    ILOAD 1
    ICONST_1
    ISUB
    LLOAD 2
    ILOAD 1
    I2L
    LMUL
    INVOKEVIRTUAL rapaio/experiment/asm/Factorial.factTailRec (IJ)J
    LRETURN
   L3
    LOCALVARIABLE this Lrapaio/experiment/asm/Factorial; L0 L3 0
    LOCALVARIABLE n I L0 L3 1
    LOCALVARIABLE ret J L0 L3 2
    MAXSTACK = 6
    MAXLOCALS = 4
}

Let us explain the byte code. We have a class which implements an interface and we have a default constructor.

The default constructor is optional in the Java language, but it is not in JVM bytecode. The default constructor has as its first instruction ALOAD 0, which loads the this pointer onto the operand stack from the local variable with index 0 (remember how the local variables are initialized). The second instruction invokes the <init> method of class Object. This method takes one parameter, a pointer to an object, which is taken from the operand stack (this is why we have the first instruction). It returns void, as signaled in its descriptor by V. After that it simply returns control to the caller. Then follows the description of the frame, with the variables used and the number of entries allocated for the operand stack and the local variable table.

Next follows the description of method fact, which takes a single integer argument (described by I) and returns a long value (described by J). The instruction list starts with loading the this pointer (ALOAD 0), then the value of the parameter (ILOAD 1), then pushing onto the operand stack the constant long value 1 (LCONST_1). All three values loaded on the operand stack are used by the following instruction, a call to method factTailRec, which consumes them in reverse order (it is a stack). After the method call ends, its return value is pushed on the operand stack. That value is used by the next instruction, LRETURN, which returns the long value from the operand stack and resumes control to the calling method.

The description of method factTailRec is a little more complex, but not overly complicated. It loads the value of the first parameter onto the operand stack (ILOAD 1) and also the integer constant 1 (ICONST_1). The next instruction compares the parameter with the constant (IF_ICMPGE L1): if the parameter is greater than or equal to the constant, it jumps to label L1; otherwise it continues. When the jump is not taken, it loads the value of the second parameter onto the operand stack (LLOAD 2) and returns it (LRETURN).

At the label L1 there are the instructions for the recursive call. First the this pointer is loaded on the stack in preparation for the recursive call (ALOAD 0). Then the first variable is loaded on the stack (ILOAD 1), followed by the constant 1 (ICONST_1). These two are used by the subtract operation, which decreases the value of the variable by the value of the constant (ISUB) and puts the result on the stack. Notice at this point that the stack holds two values: the this pointer and the decremented value of the variable (the subtraction instruction pops two operands off the stack and pushes back one, the result).

Next the second variable is put on the stack (LLOAD 2), followed by the value of the first variable (ILOAD 1). The first variable still has its original value, the modified one being on the stack. The integer value is converted to the long type (I2L) and the two values are multiplied (LMUL). The multiplication consumes the last two stack operands and pushes back the result.

The operand stack now holds three values, exactly those required for the recursive call (INVOKEVIRTUAL). The three operands are consumed by the method call and the result is put on the stack. The final result is returned (LRETURN). The last part is the description of the frame, which contains three local variables and has 6 slots allocated for the operand stack and 4 for the local variable table.

I hope you were not bored reading all of that. Maybe you already know how the JVM stack machine works, but I included the description for the case when you don’t.

The structure of a tail call recursion

In general, a tail call recursion has a very simple structure, with three phases. The first phase contains the stopping rules, which define when the recursion ends. The second phase contains the calculations, and the third phase is the recursive call whose result is returned. Often, when the computation is simple, the last two phases are merged into one, with the computation happening just before passing the parameter values. This is the case with our method.
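The three phases can be seen on a slightly more verbose version of our factorial, with the merged computation pulled out into its own phase (the class and variable names are mine, for illustration):

```java
public class TailRecShape {
    // The three phases of a tail-recursive method, illustrated on factorial.
    public static long factTailRec(int n, long ret) {
        // phase 1: stopping rules -- define when the recursion ends
        if (n < 1) {
            return ret;
        }
        // phase 2: calculations
        long nextRet = ret * n;
        int nextN = n - 1;
        // phase 3: the recursive call whose result is returned
        return factTailRec(nextN, nextRet);
    }
}
```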

Given this clear structure, some observations can be made which lead to a straightforward optimization.

The first observation is that, since the same method is called, the shape of the frame for the recursive call is identical to the current frame. The reason is that the shape of a frame is determined at compile time and remains fixed. Since we recursively call the same method, we are sure that the current frame fits the needs of the recursive call. The possible optimization is to avoid creating a new frame which would have to be pushed on the JVM execution stack with each call.

The second observation is that in order to reuse the current frame we need to prepare the stack and local variables in the same way a proper frame initialization would. This is very easy to do, simply because the last call before the return is the recursive call. In order to make a recursive call, the stack has to be filled with the value of the this pointer and all the parameter values, in order. The call would pop all those values from the current operand stack and initialize the next frame with them. This is exactly what we have to do: take those values and properly initialize the local variables of the current frame. The values are already prepared for us.

Using ASM to transform byte code

ASM is a wonderful and neat library which allows one to analyze, transform and generate byte code at compile time or at runtime. It is used by many platforms and tools, including the OpenJDK compiler itself. I do not have enough words to describe the usefulness and elegance of this library, and I feel greatly indebted to its creator and contributors.

The ASM library allows one to transform byte code using two approaches: event based and tree based. I will use the tree-based API since the changes are not trivial and could not be performed in a single pass of the parser. This is the code used to optimize a method which is tail recursive:

import static org.objectweb.asm.Opcodes.*;

import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.*;

class TailRecTransformer extends ClassNode {
    private static final String METHOD_SUFFIX = "TailRec";
    public TailRecTransformer(ClassVisitor cv) {
        super(ASM9);
        this.cv = cv;
    }
    @Override
    public void visitEnd() {
        // we optimize all methods which ends with TailRec for simplicity
        methods.stream().filter(mn -> mn.name.endsWith(METHOD_SUFFIX))
                .forEach(this::transformTailRec);
        accept(cv);
    }
    void transformTailRec(MethodNode methodNode) {
        // method argument types
        Type[] argumentTypes = Type.getArgumentTypes(methodNode.desc);
        // iterator over instructions
        var it = methodNode.instructions.iterator();
        LabelNode firstLabel = null;
        while (it.hasNext()) {
            var inode = it.next();
            // locate the first label
            // this label will be used to jump instead of recursive call
            if (firstLabel == null && inode instanceof LabelNode labelNode) {
                firstLabel = labelNode;
                continue;
            }
            if (inode instanceof FrameNode) {
                // remove all frames since we recompute them all at writing
                it.remove();
                continue;
            }
            if (inode instanceof MethodInsnNode methodInsnNode &&
                    methodInsnNode.name.equals(methodNode.name) &&
                    methodInsnNode.desc.equals(methodNode.desc)) {
                // find the recursive call which has to have
                // same signature and be followed by return
                // check if the next instruction is return of proper type
                var nextInstruction = it.next();
                Type returnType = Type.getReturnType(methodNode.desc);
                if (!(nextInstruction.getOpcode() == 
                        returnType.getOpcode(IRETURN))) {
                    continue;
                }
                // remove the return and recursive call from instructions
                it.previous();
                it.previous();
                it.remove();
                it.next();
                it.remove();
                // pop values from stack and store them in local 
                // variables in reverse order
                for (int i = argumentTypes.length - 1; i >= 0; i--) {
                    Type type = argumentTypes[i];
                    it.add(new VarInsnNode(type.getOpcode(ISTORE), i + 1));
                }
                // add a new jump instruction to the first label
                it.add(new JumpInsnNode(GOTO, firstLabel));
                // finally remove the instruction which loaded 'this'
                // since it was required by the recursive call
                while (it.hasPrevious()) {
                    AbstractInsnNode node = it.previous();
                    if (node instanceof VarInsnNode varInsnNode) {
                        if (varInsnNode.getOpcode() == Opcodes.ALOAD && 
                                varInsnNode.var == 0) {
                            it.remove();
                            // we remove only the last instruction of this kind
                            // we don't touch it other similar instructions 
                            // to not break the existent code
                            break;
                        }
                    }
                }
            }
        }
    }
}

I really hope the code and comments are self-explanatory. I will briefly present its logic anyway.

In order to transform a method with the tree API of the ASM library, one needs to change the values in the class MethodNode, since this is the representation of JVM byte code in the ASM library. For simplicity, I created a transformer which tries to optimize all the methods whose names end with the suffix TailRec. This is for illustrative purposes; an annotation would be preferable, but it would require more code and building an agent.

The core of the optimization logic lies in the method transformTailRec. This method receives the bytecode representation of any class method whose name ends with our suffix. The optimization has the following stages.

We identify the first code label. This is the start of the code of the recursive method. We will use this label when we replace the recursive call with a simple jump instruction: goto. As a fun fact, this infamous instruction does not exist in the Java language, for good reason: this kind of uncontrolled jump would break all the accounting machinery of the JVM. However, the same instruction exists in the JVM, and because there one can jump only within the instructions of the same method, it is safe to use.

Instead of the recursive method call, which would create a new frame, we reuse the current frame. The next stage is to remove the recursive call and the return instruction after it, together with preparing the local variables and the stack for the next use. In place of the recursive call we introduce a goto instruction which points to the first label. Basically, we implemented a while loop. The stopping conditions are already in the code, so the optimization will not produce an infinite loop.
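At the source level, the transformed bytecode behaves like the following hand-written loop. This is a sketch of the effect, not the actual output of the transformer; the class and variable names are mine.

```java
public class FactorialTransformed {
    // What the transformed factTailRec is equivalent to: the recursive call
    // has been replaced by storing the new argument values into the local
    // variables and jumping back to the first label -- in other words, a loop.
    public static long factTailRec(int n, long ret) {
        while (true) {             // the first label becomes the loop head
            if (n < 1) {           // the stopping rule is unchanged
                return ret;
            }
            long newRet = ret * n; // values that were prepared on the operand stack
            int newN = n - 1;      // for the recursive call
            ret = newRet;          // LSTORE/ISTORE: pop them into the locals
            n = newN;
        }                          // GOTO: jump back to the first label
    }
}
```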

We are done!

Testing the recursive tail optimization

A complete treatment of this would imply implementing a Java agent which would optimize the code before class loading. I avoided those complications because they are irrelevant to the subject. Maybe in the future I will create a tiny GitHub project with this annotation and optimization.

To keep things simple, I wrote a custom class loader which creates classes with optimized code. Java allows one to have two classes with the same specification if those classes are loaded by different class loaders. To make them easy to use, I also created an interface.

In this way we have two classes, one optimized and one not, both implementing the same interface, so we can use them in the same JVM instance and test them with JMH. For reference, the code for the class loader is listed below.

import java.io.IOException;
import java.io.PrintWriter;
import java.lang.reflect.InvocationTargetException;

import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.util.TraceClassVisitor;

public class CustomClassLoader extends ClassLoader {
    private final boolean verbose;
    public CustomClassLoader(boolean verbose) {
        this.verbose = verbose;
    }
    @Override
    protected Class<?> findClass(String name) {
        ClassWriter cw = new ClassWriter(0);
        ClassVisitor lastCv;
        if (verbose) {
            TraceClassVisitor beforeTcv = new TraceClassVisitor(cw, new PrintWriter(System.out));
            TailRecTransformer trt = new TailRecTransformer(beforeTcv);
            lastCv = new TraceClassVisitor(trt, new PrintWriter(System.out));
        } else {
            lastCv = new TailRecTransformer(cw);
        }
        ClassReader cr;
        try {
            cr = new ClassReader(name);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        cr.accept(lastCv, 0);
        byte[] buffer = cw.toByteArray();
        return defineClass(name, buffer, 0, buffer.length);
    }
    public <T> T newTailRecInstance(Class<T> external, Class<?> internal) throws NoSuchMethodException,
            InvocationTargetException, InstantiationException, IllegalAccessException {
        Class<?> c = findClass(internal.getCanonicalName());
        return (T) c.getConstructor().newInstance();
    }
}

Factorial JMH benchmark

I implemented two simple recursive methods. The first one was already presented: the factorial.

public class Factorial implements FactorialInterface {
    public long fact(int n) {
        return factTailRec(n, 1L);
    }
    private long factTailRec(int n, long ret) {
        if (n < 1) {
            return ret;
        }
        ret *= n;
        n -= 1;
        return factTailRec(n, ret);
    }
}

JMH benchmark results are presented below:

Benchmark                     (n)   Mode  Cnt    Score   Error   Units
TailRec.recursiveFact           1  thrpt    5  771.714 ± 9.722  ops/us
TailRec.recursiveFact           3  thrpt    5  242.958 ± 1.693  ops/us
TailRec.recursiveFact           5  thrpt    5  194.606 ± 2.418  ops/us
TailRec.recursiveFact          10  thrpt    5   90.850 ± 2.345  ops/us
TailRec.recursiveFact          15  thrpt    5   66.567 ± 0.898  ops/us
TailRec.recursiveFact          20  thrpt    5   48.615 ± 0.308  ops/us
TailRec.recursiveFactTailRec    1  thrpt    5  735.701 ± 4.936  ops/us
TailRec.recursiveFactTailRec    3  thrpt    5  512.596 ± 0.946  ops/us
TailRec.recursiveFactTailRec    5  thrpt    5  409.343 ± 3.884  ops/us
TailRec.recursiveFactTailRec   10  thrpt    5  263.263 ± 3.033  ops/us
TailRec.recursiveFactTailRec   15  thrpt    5  184.061 ± 2.992  ops/us
TailRec.recursiveFactTailRec   20  thrpt    5  133.968 ± 1.070  ops/us

The difference is pretty clear: the optimized version is faster. The differences are not large, though. This is simply because of the small number of recursive calls, which has to stay small to avoid overflowing the long result.

Sum JMH Benchmark

For illustrative purposes I implemented a sum over the values of an array in a tail-recursive manner. Of course, this is not the best option, but if the container were a linked list, it would be an appealing implementation in a functional style. Below is the implementation of the sum method.

public class Sum implements SumInterface {
    public int sum(int[] array) {
        return sumTailRec(array, 0, 0);
    }
    public int sumTailRec(int[] array, int i, int sum) {
        if(i>=array.length) {
            return sum;
        }
        return sumTailRec(array, i+1, sum+array[i]);
    }
}
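As with the factorial, the tail call elimination effectively turns sumTailRec into a loop; at the source level the optimized method behaves like this sketch (the class name is mine):

```java
public class SumLoop {
    // Iterative form equivalent to sumTailRec after tail call elimination:
    // the index and accumulator arguments become loop variables.
    public static int sum(int[] array) {
        int i = 0;
        int acc = 0;
        while (i < array.length) { // negation of the stopping rule i >= array.length
            acc = acc + array[i];  // same computation as the recursive argument
            i = i + 1;
        }
        return acc;
    }
}
```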

And below we have the JMH benchmark results.

Benchmark                      (n)   Mode  Cnt       Score      Error   Units
TailRec.recursiveSum            10  thrpt    5  102800.521 ± 7870.635  ops/ms
TailRec.recursiveSum           100  thrpt    5    8949.731 ±  473.936  ops/ms
TailRec.recursiveSum          1000  thrpt    5     846.104 ±   30.766  ops/ms
TailRec.recursiveSum         10000  thrpt    5      73.955 ±   17.637  ops/ms
TailRec.recursiveSumTailRec     10  thrpt    5  132477.710 ± 2955.738  ops/ms
TailRec.recursiveSumTailRec    100  thrpt    5   16956.311 ±  541.083  ops/ms
TailRec.recursiveSumTailRec   1000  thrpt    5    1915.083 ±  116.170  ops/ms
TailRec.recursiveSumTailRec  10000  thrpt    5     187.088 ±   10.059  ops/ms

We also notice improvements produced by tail call elimination.

Final remarks

I am not a huge fan of recursion in general, and I tend to prefer tight iterative implementations when possible. This is by no means an argument against tail call optimization, especially tail call recursion.

Java at this moment does not offer any kind of tail call optimization. Project Loom seems to take into consideration an even larger class of call optimizations, but those do not look to be a priority now. Tail recursion optimization could instead be implemented in a library, like Lombok, offering the proposed optimization when a given annotation is present.

