Spark treeAggregate example and treeReduce example

treeAggregate is a specialized implementation of aggregate that iteratively applies the combine function to a subset of partitions, so that not all partial results have to be returned to the driver. Picture an n-ary tree with all the partitions at its leaves: partial results are merged level by level, and the root holds the final reduced value. This way the driver no longer has to merge every partial result itself.

treeReduce is the tree-based counterpart of the reduce operation and is available on any RDD. reduceByKey is used internally to implement treeReduce (the exact by-key operation can differ between Spark versions), but the two are not related in any other sense. reduceByKey performs a reduction for each key and results in an RDD; it is a transformation, not an action, and it returns a ShuffledRDD. treeReduce, on the other hand, is an action: it performs the reduction in parallel by creating a key-value pair RDD on the fly, with keys derived from the partition index at each level of the tree, and reducing that RDD by key.
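The keyed step described above can be imitated in plain Java, without Spark. This is only a sketch of the idea, not Spark's implementation: the class KeyedLevelSketch, the method reduceOneLevel, and the choice of key (position modulo the number of groups) are assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.BinaryOperator;

class KeyedLevelSketch {

    // Performs one level of the tree: the partial result at position i is tagged
    // with the key (i % groups), and values that share a key are merged with op.
    static <T> List<T> reduceOneLevel(List<T> partials, int groups, BinaryOperator<T> op) {
        Map<Integer, T> byKey = new TreeMap<>();
        for (int i = 0; i < partials.size(); i++) {
            byKey.merge(i % groups, partials.get(i), op);
        }
        // One merged value per key is left for the next level.
        return new ArrayList<>(byKey.values());
    }
}
```

For instance, reduceOneLevel(Arrays.asList(1, 2, 3, 4, 5, 6), 2, Integer::sum) leaves the two values 9 and 12, and a further call with a single group merges them into 21.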

With the plain reduce or aggregate functions in Spark (as in MapReduce), every partition has to send its reduced value to the driver machine, and the driver spends time linear in the number of partitions merging them, because of the CPU cost of merging partial results and the limit on network bandwidth. This becomes a bottleneck when there are many partitions and the data from each partition is large.
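The effect on the driver can be illustrated with a small plain-Java simulation. It is a sketch under stated assumptions, not Spark's code: the names TreeCombineSketch, Result and treeSum are hypothetical, and the fan-in rule (roughly the depth-th root of the number of partial results, at least 2) is a simplification chosen for the example.

```java
import java.util.ArrayList;
import java.util.List;

class TreeCombineSketch {

    // Final value plus the number of partial results merged in the last step,
    // which stands in for the work left to the driver.
    static final class Result {
        final long value;
        final int mergedAtDriver;

        Result(long value, int mergedAtDriver) {
            this.value = value;
            this.mergedAtDriver = mergedAtDriver;
        }
    }

    // Sums per-partition partial results level by level (assumed fan-in rule).
    static Result treeSum(List<Long> partials, int depth) {
        List<Long> current = new ArrayList<>(partials);
        int scale = Math.max(2, (int) Math.ceil(Math.pow(current.size(), 1.0 / depth)));
        while (current.size() > scale) {
            int groups = (current.size() + scale - 1) / scale; // ceiling division
            long[] merged = new long[groups];
            for (int i = 0; i < current.size(); i++) {
                merged[i % groups] += current.get(i); // merge within a group
            }
            current = new ArrayList<>();
            for (long m : merged) {
                current.add(m);
            }
        }
        long total = 0;
        for (long v : current) {
            total += v; // the final merge, done by the "driver"
        }
        return new Result(total, current.size());
    }
}
```

With 64 partial results, a depth of 1 leaves all 64 values for the final merge, while a depth of 2 leaves only 8; the total is the same in both cases.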

It is recommended to use treeAggregate and treeReduce instead of aggregate and reduce wherever applicable, for better performance. Although treeAggregate and treeReduce can be called on a key-value RDD, they are generally not suitable for it, because they do not retain the key grouping of the data.
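What losing the key grouping means can be shown in plain Java. This is a sketch only: the class KeyGroupingSketch, the record type Usage and the second id 5868905 are hypothetical and do not come from the input data used in this post.

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class KeyGroupingSketch {

    // An (id, octets) record; stands in for one element of a key-value RDD.
    static final class Usage {
        final String id;
        final double octets;

        Usage(String id, double octets) {
            this.id = id;
            this.octets = octets;
        }
    }

    // What a per-key reduction keeps: one total for each id.
    static Map<String, Double> sumPerKey(List<Usage> records) {
        Map<String, Double> totals = new TreeMap<>();
        for (Usage u : records) {
            totals.merge(u.id, u.octets, Double::sum);
        }
        return totals;
    }

    // What a whole-dataset reduction keeps: a single total, with the ids gone.
    static double sumAll(List<Usage> records) {
        double total = 0;
        for (Usage u : records) {
            total += u.octets;
        }
        return total;
    }
}
```

A tree-style reduction behaves like sumAll here: it yields one value for the whole dataset, so per-id totals have to come from a by-key operation instead.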

Let's take an example.

Problem: find the sum of the octets consumed. The focus here is on using the treeAggregate and treeReduce functions; the algorithm itself does not represent a real use case.

Input data


#id,octet,status,time

5868904,100,1,18/01/2018 18:03:44
5868904,200,1,18/01/2018 20:03:55
5868904,300,1,18/01/2018 22:12:22
5868904,400,1,19/01/2018 00:16:22
5868904,500,1,19/01/2018 02:19:12
5868904,600,1,19/01/2018 04:11:44

Let's write the Spark code in Java. In the pair RDD example we use mapPartitionsToPair instead of mapToPair because we parse the date with SimpleDateFormat and don't want to instantiate it for each record.
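The per-partition setup can be sketched in plain Java, with an Iterator standing in for one partition. The names PerPartitionSetupSketch and parseTimes are hypothetical; the pattern dd/MM/yyyy HH:mm:ss is taken from the sample input above, and fixing the time zone to UTC is an assumption made so that the result does not depend on the machine.

```java
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.TimeZone;

class PerPartitionSetupSketch {

    // Parses the time column of every line in one "partition".
    static List<Date> parseTimes(Iterator<String> lines) throws ParseException {
        // Created once for the whole iterator, not once per record.
        SimpleDateFormat formatter = new SimpleDateFormat("dd/MM/yyyy HH:mm:ss");
        formatter.setLenient(false);                          // reject malformed dates
        formatter.setTimeZone(TimeZone.getTimeZone("UTC"));   // assumption: input is UTC

        List<Date> times = new ArrayList<>();
        while (lines.hasNext()) {
            String[] fields = lines.next().split(",");
            times.add(formatter.parse(fields[3]));            // fields: id,octet,status,time
        }
        return times;
    }
}
```

SimpleDateFormat is not thread-safe, so creating one instance inside the per-partition function, rather than sharing a static one, is a reasonable default.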

As the code below shows, treeAggregate does not return an RDD but a single object whose type is the return type of mergePartition. Like treeAggregate, treeReduce returns just the final value, not an RDD.

Using treeAggregate and treeReduce on RDD


import java.util.TreeSet;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;

public class TreeAggregate {

    public static void main(String[] args) {

        SparkConf conf = new SparkConf().setAppName("aggregate_by_key").setMaster("local");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        JavaRDD<String> rdd = jsc.textFile("C:\\codebase\\scala-project\\inputdata\\agg_data");

        // Zero value: the empty set every partition starts from.
        TreeSet<String> zero = new TreeSet<String>();

        // Adds one input line to the set being built within a partition.
        Function2<TreeSet<String>, String, TreeSet<String>> mergeValue =
                new Function2<TreeSet<String>, String, TreeSet<String>>() {

            private static final long serialVersionUID = 2323;

            @Override
            public TreeSet<String> call(TreeSet<String> arg0, String arg1) throws Exception {
                arg0.add(arg1);
                return arg0;
            }
        };

        // Merges the sets produced by two partitions (or two tree nodes).
        Function2<TreeSet<String>, TreeSet<String>, TreeSet<String>> mergePartition =
                new Function2<TreeSet<String>, TreeSet<String>, TreeSet<String>>() {

            private static final long serialVersionUID = 9898;

            @Override
            public TreeSet<String> call(TreeSet<String> arg0, TreeSet<String> arg1) throws Exception {
                arg0.addAll(arg1);
                return arg0;
            }
        };

        // treeAggregate is an action: the result is a plain TreeSet, not an RDD.
        TreeSet<String> finalRdd3 = rdd.treeAggregate(zero, mergeValue, mergePartition);

        // Sum the octet column (second field) of the collected lines on the driver.
        double add = 0d;
        for (String string : finalRdd3) {
            String[] data = string.split(",");
            add = add + Double.parseDouble(data[1]);
        }
        System.out.println(add);

        // treeReduce keeps the id of the first argument and adds up the octets.
        String sumTreeReduce = rdd.treeReduce(new Function2<String, String, String>() {

            private static final long serialVersionUID = 89898;

            @Override
            public String call(String arg0, String arg1) throws Exception {
                String[] data = arg1.split(",");
                String[] data1 = arg0.split(",");
                data1[1] = String.valueOf(Double.parseDouble(data1[1]) + Double.parseDouble(data[1]));
                return data1[0] + "," + data1[1];
            }
        });
        System.out.println(sumTreeReduce);

        jsc.close();
    }
}
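To see what the treeReduce function above produces without starting Spark, the same merge logic can be applied to the sample lines in plain Java. This is a sketch: the class ReduceFunctionSketch and its methods merge and reduceAll are hypothetical names, and the left-to-right fold stands in for whatever order Spark actually merges in.

```java
import java.util.List;

class ReduceFunctionSketch {

    // Same logic as the treeReduce function: keep the id of the left argument
    // and add up the octet column (second field) of both arguments.
    static String merge(String left, String right) {
        String[] l = left.split(",");
        String[] r = right.split(",");
        double sum = Double.parseDouble(l[1]) + Double.parseDouble(r[1]);
        return l[0] + "," + sum;
    }

    // Folds the merge function over all lines, left to right.
    static String reduceAll(List<String> lines) {
        String result = lines.get(0);
        for (int i = 1; i < lines.size(); i++) {
            result = merge(result, lines.get(i));
        }
        return result;
    }
}
```

Applied to the six sample lines, reduceAll returns "5868904,2100.0". The merged string has only two fields, which is fine here because the function reads only the first two fields of either argument.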

Using treeAggregate and treeReduce on JavaPairRDD

As in the previous example, treeAggregate returns a single object of the mergePartition return type rather than an RDD, and treeReduce likewise returns only the final value. For this reason they are not suitable for a pair RDD: they do not return a value per key.


import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;

import scala.Tuple2;

// CallData is the author's own class and is not shown here. It has to be
// Serializable, and Comparable so that it can be stored in a TreeSet.
public class TreeAggregatePairRdd {

    public static void main(String[] args) {

        SparkConf conf = new SparkConf().setAppName("aggregate_by_key").setMaster("local");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        JavaRDD<String> rdd = jsc.textFile("C:\\codebase\\scala-project\\inputdata\\agg_data");

        // Build (id, CallData) pairs, one partition at a time.
        JavaPairRDD<String, CallData> pair = rdd
                .mapPartitionsToPair(new PairFlatMapFunction<Iterator<String>, String, CallData>() {

                    private static final long serialVersionUID = 4454545;

                    @Override
                    public Iterator<Tuple2<String, CallData>> call(Iterator<String> arg0) throws Exception {

                        List<Tuple2<String, CallData>> list = new ArrayList<>();

                        // Created once per partition. The pattern matches the input,
                        // e.g. 18/01/2018 18:03:44.
                        SimpleDateFormat formatter1 = new SimpleDateFormat("dd/MM/yyyy HH:mm:ss");

                        while (arg0.hasNext()) {
                            String[] data = arg0.next().split(",");
                            list.add(new Tuple2<String, CallData>(data[0],
                                    new CallData(Double.parseDouble(data[1]),
                                            Integer.parseInt(data[2]), formatter1.parse(data[3]))));
                        }

                        return list.iterator();
                    }
                });

        // Zero value: the empty set every partition starts from.
        TreeSet<CallData> set = new TreeSet<CallData>();

        // Adds the value of one pair to the set; the key is dropped here.
        Function2<TreeSet<CallData>, Tuple2<String, CallData>, TreeSet<CallData>> mergeValue =
                new Function2<TreeSet<CallData>, Tuple2<String, CallData>, TreeSet<CallData>>() {

            private static final long serialVersionUID = 2323;

            @Override
            public TreeSet<CallData> call(TreeSet<CallData> arg0, Tuple2<String, CallData> arg1)
                    throws Exception {
                arg0.add(arg1._2());
                return arg0;
            }
        };

        // Merges the sets produced by two partitions (or two tree nodes).
        Function2<TreeSet<CallData>, TreeSet<CallData>, TreeSet<CallData>> mergePartition =
                new Function2<TreeSet<CallData>, TreeSet<CallData>, TreeSet<CallData>>() {

            private static final long serialVersionUID = 9898;

            @Override
            public TreeSet<CallData> call(TreeSet<CallData> arg0, TreeSet<CallData> arg1)
                    throws Exception {
                arg0.addAll(arg1);
                return arg0;
            }
        };

        // The last argument is the depth of the aggregation tree.
        TreeSet<CallData> finalRdd2 = pair.treeAggregate(set, mergeValue, mergePartition, 6);

        double sum = 0;
        for (CallData callData : finalRdd2) {
            sum = sum + callData.getOctets();
        }
        System.out.println(sum);

        // Accumulates the octets into the first argument and returns it;
        // the result is a single pair, not one pair per key.
        Tuple2<String, CallData> p = pair.treeReduce(
                new Function2<Tuple2<String, CallData>, Tuple2<String, CallData>, Tuple2<String, CallData>>() {

                    @Override
                    public Tuple2<String, CallData> call(Tuple2<String, CallData> arg0,
                            Tuple2<String, CallData> arg1) throws Exception {
                        arg0._2.setOctets(arg0._2.getOctets() + arg1._2.getOctets());
                        return arg0;
                    }
                }, 6);

        System.out.println(p._2.getOctets());

        jsc.close();
    }
}
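The listing above relies on a CallData class that is not shown. A minimal sketch of what it might look like follows, assuming only what the listing uses: a constructor taking (double, int, Date), getOctets and setOctets, Serializable so that Spark can ship instances between tasks, and Comparable so that instances can be stored in a TreeSet. The field names and the ordering by time are guesses, not the author's code.

```java
import java.io.Serializable;
import java.util.Date;

class CallData implements Serializable, Comparable<CallData> {

    private static final long serialVersionUID = 1L;

    private double octets;
    private final int status;
    private final Date time;

    public CallData(double octets, int status, Date time) {
        this.octets = octets;
        this.status = status;
        this.time = time;
    }

    public double getOctets() {
        return octets;
    }

    public void setOctets(double octets) {
        this.octets = octets;
    }

    public int getStatus() {
        return status;
    }

    public Date getTime() {
        return time;
    }

    // Assumed ordering: by timestamp. A TreeSet treats two records that
    // compare as equal as duplicates and keeps only one of them.
    @Override
    public int compareTo(CallData other) {
        return time.compareTo(other.time);
    }
}
```

With this ordering, two records carrying the same timestamp would collapse into one inside the TreeSet and the sum computed from it would be too low, so a real implementation would probably need a more discriminating comparison.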
