combineByKey() is the most general of the per-key aggregation functions; most of the other per-key combiners are implemented using it. Like aggregate(), combineByKey() allows the user to return values of a different type than the input data.
To understand combineByKey(), it’s useful to think about how it handles each element it processes. As combineByKey() goes through the elements in a partition, each element either has a key it hasn’t seen before or has the same key as a previous element. If it’s a new key, combineByKey() uses a function we provide, called createCombiner(), to create the initial value for that key’s accumulator.
It’s important to note that this happens the first time a key is found in each partition, rather than only the first time the key is found in the RDD. If it is a key we have seen before while processing that partition, combineByKey() will instead use the provided function, mergeValue(), with the current value of the accumulator for that key and the new value.
Since each partition is processed independently, we can have multiple accumulators for the same key. When we are merging the results from each partition, if two or more partitions have an accumulator for the same key, we merge the accumulators using the user-supplied mergeCombiners() function.
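For reference, the simplest Scala overload of combineByKey() on a pair RDD of type RDD[(K, V)] takes exactly these three functions (other overloads add partitioner and serialization options):

def combineByKey[C](
    createCombiner: V => C,
    mergeValue: (C, V) => C,
    mergeCombiners: (C, C) => C): RDD[(K, C)]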
Let’s take an example: processing a ratings CSV file with movie_name, rating, and timestamp columns, where we need to find the average rating of each movie.
Below is the sample data:

Spider Man,4,978302174
Spider Man,4,978301398
Spider Man,4,978302091
Bat Man,5,978298709
Bat Man,4,978299000
We need to specify three functions: createCombiner(), mergeValue(), and mergeCombiners().
Let’s say we are using the map function below, where the movie name is the key and the value is the rating, which is a Double.
In Scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf().setAppName("scala spark").setMaster("local")
val sc = new SparkContext(conf)
val rdd = sc.textFile("C:\\codebase\\scala-project\\input data\\movies_data_2")
val pairRdd = rdd.map { x =>
  val data = x.split(",")
  (data(0), data(1).toDouble)
}
In Java
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import scala.Tuple2;

SparkConf sparkConf = new SparkConf().setAppName("test").setMaster("local");
JavaSparkContext jsc = new JavaSparkContext(sparkConf);
JavaRDD<String> rdd = jsc.textFile("C:\\codebase\\scala-project\\input data\\movies_data_2");
JavaPairRDD<String, Double> pairRdd = rdd.mapToPair(new PairFunction<String, String, Double>() {
    @Override
    public Tuple2<String, Double> call(String str) throws Exception {
        String[] data = str.split(",", -1);
        return new Tuple2<String, Double>(data[0], Double.parseDouble(data[1]));
    }
});
Combiner Function
In the combiner function we need to specify the input type as the value type we pass from the key/value pair RDD. In the example above, the map that converts the data into key/value format produces Tuple2(String, Double), so our combiner function should take a Double as its input type. In the code below we take the rating value passed from the pair RDD and return an AverageRating object, passing 1 as the count and x as the sum. This object is used as the accumulator for each unique key.
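Note that the AverageRating class itself is not shown here. A minimal Scala sketch, assuming mutable count and sum fields and an average() method (the Java examples additionally assume matching getter/setter methods), might look like this:

// Hypothetical accumulator class assumed by the examples; it must be
// Serializable so Spark can ship instances between executors.
class AverageRating(var count: Int, var sum: Double) extends Serializable {
  def average(): Double = sum / count
}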
In Scala
val combiner = (x: Double) => new AverageRating(1, x)
In Java
In Java we need to implement an interface of type Function, specifying the input type and the return type; in our case the input value type is Double and the return type is AverageRating. Since the return type is AverageRating, this will be the type of our accumulator, which gets passed into the mergeValue() function along with the value from the JavaPairRDD.
static Function<Double, AverageRating> combiner = new Function<Double, AverageRating>() {
    @Override
    public AverageRating call(Double value) throws Exception {
        return new AverageRating(1, value);
    }
};
MergeValue Function
The accumulator object from the createCombiner function and the next value from the pair RDD get passed into the mergeValue function. In our case, the accumulator type returned from the combiner is AverageRating and the value type from the input pair RDD is Double, so our mergeValue function has the input types (y: AverageRating, d: Double). We add 1 to the accumulator’s count variable, add the value to the accumulator’s sum variable, and return the accumulator object.
In Scala
val mergeValue = (y: AverageRating, d: Double) => {
  y.count += 1
  y.sum += d
  y
}
In Java
In Java we need to implement an interface of type Function2, specifying the accumulator type, the input value type from the mapToPair function, and the return type. In our case the accumulator type is AverageRating, the input value type is Double, and the return type is AverageRating. Since the return type is AverageRating, this is the type that gets passed into the final mergeCombiners function, which combines the data from different partitions.
static Function2<AverageRating, Double, AverageRating> mergeValue = new Function2<AverageRating, Double, AverageRating>() {
    @Override
    public AverageRating call(AverageRating avg, Double value) throws Exception {
        avg.setSum(avg.getSum() + value);
        avg.setCount(avg.getCount() + 1);
        return avg;
    }
};
MergeCombiners
Recall that each partition is processed independently, so two or more partitions may hold an accumulator for the same key. In our case the output from the mergeValue function is of type AverageRating, so mergeCombiners takes two AverageRating accumulators from different partitions and combines them into a single accumulator.
In Scala
val mergeCombiners = (aa: AverageRating, bb: AverageRating) => {
  aa.count += bb.count
  aa.sum += bb.sum
  aa
}
In Java
In Java we need to implement an interface of type Function2, specifying the accumulator type that is returned from each mergeValue function. In our case the accumulator type from each mergeValue function is AverageRating, so the AverageRating object from each partition gets passed into the final mergeCombiners function, which combines the data from different partitions.
static Function2<AverageRating, AverageRating, AverageRating> mergeCombiners = new Function2<AverageRating, AverageRating, AverageRating>() {
    @Override
    public AverageRating call(AverageRating avg1, AverageRating avg2) throws Exception {
        avg1.setCount(avg1.getCount() + avg2.getCount());
        avg1.setSum(avg1.getSum() + avg2.getSum());
        return avg1;
    }
};
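To see how the three functions cooperate, suppose (purely for illustration; the actual split depends on how Spark partitions the file) the sample data lands in two partitions:

// Partition 1: (Spider Man,4.0), (Spider Man,4.0)
// Partition 2: (Spider Man,4.0), (Bat Man,5.0), (Bat Man,4.0)
//
// Partition 1: combiner(4.0)                          -> AverageRating(1, 4.0)
//              mergeValue(AverageRating(1, 4.0), 4.0) -> AverageRating(2, 8.0)
// Partition 2: Spider Man -> AverageRating(1, 4.0)
//              Bat Man    -> AverageRating(2, 9.0)
// Shuffle:     mergeCombiners(AverageRating(2, 8.0), AverageRating(1, 4.0))
//              -> AverageRating(3, 12.0), so Spider Man averages 12.0 / 3 = 4.0
// Bat Man appears in only one partition, so mergeCombiners is never called
// for it and its accumulator AverageRating(2, 9.0) is used as-is.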
Finally, let’s write the driver code.
In Scala
pairRdd.combineByKey(combiner, mergeValue, mergeCombiners)
  .collect()
  .foreach(x => println(x._1 + " " + x._2.average()))
In Java
// combiner, mergeValue, and mergeCombiners are the static fields defined
// above, assumed here to live in a wrapper class named Combiner.
JavaPairRDD<String, AverageRating> output = pairRdd.combineByKey(Combiner.combiner, Combiner.mergeValue, Combiner.mergeCombiners);
for (Tuple2<String, AverageRating> entry : output.collect()) {
    System.out.println(entry._1 + " " + entry._2.average());
}
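Run against the sample data above, either driver should print the per-movie averages (output order is not guaranteed):

Spider Man 4.0
Bat Man 4.5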