Spark: finding the average using RDD, DataFrame and Dataset

Problem to Solve: Given a list of employees with their department and salary, find the average salary in each department.

Input Data sample

First Name,Last Name,Job Titles,Department,Full or Part-Time,Salary or Hourly,Typical Hours,Annual Salary,Hourly Rate

dubert,tomasz ,paramedic i/c,fire,f,salary,,91080.00,
edwards,tim p,lieutenant,fire,f,salary,,114846.00,
elkins,eric j,sergeant,police,f,salary,,104628.00,
estrada,luis f,police officer,police,f,salary,,96060.00,
ewing,marie a,clerk iii,police,f,salary,,53076.00,
finn,sean p,firefighter,fire,f,salary,,87006.00,
fitch,jordan m,law clerk,law,f,hourly,35,,14.51
Below is the code
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;

import scala.Tuple2;

/**
 * Computes the average annual salary per department from a CSV employee file
 * using the Spark RDD API.
 *
 * <p>Each input line is mapped to a (department, AverageTuple) pair, where the
 * tuple carries a record count and a running salary sum. {@code reduceByKey}
 * merges the tuples per department, and the final average is printed as
 * sum / count.
 *
 * <p>NOTE: despite its name, {@code AverageTuple.getAverage()} holds a running
 * SUM of salaries during the reduce phase; the average only exists after the
 * final division below.
 */
public class Numeric {

    public static void main(String[] args) {

        SparkConf conf = new SparkConf().setAppName("pattern").setMaster("local");

        JavaSparkContext jsc = new JavaSparkContext(conf);

        JavaRDD<String> rdd = jsc.textFile("C:\\codebase\\scala-project\\inputdata\\employee");

        // Map each CSV line to (department, (1, annualSalary)). Lines that do not
        // have exactly 9 fields, or that have an empty "Annual Salary" column
        // (hourly employees), are bucketed under a sentinel key with a zero count
        // so they never contribute to any department's sum.
        JavaPairRDD<String, AverageTuple> pair2 = rdd.mapToPair(new PairFunction<String, String, AverageTuple>() {

            @Override
            public Tuple2<String, AverageTuple> call(String value) throws Exception {

                // split with limit -1 keeps trailing empty fields, so hourly rows
                // (empty Annual Salary column) still produce 9 fields
                String[] field = value.split(",", -1);

                if (null != field && field.length == 9 && field[7].length() > 0) {
                    return new Tuple2<String, AverageTuple>(field[3], new AverageTuple(1, Double.parseDouble(field[7])));
                }
                return new Tuple2<String, AverageTuple>("Invalid_Record", new AverageTuple(0, 0.0));
            }
        });

        // Per department, add up the record counts and the salary sums.
        JavaPairRDD<String, AverageTuple> result = pair2
                .reduceByKey(new Function2<AverageTuple, AverageTuple, AverageTuple>() {

                    @Override
                    public AverageTuple call(AverageTuple result, AverageTuple value) throws Exception {

                        result.setAverage(result.getAverage() + value.getAverage());
                        result.setCount(result.getCount() + value.getCount());

                        return result;
                    }
                });

        for (Tuple2<String, AverageTuple> pair : result.collect()) {

            // Guard against division by zero: the Invalid_Record bucket has a
            // zero count and would otherwise print NaN.
            if (pair._2().getCount() > 0) {
                System.out.println(pair._1() + " " + pair._2().getAverage() / pair._2().getCount());
            } else {
                System.out.println(pair._1() + " has no valid salary records");
            }
        }

        jsc.stop();
    }

}

Below is the AverageTuple class


import java.io.Serializable;

/**
 * Mutable accumulator pair used by the Spark average-salary job: a record
 * count plus a numeric accumulator.
 *
 * <p>NOTE: the {@code average} field actually holds a running SUM while
 * records are being reduced; it only represents an average after the caller
 * divides it by {@code count}. The getter/setter names are kept as-is for
 * compatibility with existing callers.
 *
 * <p>Implements {@link Serializable} because instances are shipped between
 * Spark executors during the shuffle.
 */
public class AverageTuple implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Number of records accumulated so far. */
    private int count;

    /** Running salary sum (see class note: not yet an average). */
    private double average;

    /**
     * @param count   initial record count
     * @param average initial salary sum
     */
    public AverageTuple(int count, double average) {
        this.count = count;
        this.average = average;
    }

    public int getCount() {
        return count;
    }

    public void setCount(int count) {
        this.count = count;
    }

    public double getAverage() {
        return average;
    }

    public void setAverage(double average) {
        this.average = average;
    }

}

Output

POLICE BOARD 86136.0
HUMAN RESOURCES 79851.76119402985
ADMIN HEARNG 78912.94736842105
FIRE 97762.3486619425
BOARD OF ELECTION 56051.142857142855
PUBLIC LIBRARY 71273.28813559322
CITY CLERK 69762.43902439025
IPRA 94429.28571428571
AVIATION 76140.01877697841
BOARD OF ETHICS 94552.5
COMMUNITY DEVELOPMENT 88363.25714285714
POLICE 87836.02534889111
LICENSE APPL COMM 80568.0
PROCUREMENT 83278.24390243902
TRANSPORTN 89976.89606060603
FAMILY & SUPPORT 79013.58878504673
ANIMAL CONTRL 66089.68421052632
BUSINESS AFFAIRS 80446.425
STREETS & SAN 84347.77570093458
FINANCE 73276.36466165414
GENERAL SERVICES 83095.5283902439
CITY COUNCIL 63577.17206896553
DISABILITIES 82431.72413793103
TREASURER 88062.65217391304
LAW 84582.81440443214
DoIT 99681.02970297029
BUILDINGS 98864.83353383462
Invalid_Record NaN
HUMAN RELATIONS 4664366.823529412
BUDGET & MGMT 93925.3953488372
HEALTH 85488.2109375
COPA 98784.70588235294
WATER MGMNT 89894.1118032787
OEMC 73153.77822115383
CULTURAL AFFAIRS 87048.90909090909
INSPECTOR GEN 84030.66666666667
MAYOR'S OFFICE 96165.5123076923

The same result can be achieved using the Dataset API as shown below.


import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.RelationalGroupedDataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.*;

/**
 * Computes the average salary per department using the Spark Dataset API:
 * reads the employee CSV, groups by department, and shows avg(salary).
 */
public class Test {

    public static void main(String[] args) {
        SparkSession session = SparkSession.builder().appName("Test").master("local").getOrCreate();

        // NOTE(review): the sample input shown above includes a header row; if the
        // file on disk has one too, add .option("header", "true") so the header is
        // not parsed as a data row — confirm against the actual input file.
        Dataset<Row> dataset = session.read().option("inferSchema", "true")
                .csv("C:\\codebase\\scala-project\\inputdata\\employee")
                .toDF("fn", "ln", "designation", "department", "emp_type",
                        "full_hour", "NA", "salary", "NA2");

        // Group by department and print the mean of the salary column.
        dataset.groupBy(col("department")).agg(avg("salary")).show();

        // Release the local Spark resources before exiting.
        session.stop();
    }

}