Thursday, June 26, 2008

Schwartzian Transforms in Java

Libraries like Java's Collections API contain numerous algorithms tuned to give good performance in a theoretical 'typical case', but it can be very easy to forget that these algorithms aren't magical, and have certain requirements of their own if that goal is to be achieved.

On numerous occasions, I've written or encountered code that seemed to work perfectly well during development and testing, but performed very poorly when it had to deal with a larger or more complex data-set, something which often happens only after the code has been put into production use.

Here's a program which lists the contents of a specified directory ordered by size, with the size of a child directory being calculated from the total size of all of its contents:

import java.io.*;
import java.util.*;

public class FileSizes {

    public static void main(String[] args) {

        File directory = new File(args[0]);
        List<File> files = new ArrayList<File>(Arrays.asList(directory.listFiles()));

        Collections.sort(files, new Comparator<File>() {
            public int compare(File file1, File file2) {
                // Both sizes are recalculated on every comparison
                long file1Size = calculateSize(file1);
                long file2Size = calculateSize(file2);
                if (file1Size == file2Size)
                    return 0;
                return file1Size < file2Size ? -1 : 1;
            }
        });

        for (File file : files)
            System.out.println(file);
    }

    static long calculateSize(File file) {
        long size = 0;
        if (file.isDirectory()) {
            File[] children = file.listFiles();
            // listFiles() returns null if the directory can't be read; count it as empty
            if (children != null)
                for (File child : children)
                    size += calculateSize(child);
        } else {
            size = file.length();
        }
        return size;
    }
}

Now, how long does this program take to execute?

Potentially a lot longer than it should, that's how long.

The problem here is that during the sort operation, each item being sorted may be compared against other items multiple times; exactly how many depends on the sorting algorithm in use and on the size and initial ordering of the data-set. A comparison sort of n items typically performs on the order of n log n comparisons, and the Comparator above calls calculateSize twice for each one, re-walking a directory's whole subtree whenever one of the files is a directory. If the comparison involves an expensive calculation, as it does in this rather blatant example, repeating that calculation on every comparison is normally wasted effort. In such cases, it's often preferable to cache the result of the calculation the first time it's performed, and reuse that cached version if it's needed again.
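To make the caching idea concrete, here's a minimal sketch of a Comparator that remembers each computed value in a HashMap. It isn't part of the original program: the class name CachedKeySort, the slowLength method and the calls counter are all made up for illustration, with a string's length standing in for an expensive calculation.

```java
import java.util.*;

class CachedKeySort {

    // Counts how often the (pretend) expensive calculation runs.
    static int calls = 0;

    // Hypothetical stand-in for an expensive calculation.
    static int slowLength(String s) {
        calls++;
        return s.length();
    }

    // Sorts a copy of the input by slowLength, computing the value
    // at most once per distinct string.
    static List<String> sortByCachedLength(List<String> input) {
        final Map<String, Integer> cache = new HashMap<String, Integer>();
        List<String> result = new ArrayList<String>(input);
        Collections.sort(result, new Comparator<String>() {
            // Looks the value up in the cache, calculating it only on a miss.
            private int key(String s) {
                Integer cached = cache.get(s);
                if (cached == null) {
                    cached = slowLength(s);
                    cache.put(s, cached);
                }
                return cached;
            }
            public int compare(String a, String b) {
                int ka = key(a);
                int kb = key(b);
                return ka < kb ? -1 : (ka == kb ? 0 : 1);
            }
        });
        return result;
    }

    public static void main(String[] args) {
        List<String> words = Arrays.asList("pear", "fig", "banana", "kiwi", "apple");
        List<String> sorted = sortByCachedLength(words);
        System.out.println(sorted);
        System.out.println("calculations: " + calls);
    }
}
```

With these five sample strings the calculation runs once per distinct string, however many comparisons the sort happens to make. The catch is that the cache relies on the items having sensible equals and hashCode methods, and it has to live somewhere the Comparator can reach.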

A common technique used by Perl programmers faced with this problem is to employ something known as the Schwartzian Transform, which looks something like this:

my @sorted_data =
    map { $_->[0] }
    sort { $a->[1] cmp $b->[1] }
    map { [$_, expensive_calculation($_)] }
    @unsorted_data;

The idea here is that we start with the unsorted data (read the code from the bottom line up), and apply a mapping function to it which maps each item to a two-element array containing the original item at index 0 and its associated calculated value at index 1. These arrays are then sorted, using the element at index 1 for the comparisons, and the result is finally fed through another mapping function to extract just the original items, now in the desired order.

We can try something similar in Java. We'll need a mapping method, and let's make it general enough that we can give it any Iterable (such as a List) and a Mapper to apply to each element:

interface Mapper<T,U> {
    U map(T item);
}

<T,U> Iterable<U> map(final Iterable<T> items, final Mapper<? super T, ? extends U> mapper) {
    return new Iterable<U>() {
        public Iterator<U> iterator() {
            // must be final to be visible inside the anonymous Iterator below
            final Iterator<T> iter = items.iterator();
            return new Iterator<U>() {
                public boolean hasNext() { return iter.hasNext(); }
                public U next() { return mapper.map(iter.next()); }
                public void remove() { iter.remove(); }
            };
        }
    };
}
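A small usage sketch may help show how this map method behaves. The class name MapDemo, the toList helper and the sample data are made up for illustration; the Mapper interface and map method are repeated (and made static) only so that the example compiles on its own. Note that the returned Iterable is lazy: the mapper runs as each element is pulled from an iterator, so iterating twice would apply it twice per element.

```java
import java.util.*;

class MapDemo {

    interface Mapper<T, U> {
        U map(T item);
    }

    // Same idea as the map method above, made static so main can call it.
    static <T, U> Iterable<U> map(final Iterable<T> items,
                                  final Mapper<? super T, ? extends U> mapper) {
        return new Iterable<U>() {
            public Iterator<U> iterator() {
                final Iterator<T> iter = items.iterator();
                return new Iterator<U>() {
                    public boolean hasNext() { return iter.hasNext(); }
                    public U next() { return mapper.map(iter.next()); }
                    public void remove() { iter.remove(); }
                };
            }
        };
    }

    // Hypothetical helper: copies an Iterable into a List for easy inspection.
    static <T> List<T> toList(Iterable<T> items) {
        List<T> list = new ArrayList<T>();
        for (T t : items)
            list.add(t);
        return list;
    }

    public static void main(String[] args) {
        List<String> words = Arrays.asList("one", "three", "seventeen");
        Iterable<Integer> lengths = map(words, new Mapper<String, Integer>() {
            public Integer map(String s) { return s.length(); }
        });
        List<Integer> result = toList(lengths);
        System.out.println(result); // [3, 5, 9]
    }
}
```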

To match the Perl example, we'll also want a sort method which sorts and returns a new List rather than modifying the one passed to it:

<T> List<T> sort(Iterable<T> items, Comparator<? super T> comparator) {
    List<T> list = new ArrayList<T>();
    for (T t : items)
        list.add(t);
    Collections.sort(list, comparator);
    return list;
}

Finally, we'll need some kind of data structure to hold the item together with its calculated value. A type-safe Pair class will do:

class Pair<A,B> {
    public static <A,B> Pair<A,B> of(A fst, B snd) {
        return new Pair<A,B>(fst, snd);
    }
    private final A fst;
    private final B snd;
    private Pair(A fst, B snd) {
        this.fst = fst;
        this.snd = snd;
    }
    public A fst() { return fst; }
    public B snd() { return snd; }
}

Now we can replace the Collections.sort call in the original example with a Schwartzian Transform:

Iterable<File> sortedFiles =
    map(sort(map(files,
                 new Mapper<File, Pair<File,Long>>() {
                     public Pair<File,Long> map(File f) {
                         return Pair.of(f, calculateSize(f));
                     }
                 }),
             new Comparator<Pair<File,Long>>() {
                 public int compare(Pair<File,Long> p1, Pair<File,Long> p2) {
                     return p1.snd().compareTo(p2.snd());
                 }
             }),
        new Mapper<Pair<File,Long>,File>() {
            public File map(Pair<File,Long> p) {
                return p.fst();
            }
        });

Well, that's pretty horrible, whichever way you try to lay out the code. Still, there's one improvement we could make - since we're just writing a Comparator which extracts the Long values and compares them, we could overload our sort method to accept a Mapper instead of a Comparator:

<T, U extends Comparable<U>> List<T> sort(Iterable<T> items, final Mapper<? super T, U> mapper) {
    List<T> list = new ArrayList<T>();
    for (T t : items)
        list.add(t);
    Collections.sort(list, new Comparator<T>() {
        public int compare(T t1, T t2) { return mapper.map(t1).compareTo(mapper.map(t2)); }
    });
    return list;
}
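One thing worth checking about this overload is how often it calls the mapper. The sketch below counts the invocations; KeySortDemo, mapperCalls and the sample words are assumptions for illustration, and the sort method is simply the one above made static so that it can run on its own.

```java
import java.util.*;

class KeySortDemo {

    interface Mapper<T, U> {
        U map(T item);
    }

    // The key-extracting sort from above, made static for this demo.
    static <T, U extends Comparable<U>> List<T> sort(Iterable<T> items,
                                                     final Mapper<? super T, U> mapper) {
        List<T> list = new ArrayList<T>();
        for (T t : items)
            list.add(t);
        Collections.sort(list, new Comparator<T>() {
            public int compare(T t1, T t2) {
                return mapper.map(t1).compareTo(mapper.map(t2));
            }
        });
        return list;
    }

    // Number of times the counting mapper in main has been invoked.
    static int mapperCalls = 0;

    public static void main(String[] args) {
        List<String> words = Arrays.asList("pear", "fig", "banana", "kiwi", "apple");
        List<String> sorted = sort(words, new Mapper<String, Integer>() {
            public Integer map(String s) {
                mapperCalls++; // count every key extraction
                return s.length();
            }
        });
        System.out.println(sorted);
        System.out.println("mapper calls: " + mapperCalls);
    }
}
```

The mapper runs twice per comparison, so the count comes out higher than the number of items. That's harmless when the mapper just reads a stored value such as p.snd(), but it's why the expensive calculation still has to happen up front in the first map step rather than being handed to sort directly.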

That allows us to shorten the code a little:

Iterable<File> sortedFiles =
    map(sort(map(files,
                 new Mapper<File, Pair<File,Long>>() {
                     public Pair<File,Long> map(File f) {
                         return Pair.of(f, calculateSize(f));
                     }
                 }),
             new Mapper<Pair<File,Long>,Long>() {
                 public Long map(Pair<File,Long> p) {
                     return p.snd();
                 }
             }),
        new Mapper<Pair<File,Long>,File>() {
            public File map(Pair<File,Long> p) {
                return p.fst();
            }
        });

It's still far too cumbersome though - if I were writing this 'for real' I'd probably have given up by now.
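Cumbersome or not, it does work. For anyone who wants to run the anonymous-class version, here's a sketch that assembles the pieces into one file. To keep it independent of the file system it sorts strings by length rather than files by size; AssembledTransform, expensiveLength, sortByLength and the calls counter are hypothetical names introduced for the demonstration.

```java
import java.util.*;

class AssembledTransform {

    interface Mapper<T, U> { U map(T item); }

    static final class Pair<A, B> {
        private final A fst;
        private final B snd;
        private Pair(A fst, B snd) { this.fst = fst; this.snd = snd; }
        static <A, B> Pair<A, B> of(A fst, B snd) { return new Pair<A, B>(fst, snd); }
        A fst() { return fst; }
        B snd() { return snd; }
    }

    // Lazily applies mapper to each element as it is iterated.
    static <T, U> Iterable<U> map(final Iterable<T> items,
                                  final Mapper<? super T, ? extends U> mapper) {
        return new Iterable<U>() {
            public Iterator<U> iterator() {
                final Iterator<T> iter = items.iterator();
                return new Iterator<U>() {
                    public boolean hasNext() { return iter.hasNext(); }
                    public U next() { return mapper.map(iter.next()); }
                    public void remove() { iter.remove(); }
                };
            }
        };
    }

    // Copies items into a new list and sorts it by the key the mapper extracts.
    static <T, U extends Comparable<U>> List<T> sort(Iterable<T> items,
                                                     final Mapper<? super T, U> mapper) {
        List<T> list = new ArrayList<T>();
        for (T t : items)
            list.add(t);
        Collections.sort(list, new Comparator<T>() {
            public int compare(T t1, T t2) {
                return mapper.map(t1).compareTo(mapper.map(t2));
            }
        });
        return list;
    }

    // Counts invocations of the stand-in "expensive" calculation.
    static int calls = 0;

    static int expensiveLength(String s) {
        calls++;
        return s.length();
    }

    // map / sort / map, as in the transform above but over strings.
    static List<String> sortByLength(List<String> words) {
        Iterable<String> sorted =
            map(sort(map(words,
                         new Mapper<String, Pair<String, Integer>>() {
                             public Pair<String, Integer> map(String s) {
                                 return Pair.of(s, expensiveLength(s));
                             }
                         }),
                     new Mapper<Pair<String, Integer>, Integer>() {
                         public Integer map(Pair<String, Integer> p) { return p.snd(); }
                     }),
                new Mapper<Pair<String, Integer>, String>() {
                    public String map(Pair<String, Integer> p) { return p.fst(); }
                });
        List<String> result = new ArrayList<String>();
        for (String s : sorted)
            result.add(s);
        return result;
    }

    public static void main(String[] args) {
        List<String> sorted = sortByLength(Arrays.asList("pear", "fig", "banana", "kiwi", "apple"));
        System.out.println(sorted);
        System.out.println("calculations: " + calls);
    }
}
```

Because the inner map is only iterated once (by the copy loop inside sort), expensiveLength runs once per item, and the comparisons only ever touch the values stored in the pairs.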

Let's try it using closures (in the syntax of the BGGA closures proposal) instead of anonymous classes:

Iterable<File> sortedFiles =
    map(sort(map(files,
                 {File f => Pair.of(f, calculateSize(f))}),
             {Pair<File,Long> p => p.snd()}),
        {Pair<File,Long> p => p.fst()});

That's much better to my eyes - maybe not quite as succinct as the Perl version, but it's digestible.

The Mapper interface can be replaced by function types so let's lose that and apply a bit more closures magic to the map and sort methods:

<T,U> Iterable<U> map(Iterable<T> items, {T=>U} mapper) {
    return {=>
        Iterator<T> iter = items.iterator();
        new Iterator<U>() {
            public boolean hasNext() { return iter.hasNext(); }
            public U next() { return mapper.invoke(iter.next()); }
            public void remove() { iter.remove(); }
        }
    };
}

<T,U extends Comparable<U>> List<T> sort(Iterable<T> items, {T=>U} c) {
    List<T> list = new ArrayList<T>();
    for (T t : items)
        list.add(t);
    Collections.sort(list, {T t1, T t2 => c.invoke(t1).compareTo(c.invoke(t2))});
    return list;
}

Of course we could also tuck all that map/sort/map stuff away in a handy reusable method:

<T,U extends Comparable<U>> Iterable<T> schwartzianSort(Iterable<T> items, {T=>U} mapper) {
    return map(sort(map(items,
                        {T t => Pair.of(t, mapper.invoke(t))}),
                    {Pair<T,U> p => p.snd()}),
               {Pair<T,U> p => p.fst()});
}

Leaving us with:

public static void main(String[] args) {

    File directory = new File(args[0]);
    List<File> files = new ArrayList<File>(Arrays.asList(directory.listFiles()));

    for (File file : schwartzianSort(files, {File f => calculateSize(f)}))
        System.out.println(file);
}

Much nicer!
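For readers using Java 8 or later, roughly the same reusable method can be sketched with lambdas and streams. This is a translation made under assumptions rather than code from this post: java.util.function.Function and Map.Entry stand in for the function type and the Pair class, the method takes a Collection so that it can call stream(), and LambdaSchwartzian is a made-up class name.

```java
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;

class LambdaSchwartzian {

    // Decorate each item with its key, sort on the key, then strip the key.
    static <T, U extends Comparable<? super U>> List<T> schwartzianSort(
            Collection<T> items, Function<? super T, U> keyOf) {
        return items.stream()
                // map: pair each item with its key, computed once
                .<Map.Entry<T, U>>map(item ->
                        new AbstractMap.SimpleImmutableEntry<>(item, keyOf.apply(item)))
                // sort: compare only the stored keys
                .sorted(Map.Entry.comparingByValue())
                // map: recover the original items, now in key order
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String> words = Arrays.asList("pear", "fig", "banana", "kiwi", "apple");
        List<String> sorted = schwartzianSort(words, String::length);
        System.out.println(sorted);
    }
}
```

In the file-size program, the call would then look something like `schwartzianSort(files, FileSizes::calculateSize)`.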

Sometimes, it's not immediately obvious what a method like that should look like, or how it might be implemented. Breaking the algorithm down into a series of functions, as the Schwartzian Transform does, can be a useful way to approach the problem at hand, but only if the language you're using has sufficiently practical constructs.

updated 29/06/08 to fix a bug in map() pointed out by Konstantin Triger.
updated 03/07/08 to fix a bug in the Comparator pointed out by bjkail.