Spark DataFrame Untyped Custom User Defined Aggregate Functions – Tutorial 17

The built-in DataFrame functions provide common aggregations such as count(), countDistinct(), avg(), max(), and min(). While these functions are designed for DataFrames, Spark SQL also has type-safe versions of some of them in Scala and Java that work with strongly typed Datasets. Moreover, users are not limited to the predefined aggregate functions and can create their own.
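For example, the built-in functions can be applied directly through the DataFrame API. A minimal sketch, assuming a hypothetical DataFrame named ratings with columns movie and rating:


import org.apache.spark.sql.functions.avg
import org.apache.spark.sql.functions.countDistinct

// overall average rating and number of distinct movies
// (`ratings` is an assumed example DataFrame, not from this tutorial's code)
ratings.agg(avg("rating"), countDistinct("movie")).show()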

Untyped User Defined Aggregate Functions

Untyped user defined aggregate functions can be created by extending the abstract class UserDefinedAggregateFunction in Java and Scala.

Let's write a user defined aggregate function to calculate the average rating of all the movies in the input file.

Below is the sample data


Spider Man,4,978301398
Spider Man,4,978302091
Bat Man,5,978298709
Bat Man,4,978299000
Bat Man,4,978299620
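
With this data, the function should return the overall average rating, (4 + 4 + 5 + 4 + 4) / 5 = 4.2.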

In Java

Here we have two instance variables, inputSchema and bufferSchema, where we define the StructFields. As we want to find the average rating, we pass the rating column as the input, and the bufferSchema has two fields, one for the running sum and one for the count, so that we can calculate the average rating. Below is the code.


import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class UntypedUDF extends UserDefinedAggregateFunction {

    private StructType inputSchema;
    private StructType bufferSchema;

    public UntypedUDF() {
        List<StructField> inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("inputType", DataTypes.DoubleType, true));
        inputSchema = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("sum", DataTypes.DoubleType, true));
        bufferFields.add(DataTypes.createStructField("count", DataTypes.DoubleType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
    }

    // Data types of input arguments of this aggregate function
    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    // Data types of values in the aggregation buffer
    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    // The data type of the returned value
    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    // Whether this function always returns the same output on identical input
    @Override
    public boolean deterministic() {
        return true;
    }

    // Initializes the given aggregation buffer. The buffer itself is a `Row` that, in addition to
    // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
    // the opportunity to update its values. Note that arrays and maps inside the buffer are still
    // immutable.
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        System.out.println("Inside Initialize Method");
        buffer.update(0, 0.0);
        buffer.update(1, 0.0);
    }

    // Updates the given aggregation buffer `buffer` with new input data from `input`,
    // skipping null ratings so they don't skew the sum or the count
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        System.out.println("Inside Update Method");
        if (!input.isNullAt(0)) {
            buffer.update(0, buffer.getDouble(0) + input.getDouble(0));
            buffer.update(1, buffer.getDouble(1) + 1);
        }
    }

    // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        System.out.println("Inside Merge Method");
        buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0));
        buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1));
    }

    // Calculates the final result
    @Override
    public Object evaluate(Row buffer) {
        return buffer.getDouble(0) / buffer.getDouble(1);
    }

}
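
To see how Spark drives these methods, below is a rough trace of the buffer contents on the sample data, assuming the five rows are split across two partitions (the actual split depends on how Spark reads the file):

Partition 1 (Spider Man): initialize -> (0.0, 0.0), update(4.0) -> (4.0, 1.0), update(4.0) -> (8.0, 2.0)
Partition 2 (Bat Man): initialize -> (0.0, 0.0), update(5.0) -> (5.0, 1.0), update(4.0) -> (9.0, 2.0), update(4.0) -> (13.0, 3.0)
merge -> (21.0, 5.0), evaluate -> 21.0 / 5.0 = 4.2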

Driver Code

Here we register the UDAF using session.udf().register("average_rating", new UntypedUDF()) and then use it just like any other built-in function in the query.


import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class CustomUdfDriver {

    public static void main(String[] args) {

        SparkSession session = SparkSession.builder().appName("Test").master("local").getOrCreate();

        session.udf().register("average_rating", new UntypedUDF());

        Dataset<Row> dataframe = session.read().csv("C:\\codebase\\scala-project\\inputdata\\movies_data_2");

        dataframe.createOrReplaceTempView("test_udf");

        dataframe.show();

        session.sql("select average_rating(_c1) from test_udf").show();
    }

}

In Scala


import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType

object UntypedUDF extends UserDefinedAggregateFunction {

// Data types of input arguments of this aggregate function

def inputSchema: StructType = StructType(StructField("input", DoubleType) :: Nil)

// Data types of values in the aggregation buffer

def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: StructField("count", DoubleType) :: Nil)

// The data type of the returned value

def dataType: DataType = DoubleType

// Whether this function always returns the same output on the identical input

def deterministic: Boolean = true

// Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
// standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
// the opportunity to update its values. Note that arrays and maps inside the buffer are still
// immutable.

def initialize(buffer: MutableAggregationBuffer): Unit = {

buffer(0) = 0d
buffer(1) = 0d

}

// Updates the given aggregation buffer `buffer` with new input data from `input`

def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

// skip null ratings so they don't skew the sum or the count
if (!input.isNullAt(0)) {
buffer(0) = buffer.getDouble(0) + input.getDouble(0)
buffer(1) = buffer.getDouble(1) + 1
}

}

// Merges two aggregation buffers and stores the updated buffer values back to `buffer1`

def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)

}

// Calculates the final result

def evaluate(row: Row): Double = {

row.getDouble(0) / row.getDouble(1)

}

}

Driver Code

import org.apache.spark.sql.SparkSession

object CustomUdfDriver extends App {

val session = SparkSession.builder().appName("Test").master("local").getOrCreate()

val dataframe = session.read.csv("C:\\codebase\\scala-project\\inputdata\\movies_data_2")

session.udf.register("average_rating",UntypedUDF)

dataframe.createOrReplaceTempView("test_untype_udf")

dataframe.show()

session.sql("select average_rating(_c1) from test_untype_udf").show()

}
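
Since UntypedUDF is a Scala object, it can also be applied directly through the DataFrame API without registering it. A minimal sketch; the explicit cast is an assumption that is safe here, since csv() without a schema reads every column as a string:


import org.apache.spark.sql.functions.col

// apply the UDAF as a column expression and alias the result
dataframe.agg(UntypedUDF(col("_c1").cast("double")).as("average_rating")).show()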