2021.07.06 10:10 / spark / test-driven / unit test / in-memory db

Spark Test-driven Development

This short article goes over a method for developing unit tests for Spark jobs. A unit test for a Spark job can be seen as having the following steps:

- Test data, which is part of the test code base.
- Preparing the test environment, for example removing outputs from old test runs or starting up embedded or in-memory databases that can receive the outputs.
- Running the Spark job.
- Looking at the output location and verifying that the data has been processed as expected.

We'll start with the simple use case, reading from and writing to the file system, and then move on to writing data to Hive and Mongo databases. The examples are written in Scala, built with sbt and tested with the ScalaTest library.

HDFS test

First, all our Spark jobs are written so that they can also run locally, controlled by an input parameter (local=true). This parameter decides how the SparkSession is initialized: in local mode the master is set to local[*]. When we run Spark locally, we also stop the SparkSession at the end of the program.

```scala
import org.apache.spark.sql.SparkSession

object HdfsSparkJob {
  def main(args: Array[String]): Unit = {
    val argmap = getArgumentsMap(args)
    val local = argmap.getOrElse("local", "false").toBoolean
    // read input parameters
    val inputLocation = argmap.getOrElse("input_location", "")
    val outputLocation = argmap.getOrElse("output_location", "")

    // in local mode, run Spark inside this JVM on all available cores
    implicit val spark: SparkSession = (
      if (local) SparkSession.builder().master("local[*]")
      else SparkSession.builder()
    )
      .config("spark.sql.shuffle.partitions", 12) // spark configuration
      .getOrCreate()
    import spark.implicits._

    val inputData = spark.read.parquet(inputLocation)
    // spark operations resulting in output data
    val outputData = inputData // placeholder: replace with your transformations

    outputData
      .write
      .mode("append")
      .parquet(outputLocation)

    if (local) spark.stop()
  }

  // turns arguments of the form key=value into a map
  def getArgumentsMap(args: Array[String]): Map[String, String] = {
    args
      .filter(_.contains('=')) // skip arguments without a key=value form
      .map(a => {
        val firstEq = a.indexOf('=')
        Seq(a.substring(0, firstEq), a.substring(firstEq + 1))
      })
      .filter(a => a(0).nonEmpty && a(1).nonEmpty)
      .map(a => a(0) -> a(1))
      .toMap
  }
}
```
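The argument parsing has no Spark dependency, so it can be unit tested on its own. The sketch below is a standalone variant of getArgumentsMap, and the names ArgsParsing and parseArgs are hypothetical. It splits each argument on the first '=' only, so values may themselves contain '='. Arguments with no '=', an empty key or an empty value are skipped rather than causing an exception.

```scala
// Hypothetical standalone variant of getArgumentsMap, with no Spark dependency.
object ArgsParsing {
  /** Parses arguments of the form key=value into a map.
    * Only the first '=' separates key from value, so values may contain '='.
    * Arguments without '=', or with an empty key or value, are ignored.
    * If a key appears twice, the last occurrence wins.
    */
  def parseArgs(args: Array[String]): Map[String, String] =
    args.iterator
      .map(a => (a, a.indexOf('=')))
      // i > 0: non-empty key; i < length - 1: non-empty value
      .collect { case (a, i) if i > 0 && i < a.length - 1 =>
        a.substring(0, i) -> a.substring(i + 1)
      }
      .toMap
}
```

For example, `ArgsParsing.parseArgs(Array("local=true", "input_location=a=b", "junk"))` returns `Map("local" -> "true", "input_location" -> "a=b")`. Using `collect` filters and splits the arguments in a single pass.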

The test for this job starts by cleaning up any output data left over from previous runs. It then runs HdfsSparkJob in local mode with the test input and output locations. Once the job has finished and stopped its SparkSession, the test creates a new SparkSession and checks that the output location contains the expected results. The test also includes a helper method, assertSchemaField, that verifies the data type of a field in the output schema.

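```scala
import java.io.File
import scala.reflect.ClassTag
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DoubleType, StructType}
// ScalaTest 3.0 style; newer versions deprecate these in favor of
// flatspec.AnyFlatSpec and matchers.should.Matchers
import org.scalatest.{FlatSpec, Matchers}

class HdfsSparkJobTest extends FlatSpec with Matchers {
  val INPUT_LOCATION = "src/test/resources/input/hdfs_spark_job"
  val OUTPUT_LOCATION = "src/test/resources/output/hdfs_spark_job"

  def cleanup(): Unit = {
    FileUtils.deleteDirectory(new File(OUTPUT_LOCATION))
  }

  "running hdfs spark job" should "process the input data correctly" in {
    cleanup()
    HdfsSparkJob.main(Array(
      "local=true",
      "input_location=" + INPUT_LOCATION,
      "output_location=" + OUTPUT_LOCATION
    ))

    // the job stopped its session, so getOrCreate starts a fresh one
    implicit val spark: SparkSession = SparkSession.builder()
      .master("local[*]").getOrCreate()
    val output = spark.read.parquet(OUTPUT_LOCATION)
    output.count() should be (7)
    assertSchemaField[DoubleType](output.schema, "sensor_value")
    spark.stop()
  }

  // checks that the schema contains a field `name` whose data type is T
  def assertSchemaField[T: ClassTag](parent: StructType, name: String): Unit = {
    val field = parent.find(f => f.name == name)
    field should not be None
    field.get.dataType shouldBe a[T]
  }
}
```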

As seen in the test, both the input and output locations are relative to the project folder. The project structure would be the following:

- src
  - main
    - scala
      - my.package
        - HdfsSparkJob.scala
  - test
    - resources
      - input
        - hdfs_spark_job
      - [output]
        - [hdfs_spark_job]
    - scala
      - my.package
        - HdfsSparkJobTest.scala
- .gitignore
- build.sbt

The output folder, shown in brackets because the test creates it, should not be checked in to source control. If you are using Git, exclude everything under it in the .gitignore file, for example with the line src/test/resources/output/.

You can run the test above from the sbt shell with testOnly *HdfsSparkJobTest, or from the command line with sbt "testOnly *HdfsSparkJobTest".

Hive test

This second Spark job writes its output data into Hive. Locally run unit tests have no Hive metastore available, but an in-memory instance of Apache Derby can stand in for the metastore database. When initializing the SparkSession in local mode, we configure it to use the in-memory Derby instance. The example job also shows how to create a Hive table or, if the table already exists, insert overwrite into it with dynamic partition overwrite, so that only the partitions present in the new data are replaced.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import HdfsSparkJob.getArgumentsMap // reuse the argument parser from the first job

object HiveSparkJob {
  def main(args: Array[String]): Unit = {
    val argmap = getArgumentsMap(args)
    val local = argmap.getOrElse("local", "false").toBoolean
    // read input parameters
    val inputLocation = argmap.getOrElse("input_location", "")
    val outputLocation = argmap.getOrElse("output_location", "")
    val database = argmap.getOrElse("database", "default")
    val table = argmap.getOrElse("table", "")
    val tempView = argmap.getOrElse("temp_view", "temp")
    // hypothetical parameter: comma-separated partition columns of the output table
    val partitions = argmap.getOrElse("partitions", "")
      .split(",").map(_.trim).filter(_.nonEmpty).toSeq

    implicit val spark: SparkSession = (
      if (local)
        SparkSession.builder().master("local[*]")
          // in-memory Derby database standing in for the Hive metastore
          .config("spark.hadoop.javax.jdo.option.ConnectionDriverName", "org.apache.derby.jdbc.EmbeddedDriver")
          .config("spark.hadoop.javax.jdo.option.ConnectionURL", "jdbc:derby:memory:default;create=true")
          .config("spark.hadoop.javax.jdo.option.ConnectionUserName", "hiveuser")
          .config("spark.hadoop.javax.jdo.option.ConnectionPassword", "hivepass")
      else SparkSession.builder()
    )
      // insert overwrite replaces only the partitions present in the new data
      .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
      // other settings
      .enableHiveSupport()
      .getOrCreate()

    val inputData = spark.read.parquet(inputLocation)
    // spark operations resulting in output data
    val output = inputData // placeholder: replace with your transformations
    output.createOrReplaceTempView(tempView)

    val tableExists = spark.sql("show tables from " + database)
      .filter(col("tableName") === table)
      .count() == 1

    if (tableExists) {
      // select the columns in the existing table's order
      val tableRow = spark.sql("select * from " + database + "." + table + " limit 1")
      val columns = tableRow.columns.toSeq
      spark.sql(
        "insert overwrite table " + database + "." + table +
          " select " + columns.mkString(",") + " from " + tempView
      )
    } else {
      val partitionClause =
        if (partitions.nonEmpty) s"partitioned by (${partitions.mkString(",")})" else ""
      spark.sql(s"""
        create table if not exists $database.$table
        using parquet
        $partitionClause
        options (path "$outputLocation")
        as select * from $tempView
      """)
    }

    if (local) spark.stop()
  }
}
```
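The two SQL statements the job sends to Spark can be built by plain functions and checked without a Spark session. This is a sketch under some assumptions. HiveStatements, insertOverwrite and createAsSelect are hypothetical names. Table names and paths are assumed to be trusted values, because nothing is quoted or escaped. The sketch also shows why the job reads the column list from the existing table. As far as I understand Spark's behavior, a partitioned data source table keeps its partition columns at the end of its schema, and INSERT OVERWRITE ... SELECT matches columns by position. The select list therefore has to follow the table's column order rather than the temp view's.

```scala
// Hypothetical helpers building the two statements used by the Hive job.
// Identifiers and paths are interpolated as-is (no quoting or escaping).
object HiveStatements {
  /** Positional insert: `columns` must be in the target table's order,
    * which for a partitioned table ends with the partition columns.
    */
  def insertOverwrite(database: String, table: String,
                      columns: Seq[String], tempView: String): String =
    s"insert overwrite table $database.$table select ${columns.mkString(", ")} from $tempView"

  /** CTAS creating a Parquet table at an explicit path;
    * the partitioned by clause is omitted when there are no partition columns.
    */
  def createAsSelect(database: String, table: String, partitions: Seq[String],
                     path: String, tempView: String): String = {
    val partitionClause =
      if (partitions.isEmpty) "" else s" partitioned by (${partitions.mkString(", ")})"
    s"create table if not exists $database.$table using parquet$partitionClause " +
      s"options (path '$path') as select * from $tempView"
  }
}
```

Keeping the string building in pure functions means that a typo in the SQL shows up in a fast test instead of at the end of a Spark run.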

In this example, the output Hive table data is stored as Parquet files, so, as in the previous test, we can verify the output data by reading those files.

```scala
import java.io.File
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.SparkSession
import org.scalatest.{FlatSpec, Matchers}

class HiveSparkJobTest extends FlatSpec with Matchers {
  val INPUT_LOCATION = "src/test/resources/input/hive_spark_job"
  val OUTPUT_LOCATION = "src/test/resources/output/hive_spark_job"
  val TABLE = "hive_spark_table"

  def cleanup(): Unit = {
    // delete output folder
    FileUtils.deleteDirectory(new File(OUTPUT_LOCATION))
  }

  "running the hive spark job" should "process data and create table and parquet files" in {
    cleanup()
    HiveSparkJob.main(Array(
      "local=true",
      "input_location=" + INPUT_LOCATION,
      "output_location=" + OUTPUT_LOCATION,
      "table=" + TABLE
    ))

    // the table data is plain Parquet, so a session without Hive support can read it
    implicit val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()
    val output = spark.read.parquet(OUTPUT_LOCATION)
    output.count() should be (3)
    // other verifications
    spark.stop()
  }
}
```
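The HDFS test deletes its output folder with Apache Commons IO's FileUtils.deleteDirectory, and the Hive test needs the same step. If you would rather not add commons-io as a test dependency, the JDK's java.nio.file API can do the same job. In this sketch, TestCleanup and deleteRecursively are hypothetical names. Files are deleted as they are visited, and each directory is deleted after its contents, so directories are already empty when they are removed.

```scala
import java.io.IOException
import java.nio.file.{FileVisitResult, Files, Path, SimpleFileVisitor}
import java.nio.file.attribute.BasicFileAttributes

// Hypothetical JDK-only replacement for FileUtils.deleteDirectory.
object TestCleanup {
  /** Recursively deletes `dir` if it exists; does nothing otherwise. */
  def deleteRecursively(dir: Path): Unit =
    if (Files.exists(dir)) {
      Files.walkFileTree(dir, new SimpleFileVisitor[Path] {
        override def visitFile(file: Path, attrs: BasicFileAttributes): FileVisitResult = {
          Files.delete(file) // regular files are removed as they are visited
          FileVisitResult.CONTINUE
        }
        override def postVisitDirectory(d: Path, exc: IOException): FileVisitResult = {
          if (exc != null) throw exc // propagate errors from listing the directory
          Files.delete(d) // all entries are gone by now, so the directory is empty
          FileVisitResult.CONTINUE
        }
      })
    }
}
```

In the Hive test, cleanup() could then call `TestCleanup.deleteRecursively(java.nio.file.Paths.get(OUTPUT_LOCATION))`.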

Mongo test

The last test we'll look at covers a Spark job that writes its results to a Mongo database. I will not show the actual job, but for reference I include the test. Before running the Spark job, the test starts an embedded Mongo instance using Flapdoodle's embedded Mongo, which downloads and runs a real mongod process locally rather than an in-memory database. After the Spark job has finished, the Mongo database keeps running, so the test can connect to it and verify the output data. At the end of the test, we stop the Mongo daemon.

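```scala
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import de.flapdoodle.embed.mongo.MongodStarter
import de.flapdoodle.embed.mongo.config.{MongodConfigBuilder, Net}
import de.flapdoodle.embed.mongo.distribution.Version
import de.flapdoodle.embed.process.runtime.Network
import org.mongodb.scala._ // MongoClient and the Observable-to-Future implicits
import org.scalatest.{FlatSpec, Matchers}

class MongoSparkJobTest extends FlatSpec with Matchers {
  val INPUT_LOCATION = "src/test/resources/input/mongo_spark_job"
  val DATABASE = "test_db"
  val COLLECTION_NAME = "test_collection"

  def cleanup(): Unit = {
    // cleanup code
  }

  "running mongo spark job" should "create mongo collection with data" in {
    cleanup()
    // start an embedded mongod on a free local port (Flapdoodle 2.x API)
    val starter = MongodStarter.getDefaultInstance()
    val port = Network.getFreeServerPort()
    println("mongo service port: " + port)
    val mongodConfig = new MongodConfigBuilder()
      .version(Version.Main.PRODUCTION)
      .net(new Net(port, Network.localhostIsIPv6()))
      .build()
    val mongodExecutable = starter.prepare(mongodConfig)
    val mongod = mongodExecutable.start()

    try {
      val mongoConnectionString = "mongodb://localhost:" + port
      MongoSparkJob.main(Array(
        "local=true",
        "mongo_connection=" + mongoConnectionString,
        "database=" + DATABASE,
        "collection_name=" + COLLECTION_NAME
      ))

      // the job has finished, but mongod is still running: connect and verify
      val mongoClient = MongoClient(mongoConnectionString)
      try {
        val collection = mongoClient.getDatabase(DATABASE).getCollection(COLLECTION_NAME)
        val outputCount: Long = Await.result(collection.countDocuments().toFuture(), Duration.Inf)
        outputCount should be (2)
      } finally mongoClient.close()
    } finally {
      // stop the mongo daemon even if the job or an assertion fails
      mongod.stop()
    }
  }
}
```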
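A failed assertion throws an exception, so clean-up code placed after it, such as stopping mongod, is skipped unless it sits in a finally block. ScalaTest's documentation calls the general technique "loan-fixture methods". Below is a minimal generic version of it. Fixtures and withResource are hypothetical names and are not part of ScalaTest or Flapdoodle.

```scala
// Hypothetical loan-fixture helper: acquire a resource, lend it, always release it.
object Fixtures {
  /** Acquires a resource, passes it to `body`, and releases it afterwards,
    * even if `body` throws (for example, because an assertion failed).
    */
  def withResource[R, A](acquire: => R)(release: R => Unit)(body: R => A): A = {
    val resource = acquire // if acquiring fails, there is nothing to release
    try body(resource)
    finally release(resource)
  }
}
```

With the embedded Mongo, the body of the test could be written as `Fixtures.withResource(mongodExecutable.start())(_.stop()) { _ => ... }`. The daemon is then stopped whether the Spark job, the connection or an assertion fails, and the same helper can wrap other embedded services.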

These three tests can be starting points for a more complete unit test setup. Such a setup can speed up the development of Spark pipelines in your project, improve the quality of your code, and make onboarding simpler for people who are new to your project and to Spark.