User-defined aggregations for strongly typed Datasets revolve around the Aggregator abstract class.
Let's write a user-defined aggregation function to calculate the average rating of all the movies in the input file.
Below is the sample data (one record per line, in the format name,rating,timestamp):
Spider Man,4,978301398
Spider Man,4,978302091
Bat Man,5,978298709
Bat Man,4,978299000
Bat Man,4,978299620
Let's code the example.
In Java
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;
/**
 * Strongly typed user-defined aggregation that computes the average movie rating.
 *
 * Type parameters of {@code Aggregator<IN, BUF, OUT>}:
 *   IN  = Movie   — the input record type
 *   BUF = Average — the mutable intermediate buffer (running sum and count)
 *   OUT = Double  — the final result, sum / count
 */
public class TypedUDF extends Aggregator<Movie, Average, Double> {

	/** Encoder for the intermediate (buffer) value type. */
	@Override
	public Encoder<Average> bufferEncoder() {
		// Average is a Java bean, so the bean encoder serializes the buffer.
		return Encoders.bean(Average.class);
	}

	/** Transforms the final reduction buffer into the output value. */
	@Override
	public Double finish(Average reduction) {
		System.out.println("Inside Method Finish");
		// NOTE(review): on an empty input the count is 0 here; with double
		// fields this yields NaN, with integral fields it would throw an
		// ArithmeticException — confirm Average's field types and guard if needed.
		return reduction.getSum() / reduction.getCount();
	}

	/** Merges two intermediate buffers (e.g. from different partitions). */
	@Override
	public Average merge(Average left, Average right) {
		System.out.println("Inside Method Merge"); // was mislabelled "Inside Merge Finish"
		left.setSum(left.getSum() + right.getSum());
		left.setCount(left.getCount() + right.getCount());
		return left;
	}

	/** Encoder for the final output value type. */
	@Override
	public Encoder<Double> outputEncoder() {
		return Encoders.DOUBLE();
	}

	/**
	 * Folds one input row into the buffer. For performance the buffer is
	 * mutated and returned rather than allocating a new object.
	 */
	@Override
	public Average reduce(Average buffer, Movie movie) {
		System.out.println("Inside Method Reduce");
		buffer.setSum(buffer.getSum() + movie.getRating());
		buffer.setCount(buffer.getCount() + 1);
		return buffer;
	}

	/** Zero value for the aggregation; must satisfy b + zero = b. */
	@Override
	public Average zero() {
		System.out.println("Inside Method Zero"); // was mislabelled "Inside Zero Finish"
		return new Average(0, 0);
	}
}
Driver Code
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.TypedColumn;
/**
 * Driver: loads the movie ratings CSV into a typed Dataset&lt;Movie&gt; and
 * applies the TypedUDF aggregator to compute the overall average rating.
 */
public class CustomUdfDriver {

	public static void main(String[] args) {
		SparkSession session = SparkSession.builder().appName("Test").config("key", "value").master("local")
				.getOrCreate();

		// The CSV has no header, so Spark names the columns _c0.._c2.
		Dataset<Row> movieDataframe = session.read().csv("C:\\codebase\\scala-project\\inputdata\\movies_data_2")
				.withColumnRenamed("_c0", "name")
				.withColumnRenamed("_c1", "rating")
				.withColumnRenamed("_c2", "timestamp");

		// CSV columns are read as strings; cast rating to double so it matches
		// the Movie bean's numeric rating field (the Scala driver does the same).
		movieDataframe = movieDataframe.withColumn("rating", movieDataframe.col("rating").cast("double"));

		Encoder<Movie> movieEncoder = Encoders.bean(Movie.class);
		Dataset<Movie> dataset = movieDataframe.as(movieEncoder);
		dataset.show();

		// Turn the aggregator into a TypedColumn usable inside select().
		TypedUDF typed = new TypedUDF();
		TypedColumn<Movie, Double> averageRating = typed.toColumn().name("average_rating");
		dataset.select(averageRating).show();

		// Release the local Spark context before exiting.
		session.stop();
	}
}
In Scala
import org.apache.spark.sql.expressions.Aggregator
/**
 * Strongly typed aggregator that averages movie ratings.
 *
 * Type parameters of `Aggregator[IN, BUF, OUT]`:
 * IN = Movie, BUF = Average (running sum and count), OUT = Double.
 */
object TypedUDF extends Aggregator[Movie, Average, Double] {

  /** Identity buffer for the aggregation; satisfies b + zero = b. */
  def zero: Average = Average(0d, 0d)

  /**
   * Folds a single movie into the buffer. The buffer is mutated and
   * returned rather than allocating a fresh object, for performance.
   */
  def reduce(average: Average, movie: Movie): Average = {
    average.sum = average.sum + movie.rating
    average.count = average.count + 1
    average
  }

  /** Combines two partial buffers (e.g. produced on different partitions). */
  def merge(ave1: Average, ave2: Average): Average = {
    ave1.sum = ave1.sum + ave2.sum
    ave1.count = ave1.count + ave2.count
    ave1
  }

  /** Final result: the mean rating (0.0 / 0.0 == NaN on an empty input). */
  def finish(average: Average): Double = average.sum / average.count

  /** Encoder for the intermediate Average buffer (a Product/case class). */
  def bufferEncoder: org.apache.spark.sql.Encoder[Average] =
    org.apache.spark.sql.Encoders.product

  /** Encoder for the final Double output. */
  def outputEncoder: org.apache.spark.sql.Encoder[Double] =
    org.apache.spark.sql.Encoders.scalaDouble
}
Driver Code
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.DataType
/**
 * Driver: loads the movie ratings CSV into a typed Dataset[Movie] and applies
 * the TypedUDF aggregator to compute the overall average rating.
 */
object CustomUdfDriver extends App {

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.functions.col

  val session = SparkSession.builder().appName("Test").master("local").getOrCreate()

  // The CSV has no header, so Spark names the columns _c0.._c2; rename them
  // and cast rating (read as a string) to double so it matches Movie.rating.
  // A single chained val replaces the original var + repeated reassignment.
  val dataframe = session.read
    .csv("C:\\codebase\\scala-project\\inputdata\\movies_data_2")
    .withColumnRenamed("_c0", "name")
    .withColumnRenamed("_c1", "rating")
    .withColumnRenamed("_c2", "timestamp")
    .withColumn("rating", col("rating").cast("double"))

  import session.implicits._
  val dataset: Dataset[Movie] = dataframe.as[Movie]
  dataset.show()

  // Turn the aggregator into a TypedColumn usable inside select().
  val average = TypedUDF.toColumn.name("average_rating_name")
  dataset.select(average).show()

  // Release the local Spark context before exiting.
  session.stop()
}
If I want to return a value whose type is a mutable List[(String, Double)], what should I do? Thanks.
For example: public class TypedUDF extends Aggregator