Spark aggregateByKey example in Java

Both foldByKey() and reduceByKey() require that the return type of our result be the same type as that of the elements in the RDD we are operating over. This works well for operations like sum, but sometimes we want to return a different type.

The aggregateByKey() function frees us from the constraint of having the return be the same type as the RDD we are working on. With aggregateByKey(), like foldByKey(), we supply an initial zero value of the type we want to return. We then supply a function to combine the elements from our RDD with the accumulator. Finally, we need to supply a second function to merge two accumulators, given that each node accumulates its own results locally.

Let's take an example.

Problem: Find the sum of octets consumed for each id.

Input data


#id,octet,status,time

5868904,100,1,18/01/2018 18:03:44
5868904,200,1,18/01/2018 20:03:55
5868904,300,1,18/01/2018 22:12:22
5868904,400,1,19/01/2018 00:16:22
5868904,500,1,19/01/2018 02:19:12
5868904,600,1,19/01/2018 04:11:44

Let's write the Spark code in Java. We use mapPartitionsToPair instead of mapToPair because we use a SimpleDateFormat to parse the date and we don't want to instantiate it for each record.


import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;

import scala.Tuple2;

/**
 * Demonstrates {@code aggregateByKey} on a pair RDD whose values ({@link CallData}) have a
 * different type than the aggregation result ({@code TreeSet<CallData>}).
 *
 * <p>Each input line has the form {@code id,octet,status,time}, for example
 * {@code 5868904,100,1,18/01/2018 18:03:44}. The records of every id are gathered into a
 * sorted set and the octets of that set are then summed and printed per id.
 */
public class AggregateByKeyExample {

    /** Input file used when no path is passed on the command line. */
    private static final String DEFAULT_INPUT_PATH = "C:\\codebase\\scala-project\\inputdata\\agg_data";

    /** Timestamp layout of the fourth CSV column (day/month/year, 24-hour clock). */
    private static final String DATE_PATTERN = "dd/MM/yyyy HH:mm:ss";

    /**
     * Runs the example job on a local Spark master.
     *
     * @param args optional; {@code args[0]} overrides the default input path
     */
    public static void main(String[] args) {

        String inputPath = (args != null && args.length > 0) ? args[0] : DEFAULT_INPUT_PATH;

        SparkConf conf = new SparkConf().setAppName("aggregate_by_key").setMaster("local");

        // try-with-resources closes the context even when the job throws.
        try (JavaSparkContext jsc = new JavaSparkContext(conf)) {

            JavaRDD<String> rdd = jsc.textFile(inputPath);

            JavaPairRDD<String, CallData> pair = rdd
                    .mapPartitionsToPair(new PairFlatMapFunction<Iterator<String>, String, CallData>() {

                        private static final long serialVersionUID = 4454545;

                        @Override
                        public Iterator<Tuple2<String, CallData>> call(Iterator<String> lines) throws Exception {

                            List<Tuple2<String, CallData>> records = new ArrayList<>();

                            // SimpleDateFormat is not thread-safe, so one instance is created per
                            // partition call instead of being shared or rebuilt for every record.
                            SimpleDateFormat formatter = new SimpleDateFormat(DATE_PATTERN);
                            // Reject malformed timestamps instead of silently rolling fields over.
                            formatter.setLenient(false);

                            while (lines.hasNext()) {
                                String line = lines.next().trim();

                                // Skip blank lines and the "#id,octet,status,time" header/comment lines.
                                if (line.isEmpty() || line.startsWith("#")) {
                                    continue;
                                }

                                String[] data = line.split(",");
                                if (data.length < 4) {
                                    throw new IllegalArgumentException(
                                            "Expected id,octet,status,time but got: " + line);
                                }

                                CallData record = new CallData(Double.parseDouble(data[1].trim()),
                                        Integer.parseInt(data[2].trim()), formatter.parse(data[3].trim()));

                                records.add(new Tuple2<String, CallData>(data[0].trim(), record));
                            }

                            return records.iterator();
                        }
                    });

            // Zero value of the aggregation: an empty sorted set.
            // NOTE: a TreeSet keeps a single element for records whose compareTo returns 0,
            // so records that compare as equal are collapsed before the sum is computed.
            TreeSet<CallData> zero = new TreeSet<CallData>();

            // Folds one record into an accumulator.
            Function2<TreeSet<CallData>, CallData, TreeSet<CallData>> addRecord = new Function2<TreeSet<CallData>, CallData, TreeSet<CallData>>() {

                private static final long serialVersionUID = 2323;

                @Override
                public TreeSet<CallData> call(TreeSet<CallData> acc, CallData record) throws Exception {
                    acc.add(record);
                    return acc;
                }
            };

            // Merges two accumulators built for the same key.
            Function2<TreeSet<CallData>, TreeSet<CallData>, TreeSet<CallData>> mergeSets = new Function2<TreeSet<CallData>, TreeSet<CallData>, TreeSet<CallData>>() {

                private static final long serialVersionUID = 9898;

                @Override
                public TreeSet<CallData> call(TreeSet<CallData> left, TreeSet<CallData> right) throws Exception {
                    left.addAll(right);
                    return left;
                }
            };

            JavaPairRDD<String, TreeSet<CallData>> finalRdd = pair.aggregateByKey(zero, addRecord, mergeSets);

            for (Tuple2<String, TreeSet<CallData>> entry : finalRdd.collect()) {

                // The total is reset for every id; otherwise each id would also include
                // the octets of all ids printed before it.
                double sum = 0d;

                for (CallData record : entry._2) {
                    sum = sum + record.getOctets();
                }

                System.out.println("Total Octet for id " + entry._1 + " " + sum);
            }
        }
    }

}

Below is the model class used:


import java.io.Serializable;
import java.util.Date;

public class CallData implements Comparable<CallData>, Serializable {

private static final long serialVersionUID = 232323;

private double octets;

private int status;

private Date input_date;

public CallData(double octets, int status, Date input_date) {
super();
this.octets = octets;
this.status = status;
this.input_date = input_date;
}

public double getOctets() {
return octets;
}

public void setOctets(double octets) {
this.octets = octets;
}

public int getStatus() {
return status;
}

public void setStatus(int status) {
this.status = status;
}

public Date getInput_date() {
return input_date;
}

public void setInput_date(Date input_date) {
this.input_date = input_date;
}

@Override
public int compareTo(CallData data) {

if (this.getInput_date().compareTo(data.getInput_date()) < 0) {
return -1;
} else if (this.getInput_date().compareTo(data.getInput_date()) > 0) {
return 1;
} else if (this.getInput_date().compareTo(data.getInput_date()) == 0) {

return Double.compare(this.getOctets(), data.getOctets());

}

return -1;
}

}