Spark SQL example to find the max of average

In this short article I will show how to find the max of average in Spark SQL, that is, how to average a value per group and then keep the group with the highest average. Below is the sample dataset we will use.

+--------+-------------+--------+-----+--------+
|Hospital|AccountNumber|    date|Visit|  Amount|
+--------+-------------+--------+-----+--------+
|  Apollo|            1|20200901|    1|   234.0|
|  Apollo|            2|20200901|    0|   343.0|
|  Apollo|            3|20200901|    1|   434.0|
|  Apollo|            4|20200901|    0|565656.0|
|  Apollo|            1|20190901|    1|   234.0|
|  Apollo|            2|20190901|    0|   343.0|
|  Apollo|            3|20190901|    1|   434.0|
|  Apollo|            4|20190901|    0|565656.0|
|  Apollo|            1|20200902|    1|  4343.0|
|  Apollo|            2|20200902|    0|  3434.0|
|  Apollo|            3|20200902|    1|  3434.0|
|  Apollo|            4|20200902|    1|  3434.0|
|  Apollo|            1|20200903|    0|  6534.0|
|  Apollo|            2|20200903|    0|  3423.0|
|  Apollo|            3|20200903|    0| 32323.0|
|  Apollo|            4|20200903|    1|   454.0|
|  Apollo|            1|20200904|    0| 32323.0|
|  Apollo|            2|20200904|    0|   232.0|
|  Apollo|            3|20200904|    1|  3232.0|
|  Apollo|            4|20200904|    1|  6767.0|
|     JMC|            1|20200903|    0|   898.0|
|     JMC|            2|20200903|    0|  7878.0|
|     JMC|            3|20200903|    0|  4545.0|
|     JMC|            4|20200903|    1|  3434.0|
|     JMC|            1|20200904|    0|  6767.0|
|     JMC|            2|20200904|    0|  3434.0|
|     JMC|            3|20200904|    1| 34343.0|
|     JMC|            4|20200904|    1|  3434.0|
|  Fortis|            1|20200903|    0|  2121.0|
|  Fortis|            2|20200903|    0| 12323.0|
|  Fortis|            3|20200903|    0|  2323.0|
|  Fortis|            4|20200903|    1|  2323.0|
|  Fortis|            1|20200904|    0|  2323.0|
|  Fortis|            2|20200904|    0|  2323.0|
|  Fortis|            3|20200904|    1|  2323.0|
|  Fortis|            4|20200904|    1|   323.0|
+--------+-------------+--------+-----+--------+

Let's say we want to find, for each year, the hospital with the highest average Amount. Below is the result we want to arrive at:

+----+--------+-------------+----+
|year|Hospital|averageAmount|rank|
+----+--------+-------------+----+
|2019|  Apollo|    141666.75|   1|
|2020|  Apollo|      41662.5|   1|
+----+--------+-------------+----+
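As a quick arithmetic check of these numbers, the plain Scala sketch below (no Spark needed) recomputes the average Amount per year and hospital from the sample table and keeps the highest average per year. The object name `ExpectedResult` and its members are hypothetical, and the amounts are typed in by hand from the table, with the year taken from the first four characters of `date`.

```scala
object ExpectedResult {

  // Amount values copied by hand from the sample table, keyed by (year, Hospital).
  val amountsByYearAndHospital: Map[(Int, String), Seq[Double]] = Map(
    (2019, "Apollo") -> Seq(234.0, 343.0, 434.0, 565656.0),
    (2020, "Apollo") -> Seq(
      234.0, 343.0, 434.0, 565656.0,  // 20200901
      4343.0, 3434.0, 3434.0, 3434.0, // 20200902
      6534.0, 3423.0, 32323.0, 454.0, // 20200903
      32323.0, 232.0, 3232.0, 6767.0  // 20200904
    ),
    (2020, "JMC") -> Seq(898.0, 7878.0, 4545.0, 3434.0, 6767.0, 3434.0, 34343.0, 3434.0),
    (2020, "Fortis") -> Seq(2121.0, 12323.0, 2323.0, 2323.0, 2323.0, 2323.0, 2323.0, 323.0)
  )

  // Step 1: average Amount per (year, Hospital), the same idea as groupBy + avg.
  val averages: Map[(Int, String), Double] =
    amountsByYearAndHospital.map { case (key, amounts) => key -> amounts.sum / amounts.size }

  // Step 2: for one year, keep every hospital whose average equals that year's
  // maximum; like a rank filter, tied hospitals would all be kept.
  def topHospitals(year: Int): Set[(String, Double)] = {
    val ofYear = averages.collect { case ((y, hospital), avg) if y == year => (hospital, avg) }
    val best = ofYear.values.reduce(_ max _)
    ofYear.filter { case (_, avg) => avg == best }.toSet
  }
}
```

For example, `ExpectedResult.averages((2020, "JMC"))` gives 8091.625 and `ExpectedResult.averages((2020, "Fortis"))` gives 3297.75, so Apollo's 41662.5 is the highest 2020 average. Apollo's averages are dominated by the two 565656.0 rows.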

Let's create a dummy DataFrame with the data above.

package com.timepasstechies.blog.examples

import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.collection.mutable.ListBuffer

class DataLoader {

  // Builds the sample hospital dataset shown above as a DataFrame
  def loadData(sparkSession: SparkSession): DataFrame = {

    // Implicit encoders needed to call .toDF on a Scala collection
    import sparkSession.implicits._

    // Each tuple is (Hospital, AccountNumber, date, Visit, Amount);
    // the amounts match the sample table
    val sequenceOfOverview =
      ListBuffer[(String, String, String, Int, Double)]()
    sequenceOfOverview += Tuple5("Apollo", "1", "20200901", 1, 234)
    sequenceOfOverview += Tuple5("Apollo", "2", "20200901", 0, 343)
    sequenceOfOverview += Tuple5("Apollo", "3", "20200901", 1, 434)
    sequenceOfOverview += Tuple5("Apollo", "4", "20200901", 0, 565656)

    sequenceOfOverview += Tuple5("Apollo", "1", "20190901", 1, 234)
    sequenceOfOverview += Tuple5("Apollo", "2", "20190901", 0, 343)
    sequenceOfOverview += Tuple5("Apollo", "3", "20190901", 1, 434)
    sequenceOfOverview += Tuple5("Apollo", "4", "20190901", 0, 565656)

    sequenceOfOverview += Tuple5("Apollo", "1", "20200902", 1, 4343)
    sequenceOfOverview += Tuple5("Apollo", "2", "20200902", 0, 3434)
    sequenceOfOverview += Tuple5("Apollo", "3", "20200902", 1, 3434)
    sequenceOfOverview += Tuple5("Apollo", "4", "20200902", 1, 3434)

    sequenceOfOverview += Tuple5("Apollo", "1", "20200903", 0, 6534)
    sequenceOfOverview += Tuple5("Apollo", "2", "20200903", 0, 3423)
    sequenceOfOverview += Tuple5("Apollo", "3", "20200903", 0, 32323)
    sequenceOfOverview += Tuple5("Apollo", "4", "20200903", 1, 454)

    sequenceOfOverview += Tuple5("Apollo", "1", "20200904", 0, 32323)
    sequenceOfOverview += Tuple5("Apollo", "2", "20200904", 0, 232)
    sequenceOfOverview += Tuple5("Apollo", "3", "20200904", 1, 3232)
    sequenceOfOverview += Tuple5("Apollo", "4", "20200904", 1, 6767)

    sequenceOfOverview += Tuple5("JMC", "1", "20200903", 0, 898)
    sequenceOfOverview += Tuple5("JMC", "2", "20200903", 0, 7878)
    sequenceOfOverview += Tuple5("JMC", "3", "20200903", 0, 4545)
    sequenceOfOverview += Tuple5("JMC", "4", "20200903", 1, 3434)

    sequenceOfOverview += Tuple5("JMC", "1", "20200904", 0, 6767)
    sequenceOfOverview += Tuple5("JMC", "2", "20200904", 0, 3434)
    sequenceOfOverview += Tuple5("JMC", "3", "20200904", 1, 34343)
    sequenceOfOverview += Tuple5("JMC", "4", "20200904", 1, 3434)

    sequenceOfOverview += Tuple5("Fortis", "1", "20200903", 0, 2121)
    sequenceOfOverview += Tuple5("Fortis", "2", "20200903", 0, 12323)
    sequenceOfOverview += Tuple5("Fortis", "3", "20200903", 0, 2323)
    sequenceOfOverview += Tuple5("Fortis", "4", "20200903", 1, 2323)

    sequenceOfOverview += Tuple5("Fortis", "1", "20200904", 0, 2323)
    sequenceOfOverview += Tuple5("Fortis", "2", "20200904", 0, 2323)
    sequenceOfOverview += Tuple5("Fortis", "3", "20200904", 1, 2323)
    sequenceOfOverview += Tuple5("Fortis", "4", "20200904", 1, 323)

    // Name the tuple fields to get the columns shown in the table
    sequenceOfOverview.toDF(
      "Hospital",
      "AccountNumber",
      "date",
      "Visit",
      "Amount"
    )
  }

}

Let's code the solution to find the max of average in Spark SQL.

package com.timepasstechies.blog.examples

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object MaxOfAverage extends App {

  val dataLoader = new DataLoader()
  lazy val sparkSession: SparkSession = SparkSession
    .builder()
    .appName("MaxOfAverage")
    .master("local[*]")
    .getOrCreate()

  val data = dataLoader.loadData(sparkSession)

  // Step 1: extract the year from the yyyyMMdd date string and
  // compute the average Amount per (year, Hospital)
  val yearlyData = data
    .select(
      col("Hospital"),
      col("Amount"),
      year(to_date(col("date"), "yyyyMMdd")).as("year")
    )
    .groupBy("year", "Hospital")
    .agg(avg(col("Amount")).as("averageAmount"))

  // Step 2: within each year, order hospitals by average, highest first
  val win =
    Window
      .partitionBy(col("year"))
      .orderBy(col("averageAmount").desc)

  // Step 3: keep the top-ranked hospital per year; rank() gives tied
  // averages the same rank, so every tied hospital is kept
  yearlyData
    .select(col("*"), rank().over(win).as("rank"))
    .filter(col("rank") === 1)
    .show()

  sparkSession.stop()
}
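The same three steps can also be written as a single SQL statement run through `sparkSession.sql` against a temporary view. Below is a minimal sketch, assuming a DataFrame with the columns produced by `DataLoader.loadData`; the object name `MaxOfAverageSql`, the method `topHospitalPerYear` and the view name `hospital_visits` are hypothetical. The backticks around `date`, `year` and `rank` are a precaution because those names overlap with SQL keywords and function names.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object MaxOfAverageSql {

  // Registers the input as a temporary view and runs the
  // average-then-rank logic as one SQL statement.
  def topHospitalPerYear(spark: SparkSession, data: DataFrame): DataFrame = {
    // Hypothetical view name used only for this sketch
    data.createOrReplaceTempView("hospital_visits")

    spark.sql(
      """
        |WITH yearly AS (
        |  -- extract the year from the yyyyMMdd date string
        |  SELECT year(to_date(`date`, 'yyyyMMdd')) AS `year`, Hospital, Amount
        |  FROM hospital_visits
        |),
        |averages AS (
        |  -- average Amount per (year, Hospital)
        |  SELECT `year`, Hospital, avg(Amount) AS averageAmount
        |  FROM yearly
        |  GROUP BY `year`, Hospital
        |),
        |ranked AS (
        |  -- rank hospitals within each year, highest average first
        |  SELECT `year`, Hospital, averageAmount,
        |         RANK() OVER (PARTITION BY `year` ORDER BY averageAmount DESC) AS `rank`
        |  FROM averages
        |)
        |SELECT `year`, Hospital, averageAmount, `rank`
        |FROM ranked
        |WHERE `rank` = 1
        |""".stripMargin)
  }
}
```

Inside `MaxOfAverage`, calling `MaxOfAverageSql.topHospitalPerYear(sparkSession, data).show()` should give the same rows as the DataFrame version, since both use `avg` per group followed by `RANK()` over a window partitioned by year.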

That’s a brief overview of how we can find the max of average in Spark SQL.