Spark JDBC batch processing not inserting all records

In my Spark job, I'm using JDBC batch processing to insert records into MySQL, but I noticed that not all of the records were making it into the database. For example:

//count records before insert
println(s"dataframe: ${dataframe.count()}")

dataframe.foreachPartition(partition => {

  Class.forName(jdbcDriver)
  val dbConnection: Connection = DriverManager.getConnection(jdbcUrl, username, password)

  var preparedStatement: PreparedStatement = null
  dbConnection.setAutoCommit(false)
  val batchSize = 100

  partition.grouped(batchSize).foreach(batch => {
    batch.foreach(row => {
      val productName = row.getString(row.fieldIndex("productName"))
      val quantity = row.getLong(row.fieldIndex("quantity"))
      val sqlString =
        s"""
           |INSERT INTO myDb.product (productName, quantity)
           |VALUES (?, ?)
          """.stripMargin

      preparedStatement = dbConnection.prepareStatement(sqlString)
      preparedStatement.setString(1, productName)
      preparedStatement.setLong(2, quantity)

      preparedStatement.addBatch()
    })

    preparedStatement.executeBatch()
    dbConnection.commit()
    preparedStatement.close()
  })
  dbConnection.close()
})

dataframe.count() shows 650 records, but when I check MySQL I see only 195. This is deterministic: I tried different batch sizes and still get the same number. However, when I move preparedStatement.executeBatch() inside batch.foreach(), i.e. on the line right after preparedStatement.addBatch(), all 650 records show up in MySQL. That defeats the purpose of batching, though, since each statement is executed immediately after it is added within a single iteration. What could be preventing the queries from being batched?

Answer

It seems you're creating a new preparedStatement for every row, so each statement ends up holding exactly one batched insert. When preparedStatement.executeBatch() runs at the end of a group, it only executes the batch on the most recently created statement; the rows added to the earlier, discarded statements are silently lost, which is why far fewer rows (195) than the expected 650 reach MySQL. Instead, create the PreparedStatement once per partition and only bind the parameters inside the loop, like this:

import java.sql.{Connection, DriverManager, PreparedStatement}

dataframe.foreachPartition(partition => {

  Class.forName(jdbcDriver)
  val dbConnection: Connection = DriverManager.getConnection(jdbcUrl, username, password)

  val sqlString =
    s"""
       |INSERT INTO myDb.product (productName, quantity)
       |VALUES (?, ?)
      """.stripMargin

  // create the statement once per partition and reuse it for every row
  val preparedStatement: PreparedStatement = dbConnection.prepareStatement(sqlString)

  dbConnection.setAutoCommit(false)
  val batchSize = 100

  partition.grouped(batchSize).foreach(batch => {
    batch.foreach(row => {
      val productName = row.getString(row.fieldIndex("productName"))
      val quantity = row.getLong(row.fieldIndex("quantity"))

      // reuse the same PreparedStatement: just bind this row's values and add to the batch
      preparedStatement.setString(1, productName)
      preparedStatement.setLong(2, quantity)
      preparedStatement.addBatch()
    })

    // executeBatch() clears the batch after running it, so the statement can be reused
    preparedStatement.executeBatch()
    dbConnection.commit()
  })

  // close the statement once, after all groups have been executed
  preparedStatement.close()
  dbConnection.close()
})
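
Two optional follow-ups. MySQL Connector/J still sends each batched INSERT as its own round trip unless you enable the driver's rewriteBatchedStatements flag, which rewrites the batch into multi-row INSERTs. And if you don't need the hand-rolled JDBC code at all, Spark's built-in JDBC writer batches for you. A minimal sketch of both; the host/database in the URL are placeholders, and I'm assuming your existing username and password values:

// Connector/J flag that rewrites a batch into multi-row INSERTs (host/db are placeholders)
val batchingUrl = "jdbc:mysql://db-host:3306/myDb?rewriteBatchedStatements=true"

// Alternative: Spark's built-in JDBC writer ("batchsize" controls rows per round trip, default 1000)
import java.util.Properties

val props = new Properties()
props.setProperty("user", username)
props.setProperty("password", password)
props.setProperty("batchsize", "100")

dataframe.write
  .mode("append")
  .jdbc(batchingUrl, "myDb.product", props)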