In this article I will illustrate how to perform schema discovery to validate column names before firing a select query on a Spark DataFrame.
Let’s take the below example
import static org.apache.spark.sql.functions.col;
import java.util.Arrays;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
/**
 * Minimal Spark demo: builds a small in-memory DataFrame of employees and
 * selects two known-good columns.
 */
public class SparkDriver {
	public static void main(String[] args) {
		// Local master runs Spark entirely in-process — fine for a demo.
		SparkSession spark = SparkSession.builder().master("local").getOrCreate();
		// Lets create the dataset of row using the Arrays asList Function
		Dataset<Row> test = spark.createDataFrame(Arrays.asList(new Employee("user-1", "male", "BSC", "USA"),
				new Employee("user-2", "male", "BSC", "USA"), new Employee("user-3", "male", "BSC", "USA"),
				new Employee("user-4", "male", "BSC", "USA")), Employee.class);
		// Both columns exist in the schema, so this select succeeds.
		test.select(col("name"), col("gender")).show();
		// Fix: release Spark resources — the original never stopped the session.
		spark.stop();
	}
}
This prints the below output as expected
+------+------+ | name|gender| +------+------+ |user-1| male| |user-2| male| |user-3| male| |user-4| male| +------+------+
Let’s consider a Spark job that performs parametric extraction based on user input, and the user passes an invalid column name as below
spark-submit --deploy-mode cluster --class org.blog.SparkDriver ${jar_to_run} name gender invalid_column_name
import static org.apache.spark.sql.functions.col;
import java.util.Arrays;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
/**
 * Demonstrates what happens when column names come straight from the command
 * line without validation: an unknown name fails at analysis time.
 */
public class SparkDriver {
	public static void main(String[] args) {
		SparkSession sparkSession = SparkSession.builder().master("local").getOrCreate();
		// Assemble the sample employees in memory and turn them into a DataFrame.
		Dataset<Row> employees = sparkSession.createDataFrame(
				Arrays.asList(
						new Employee("user-1", "male", "BSC", "USA"),
						new Employee("user-2", "male", "BSC", "USA"),
						new Employee("user-3", "male", "BSC", "USA"),
						new Employee("user-4", "male", "BSC", "USA")),
				Employee.class);
		// The three column names are taken verbatim from the CLI arguments;
		// a name not present in the schema raises an AnalysisException here.
		employees.select(col(args[0]), col(args[1]), col(args[2])).show();
	}
}
This will result in the below error, as invalid_column_name is not present in the dataframe
Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve '`invalid_column_name`' given input columns: [address, education, gender, name];; 'Project [name#3, gender#2, 'invalid_column_name] +- LocalRelation [address#0, education#1, gender#2, name#3]
So in these kinds of scenarios, where the user is expected to pass the parameters to extract, it may be required to validate each parameter before firing a select query on the dataframe. Below is the code that validates the schema for valid column names and filters out the column names that are not part of the schema.
import static org.apache.spark.sql.functions.col;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.collection.JavaConverters;
/**
 * Validates user-supplied column names against the DataFrame schema and
 * selects only the columns that actually exist, silently dropping the rest.
 */
public class SparkDriver {
	public static void main(String[] args) {
		// Mutable copy of the CLI arguments so retainAll() can filter it in place.
		List<String> userParameter = new ArrayList<String>(Arrays.asList(args));
		SparkSession spark = SparkSession.builder().master("local").getOrCreate();
		// Lets create the dataset of row using the Arrays asList Function
		Dataset<Row> dataset = spark.createDataFrame(Arrays.asList(new Employee("user-1", "male", "BSC", "USA"),
				new Employee("user-2", "male", "BSC", "USA"), new Employee("user-3", "male", "BSC", "USA"),
				new Employee("user-4", "male", "BSC", "USA")), Employee.class);
		// Schema discovery: fieldNames() replaces the original manual loop
		// over StructField objects — same result, one line.
		List<String> schemaList = new ArrayList<String>(Arrays.asList(dataset.schema().fieldNames()));
		// Keep only the requested names that exist in the schema; anything
		// invalid (e.g. invalid_column_name) is dropped before the select.
		userParameter.retainAll(schemaList);
		List<Column> parameters = convertStringToColumn(userParameter);
		scala.collection.Seq<Column> columnFilter = javaListToScalaSeq(parameters);
		dataset.select(columnFilter).show();
		// Fix: release Spark resources — the original never stopped the session.
		spark.stop();
	}

	/**
	 * Converts a Java list into a Scala {@code Seq}, as required by the
	 * Seq-based {@code Dataset.select} overload.
	 *
	 * @param javaList the list to convert
	 * @return an equivalent Scala sequence, in the same order
	 */
	public static <T> scala.collection.Seq<T> javaListToScalaSeq(List<T> javaList) {
		return JavaConverters.asScalaIteratorConverter(javaList.iterator()).asScala().toSeq();
	}

	/**
	 * Wraps each column name in a Spark {@link Column} via {@code col()}.
	 *
	 * @param parameters validated column names
	 * @return one {@link Column} per input name, in the same order
	 */
	public static List<Column> convertStringToColumn(List<String> parameters) {
		List<Column> listOfColumn = new ArrayList<Column>();
		for (String column : parameters) {
			listOfColumn.add(col(column));
		}
		return listOfColumn;
	}
}