Beware Java Enums in Spark

Josh - 23 Feb 2014

A few days back I wrote a Spark job that runs an A/B test to compare the conversion rates between two groups of website visitors on one of our client's websites.

In this case, visitors were placed into the test or control group based on a deterministic method using their uuid. At the time of writing we had multiple projects wanting to use our general A/B test code. However, some projects were using Scala 2.9.x and others were on Scala 2.10. We wanted to share code but didn't want the hassle of maintaining separate artifacts for each Scala version, so we packaged the group labels and the method for determining who belongs to which group in a Java project. It looked something like this:

  public enum TestControl {
    TEST, CONTROL;

    public static TestControl getGroup(UUID id) {
      // return appropriate group for the given uuid
    }
  }

The Spark Code

The report was a fairly standard Spark report for us. Basically it involved loading and massaging our log files into the appropriate RDD format, then mapping them into a format conducive to aggregation. Spark supports aggregation of key/value pairs using the reduceByKey method via an implicit conversion on an RDD[Tuple2] (see PairRDDFunctions).

Since I was interested in comparing the average conversion rate between the two groups, it was a simple matter of mapping the log files into the appropriate key/value pairs and then calling reduceByKey like so:

// aggregation object
case class Stats(impressions: Long = 0, conversions: Long = 0) {
  lazy val conversionRate: Double = conversions / impressions.toDouble

  def +(other: Stats): Stats = Stats(
    impressions + other.impressions,
    conversions + other.conversions
  )
}

val myRDD = // load/map log files from S3 to case class with methods for each field

val readyForAggregation = myRDD.map {
  case line if line.isImp  => (TestControl.getGroup(line.uuID), Stats(1, 0))
  case line if line.isConv => (TestControl.getGroup(line.uuID), Stats(0, 1))
}

val results = readyForAggregation.reduceByKey { _ + _ }

At this point we'd expect the resulting RDD to contain just two elements, one for the test group and another for the control (as our key was a TestControl value and getGroup only produces two values).


It might seem obvious that we should only get 2 items in our results RDD, but it's always good to check with a unit test. Running a local unit test proved that this was indeed the case:

  def resultsShouldHaveJust2Elements() {
    val lines = // make some stub log lines (Scala case classes)
    val results = new ABTestReport(lines).run // runs aggregation discussed above
    results.size shouldBe 2
  }

Sweet, all our tests pass, we're done! But wait! Not so fast: running this on a real Spark cluster yielded very different results...

  val data = results.collect // pulls RDD onto master as an Array
  println(data.size) // 52 -> WTF?!!!

  println(data.mkString("Array(", ", ", ")")) // Array((CONTROL, Stats), (TEST, Stats), (CONTROL, Stats), (TEST, Stats)...)

WTF Spark?!

So what's going on here? We have a local unit test that shows that reduceByKey works as expected when using our enum for a key, but our Spark cluster seems to reduce our results incorrectly: we wind up with way too many keys.

Well, after some digging with a few other engineers, it turns out that by default Spark maps your data items to partitions using a HashPartitioner. HashPartitioner uses the hashCode of the key to determine which partition it will live in. OK so far, that seems completely sensible.
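
To make that concrete, here is roughly what hash partitioning boils down to (a simplified sketch, not Spark's actual source): the partition index is just the key's hashCode modulo the number of partitions, kept non-negative.

// Simplified sketch of hash partitioning: partition = hashCode mod numPartitions.
// The real HashPartitioner does essentially this for every key in the RDD.
def partitionFor(key: Any, numPartitions: Int): Int = {
  val rawMod = key.hashCode % numPartitions
  rawMod + (if (rawMod < 0) numPartitions else 0) // hashCode may be negative
}

The catch: if two JVMs disagree on a key's hashCode, they will also disagree on which partition that key belongs to.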

The Devil is in the Enum

But wait: the hashCode method on Java's enum type is effectively based on the memory address of the object (Enum does not override Object's default identity-based hashCode). So while yes, we're guaranteed that the same enum value has a stable hashCode inside a particular JVM (since the enum will be a static object), we don't have this guarantee when we compare hashCodes of Java enums with identical values living in different JVMs. They will very likely have different hashCode values.
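
A quick way to convince yourself of this (just a sketch, exact numbers will vary): print an enum value's hashCode and run the program twice. You will almost certainly see two different values, because Enum.hashCode() is final and falls back to the default identity hash.

// Run this in two separate JVM invocations and compare the output.
// The same enum value will typically report a different hashCode each run.
object EnumHashCheck extends App {
  println(TestControl.TEST.hashCode())
  println(TestControl.CONTROL.hashCode())
}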

Our local unit test passed because it executed in a single JVM, so the enum's hashCode remained consistent whenever HashPartitioner asked for it. Our real cluster failed because HashPartitioner was getting different hashCodes for the same enum values, since each slave has its own machine and JVM.

What to do instead

At this point it should be pretty clear that we should not use Java enums as keys for RDDs we'd like to aggregate. Fortunately, there are two easy alternatives:

The omg I'm lazy way

You can simply toString() your enums prior to calling reduceByKey, since String's hashCode is computed from its characters and does not rely on the memory address of the object.
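
For example, assuming the same myRDD and Stats from the aggregation code above, the only change is keying by the enum's string representation:

// Keys are now Strings, whose hashCode is computed from their characters,
// so every JVM in the cluster agrees on which partition each key maps to.
val readyForAggregation = myRDD.map {
  case line if line.isImp  => (TestControl.getGroup(line.uuID).toString, Stats(1, 0))
  case line if line.isConv => (TestControl.getGroup(line.uuID).toString, Stats(0, 1))
}

val results = readyForAggregation.reduceByKey { _ + _ } // two keys: "TEST" and "CONTROL"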

A better way

Use a sealed trait and case objects instead; they have a fixed hashCode value and do not use the memory location to calculate it:

sealed trait Group 

object TestControl {
  case object Test extends Group
  case object Control extends Group
}
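
With those definitions in place the aggregation code barely changes. This sketch assumes a TestControl.getGroup that returns a Group (analogous to the Java version, whose actual assignment logic isn't shown here), plus the same myRDD and Stats as before:

// Keys are now case objects, whose hashCode is a fixed value that does not
// depend on memory location, so the cluster reduces down to exactly two keys.
val readyForAggregation = myRDD.map {
  case line if line.isImp  => (TestControl.getGroup(line.uuID), Stats(1, 0))
  case line if line.isConv => (TestControl.getGroup(line.uuID), Stats(0, 1))
}

val results = readyForAggregation.reduceByKey { _ + _ } // Test and Control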

Wrapping Up

I think this goes to show that while local unit tests on a single JVM are always a good idea, you should also check your results carefully on clustered systems. When you start doing parallel computing, you often run into bugs that only show up on real clusters. Happy Spark-ing!
