spark dataset type-safe custom user-defined aggregate functions – tutorial 18

User-defined aggregations for strongly typed Datasets revolve around the Aggregator abstract class.
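
Before diving into the example, it helps to recall the shape of that class. The outline below is a rough Scala summary of the Aggregator contract (org.apache.spark.sql.expressions.Aggregator), not a verbatim copy of Spark's source:

// IN is the input record type, BUF the intermediate buffer, OUT the final result.
abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
  def zero: BUF                        // initial buffer value
  def reduce(b: BUF, a: IN): BUF       // fold one input record into the buffer
  def merge(b1: BUF, b2: BUF): BUF     // combine two partial buffers
  def finish(reduction: BUF): OUT      // produce the final result from the buffer
  def bufferEncoder: Encoder[BUF]      // how Spark serializes the buffer
  def outputEncoder: Encoder[OUT]      // how Spark serializes the result
  def toColumn: TypedColumn[IN, OUT]   // concrete helper that turns the aggregator into a column
}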

Let's write a user-defined aggregate function that calculates the average rating across all the movies in the input file.

Below is the sample data; each row holds the movie name, the rating, and a timestamp.


Spider Man,4,978301398
Spider Man,4,978302091
Bat Man,5,978298709
Bat Man,4,978299000
Bat Man,4,978299620

With these five ratings (sum 21, count 5), the aggregation should return an overall average of 4.2. Let's code the example.

In Java


import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

public class TypedUDF extends Aggregator<Movie, Average, Double> {

    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
    @Override
    public Average zero() {
        System.out.println("Inside Method Zero");
        return new Average(0, 0);
    }

    // Combine the buffer with an input value to produce a new buffer. For performance,
    // the function may modify the buffer and return it instead of constructing a new object
    @Override
    public Average reduce(Average buffer, Movie movie) {
        System.out.println("Inside Method Reduce");
        buffer.setSum(buffer.getSum() + movie.getRating());
        buffer.setCount(buffer.getCount() + 1);
        return buffer;
    }

    // Merge two intermediate values
    @Override
    public Average merge(Average buffer1, Average buffer2) {
        System.out.println("Inside Method Merge");
        buffer1.setSum(buffer1.getSum() + buffer2.getSum());
        buffer1.setCount(buffer1.getCount() + buffer2.getCount());
        return buffer1;
    }

    // Transform the output of the reduction
    @Override
    public Double finish(Average buffer) {
        System.out.println("Inside Method Finish");
        return buffer.getSum() / buffer.getCount();
    }

    // Specifies the Encoder for the intermediate value type
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    // Specifies the Encoder for the final output value type
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }

}
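
The aggregator above references two JavaBeans, Movie and Average, which are not shown in the original listing. Below is a minimal sketch of what they need to look like; the exact field types (double rating, String timestamp, double count) are assumptions based on how the classes are used, and Encoders.bean requires the public no-arg constructors and getters/setters. Each class would normally live in its own file.

import java.io.Serializable;

// Input record type: one row of the ratings file
public class Movie implements Serializable {
    private String name;
    private double rating;
    private String timestamp;

    public Movie() {}

    public String getName() { return name; }
    public void setName(String name) { this.name = name; }
    public double getRating() { return rating; }
    public void setRating(double rating) { this.rating = rating; }
    public String getTimestamp() { return timestamp; }
    public void setTimestamp(String timestamp) { this.timestamp = timestamp; }
}

// Intermediate buffer type: running sum and count of ratings
public class Average implements Serializable {
    private double sum;
    private double count;

    public Average() {}

    public Average(double sum, double count) {
        this.sum = sum;
        this.count = count;
    }

    public double getSum() { return sum; }
    public void setSum(double sum) { this.sum = sum; }
    public double getCount() { return count; }
    public void setCount(double count) { this.count = count; }
}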

Driver Code

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.functions;

public class CustomUdfDriver {

    public static void main(String[] args) {

        SparkSession session = SparkSession.builder().appName("Test").config("key", "value").master("local")
                .getOrCreate();

        Encoder<Movie> movie_encoder = Encoders.bean(Movie.class);

        Dataset<Row> movie_dataframe = session.read().csv("C:\\codebase\\scala-project\\inputdata\\movies_data_2");

        // Give the CSV columns meaningful names
        movie_dataframe = movie_dataframe.withColumnRenamed("_c0", "name");
        movie_dataframe = movie_dataframe.withColumnRenamed("_c1", "rating");
        movie_dataframe = movie_dataframe.withColumnRenamed("_c2", "timestamp");

        // The CSV reader infers every column as a string, so cast the rating to double
        // before converting to a typed Dataset<Movie> (same as in the Scala driver below)
        movie_dataframe = movie_dataframe.withColumn("rating", functions.col("rating").cast("double"));

        Dataset<Movie> dataset = movie_dataframe.as(movie_encoder);

        dataset.show();

        TypedUDF typed = new TypedUDF();

        // Convert the Aggregator to a TypedColumn so it can be used in select()
        TypedColumn<Movie, Double> average_rating = typed.toColumn().name("average_rating");

        dataset.select(average_rating).show();

    }

}

In Scala


import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

object TypedUDF extends Aggregator[Movie, Average, Double] {

  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  def zero: Average = Average(0d, 0d)

  // Combine the buffer with an input value to produce a new buffer. For performance,
  // the function may modify the buffer and return it instead of constructing a new object
  def reduce(average: Average, movie: Movie): Average = {
    average.sum += movie.rating
    average.count += 1
    average
  }

  // Merge two intermediate values
  def merge(ave1: Average, ave2: Average): Average = {
    ave1.sum += ave2.sum
    ave1.count += ave2.count
    ave1
  }

  // Transform the output of the reduction
  def finish(average: Average): Double = average.sum / average.count

  // Specifies the Encoder for the intermediate value type
  def bufferEncoder: Encoder[Average] = Encoders.product[Average]

  // Specifies the Encoder for the final output value type
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble

}
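
The Movie and Average case classes used above are not shown in the original listing either. A minimal sketch, assuming the field types implied by the code; note that Average needs var fields because reduce and merge update the buffer in place:

// Input record type: one row of the ratings file
case class Movie(name: String, rating: Double, timestamp: String)

// Mutable buffer type: running sum and count of ratings.
// Declared with var fields because reduce/merge update it in place.
case class Average(var sum: Double, var count: Double)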

Driver Code

import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.functions.col

object CustomUdfDriver extends App {

  val session = SparkSession.builder().appName("Test").master("local").getOrCreate()

  var dataframe = session.read.csv("C:\\codebase\\scala-project\\inputdata\\movies_data_2")

  // Give the CSV columns meaningful names
  dataframe = dataframe.withColumnRenamed("_c0", "name")
  dataframe = dataframe.withColumnRenamed("_c1", "rating")
  dataframe = dataframe.withColumnRenamed("_c2", "timestamp")

  // The CSV reader infers every column as a string, so cast the rating to double
  dataframe = dataframe.withColumn("rating", col("rating").cast("double"))

  import session.implicits._

  val dataset: Dataset[Movie] = dataframe.as[Movie]

  dataset.show()

  // Convert the Aggregator into a TypedColumn so it can be used in select()
  val average = TypedUDF.toColumn.name("average_rating_name")

  dataset.select(average).show()

}
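
As a quick follow-up (a hypothetical variation, not part of the original tutorial), the same aggregator can also be applied per key after a groupByKey, which yields the average rating for each movie name rather than one overall value:

// Average rating per movie instead of one value for the whole Dataset
dataset.groupByKey(movie => movie.name)
  .agg(TypedUDF.toColumn.name("average_rating"))
  .show()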
