Spark partition level functions by examples

Spark supports partition level functions, which operate on the data of each partition as a whole. Working with data on a per-partition basis lets us avoid redoing setup work for every data item.

partitionBy function

The partitionBy function returns a copy of the RDD partitioned using the specified partitioner. The function is available in the class JavaPairRDD.

Below is the Spark code in Java

import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import scala.Tuple2;

public class PartionFunctions {

    public static void main(String[] args) {

        SparkConf conf = new SparkConf().setAppName("Test").setMaster("local");

        JavaSparkContext context = new JavaSparkContext(conf);

        // Load the input data with a minimum of 5 partitions
        JavaRDD<String> rdd1 = context.textFile("C:\\codebase\\scala-project\\inputdata\\employee\\dataset1", 5);

        // Turn each comma separated line into a (key, value) pair of its first two fields
        JavaPairRDD<String, String> pairRdd = rdd1.mapToPair(new PairFunction<String, String, String>() {

            private static final long serialVersionUID = 23232323;

            public Tuple2<String, String> call(String data) {
                String[] record = data.split(",");
                return new Tuple2<>(record[0], record[1]);
            }

        });

        // Repartition the pair RDD across 10 partitions using a HashPartitioner
        JavaPairRDD<String, String> pairRdd2 = pairRdd.partitionBy(new HashPartitioner(10));

        System.out.println(pairRdd2.partitioner().get());

    }
}
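
For reference, the same pairing and repartitioning can be written more compactly with Java 8 lambdas. This is a minimal sketch, assuming Spark 2.x where the Java function interfaces are functional interfaces, and the same rdd1 as above:

JavaPairRDD<String, String> pairs = rdd1.mapToPair(data -> {
    String[] record = data.split(",");
    return new Tuple2<>(record[0], record[1]);
});

// partitionBy returns a new pair RDD spread over 10 partitions
JavaPairRDD<String, String> repartitioned = pairs.partitionBy(new HashPartitioner(10));
System.out.println(repartitioned.getNumPartitions()); // prints 10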

partitions function

Returns the array of partitions of this RDD. Each partition has an index, which starts at 0 and goes up to n - 1 for an RDD with n partitions.

Below is the Spark code in Java

import java.util.List;
import org.apache.spark.Partition;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class PartionFunctions {

    public static void main(String[] args) {

        SparkConf conf = new SparkConf().setAppName("Test").setMaster("local");

        JavaSparkContext context = new JavaSparkContext(conf);

        JavaRDD<String> rdd1 = context.textFile("C:\\codebase\\scala-project\\inputdata\\employee\\dataset1", 5);

        // partitions() returns one Partition object per partition of the RDD
        List<Partition> partitions = rdd1.partitions();

        // Print the index of each partition
        for (Partition partition : partitions) {
            System.out.println(partition.index());
        }
    }
}
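
The size of this list is the partition count, which can also be read directly with getNumPartitions. A small usage sketch, assuming the same rdd1 as above:

System.out.println(rdd1.partitions().size());   // number of partitions
System.out.println(rdd1.getNumPartitions());    // same value, read directly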

collectPartitions function

Returns an array that contains all of the elements in the specified partitions of this RDD. In the code below we collect all the elements from partitions 1 and 2 into lists of strings.

Below is the Spark code in Java

import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class PartionFunctions {

    public static void main(String[] args) {

        SparkConf conf = new SparkConf().setAppName("Test").setMaster("local");

        JavaSparkContext context = new JavaSparkContext(conf);

        JavaRDD<String> rdd1 = context.textFile("C:\\codebase\\scala-project\\inputdata\\employee\\dataset1", 5);

        // Collect the contents of partitions 1 and 2 to the driver, one list per partition
        List<String>[] list = rdd1.collectPartitions(new int[] { 1, 2 });

        for (List<String> list2 : list) {
            for (String string : list2) {
                System.out.println(string);
            }
        }
    }
}
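
Note that collectPartitions brings the selected partitions back to the driver. If the per-partition lists should stay distributed instead, glom() gives a similar view as an RDD. A hedged sketch, not part of the original example:

JavaRDD<List<String>> glommed = rdd1.glom();   // one List<String> element per partition
System.out.println(glommed.count());           // equals the number of partitions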

foreachPartition and foreachPartitionAsync functions

Applies a function f to each partition of this RDD. The foreachPartitionAsync function is the asynchronous version of the foreachPartition action; it also applies a function f to each partition of this RDD, but it returns a JavaFutureAction, an interface that extends java.util.concurrent.Future<T>. It therefore inherits methods like cancel, get, isCancelled and isDone, and adds a jobIds() method that returns the ids of the Spark jobs started by the action. We also print the number of partitions using the getNumPartitions function.

Below is the Spark code in Java

import java.util.Iterator;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaFutureAction;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.VoidFunction;

public class PartionFunctions {

    public static void main(String[] args) {

        SparkConf conf = new SparkConf().setAppName("Test").setMaster("local");

        JavaSparkContext context = new JavaSparkContext(conf);

        JavaRDD<String> rdd1 = context.textFile("C:\\codebase\\scala-project\\inputdata\\employee\\dataset1", 5);

        // foreachPartition receives an iterator over the records of one partition
        rdd1.foreachPartition(new VoidFunction<Iterator<String>>() {

            private static final long serialVersionUID = 554545;

            @Override
            public void call(Iterator<String> arg0) throws Exception {
                while (arg0.hasNext()) {
                    System.out.println(arg0.next());
                }
            }
        });

        // Asynchronous variant: returns a JavaFutureAction instead of blocking
        JavaFutureAction<Void> f = rdd1.foreachPartitionAsync(new VoidFunction<Iterator<String>>() {

            private static final long serialVersionUID = 554547;

            @Override
            public void call(Iterator<String> arg0) throws Exception {
                while (arg0.hasNext()) {
                    System.out.println(arg0.next());
                }
            }
        });

        System.out.println(rdd1.getNumPartitions());

    }
}
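
Because foreachPartitionAsync returns immediately, a standalone driver would normally wait on the returned future before shutting down. A minimal sketch using the inherited Future methods and jobIds() on the f variable from the example above:

try {
    f.get();                           // block until the asynchronous job completes
    System.out.println(f.isDone());    // true once the job has finished
    System.out.println(f.jobIds());    // ids of the Spark jobs triggered by this action
} catch (Exception e) {
    e.printStackTrace();
}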

mapPartitions function

Returns a new RDD by applying a function to each partition of this RDD. mapPartitions can be used as an alternative to the map() function: map() calls the given function for every record, whereas mapPartitions calls the function once per partition. If we have some expensive initialization to do, we can use mapPartitions so that the initialization happens once per partition rather than once per element of the RDD. In the example below we need to enrich the data using a database, so instead of creating a connection object for each record we create it once per partition, which can improve performance many fold. We need to implement a FlatMapFunction that takes an Iterator<String> as input; the iterator returned from each partition is flattened into an RDD of strings.

Below is the Spark code in Java

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.DoubleFlatMapFunction;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;

import scala.Tuple2;

public class PartionFunctions {

    public static void main(String[] args) {

        SparkConf conf = new SparkConf().setAppName("Test").setMaster("local");

        JavaSparkContext context = new JavaSparkContext(conf);

        JavaRDD<String> rdd1 = context.textFile("C:\\codebase\\scala-project\\inputdata\\employee\\dataset1", 5);

        JavaRDD<String> mapPartitions = rdd1.mapPartitions(new FlatMapFunction<Iterator<String>, String>() {

            private static final long serialVersionUID = 34389;

            @Override
            public Iterator<String> call(Iterator<String> arg0) throws Exception {

                // Connection stands in for your own database/connection helper;
                // the expensive setup runs once per partition, not once per record
                Connection conn = Connection.getConnection();

                List<String> list = new ArrayList<String>();

                while (arg0.hasNext()) {
                    String data = arg0.next();
                    String enrichedData = conn.enrichDataFromDatabase(data);
                    list.add(enrichedData);
                }

                return list.iterator();
            }
        });

        mapPartitions.saveAsTextFile("PATH_TO_OUTPUT_FILE");

    }
}
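
For contrast, doing the same enrichment with map() would create a connection for every single record, which is exactly the overhead mapPartitions avoids. A sketch using Java 8 lambdas and the same placeholder Connection helper as above:

JavaRDD<String> mapped = rdd1.map(data -> {
    Connection conn = Connection.getConnection(); // one connection per record
    return conn.enrichDataFromDatabase(data);
});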

mapPartitionsToDouble function

Returns a new JavaDoubleRDD by applying a function to each partition of this RDD. We need to implement a DoubleFlatMapFunction, and here we don't have to specify a return element type for DoubleFlatMapFunction because it always returns doubles.

Below is the Spark code in Java

JavaDoubleRDD mapPartitionsToDouble = rdd1.mapPartitionsToDouble(new DoubleFlatMapFunction<Iterator<String>>() {

    private static final long serialVersionUID = 3434343;

    @Override
    public Iterator<Double> call(Iterator<String> arg0) throws Exception {

        List<Double> list = new ArrayList<Double>();

        while (arg0.hasNext()) {

            String[] data = arg0.next().split(",");

            // Keep numeric fields as doubles, default everything else to 0.0
            for (String d : data) {
                boolean b = d.matches("\\d+");
                list.add(b ? Double.parseDouble(d) : 0.0);
            }
        }

        return list.iterator();
    }
});
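
Once the values are in a JavaDoubleRDD, the numeric helper actions become available. A small usage sketch:

System.out.println(mapPartitionsToDouble.sum());
System.out.println(mapPartitionsToDouble.mean());
System.out.println(mapPartitionsToDouble.stats()); // count, mean, stdev, min, max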

mapPartitionsToPair function

Returns a new JavaPairRDD by applying a function to each partition of this RDD.

Below is the Spark code in Java

JavaPairRDD<String, String> mapPartitionsToPair = rdd1.mapPartitionsToPair(new PairFlatMapFunction<Iterator<String>, String, String>() {

    private static final long serialVersionUID = 3434566;

    @Override
    public Iterator<Tuple2<String, String>> call(Iterator<String> arg0) throws Exception {

        List<Tuple2<String, String>> list = new ArrayList<Tuple2<String, String>>();

        while (arg0.hasNext()) {

            String[] data = arg0.next().split(",");

            // Emit one (key, value) pair per record from its first two fields
            list.add(new Tuple2<String, String>(data[0], data[1]));
        }

        return list.iterator();
    }
});
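
The resulting pair RDD can then feed the usual key-based operations, for example counting records per key. A small usage sketch:

System.out.println(mapPartitionsToPair.countByKey()); // number of records per key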

mapPartitionsWithIndex function

Returns a new RDD by applying a function to each partition of this RDD, while tracking the index of the original partition. The Function2 takes an Integer as its first parameter, which is the partition index; the second argument to mapPartitionsWithIndex is a boolean indicating whether the partitioning of the parent RDD should be preserved.

Below is the Spark code in Java

JavaRDD<String> rddWithIndex = rdd1.mapPartitionsWithIndex(new Function2<Integer, Iterator<String>, Iterator<String>>() {

    private static final long serialVersionUID = 343434;

    @Override
    public Iterator<String> call(Integer arg0, Iterator<String> arg1) throws Exception {

        List<String> list = new ArrayList<String>();

        while (arg1.hasNext()) {

            // Prefix each record with the index of the partition it came from
            String dataWithPartitionIndex = String.valueOf(arg0) + arg1.next();
            list.add(dataWithPartitionIndex);
        }

        return list.iterator();
    }
}, true);
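
To verify which partition each line came from, the index-tagged records can be collected and printed. A small usage sketch:

for (String line : rddWithIndex.collect()) {
    System.out.println(line);
}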

As we can see, most of the per-partition functions require an implementation of one of the FlatMapFunction variants, since the iterator returned from each partition is flattened into the resulting RDD.