home / 2018.10.02 19:00 / scala / apache spark / mongo db / apache kafka / big data

Spark Foreach Mongo Upsert Writer

Introduction

We need a way to write or update results in MongoDB after obtaining those results through Spark Structured Streaming processing. We need a solution that can be easily adapted for each Spark job we run, which means it should be able to save any case class we define as the result (and preferably deep case classes as well).

Our example Spark job is included below:

package com.cacoveanu.bigdata

import org.apache.spark.sql.{Dataset, SparkSession}

case class SimpleResult(_id: String, info: String, value: Int)

case class MongoResultEntry(entryId: String, value: String)

case class DeepMongoResult(_id: String, info: MongoResultEntry, value: Map[String, Int])

object StructuredStreamingToMongoGeneric extends App {

  val spark = SparkSession
    .builder
    .appName(this.getClass.getName)
    .master("local[2]")
    .getOrCreate()

  import spark.implicits._

  val frame = spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "localhost:9092")
    .option("startingOffsets", "earliest")
    .option("subscribe", "data")
    .load()

  val results: Dataset[DeepMongoResult] = frame.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
    .as[(String, String)]
    .map { case (key, value) =>
      // do some processing
    }

  val query = results
    .writeStream
    .outputMode("complete")
    .foreach(
      new MongoUpsertWriter("mongodb://localhost:27017", "local", "results_collection"))
    .start()

  query.awaitTermination()
}
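
The map step above is left as a placeholder. Purely for illustration, assuming a hypothetical semicolon-separated Kafka payload of the form id;entryId;entryValue;count (a format invented here, not defined by the original job), the processing step could look something like this:

// hypothetical parsing logic, for illustration only
val results: Dataset[DeepMongoResult] = frame
    .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
    .as[(String, String)]
    .map { case (_, value) =>
      val parts = value.split(";")
      DeepMongoResult(
        _id = parts(0),
        info = MongoResultEntry(parts(1), parts(2)),
        value = Map(parts(1) -> parts(3).toInt)
      )
    }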

Next I'll go over a few approaches to implementing a MongoUpsertWriter class that can help us in this situation.

Full-Featured Writer Solution

The first solution supports upserts to Mongo and can handle any deep case class, but its implementation is not generic. This means that, while most of the code for the MongoUpsertWriter can be reused when we want to write a new case class to Mongo, we need to copy the code into a new writer class and replace all references to the old case class (DeepMongoResult) with references to our new case class.

package com.cacoveanu.bigdata

import com.mongodb.client.model.ReplaceOptions
import org.apache.spark.sql.ForeachWriter
import org.bson.codecs.configuration.CodecRegistries.{fromProviders, fromRegistries}
import org.mongodb.scala._
import org.mongodb.scala.bson.codecs.{DEFAULT_CODEC_REGISTRY, Macros}
import org.mongodb.scala.model.Filters._

import scala.concurrent.Await
import scala.concurrent.duration.Duration

class MongoUpsertWriter(connectionString: String,
                        databaseName: String,
                        collectionName: String) extends ForeachWriter[DeepMongoResult] {

  var mongoClient: MongoClient = _
  var database: MongoDatabase = _
  var collection: MongoCollection[DeepMongoResult] = _

  override def open(partitionId: Long, version: Long): Boolean = {
    mongoClient = MongoClient(connectionString)
    val codecRegistry = fromRegistries(fromProviders(
      Macros.createCodecProvider[DeepMongoResult](),
      Macros.createCodecProvider[MongoResultEntry]()
    ), DEFAULT_CODEC_REGISTRY)
    database = mongoClient.getDatabase(databaseName).withCodecRegistry(codecRegistry)
    collection = database.getCollection[DeepMongoResult](collectionName)
    true
  }

  override def process(value: DeepMongoResult): Unit = {
    try {
      val options = new ReplaceOptions().upsert(true)

      val observable = collection.replaceOne(
        equal("_id", value._id),
        value,
        options
      )
      Await.result(observable.toFuture(), Duration.Inf)
    } catch {
      case t: Throwable => println(t)
    }
  }

  override def close(errorOrNull: Throwable): Unit = {
    //do nothing
  }
}

This solution does everything we need, but we have to create a new Mongo writer class for each result type we want to save (so a new class for each Spark job).

Generic Writer Solution

A better approach would be a generic writer to which we provide the case class we want to save, so we can just instantiate it with a new case class for each new result we want to write to Mongo. This, however, is a lot more complicated, mainly because the bson codecs library uses Scala macros to dynamically create codecs for our case classes, and those macros do not work with generic type parameters.
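
To illustrate the limitation, here is a minimal sketch using the same bson macro API as in the first writer (the object name is mine, for illustration only): the macro call compiles only when given a concrete case class, so we cannot simply parameterize the writer with a type T and generate the codec inside it.

package com.cacoveanu.bigdata

import org.bson.codecs.configuration.CodecRegistries.{fromProviders, fromRegistries}
import org.mongodb.scala.bson.codecs.{DEFAULT_CODEC_REGISTRY, Macros}

object CodecLimitationSketch {

  // compiles: the macro can inspect the concrete case class at compile time
  val concreteRegistry = fromRegistries(
    fromProviders(Macros.createCodecProvider[MongoResultEntry]()),
    DEFAULT_CODEC_REGISTRY
  )

  // does not compile: T is an abstract type parameter, so the macro has no
  // case class definition to inspect
  // def genericRegistry[T] = fromProviders(Macros.createCodecProvider[T]())
}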

Another approach is to build our own way of converting the case class into a bson document using reflection (more precisely, we create bson update operations that get applied to the documents in the database). The following is a working solution, but it only handles flat case classes:

package com.cacoveanu.bigdata

import org.apache.spark.sql.ForeachWriter
import org.bson.BsonDocument
import org.bson.conversions.Bson
import org.mongodb.scala.model.UpdateOptions

import scala.concurrent.Await
import scala.concurrent.duration.Duration
import org.mongodb.scala._
import org.mongodb.scala.model.Filters._
import org.mongodb.scala.model.Updates.{combine, set}

class MongoUpsertWriter[T](connectionString: String,
                           databaseName: String,
                           collectionName: String) extends ForeachWriter[T] {

  private val ID_FIELD_NAME = "_id"

  var mongoClient: MongoClient = _
  var database: MongoDatabase = _
  var collection: MongoCollection[_] = _

  override def open(partitionId: Long, version: Long): Boolean = {
    mongoClient = MongoClient(connectionString)
    database = mongoClient.getDatabase(databaseName)
    collection = database.getCollection(collectionName)
    true
  }

  private def getId(obj: T) = {
    val field = obj.getClass.getDeclaredField(ID_FIELD_NAME)
    field setAccessible true
    field.get(obj)
  }

  private def getObjectFields(obj: T): List[(String, AnyRef)] =
    obj.getClass.getDeclaredFields.map(field => {
      field setAccessible true
      field.getName -> field.get(obj)
    }).toList

  private def toSetQuery(name: String, value: AnyRef) = set(name, value)

  private def toCombineQuery(fields: List[(String, AnyRef)]) =
    combine(fields.filter(f => f._1 != ID_FIELD_NAME)
      .map(f => toSetQuery(f._1, f._2)):_*)

  override def process(value: T): Unit = {
    try {
      val options = new UpdateOptions().upsert(true)
      
      val observable = collection.updateOne(
        equal("_id", getId(value)),
        toCombineQuery(getObjectFields(value)),
        options
      )
      Await.result(observable.toFuture(), Duration.Inf)
    } catch {
      case t: Throwable => println(t)
    }
  }

  override def close(errorOrNull: Throwable): Unit = {
    //do nothing
  }
}

This class is simpler to use: we just instantiate the writer with the case class we want to write to Mongo (as long as it is a flat case class, it will work), as in the example below:

val query = results
    .writeStream
    .outputMode("complete")
    .foreach(
      new MongoUpsertWriter[SimpleResult]("mongodb://localhost:27017", "local", "results_collection"))
    .start()

More Powerful Generic Writer

We could, of course, expand the above solution into a more general implementation that can walk into a deep case class structure and create the correct set queries, using Mongo's dot notation, to update the nested bson document in the database.
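
As a concrete illustration (hand-written here, not produced by any code in this post), for a value like DeepMongoResult("id1", MongoResultEntry("e1", "v1"), Map("k" -> 1)) the update we want the writer to build would look like this:

import org.bson.conversions.Bson
import org.mongodb.scala.model.Filters.equal
import org.mongodb.scala.model.Updates.{combine, set}

object DeepUpdateSketch {
  // the filter targets the document by its _id
  val filter: Bson = equal("_id", "id1")

  // dot notation (info.entryId, info.value) reaches into the nested document;
  // maps are passed through as-is, just like in the flat writer
  val update: Bson = combine(
    set("info.entryId", "e1"),
    set("info.value", "v1"),
    set("value", Map("k" -> 1))
  )
}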

To test out how a deep dive into case classes would look, we can run the following program:

package com.cacoveanu.bigdata

import scala.collection.mutable.ListBuffer

case class RRBaseClass(str: String, i: Int)

case class RRDeepClass(str: String, bc: RRBaseClass)

case class RRDeeperClass(str: String, dc: RRDeepClass, m: Map[String, Int], l: List[String])

object RecursiveReflection {

  private def isPrimitive(obj: AnyRef) =
    obj.isInstanceOf[Int] ||
      obj.isInstanceOf[String] ||
      obj.isInstanceOf[Double] ||
      obj.isInstanceOf[Float] ||
      obj.isInstanceOf[Long] ||
      obj.isInstanceOf[Short] ||
      obj.isInstanceOf[Byte] ||
      obj.isInstanceOf[Boolean] ||
      obj.isInstanceOf[Char] ||
      obj.isInstanceOf[List[_]] ||
      obj.isInstanceOf[Map[_, _]]

  def getObjectFields(path: String, obj: AnyRef): List[(String, AnyRef)] = {
    val res = new ListBuffer[(String, AnyRef)]()
    obj.getClass.getDeclaredFields.foreach(field => {
      field setAccessible true
      val value = field.get(obj)

      if (isPrimitive(value)) {
        res += (path + field.getName -> value)
      } else {
        res ++= getObjectFields(path + field.getName + ".", value)
      }
    })
    res.toList
  }

  def main(args: Array[String]): Unit = {
    println(getObjectFields("", RRBaseClass("one", 1)))
    println(getObjectFields("", RRDeepClass("deeptwo", RRBaseClass("two", 2))))
    println(getObjectFields("",
      RRDeeperClass("deeperthree",
        RRDeepClass("deepthree", RRBaseClass("three", 3)),
        Map("map1" -> 1, "map2" -> 2),
        List("listone", "listtwo")
      )
    ))
  }
}

If you run the above example, you should get the following output:

List((str,one), (i,1))
List((str,deeptwo), (bc.str,two), (bc.i,2))
List((str,deeperthree), (dc.str,deepthree), (dc.bc.str,three), (dc.bc.i,3), (m,Map(map1 -> 1, map2 -> 2)), (l,List(listone, listtwo)))

Notice that the field names carry the full dotted path (dc.bc.str, dc.bc.i), which is the dot notation Mongo expects when updating fields in nested documents. The generic Mongo upsert writer based on this approach is detailed below:

package com.cacoveanu.bigdata

import org.apache.spark.sql.ForeachWriter
import org.mongodb.scala._
import org.mongodb.scala.model.Filters._
import org.mongodb.scala.model.UpdateOptions
import org.mongodb.scala.model.Updates.{combine, set}

import scala.collection.mutable.ListBuffer
import scala.concurrent.Await
import scala.concurrent.duration.Duration

class MongoUpsertWriter[T](connectionString: String,
                           databaseName: String,
                           collectionName: String) extends ForeachWriter[T] {

  private val ID_FIELD_NAME = "_id"

  var mongoClient: MongoClient = _
  var database: MongoDatabase = _
  var collection: MongoCollection[_] = _

  override def open(partitionId: Long, version: Long): Boolean = {
    mongoClient = MongoClient(connectionString)
    database = mongoClient.getDatabase(databaseName)
    collection = database.getCollection(collectionName)
    true
  }

  private def getId(obj: T) = {
    val field = obj.getClass.getDeclaredField(ID_FIELD_NAME)
    field setAccessible true
    field.get(obj)
  }

  private def isPrimitive(obj: AnyRef) =
    obj.isInstanceOf[Int] ||
      obj.isInstanceOf[String] ||
      obj.isInstanceOf[Double] ||
      obj.isInstanceOf[Float] ||
      obj.isInstanceOf[Long] ||
      obj.isInstanceOf[Short] ||
      obj.isInstanceOf[Byte] ||
      obj.isInstanceOf[Boolean] ||
      obj.isInstanceOf[Char] ||
      obj.isInstanceOf[List[_]] ||
      obj.isInstanceOf[Map[_, _]]

  def getObjectFields(path: String, obj: AnyRef): List[(String, AnyRef)] = {
    val res = new ListBuffer[(String, AnyRef)]()
    obj.getClass.getDeclaredFields.foreach(field => {
      field setAccessible true
      val value = field.get(obj)

      if (isPrimitive(value)) {
        res += (path + field.getName -> value)
      } else {
        res ++= getObjectFields(path + field.getName + ".", value)
      }
    })
    res.toList
  }

  private def toSetQuery(name: String, value: AnyRef) = set(name, value)

  private def toCombineQuery(fields: List[(String, AnyRef)]) =
    combine(fields.filter(f => f._1 != ID_FIELD_NAME)
      .map(f => toSetQuery(f._1, f._2)):_*)

  override def process(value: T): Unit = {
    try {
      val options = new UpdateOptions().upsert(true)
      
      val observable = collection.updateOne(
        equal("_id", getId(value)),
        toCombineQuery(getObjectFields("", value.asInstanceOf[AnyRef])),
        options
      )
      Await.result(observable.toFuture(), Duration.Inf)
    } catch {
      case t: Throwable => println(t)
    }
  }

  override def close(errorOrNull: Throwable): Unit = {
    //do nothing
  }
}

This writer is easier to use with a hierarchy of case classes, as long as the values in those case classes are the types handled by the isPrimitive method (the isPrimitive method can also be extended, as sketched after the example below):

val query = results
    .writeStream
    .outputMode("complete")
    .foreach(
      new MongoUpsertWriter[DeepMongoResult]("mongodb://localhost:27017", "local", "results_collection"))
    .start()
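
As an example of extending isPrimitive (a sketch only, not part of the writer above), we could also treat Option values and any Seq as leaf values that are written directly instead of being recursed into:

  // sketch of an extended isPrimitive method: Seq covers List and other
  // sequence types, and Option values are written as-is
  private def isPrimitive(obj: AnyRef): Boolean =
    obj.isInstanceOf[Int] ||
      obj.isInstanceOf[String] ||
      obj.isInstanceOf[Double] ||
      obj.isInstanceOf[Float] ||
      obj.isInstanceOf[Long] ||
      obj.isInstanceOf[Short] ||
      obj.isInstanceOf[Byte] ||
      obj.isInstanceOf[Boolean] ||
      obj.isInstanceOf[Char] ||
      obj.isInstanceOf[Seq[_]] ||
      obj.isInstanceOf[Map[_, _]] ||
      obj.isInstanceOf[Option[_]]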

Generic Writer Based On Generic Scala Collections

A different way of looking at the problem is to consider that any case class can be converted into a Map, a Scala structure that the library can write to Mongo without additional work. For this we need to convert a deep case class structure into a structure of Maps and Lists (or sequences) containing primitive (or at least non-case class) values.

The following experiment does just that:

package com.cacoveanu.bigdata

case class CCMOne(str: String, i: Int)
case class CCMDeep(str: String, o: CCMOne)
case class CCMDeepest(str: String, cc: CCMDeep, m: Map[Int, CCMOne], l: List[CCMOne])

object CaseClassToMap {

  def isCaseClass(o: Any) =
    o.getClass.getInterfaces.contains(classOf[scala.Product])

  def unspoolSequence(seq: Seq[_]): Seq[_] = 
    seq.map(
      v => unspool(v)
    )

  def unspoolMap(map: Map[_,_]): Map[_, _] =
    map.map(
      m => (m._1, unspool(m._2))
    )

  def unspoolCaseClass(cc: Product): Map[String, Any] =
    unspoolMap(
        cc.getClass.getDeclaredFields
            .map(_.getName)
            .zip(cc.productIterator.toList)
            .toMap
    ).asInstanceOf[Map[String, Any]]

  def unspool(v: Any) =
    if (isCaseClass(v)) unspoolCaseClass(v.asInstanceOf[Product])
    else if (v.isInstanceOf[Seq[_]]) unspoolSequence(v.asInstanceOf[Seq[_]])
    else if (v.isInstanceOf[Map[_,_]]) unspoolMap(v.asInstanceOf[Map[_, _]])
    else v

  def main(args: Array[String]): Unit = {
    println(isCaseClass(CCMOne("one", 1)))
    println(isCaseClass(1.asInstanceOf[AnyRef]))
    println(isCaseClass("test"))
    println(isCaseClass(List))

    println(unspool(CCMOne("two", 2)))
    println(unspool(CCMDeep("ccmd", CCMOne("three", 3))))
    println(unspool(CCMDeepest(
      "ccmddd",
      CCMDeep("ccmd2", CCMOne("four", 4)),
      Map(
        1 -> CCMOne("five", 5),
        2 -> CCMOne("six", 6)
      ),
      List(CCMOne("seven", 7))
    )))
  }
}

Running the above code, you should get the following output:

true
false
false
false
Map(str -> two, i -> 2)
Map(str -> ccmd, o -> Map(str -> three, i -> 3))
Map(str -> ccmddd, cc -> Map(str -> ccmd2, o -> Map(str -> four, i -> 4)), m -> Map(1 -> Map(str -> five, i -> 5), 2 -> Map(str -> six, i -> 6)), l -> List(Map(str -> seven, i -> 7)))

The Mongo writer can now be implemented as follows:

package com.cacoveanu.bigdata

import com.mongodb.client.model.ReplaceOptions
import org.apache.spark.sql.ForeachWriter
import org.mongodb.scala._
import org.mongodb.scala.model.Filters._

import scala.concurrent.Await
import scala.concurrent.duration.Duration

class MongoUpsertWriter[T](connectionString: String,
                           databaseName: String,
                           collectionName: String) extends ForeachWriter[T] {

  private val ID_FIELD_NAME = "_id"

  var mongoClient: MongoClient = _
  var database: MongoDatabase = _
  var collection: MongoCollection[Map[_,_]] = _

  override def open(partitionId: Long, version: Long): Boolean = {
    mongoClient = MongoClient(connectionString)
    database = mongoClient.getDatabase(databaseName)
    collection = database.getCollection(collectionName)
    true
  }

  private def getId(obj: T) = {
    val field = obj.getClass.getDeclaredField(ID_FIELD_NAME)
    field setAccessible true
    field.get(obj)
  }

  def isCaseClass(o: Any) =
    o.getClass.getInterfaces.contains(classOf[scala.Product])

  def unspoolSequence(seq: Seq[_]): Seq[_] =
    seq.map(
      v => unspool(v)
    )

  def unspoolMap(map: Map[_,_]): Map[_, _] =
    map.map(
      m => (m._1, unspool(m._2))
    )

  def unspoolCaseClass(cc: Product): Map[String, Any] =
    unspoolMap(cc.getClass.getDeclaredFields
      .map( _.getName )
      .zip(cc.productIterator.toList)
      .toMap).asInstanceOf[Map[String, Any]]

  def unspool(v: Any) =
    if (isCaseClass(v)) unspoolCaseClass(v.asInstanceOf[Product])
    else if (v.isInstanceOf[Seq[_]]) unspoolSequence(v.asInstanceOf[Seq[_]])
    else if (v.isInstanceOf[Map[_,_]]) unspoolMap(v.asInstanceOf[Map[_, _]])
    else v

  override def process(value: T): Unit = {
    try {
      val options = new ReplaceOptions().upsert(true)

      val observable = collection.replaceOne(
        equal("_id", getId(value)),
        unspool(value).asInstanceOf[Map[_,_]],
        options
      )
      Await.result(observable.toFuture(), Duration.Inf)
    } catch {
      case t: Throwable => println(t)
    }
  }

  override def close(errorOrNull: Throwable): Unit = {
    //do nothing
  }
}

Notice that we no longer use set queries; instead we use the replaceOne method, with upsert enabled, to save the map to Mongo directly. Keep in mind that replaceOne overwrites the entire document, while the set-based updates only modify the listed fields. The resulting structure in Mongo is exactly the same as the one saved by the first (full-featured) solution.

This writer can be instantiated in the same way as the previous ones:

val query = results
    .writeStream
    .outputMode("complete")
    .foreach(
      new MongoUpsertWriter[DeepMongoResult]("mongodb://localhost:27017", "local", "results_collection"))
    .start()