In this article I will illustrate how to do schema discovery to validate column names before firing a select query on a Spark DataFrame.
Let’s take the below example
import static org.apache.spark.sql.functions.col; import java.util.Arrays; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; public class SparkDriver { public static void main(String[] args) { SparkSession spark = SparkSession.builder().master("local").getOrCreate(); // Lets create the dataset of row using the Arrays asList Function Dataset<Row> test = spark.createDataFrame(Arrays.asList(new Employee("user-1", "male", "BSC", "USA"), new Employee("user-2", "male", "BSC", "USA"), new Employee("user-3", "male", "BSC", "USA"), new Employee("user-4", "male", "BSC", "USA")), Employee.class); test.select(col("name"), col("gender")).show(); } }
This prints the below output as expected
+------+------+ | name|gender| +------+------+ |user-1| male| |user-2| male| |user-3| male| |user-4| male| +------+------+
Let’s consider we have a Spark job where we are doing parametric extraction based on the user input, and the user passes an invalid column name as below
spark-submit --deploy-mode cluster --class org.blog.SparkDriver ${jar_to_run} name gender invalid_column_name
import static org.apache.spark.sql.functions.col; import java.util.Arrays; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; public class SparkDriver { public static void main(String[] args) { SparkSession spark = SparkSession.builder().master("local").getOrCreate(); // Lets create the dataset of row using the Arrays asList Function Dataset<Row> test = spark.createDataFrame(Arrays.asList(new Employee("user-1", "male", "BSC", "USA"), new Employee("user-2", "male", "BSC", "USA"), new Employee("user-3", "male", "BSC", "USA"), new Employee("user-4", "male", "BSC", "USA")), Employee.class); test.select(col(args[0]), col(args[1]),col(args[2])).show(); } }
This will result in the below error, as invalid_column_name is not present in the DataFrame
Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve '`invalid_column_name`' given input columns: [address, education, gender, name];; 'Project [name#3, gender#2, 'invalid_column_name] +- LocalRelation [address#0, education#1, gender#2, name#3]
So in these kinds of scenarios, where the user is expected to pass the parameters to extract, it may be required to validate the parameters before firing a select query on the DataFrame. Below is the code to validate the schema for valid column names and filter out the column names which are not part of the schema.
import static org.apache.spark.sql.functions.col; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import scala.collection.JavaConverters; public class SparkDriver { public static void main(String[] args) { List<String> userParameter = new ArrayList<String>(Arrays.asList(args)); SparkSession spark = SparkSession.builder().master("local").getOrCreate(); // Lets create the dataset of row using the Arrays asList Function Dataset<Row> dataset = spark.createDataFrame(Arrays.asList(new Employee("user-1", "male", "BSC", "USA"), new Employee("user-2", "male", "BSC", "USA"), new Employee("user-3", "male", "BSC", "USA"), new Employee("user-4", "male", "BSC", "USA")), Employee.class); StructType type = dataset.schema(); List<String> schemaList = new ArrayList<String>(); for (StructField structField : type.fields()) { String prefix = structField.name(); schemaList.add(prefix); } userParameter.retainAll(schemaList); List<Column> parameters = convertStringToColumn(userParameter); scala.collection.Seq<Column> column_filter = javaListToScalaSeq(parameters); dataset.select(column_filter).show(); } public static <T> scala.collection.Seq<T> javaListToScalaSeq(List<T> javaList) { return JavaConverters.asScalaIteratorConverter(javaList.iterator()).asScala().toSeq(); } public static List<Column> convertStringToColumn(List<String> parameters) { List<Column> listOfColumn = new ArrayList<Column>(); for (String column : parameters) { listOfColumn.add(col(column)); } return listOfColumn; } }