This short article will go over a method to develop unit tests for Spark jobs. Such a unit test can be seen as having the following steps/components:
- test data, which is part of our test code base;
- operations that prepare the test environment, like removing outputs from old test runs or starting up embedded, in-memory databases that can receive the outputs;
- after running the job under test, verification that the output location contains the data we expect, processed correctly.
We'll start with the simple use case, reading from and writing to the file system, then move on to writing data to Hive and Mongo databases. The examples are developed in Scala, with sbt and the ScalaTest library.
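For reference, the dependencies used throughout these examples could be declared in build.sbt roughly as follows. The versions, and the Scala version, are assumptions to be adjusted to your environment; the connector used by the Mongo job itself is omitted here, just like the job.

// build.sbt (sketch; versions are assumptions, pick ones matching your cluster)
scalaVersion := "2.12.15"

libraryDependencies ++= Seq(
  // provided by the cluster at deploy time, but still on the compile and test classpaths
  "org.apache.spark" %% "spark-sql" % "3.1.2" % Provided,
  "org.apache.spark" %% "spark-hive" % "3.1.2" % Provided,
  // embedded Mongo daemon and the Scala driver used to verify its contents
  "de.flapdoodle.embed" % "de.flapdoodle.embed.mongo" % "2.2.0" % Test,
  "org.mongodb.scala" %% "mongo-scala-driver" % "2.9.0" % Test,
  // FileUtils, used to clean the test output folders
  "commons-io" % "commons-io" % "2.8.0" % Test,
  "org.scalatest" %% "scalatest" % "3.0.8" % Test
)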
First, all our Spark jobs are written in a way that allows us to run them locally, based on an input parameter. This input parameter decides how the SparkSession is initialized. When we run Spark locally, we also want to stop the SparkSession at the end of the program.
import org.apache.spark.sql.SparkSession

object HdfsSparkJob {

  def main(args: Array[String]): Unit = {
    val argmap = getArgumentsMap(args)
    val local = argmap.getOrElse("local", "false").toBoolean
    // read input parameters
    val inputLocation = argmap.getOrElse("input_location", "")
    val outputLocation = argmap.getOrElse("output_location", "")

    implicit val spark: SparkSession = (
        if (local) SparkSession.builder().master("local[*]")
        else SparkSession.builder()
      )
      .config("spark.sql.shuffle.partitions", 12)
      // spark configuration
      .getOrCreate()
    import spark.implicits._

    val inputData = spark.read.parquet(inputLocation)
    // spark operations on inputData resulting in outputData

    outputData
      .write
      .mode("append")
      .parquet(outputLocation)

    if (local) spark.stop()
  }

  def getArgumentsMap(args: Array[String]): Map[String, String] = {
    args
      .filter(a => a.contains('=')) // ignore arguments that are not key=value pairs
      .map(a => {
        val firstEq = a.indexOf('=')
        Seq(a.substring(0, firstEq), a.substring(firstEq + 1))
      })
      .filter(a => a(0).nonEmpty && a(1).nonEmpty)
      .map(a => a(0) -> a(1))
      .toMap
  }
}
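The job arguments are passed as key=value pairs, which is what getArgumentsMap expects; the same main method can therefore be called from spark-submit or directly from a test. As a small illustration (the paths are made up):

val argmap = HdfsSparkJob.getArgumentsMap(Array(
  "local=true",
  "input_location=/data/in",
  "output_location=/data/out"
))
// argmap == Map("local" -> "true", "input_location" -> "/data/in", "output_location" -> "/data/out")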
The test for the job defined above will start by cleaning up any output data that may exist. Then, we run HdfsSparkJob in local mode, with the correct input and output configuration. Once the job is finished, we initialize a new SparkSession and verify that the output location contains the correct results. The example test also includes a helper method that can be used to verify the schema of the output.
import java.io.File

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.scalatest.{FlatSpec, Matchers}

import scala.reflect.ClassTag

class HdfsSparkJobTest extends FlatSpec with Matchers {

  val INPUT_LOCATION = "src/test/resources/input/hdfs_spark_job"
  val OUTPUT_LOCATION = "src/test/resources/output/hdfs_spark_job"

  def cleanup(): Unit = {
    FileUtils.deleteDirectory(new File(OUTPUT_LOCATION))
  }

  "running hdfs spark job" should "process the input data correctly" in {
    cleanup()
    HdfsSparkJob.main(Array(
      "local=true",
      "input_location=" + INPUT_LOCATION,
      "output_location=" + OUTPUT_LOCATION
    ))
    implicit val spark: SparkSession = SparkSession.builder()
      .master("local[*]").getOrCreate()
    val output = spark.read.parquet(OUTPUT_LOCATION)
    output.count() should be (7)
    assertSchemaField[DoubleType](output.schema, "sensor_value")
  }

  def assertSchemaField[T: ClassTag](parent: StructType, name: String): Unit = {
    val field = parent.find(f => f.name == name)
    field should not be None
    field.get.dataType shouldBe a[T]
  }
}
As seen in the test, both the input and output locations are relative to the project folder. The project structure would be the following:
- src
  - main
    - scala
      - my.package
        - HdfsSparkJob.scala
  - test
    - resources
      - input
        - hdfs_spark_job
      - [output]
        - [hdfs_spark_job]
    - scala
      - my.package
        - HdfsSparkJobTest.scala
- .gitignore
- build.sbt
The output folder should not be checked in to source control; if you are using Git, exclude everything under output in the .gitignore file.
You can run the test above from within sbt, with testOnly *HdfsSparkJobTest.
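The test also needs a small parquet fixture under the input folder. One way to produce it is a short, throwaway Spark program; a minimal sketch, where the object name, the sensor schema and the values are made up for illustration (the real fixture must of course match what the job and the assertions expect):

import org.apache.spark.sql.SparkSession

// one-off fixture generator; schema and values are illustrative assumptions
object GenerateHdfsFixture extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  import spark.implicits._

  Seq(
    ("sensor-1", 21.5),
    ("sensor-2", 19.0),
    ("sensor-3", 23.7)
  ).toDF("sensor_id", "sensor_value")
    .write
    .mode("overwrite")
    .parquet("src/test/resources/input/hdfs_spark_job")

  spark.stop()
}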
This second Spark job is designed to write its output data into Hive. With locally run unit tests we don't have a Hive database, but we can use an in-memory instance of Apache Derby as a stand-in for the Hive metastore. When initializing the SparkSession in local mode, we configure it to use the in-memory Derby instance. The Spark job example here also shows how we can create a Hive table, or dynamically overwrite it partition by partition if it already exists.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object HiveSparkJob {

  // getArgumentsMap is the same helper shown in HdfsSparkJob
  import HdfsSparkJob.getArgumentsMap

  def main(args: Array[String]): Unit = {
    val argmap = getArgumentsMap(args)
    val local = argmap.getOrElse("local", "false").toBoolean
    // read input parameters
    val inputLocation = argmap.getOrElse("input_location", "")
    val outputLocation = argmap.getOrElse("output_location", "")
    val database = argmap.getOrElse("database", "default")
    val table = argmap.getOrElse("table", "")
    val tempView = argmap.getOrElse("temp_view", "temp")
    // partition columns of the output table (read here from an assumed comma-separated "partitions" argument)
    val partitions = argmap.getOrElse("partitions", "").split(",").filter(_.nonEmpty).toSeq

    implicit val spark: SparkSession = (
        if (local) SparkSession.builder().master("local[*]")
          .config("spark.hadoop.javax.jdo.option.ConnectionDriverName", "org.apache.derby.jdbc.EmbeddedDriver")
          .config("spark.hadoop.javax.jdo.option.ConnectionURL", "jdbc:derby:memory:default;create=true")
          .config("spark.hadoop.javax.jdo.option.ConnectionUserName", "hiveuser")
          .config("spark.hadoop.javax.jdo.option.ConnectionPassword", "hivepass")
        else SparkSession.builder()
      )
      .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
      // other settings
      .enableHiveSupport()
      .getOrCreate()

    val inputData = spark.read.parquet(inputLocation)
    // spark operations on inputData resulting in the output DataFrame

    output.createOrReplaceTempView(tempView)
    val tableExists = spark.sql("show tables from " + database)
      .filter(col("tableName") === table)
      .count() == 1
    if (tableExists) {
      // overwrite only the partitions present in the new output
      val tableRow = spark.sql("select * from " + database + "." + table + " limit 1")
      val columns = tableRow.columns.toSeq
      spark.sql(
        "insert overwrite table " + database + "." + table
          + " select " + columns.mkString(",")
          + " from " + tempView
      )
    } else {
      // first run: create the table on top of the output location
      val partitionedBy =
        if (partitions.nonEmpty) s"partitioned by (${partitions.mkString(",")})" else ""
      spark.sql(s"""
        create table if not exists $database.$table
        using parquet $partitionedBy
        options (path "$outputLocation")
        as select * from $tempView
      """)
    }

    if (local) spark.stop()
  }
}
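As a side note on the design: once the table exists, the same dynamic insert overwrite can also be expressed through the DataFrameWriter API instead of SQL. A sketch under the same settings, not taken from the original job:

// assumes the table already exists and spark.sql.sources.partitionOverwriteMode=dynamic;
// insertInto matches columns by position, so the DataFrame must follow the table's column order
output
  .write
  .mode("overwrite")
  .insertInto(database + "." + table)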
In this example, the output Hive table's data is stored as parquet files, so, as in the previous test, we can verify our output data by looking at those files.
import java.io.File

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.SparkSession
import org.scalatest.{FlatSpec, Matchers}

class HiveSparkJobTest extends FlatSpec with Matchers {

  val INPUT_LOCATION = "src/test/resources/input/hive_spark_job"
  val OUTPUT_LOCATION = "src/test/resources/output/hive_spark_job"
  val TABLE = "hive_spark_table"

  def cleanup(): Unit = {
    // delete output folder
    FileUtils.deleteDirectory(new File(OUTPUT_LOCATION))
  }

  "running the hive spark job" should "process data and create table and parquet files" in {
    cleanup()
    HiveSparkJob.main(Array(
      "local=true",
      "input_location=" + INPUT_LOCATION,
      "output_location=" + OUTPUT_LOCATION,
      "table=" + TABLE
    ))
    implicit val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()
    val output = spark.read.parquet(OUTPUT_LOCATION)
    output.count() should be (3)
    // other verifications
  }
}
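If the output table is partitioned, the partition layout on disk can also be asserted. A minimal sketch, assuming a hypothetical partition column named event_date:

// hypothetical check: a table partitioned by event_date produces event_date=... subdirectories
val partitionDirs = new File(OUTPUT_LOCATION)
  .listFiles()
  .filter(_.isDirectory)
  .map(_.getName)
  .filter(_.startsWith("event_date="))
partitionDirs should not be empty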
The last test we'll look at works with a Spark job that writes its results to a Mongo database. I will not show the actual job, but for reference I will include the test, which starts an in-memory Mongo database before running the Spark job. The Mongo daemon keeps running after the Spark job has finished, so we can connect to it and verify the output data. At the end of the test, we stop the Mongo daemon. The embedded Mongo from Flapdoodle is used in this test.
import de.flapdoodle.embed.mongo.MongodStarter
import de.flapdoodle.embed.mongo.config.{MongodConfigBuilder, Net}
import de.flapdoodle.embed.mongo.distribution.Version
import de.flapdoodle.embed.process.runtime.Network
import org.mongodb.scala.MongoClient
import org.scalatest.{FlatSpec, Matchers}

import scala.concurrent.Await
import scala.concurrent.duration.Duration

class MongoSparkJobTest extends FlatSpec with Matchers {

  val INPUT_LOCATION = "src/test/resources/input/mongo_spark_job"
  val DATABASE = "test_db"
  val COLLECTION_NAME = "test_collection"

  // cleanup code

  "running mongo spark job" should "create mongo collection with data" in {
    cleanup()
    // start the embedded Mongo daemon on a free local port
    val starter = MongodStarter.getDefaultInstance()
    val port = Network.getFreeServerPort()
    println("mongo service port: " + port)
    val mongodConfig = new MongodConfigBuilder()
      .version(Version.Main.PRODUCTION)
      .net(new Net(port, Network.localhostIsIPv6()))
      .build()
    val mongodExecutable = starter.prepare(mongodConfig)
    val mongod = mongodExecutable.start()
    val mongoConnectionString = "mongodb://localhost:" + port

    MongoSparkJob.main(Array(
      "local=true",
      "mongo_connection=" + mongoConnectionString,
      "database=" + DATABASE,
      "collection_name=" + COLLECTION_NAME
    ))

    // connect to the still-running daemon and verify the output collection
    val mongoClient = MongoClient(mongoConnectionString)
    val db = mongoClient.getDatabase(DATABASE)
    val collection = db.getCollection[Object](COLLECTION_NAME)
    val outputCount: Long = Await.result(collection.countDocuments().toFuture(), Duration.Inf)
    outputCount should be (2)

    mongoClient.close()
    mongod.stop()
  }
}
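One caveat with this test as written: if an assertion fails, mongod.stop() is never reached and the Mongo daemon keeps running until the JVM exits. A sketch of a safer shape for the same test, wrapping the job run and verifications in try/finally (not from the original test):

// make sure the embedded Mongo daemon is stopped even when an assertion fails
val mongodExecutable = starter.prepare(mongodConfig)
val mongod = mongodExecutable.start()
try {
  // run MongoSparkJob and the verifications here
} finally {
  mongod.stop()
  mongodExecutable.stop()
}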
These three tests can serve as starting points for a more complex unit test setup that can improve the development of Spark pipelines in your project, improve the quality of your code, and make onboarding simpler for people new to your project and to Spark.