The built-in DataFrame functions provide common aggregations such as count(), countDistinct(), avg(), max(), min(), etc. While those functions are designed for DataFrames, Spark SQL also has type-safe versions of some of them in Scala and Java that work with strongly typed Datasets. Moreover, users are not limited to the predefined aggregate functions and can create their own.
Untyped User-Defined Aggregate Functions
Untyped user-defined aggregate functions can be created by extending the class UserDefinedAggregateFunction in Java and Scala.
Let's write a user-defined aggregate function to calculate the average rating of all the movies in the input file.
Below is the sample data:
Spider Man,4,978301398
Spider Man,4,978302091
Bat Man,5,978298709
Bat Man,4,978299000
Bat Man,4,978299620
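For comparison, the same overall average is available from the built-in avg() function without any custom code. Below is a minimal Scala sketch, assuming the same input path as the examples that follow and that the rating lands in column _c1 (Spark assigns _c0, _c1, _c2 since the CSV has no header); the object name BuiltInAvg is just a placeholder:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.avg

object BuiltInAvg extends App {
  val session = SparkSession.builder().appName("BuiltInAvg").master("local").getOrCreate()
  // Spark names the headerless CSV columns _c0 (title), _c1 (rating), _c2 (timestamp)
  val dataframe = session.read.csv("C:\\codebase\\scala-project\\inputdata\\movies_data_2")
  // The rating is read as a string, so cast it to double before aggregating
  dataframe.agg(avg(dataframe("_c1").cast("double"))).show()
}

The custom UDAF we write below reproduces exactly this behaviour, which makes it easy to verify the implementation.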
In Java
Here we have two instance variables, inputSchema and bufferSchema, which define the StructFields for the function. Since we want to find the average rating, the rating column is passed as the input, and the bufferSchema holds two fields, one for the running sum and one for the count, from which the average is calculated. Below is the code:
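For the sample data above, update() runs once per row, leaving the buffer with sum = 4 + 4 + 5 + 4 + 4 = 21.0 and count = 5.0, so the final average is 21.0 / 5.0 = 4.2.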
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class UntypedUDF extends UserDefinedAggregateFunction {

    private StructType inputSchema;
    private StructType bufferSchema;

    public UntypedUDF() {
        List<StructField> inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("inputType", DataTypes.DoubleType, true));
        inputSchema = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("sum", DataTypes.DoubleType, true));
        bufferFields.add(DataTypes.createStructField("count", DataTypes.DoubleType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
    }

    // Data types of input arguments of this aggregate function
    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    // Data types of values in the aggregation buffer
    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    // The data type of the returned value
    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    // Whether this function always returns the same output on the identical input
    @Override
    public boolean deterministic() {
        return true;
    }

    // Initializes the given aggregation buffer. The buffer itself is a `Row` that, in addition to
    // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
    // the opportunity to update its values. Note that arrays and maps inside the buffer are still
    // immutable.
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        System.out.println("Inside Initialize Method");
        buffer.update(0, 0.0);
        buffer.update(1, 0.0);
    }

    // Updates the given aggregation buffer `buffer` with new input data from `input`
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        System.out.println("Inside Update Method");
        buffer.update(0, buffer.getDouble(0) + input.getDouble(0));
        buffer.update(1, buffer.getDouble(1) + 1);
    }

    // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        System.out.println("Inside Merge Method");
        buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0));
        buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1));
    }

    // Calculates the final result
    @Override
    public Object evaluate(Row buffer) {
        return buffer.getDouble(0) / buffer.getDouble(1);
    }
}
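Note that Spark calls initialize() once per aggregation buffer, update() once for every input row within a partition, and merge() to combine the partial buffers produced by different partitions; the print statements above make this lifecycle visible in the console. evaluate() then runs on the single fully merged buffer.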
Driver Code
Here we register the UDAF with session.udf().register("average_rating", new UntypedUDF()) and then use it like any other built-in function in the query.
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class CustomUdfDriver {

    public static void main(String[] args) {
        SparkSession session = SparkSession.builder().appName("Test").config("key", "value").master("local")
                .getOrCreate();

        // Register the UDAF so it can be referenced by name in SQL
        session.udf().register("average_rating", new UntypedUDF());

        Dataset<Row> dataframe = session.read().csv("C:\\codebase\\scala-project\\inputdata\\movies_data_2");
        dataframe.createOrReplaceTempView("test_udf");
        dataframe.show();

        session.sql("select average_rating(_c1) from test_udf").show();
    }
}
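Running this against the sample data should print a single row with the value 4.2.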
In Scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType

object UntypedUDF extends UserDefinedAggregateFunction {

  // Data types of input arguments of this aggregate function
  def inputSchema: StructType = StructType(StructField("input", DoubleType) :: Nil)

  // Data types of values in the aggregation buffer
  def bufferSchema: StructType =
    StructType(StructField("sum", DoubleType) :: StructField("count", DoubleType) :: Nil)

  // The data type of the returned value
  def dataType: DataType = DoubleType

  // Whether this function always returns the same output on the identical input
  def deterministic: Boolean = true

  // Initializes the given aggregation buffer. The buffer itself is a `Row` that, in addition to
  // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
  // the opportunity to update its values. Note that arrays and maps inside the buffer are still
  // immutable.
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0d
    buffer(1) = 0d
  }

  // Updates the given aggregation buffer `buffer` with new input data from `input`
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getDouble(0) + input.getDouble(0)
    buffer(1) = buffer.getDouble(1) + 1
  }

  // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  // Calculates the final result
  def evaluate(buffer: Row): Double = {
    buffer.getDouble(0) / buffer.getDouble(1)
  }
}
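Because the Scala implementation is an object (a singleton), it can be registered directly by name, without instantiating it with new as in the Java driver.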
Driver Code
object CustomUdfDriver extends App {

  import org.apache.spark.sql.SparkSession

  val session = SparkSession.builder().appName("Test").master("local").getOrCreate()

  // Register the UDAF so it can be referenced by name in SQL
  session.udf.register("average_rating", UntypedUDF)

  val dataframe = session.read.csv("C:\\codebase\\scala-project\\inputdata\\movies_data_2")
  dataframe.createOrReplaceTempView("test_untype_udf")
  dataframe.show()

  session.sql("select average_rating(_c1) from test_untype_udf").show()
}
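A registered UDAF can also be called through the DataFrame API and combined with grouping. The snippet below is a small sketch, assumed to be appended to the driver above (it reuses its dataframe and the average_rating registration), that computes the average rating per movie title with callUDF:

import org.apache.spark.sql.functions.{callUDF, col}

// Average rating per movie title (_c0) instead of one global average
dataframe.groupBy(col("_c0"))
  .agg(callUDF("average_rating", col("_c1")))
  .show()

For the sample data this should give 4.0 for Spider Man and about 4.33 for Bat Man.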