Validating Spark DataFrame Schemas

In this article I will illustrate how to discover a Spark DataFrame's schema and use it to validate column names before firing a select query against the DataFrame.

Let's start with the example below:


import static org.apache.spark.sql.functions.col;

import java.util.Arrays;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class SparkDriver {

    public static void main(String[] args) {

        SparkSession spark = SparkSession.builder().master("local").getOrCreate();

        // Create a Dataset of Rows from a list of Employee beans
        Dataset<Row> test = spark.createDataFrame(
                Arrays.asList(
                        new Employee("user-1", "male", "BSC", "USA"),
                        new Employee("user-2", "male", "BSC", "USA"),
                        new Employee("user-3", "male", "BSC", "USA"),
                        new Employee("user-4", "male", "BSC", "USA")),
                Employee.class);

        test.select(col("name"), col("gender")).show();
    }
}

This prints the output below, as expected:


+------+------+
| name|gender|
+------+------+
|user-1| male|
|user-2| male|
|user-3| male|
|user-4| male|
+------+------+
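
The drivers in this article assume a simple Employee JavaBean. Here is a minimal sketch; the field names are inferred from the columns Spark reports. Spark derives the columns of a bean-backed DataFrame via bean introspection, which is why the schema later shows them in alphabetical order as [address, education, gender, name].


public class Employee {

    private String name;
    private String gender;
    private String education;
    private String address;

    public Employee(String name, String gender, String education, String address) {
        this.name = name;
        this.gender = gender;
        this.education = education;
        this.address = address;
    }

    // Getters are required for Spark's bean introspection
    public String getName() { return name; }
    public String getGender() { return gender; }
    public String getEducation() { return education; }
    public String getAddress() { return address; }
}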

Now let's consider a Spark job that selects columns based on user input, and the user passes an invalid column name, as below:


spark-submit --deploy-mode cluster --class org.blog.SparkDriver ${jar_to_run} name gender invalid_column_name


import static org.apache.spark.sql.functions.col;

import java.util.Arrays;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class SparkDriver {

    public static void main(String[] args) {

        SparkSession spark = SparkSession.builder().master("local").getOrCreate();

        // Create a Dataset of Rows from a list of Employee beans
        Dataset<Row> test = spark.createDataFrame(
                Arrays.asList(
                        new Employee("user-1", "male", "BSC", "USA"),
                        new Employee("user-2", "male", "BSC", "USA"),
                        new Employee("user-3", "male", "BSC", "USA"),
                        new Employee("user-4", "male", "BSC", "USA")),
                Employee.class);

        // Select whichever columns the user passed on the command line
        test.select(col(args[0]), col(args[1]), col(args[2])).show();
    }
}

This fails with the error below, because invalid_column_name is not present in the DataFrame:


Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve '`invalid_column_name`' given input columns: [address, education, gender, name];;
'Project [name#3, gender#2, 'invalid_column_name]
+- LocalRelation [address#0, education#1, gender#2, name#3]
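
The fix below relies on the fact that the schema is available from the DataFrame itself at runtime. For any Dataset<Row> dataset, both of these calls are part of the public Dataset API:


dataset.printSchema();                  // prints the schema as a tree
String[] columns = dataset.columns();   // column names as a plain String array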

So in scenarios like this, where the user is expected to pass the columns to extract, it is worth validating the parameters before firing the select query on the DataFrame. Below is the code that reads the valid column names from the schema and filters out the user parameters that are not part of it.


import static org.apache.spark.sql.functions.col;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import scala.collection.JavaConverters;

public class SparkDriver {

    public static void main(String[] args) {

        List<String> userParameters = new ArrayList<String>(Arrays.asList(args));

        SparkSession spark = SparkSession.builder().master("local").getOrCreate();

        // Create a Dataset of Rows from a list of Employee beans
        Dataset<Row> dataset = spark.createDataFrame(
                Arrays.asList(
                        new Employee("user-1", "male", "BSC", "USA"),
                        new Employee("user-2", "male", "BSC", "USA"),
                        new Employee("user-3", "male", "BSC", "USA"),
                        new Employee("user-4", "male", "BSC", "USA")),
                Employee.class);

        // Collect the valid column names from the DataFrame schema
        StructType schema = dataset.schema();
        List<String> schemaColumns = new ArrayList<String>();
        for (StructField structField : schema.fields()) {
            schemaColumns.add(structField.name());
        }

        // Keep only the user parameters that match a valid column name
        userParameters.retainAll(schemaColumns);

        List<Column> columns = convertStringToColumn(userParameters);
        scala.collection.Seq<Column> columnSeq = javaListToScalaSeq(columns);

        dataset.select(columnSeq).show();
    }

    // select can be called from Java with a scala.collection.Seq<Column>;
    // this helper converts a Java List to that Seq.
    public static <T> scala.collection.Seq<T> javaListToScalaSeq(List<T> javaList) {
        return JavaConverters.asScalaIteratorConverter(javaList.iterator()).asScala().toSeq();
    }

    // Wrap each column name in a Column object via functions.col
    public static List<Column> convertStringToColumn(List<String> parameters) {
        List<Column> listOfColumns = new ArrayList<Column>();
        for (String column : parameters) {
            listOfColumns.add(col(column));
        }
        return listOfColumns;
    }
}
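
Running the same spark-submit command now succeeds: invalid_column_name is silently dropped by retainAll, and the select runs over the remaining valid columns (name and gender), producing the same output as the first example.

If silently dropping unknown columns is not the behaviour you want, you can fail fast with a readable message instead. Here is a minimal sketch; ColumnValidator and validateColumns are hypothetical names, not part of the Spark API:


import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class ColumnValidator {

    // Throws early with a readable message instead of silently
    // dropping unknown columns.
    public static void validateColumns(Dataset<Row> dataset, List<String> requested) {
        List<String> schemaColumns = Arrays.asList(dataset.columns());
        List<String> unknown = new ArrayList<String>(requested);
        unknown.removeAll(schemaColumns);
        if (!unknown.isEmpty()) {
            throw new IllegalArgumentException(
                    "Unknown columns " + unknown + "; valid columns are " + schemaColumns);
        }
    }
}

As a side note, Spark annotates selectExpr with Scala's @varargs, so once the parameters are validated you may be able to skip the Java-to-Scala Seq conversion entirely with dataset.selectExpr(userParameters.toArray(new String[0])).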