Spark per-partition processing example – tutorial 11

Working with data on a per-partition basis allows us to avoid redoing setup work for each data item. Operations like opening a database connection or creating a random-number generator are examples of setup steps that we want to avoid repeating for every element. Spark has per-partition versions of map and foreach to help reduce the cost of these operations by letting you run code only once for each partition of an RDD.
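
As a minimal sketch of that idea, the snippet below opens a connection once per partition inside foreachPartition() instead of once per record. The DbConnection class and its insert()/close() methods are hypothetical stand-ins for whatever client library you would actually use.

import java.util.Arrays;
import java.util.Iterator;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.VoidFunction;

public class ForeachPartitionSetupSketch {

	// Hypothetical connection helper - stands in for a real database client
	static class DbConnection {
		DbConnection() { /* imagine an expensive connect happening here */ }
		void insert(String record) { /* write one record */ }
		void close() { /* release the connection */ }
	}

	public static void main(String[] args) {

		SparkConf conf = new SparkConf().setAppName("foreachPartition-setup").setMaster("local");
		JavaSparkContext jsc = new JavaSparkContext(conf);

		JavaRDD<String> records = jsc.parallelize(Arrays.asList("a", "b", "c", "d"), 2);

		records.foreachPartition(new VoidFunction<Iterator<String>>() {
			@Override
			public void call(Iterator<String> partition) throws Exception {
				// The connection is opened once per partition, not once per record
				DbConnection connection = new DbConnection();
				while (partition.hasNext()) {
					connection.insert(partition.next());
				}
				connection.close();
			}
		});

		jsc.close();
	}

}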

Let's take an example: say we have 1 million elements in a particular RDD partition. When we call map(), the function we provide to the mapping transformation is called 1 million times. Conversely, if we use mapPartitions(), the function is called only once, but we pass in all 1 million records and get back all of the responses in a single call. That can be a big win when the function does something expensive on every invocation that it would only need to do once if it were handed all of the elements at once.
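
As a small, hypothetical illustration of that difference, the sketch below counts how many records each single mapPartitions() call receives: with 6 elements spread over 2 partitions the supplied function runs only twice, whereas the function given to map() runs once per element.

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;

public class MapVsMapPartitions {

	public static void main(String[] args) {

		SparkConf conf = new SparkConf().setAppName("map-vs-mapPartitions").setMaster("local");
		JavaSparkContext jsc = new JavaSparkContext(conf);

		// 6 elements spread across 2 partitions
		JavaRDD<Integer> numbers = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2);

		// map(): the supplied function runs once per element (6 calls here)
		JavaRDD<Integer> doubled = numbers.map(x -> x * 2);
		System.out.println(doubled.collect()); // [2, 4, 6, 8, 10, 12]

		// mapPartitions(): the supplied function runs once per partition (2 calls here),
		// receiving all of that partition's elements through a single Iterator
		JavaRDD<Integer> counts = numbers.mapPartitions(new FlatMapFunction<Iterator<Integer>, Integer>() {
			@Override
			public Iterator<Integer> call(Iterator<Integer> records) throws Exception {
				int countInThisPartition = 0;
				while (records.hasNext()) {
					records.next();
					countInThisPartition++;
				}
				return Collections.singletonList(countInThisPartition).iterator();
			}
		});
		System.out.println(counts.collect()); // [3, 3]

		jsc.close();
	}

}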

Below is an example where we have used mapPartitions() instead of map() to return only those records whose rating is greater than 4. For this example there will not be much difference in performance, since we are just filtering the data; the real benefit of mapPartitions() shows up when we have an expensive operation that does not have to be executed for each record in the RDD, but only once per partition.

When operating on a per-partition basis, Spark gives our function an Iterator of the elements in that partition, and we return an Iterator of results. In addition to mapPartitions(), Spark has a number of other per-partition operators, such as mapPartitionsWithIndex(), which passes the Integer partition number along with the Iterator of the elements in that partition, and foreachPartition(), which passes the Iterator of the elements and returns nothing. A short sketch of mapPartitionsWithIndex() is shown next, followed by the full filtering example described above.
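
The sketch below is hypothetical (the tagWithPartition helper name and the tagging format are made up for illustration): the function receives the partition number together with the partition's Iterator, and the trailing boolean tells Spark whether the partitioning is preserved.

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function2;

public class MapPartitionsWithIndexSketch {

	static JavaRDD<String> tagWithPartition(JavaRDD<String> rdd) {
		return rdd.mapPartitionsWithIndex(new Function2<Integer, Iterator<String>, Iterator<String>>() {
			@Override
			public Iterator<String> call(Integer partitionIndex, Iterator<String> records) throws Exception {
				List<String> tagged = new ArrayList<String>();
				while (records.hasNext()) {
					// Prefix each record with the number of the partition it came from
					tagged.add(partitionIndex + " -> " + records.next());
				}
				return tagged.iterator();
			}
		}, false); // false: this transformation does not preserve any partitioner
	}

}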


import java.io.File;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFlatMapFunction;

import scala.Tuple2;

public class MapPartition {

	public static void main(String[] args) {

		// Delete the output directory from any previous run so saveAsTextFile() does not fail
		File file = new File("C:\\codebase\\scala-project\\output\\one");
		System.out.println(file.delete());

		SparkConf sparkConf = new SparkConf().setAppName("test").setMaster("local");
		JavaSparkContext jsc = new JavaSparkContext(sparkConf);
		JavaRDD<String> rdd = jsc.textFile("C:\\codebase\\scala-project\\inputdata\\movies_data_2");

		// The function below is invoked once per partition and receives an Iterator
		// over all of that partition's records
		JavaPairRDD<String, String> pair = rdd.mapPartitionsToPair(new PairFlatMapFunction<Iterator<String>, String, String>() {

			@Override
			public Iterator<Tuple2<String, String>> call(Iterator<String> records) throws Exception {

				List<Tuple2<String, String>> list = new ArrayList<Tuple2<String, String>>();

				while (records.hasNext()) {
					// Each record is expected to be "movie,rating"; keep only ratings above 4
					String[] data = records.next().split(",");
					if (Integer.parseInt(data[1]) > 4) {
						list.add(new Tuple2<String, String>(data[0], data[1]));
					}
				}

				return list.iterator();
			}

		});

		pair.saveAsTextFile("C:\\codebase\\scala-project\\output\\one");

	}

}
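
To make the "expensive operation once per partition" point more concrete, here is a hedged variant of the same job. RatingService is a made-up stand-in for something costly to construct, such as a database connection or a parser; everything apart from the Spark calls and the input path is assumed for illustration rather than taken from the original example.

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFlatMapFunction;

import scala.Tuple2;

public class MapPartitionExpensiveSetup {

	// Hypothetical stand-in for something costly to create: a connection, a parser, a model, ...
	static class RatingService {
		RatingService() { /* imagine an expensive handshake or load happening here */ }
		String categorize(String rating) { return Integer.parseInt(rating) > 4 ? "top-rated" : "regular"; }
		void close() { /* release the resource */ }
	}

	public static void main(String[] args) {

		SparkConf sparkConf = new SparkConf().setAppName("mapPartitions-expensive-setup").setMaster("local");
		JavaSparkContext jsc = new JavaSparkContext(sparkConf);
		JavaRDD<String> rdd = jsc.textFile("C:\\codebase\\scala-project\\inputdata\\movies_data_2");

		JavaPairRDD<String, String> categorized = rdd.mapPartitionsToPair(new PairFlatMapFunction<Iterator<String>, String, String>() {

			@Override
			public Iterator<Tuple2<String, String>> call(Iterator<String> records) throws Exception {

				// The expensive setup is paid once per partition instead of once per record
				RatingService service = new RatingService();

				List<Tuple2<String, String>> results = new ArrayList<Tuple2<String, String>>();
				while (records.hasNext()) {
					String[] data = records.next().split(",");
					results.add(new Tuple2<String, String>(data[0], service.categorize(data[1])));
				}

				service.close();
				return results.iterator();
			}

		});

		System.out.println(categorized.collect());

		jsc.close();
	}

}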