Using Hive UDAF in Spark SQL

In this article I will demonstrate how to build a Hive UDAF and execute it in Apache Spark.

In Hive, user-defined aggregate functions (UDAFs) work on a group of rows and return a single row as output. Some examples of built-in UDAFs in Hive are MAX(), MIN(), AVG() and COUNT().
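
For example, a built-in aggregate like AVG() collapses all the rows of each group into a single value (illustrative query, assuming a table named test_table with device and usage columns similar to the data used later in this article):

hive> SELECT device, AVG(usage) FROM test_table GROUP BY device;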

Let's say we have input data as below


1920,shelf=0/slot=5/port=1,100
1920,shelf=1/slot=4/port=6,200
1920,shelf=2/slot=5/port=24,300
1920,shelf=3/slot=5/port=0,400

We need a custom Hive aggregate function that finds the average usage per device. Let's write a custom Java class to define the user-defined aggregate function, or UDAF, which extends org.apache.hadoop.hive.ql.exec.UDAF. It contains an inner class, MeanUDAFEvaluator, which implements UDAFEvaluator and provides the required init(), iterate(), terminatePartial(), merge() and terminate() methods.

Below is the code


import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
import org.apache.hadoop.hive.ql.metadata.HiveException;

@SuppressWarnings("deprecation")
public class CustomUDAF extends UDAF {

    // Define logging
    static final Log LOG = LogFactory.getLog(CustomUDAF.class.getName());

    public static class MeanUDAFEvaluator implements UDAFEvaluator {

        /**
         * The Column class holds the intermediate computation for a group:
         * the running sum and the running count.
         */
        public static class Column {
            double sum = 0;
            int count = 0;
        }

        private Column col = null;

        public MeanUDAFEvaluator() {
            super();
            init();
        }

        // A - Initialize the evaluator, indicating that no values have been
        // aggregated yet.
        public void init() {
            LOG.debug("Initialize evaluator");
            col = new Column();
        }

        // B - Called every time there is a new value to be aggregated.
        public boolean iterate(double value) throws HiveException {
            LOG.debug("Iterating over each value for aggregation");
            if (col == null)
                throw new HiveException("Item is not initialized");
            col.sum = col.sum + value;
            col.count = col.count + 1;
            return true;
        }

        // C - Called when Hive wants partially aggregated results.
        public Column terminatePartial() {
            LOG.debug("Return partially aggregated results");
            return col;
        }

        // D - Called when Hive decides to combine one partial aggregation with
        // another.
        public boolean merge(Column other) {
            LOG.debug("Merging by combining partial aggregation");
            if (other == null) {
                return true;
            }
            col.sum += other.sum;
            col.count += other.count;
            return true;
        }

        // E - Called when the final result of the aggregation is needed.
        public double terminate() {
            LOG.debug("At the end of last record of the group - returning final result");
            return col.sum / col.count;
        }
    }
}
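
To see how these lifecycle methods fit together, below is a minimal local sketch (a hypothetical test harness, not part of the Hive deployment) that drives the evaluator directly with the usage values from the sample input. It assumes the CustomUDAF class and the hive-exec library are on the classpath:

import org.apache.hadoop.hive.ql.metadata.HiveException;

public class MeanUDAFEvaluatorTest {

    public static void main(String[] args) throws HiveException {
        // init() is called from the evaluator constructor
        CustomUDAF.MeanUDAFEvaluator evaluator = new CustomUDAF.MeanUDAFEvaluator();

        // iterate over the usage values from the sample input
        for (double usage : new double[] { 100, 200, 300, 400 }) {
            evaluator.iterate(usage);
        }

        // terminate() returns the final average: (100 + 200 + 300 + 400) / 4 = 250.0
        System.out.println(evaluator.terminate());
    }
}

In an actual Hive or Spark job, terminatePartial() and merge() are also invoked so that partial sums and counts computed by different tasks can be combined before terminate() produces the final average.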

To deploy the above code, package the Java class into a jar file, add the jar to the Hive classpath, create a temporary function and execute the UDAF as below


hive> ADD JAR HiveUdaf.jar;
hive> CREATE TEMPORARY FUNCTION AVERAGE_USAGE as 'com.blog.hive.udf.CustomUDAF';
hive> select AVERAGE_USAGE(octet) from test_table group by device;

Now let's execute the same Hive UDAF using Spark SQL and the DataFrame API. We need to create a temporary view from the dataset and register the function using the session.sql method. Once the function is registered we can use it in our queries.

Below is the code


import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import static org.apache.spark.sql.functions.col;

public class SparkUDAFExecutor {

    public static void main(String[] args) throws AnalysisException {

        // Pass the input path as a parameter
        SparkSession session = SparkSession.builder().enableHiveSupport()
                .appName("test").master("local").getOrCreate();
        // session.sparkContext().setLogLevel("DEBUG");

        Dataset<Row> df = session.read()
                .option("timestampFormat", "yyyy/MM/dd HH:mm:ss ZZ")
                .csv(args[0])
                .toDF("device", "slot", "usage")
                .withColumn("octet", col("usage").cast("double"));
        df.printSchema();
        df.show();

        df.createTempView("test_table");

        session.sql("CREATE TEMPORARY FUNCTION AVERAGE_USAGE AS 'com.blog.hive.udf.CustomUDAF'");
        session.sql("SELECT device, AVERAGE_USAGE(octet) FROM test_table GROUP BY device").show();
    }
}
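
With the four sample rows shown earlier, all of which belong to device 1920, the query returns a single group whose average usage is (100 + 200 + 300 + 400) / 4 = 250.0.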

In the above code we are using Spark 2.0 features like SparkSession. If we have to execute the Hive UDAF on an older Spark 1.x release, such as Spark 1.6, we can use the code below.

Here we first load a JavaRDD of String and define the schema, where all the fields are of string type. We create a list of StructField, adding each field with the DataTypes.createStructField method, and then build a StructType by passing that list to DataTypes.createStructType. Next we convert the JavaRDD of String into a JavaRDD of Row by calling RowFactory.create inside a map function. Finally we pass the StructType and the JavaRDD of Row to the hiveContext.createDataFrame method to get the DataFrame. Once we have the DataFrame we register it as a temporary table using the registerTempTable method, register the UDAF through hiveContext.sql, and then use the function in our queries.


import java.util.ArrayList;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.hive.HiveContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class CustomUDAF {

    public static void main(String[] args) throws AnalysisException {

        SparkConf conf = new SparkConf().setMaster("local").setAppName("test");
        JavaSparkContext sc = new JavaSparkContext(conf);

        // Define the schema for the three comma-separated input fields
        List<StructField> listOfStructField = new ArrayList<StructField>();
        listOfStructField.add(DataTypes.createStructField("device", DataTypes.StringType, true));
        listOfStructField.add(DataTypes.createStructField("slot", DataTypes.StringType, true));
        listOfStructField.add(DataTypes.createStructField("usage", DataTypes.StringType, true));

        StructType structType = DataTypes.createStructType(listOfStructField);

        // Convert each input line into a Row
        JavaRDD<Row> rdd = sc.textFile("C:\\dataset\\test").map(new Function<String, Row>() {

            @Override
            public Row call(String v1) throws Exception {
                String[] data = v1.split(",");
                return RowFactory.create(data[0], data[1], data[2]);
            }
        });

        HiveContext hc = new HiveContext(sc);

        DataFrame df = hc.createDataFrame(rdd, structType);

        df.registerTempTable("test_table");

        hc.sql("CREATE TEMPORARY FUNCTION AVERAGE_USAGE AS 'com.blog.hive.udf.CustomUDAF'");

        hc.sql("SELECT device, AVERAGE_USAGE(usage) FROM test_table GROUP BY device").show();
    }
}

Let's now execute the same UDAF on Avro data. First we create the Avro data using the code below, reusing the same input data from the example above.

We need an additional dependency to process Avro data. If we are using Maven to build the project, we can add the dependency below


<dependency>
    <groupId>com.databricks</groupId>
    <artifactId>spark-avro_2.10</artifactId>
    <version>1.1.0-cdh5.9.1</version>
</dependency>


import java.util.ArrayList;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.hive.HiveContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class CreateAvroData {

    public static void main(String[] args) {

        SparkConf conf = new SparkConf().setMaster("local").setAppName("test");
        JavaSparkContext sc = new JavaSparkContext(conf);

        // Define the schema for the three comma-separated input fields
        List<StructField> listOfStructField = new ArrayList<StructField>();
        listOfStructField.add(DataTypes.createStructField("device", DataTypes.StringType, true));
        listOfStructField.add(DataTypes.createStructField("shelf", DataTypes.StringType, true));
        listOfStructField.add(DataTypes.createStructField("usage", DataTypes.StringType, true));

        StructType structType = DataTypes.createStructType(listOfStructField);

        // Convert each input line into a Row
        JavaRDD<Row> rddData = sc.textFile("Input Path").map(new Function<String, Row>() {

            private static final long serialVersionUID = 1212;

            @Override
            public Row call(String v1) throws Exception {
                String[] data = v1.split(",");
                return RowFactory.create(data[0], data[1], data[2]);
            }
        });

        // Write the DataFrame out in Avro format using the spark-avro package
        SQLContext hiveContext = new HiveContext(sc);
        DataFrame dataFrame = hiveContext.createDataFrame(rddData, structType);
        dataFrame.write().format("com.databricks.spark.avro").save("Output Location");
    }
}

We will use the same Hive UDAF that we created above and execute it in Spark SQL to process the Avro data. In the code below we load the query from a configuration file.


import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.hive.HiveContext;

public class SparkExecutorAvro {

    public static void main(String[] args) {

        args = new String[] { "Input Path" };
        SparkConf conf = new SparkConf().setMaster("local").setAppName("test");

        JavaSparkContext sc = new JavaSparkContext(conf);
        // Read Avro files even if they do not have the .avro extension
        sc.hadoopConfiguration().set("avro.mapred.ignore.inputs.without.extension", "false");

        HiveContext sql = new HiveContext(sc);
        DataFrame df = sql.read().format("com.databricks.spark.avro").load(args[0]);

        df.registerTempTable("diequal_long_format");
        sql.sql("CREATE TEMPORARY FUNCTION testUdaf AS 'com.blog.hive.udf.CustomUDAF'");

        // Load the query text from query1.txt on the classpath
        ConfigFile scriptFile = new ConfigFile(Constants.QUERY1_TXT, FileType.script);
        String query = scriptFile.getFileContent();
        sql.sql(query).save("Output Path", SaveMode.Overwrite);
    }
}
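
The query loaded from query1.txt is whatever aggregate SQL we want to run against the registered temp table; as an illustration only, it could contain something along these lines, using the testUdaf function and the diequal_long_format table registered above:

SELECT device, testUdaf(usage) FROM diequal_long_format GROUP BY device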

Below are the helper classes used for loading the query from the configuration file


import java.io.IOException;
import java.io.InputStream;
import java.io.StringWriter;
import java.util.Properties;

import org.apache.commons.io.IOUtils;
import org.apache.log4j.Logger;

public class ConfigFile {

    private String fileName;

    private SequencedProperties properties;

    private FileType fileType;

    private String fileContent;

    private static final Logger logger = Logger.getLogger(ConfigFile.class);

    public ConfigFile(String fileName, FileType fileType) {
        this.fileName = fileName;
        this.properties = new SequencedProperties();
        this.fileType = fileType;
        loadFile();
    }

    public Properties getProperties() {
        return properties;
    }

    public void setProperties(SequencedProperties properties) {
        this.properties = properties;
    }

    public FileType getFileType() {
        return fileType;
    }

    public void setFileType(FileType fileType) {
        this.fileType = fileType;
    }

    public String getFileContent() {
        return fileContent;
    }

    public void setFileContent(String fileContent) {
        this.fileContent = fileContent;
    }

    public String getFileName() {
        return fileName;
    }

    public void setFileName(String fileName) {
        this.fileName = fileName;
    }

    private void loadFile() {
        // Load the file from the classpath, either as key/value properties
        // or as the raw text of a script file
        InputStream in = getClass().getClassLoader().getResourceAsStream(getFileName());
        try {
            if (this.getFileType() == FileType.property) {
                this.getProperties().load(in);
            } else if (this.getFileType() == FileType.script) {
                StringWriter writer = new StringWriter();
                IOUtils.copy(in, writer);
                fileContent = writer.toString();
            }
        } catch (IOException e) {
            logger.error(e.getMessage(), e);
        } finally {
            try {
                if (in != null) {
                    in.close();
                }
            } catch (IOException e) {
                logger.error(e.getMessage(), e);
            }
        }
    }

    public Properties getProperty() {
        return properties;
    }

    public String getString(String key) {
        return properties.getProperty(key);
    }
}


public class Constants {

    public static final String QUERY1_TXT = "query1.txt";

    public static final String QUERY2_TXT = "query2.txt";

    public static final String QUERY3_TXT = "query3.txt";

    /**
     * Instantiates a new constants.
     */
    public Constants() {
        super();
    }
}


public enum FileType {

    property, script

}
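
Here is a minimal sketch of how these helper classes fit together. The script case mirrors the SparkExecutorAvro code above, while the settings.properties file and the some.key property in the second case are purely hypothetical:

public class ConfigFileUsage {

    public static void main(String[] args) {
        // Load a SQL script from the classpath as a single string
        ConfigFile queryFile = new ConfigFile(Constants.QUERY1_TXT, FileType.script);
        String query = queryFile.getFileContent();
        System.out.println(query);

        // Load key/value settings from a (hypothetical) properties file
        ConfigFile propertyFile = new ConfigFile("settings.properties", FileType.property);
        String value = propertyFile.getString("some.key");
        System.out.println(value);
    }
}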


import java.util.Collections;
import java.util.Enumeration;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

/**
 * The Class SequencedProperties is a custom property handler implementation.
 */
public class SequencedProperties extends Properties {

    /** The Constant serialVersionUID. */
    private static final long serialVersionUID = 1L;

    /** The key set. */
    private transient Set<Object> keySet = new LinkedHashSet<Object>(100);

    /*
     * (non-Javadoc)
     *
     * @see java.util.Hashtable#keys()
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    @Override
    public synchronized Enumeration keys() {
        return Collections.enumeration(keySet);
    }

    /*
     * (non-Javadoc)
     *
     * @see java.util.Hashtable#keySet()
     */
    @Override
    public Set<Object> keySet() {
        return keySet;
    }

    /*
     * (non-Javadoc)
     *
     * @see java.util.Hashtable#put(java.lang.Object, java.lang.Object)
     */
    @Override
    public synchronized Object put(Object key, Object value) {
        if (!keySet.contains(key)) {
            keySet.add(key);
        }
        return super.put(key, value);
    }

    /*
     * (non-Javadoc)
     *
     * @see java.util.Hashtable#remove(java.lang.Object)
     */
    @Override
    public synchronized Object remove(Object key) {
        keySet.remove(key);
        return super.remove(key);
    }

    /*
     * (non-Javadoc)
     *
     * @see java.util.Hashtable#putAll(java.util.Map)
     */
    @SuppressWarnings("unchecked")
    @Override
    public synchronized void putAll(@SuppressWarnings("rawtypes") Map values) {
        for (Object key : values.keySet()) {
            if (!containsKey(key)) {
                keySet.add(key);
            }
        }
        super.putAll(values);
    }
}
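
SequencedProperties is used instead of a plain java.util.Properties so that keys come back in the order they were added rather than in the hash order of the underlying Hashtable. A quick illustrative sketch:

import java.util.Enumeration;

public class SequencedPropertiesDemo {

    public static void main(String[] args) {
        SequencedProperties props = new SequencedProperties();
        props.put("first", "1");
        props.put("second", "2");
        props.put("third", "3");

        // keys() enumerates in insertion order: first, second, third
        Enumeration<Object> keys = props.keys();
        while (keys.hasMoreElements()) {
            Object key = keys.nextElement();
            System.out.println(key + "=" + props.getProperty(key.toString()));
        }
    }
}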