Jim Hatcher

Spark Update Optimizations

About a year ago, I posted a blog on using Apache Spark jobs to do updates against CockroachDB. At the end of that post, I pledged to come back and do some optimizations to speed things up. And, true to my word, here I am.

Since publishing my last blog, I've switched to a new laptop, so I needed to get Spark running on it again. I did the following:

  • installed Scala 2.12 from homebrew
  • installed Spark from homebrew (latest as of the time of me writing this is 3.3.2)
  • tried to install Java 11 but gave up and installed Java 17 (Spark 3.3.2's overview page tells me I can run on Java 8/11/17)

I then tried to simply re-run the examples from my last blog post and ran into a gnarly error:
java.lang.IllegalAccessException: final field has no write access

After a little googling, I found a similarly-described problem here. This issue is fixed in the not-yet-released Spark 3.4.0, so I was a little stuck until I looked at the changes included in the fix and decided I could reproduce them myself.

Namely, the fix seems to be starting your Java process with the following option: -Djdk.reflect.useDirectMethodHandle=false

I added this option to my Java driver options like so and was able to run my previous examples (I also bumped my Postgres driver version up a bit):

spark-shell \
  --driver-class-path /Users/jimhatcher/.m2/repository/org/postgresql/postgresql/42.5.1/postgresql-42.5.1.jar \
  --jars /Users/jimhatcher/.m2/repository/org/postgresql/postgresql/42.5.1/postgresql-42.5.1.jar \
  --conf spark.jars=/Users/jimhatcher/.m2/repository/org/postgresql/postgresql/42.5.1/postgresql-42.5.1.jar \
  --conf spark.executor.memory=4g \
  --conf spark.driver.memory=4g \
  --driver-java-options=-Djdk.reflect.useDirectMethodHandle=false

This is a nasty little workaround (I admit), but I can omit it once Spark 3.4.0+ comes out. And, at this point, I was able to re-run my previous tests. My previous test involved creating a table with an id column and one additional field (which started null) and then updating all the records in the table where that field was null to have a value of 1.
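
For reference, here's a minimal sketch of the kind of setup that test runs against. The exact DDL isn't in this post, so the column types and the generate_series seeding here are assumptions based on the job code further down (and it assumes the same JDBC url and root user used there):

import java.sql.DriverManager

// Hypothetical setup sketch; the real DDL isn't shown in this post.
// Assumes the same JDBC url/user as the Spark jobs below.
val setupConn = DriverManager.getConnection(url, "root", "")
val setupStmt = setupConn.createStatement()

// an id column plus one nullable field that starts out NULL
setupStmt.execute("CREATE TABLE IF NOT EXISTS test.test_table (id INT PRIMARY KEY, f1 INT)")

// seed 1 million rows, leaving f1 NULL
setupStmt.execute("INSERT INTO test.test_table (id) SELECT generate_series(1, 1000000)")

setupStmt.close()
setupConn.close()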

I then proceeded to try to tune various settings (batch size, number of partitions in the Spark dataframe, fetch size in the dataframe, etc.). I could make some small improvements but nothing significant or interesting.
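
For anyone who wants to repeat those experiments, these are the kinds of knobs I mean. This is just a sketch of where they plug into the Spark JDBC reader/writer, assuming the url and tables from the job shown later; the specific values are illustrative, not recommendations:

// Read side: fetchsize controls rows per round trip, numPartitions controls read parallelism
val dfTuned = spark.read
  .format("jdbc")
  .option("url", url)
  .option("dbtable", "(SELECT id FROM test.test_table WHERE f1 IS NULL) t1")
  .option("user", "root")
  .option("partitionColumn", "id")
  .option("lowerBound", "1")
  .option("upperBound", "1000000")
  .option("numPartitions", "10")
  .option("fetchsize", "10000")
  .load

// Write side (when using the built-in JDBC writer instead of foreachPartition):
// batchsize controls how many rows go into each JDBC batch
dfTuned.write
  .format("jdbc")
  .option("url", url)
  .option("dbtable", "test.test_table_target")
  .option("user", "root")
  .option("batchsize", "1000")
  .mode("append")
  .save()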

I looked at the SQL Statements page in the CockroachDB DB Console and saw the queries that were being generated and run against the cluster. All my updates were coming through as individual statements. Each statement itself was fairly quick (~1ms), but running a bunch of individual statements is never as fast as batching data together.

In the Postgres JDBC driver, there is a parameter called reWriteBatchedInserts which can provide a huge performance lift when writing data to CockroachDB. I verified that I was passing the parameter on my JDBC URL, and I was. However, this parameter only comes into play when running INSERTs, not UPDATEs, so I wrote a different version of my Spark job that reads from one table and INSERTs values into another table. When I ran this job, it took around 6 seconds to INSERT 1 million records instead of the ~40 seconds I had been seeing when trying to UPDATE 1 million records.

Here are the important bits of that Spark code:

df.coalesce(10).foreachPartition(
  (partition: Iterator[Row]) => {

    // one connection and one prepared statement per Spark partition
    val dbc: Connection = DriverManager.getConnection(url, "root", "")
    val batchSize = 1000
    val st: PreparedStatement = dbc.prepareStatement("INSERT INTO test.test_table_target ( id, f1 ) VALUES ( ?, 1 );")

    // group the partition's rows into batches and send each batch as one JDBC batch
    partition.grouped(batchSize).foreach (
      batch => {
        batch.foreach (
          row => {
            st.setLong(1, row.getLong(0)) 
            st.addBatch()
          }
        )
        st.executeBatch()
      }
    )

    dbc.close()

  }
)

I then turned my thoughts to how I could force CockroachDB to do UPDATE statements in a batch, similarly to how the Postgres JDBC driver leverages query re-writing to batch INSERT queries so efficiently. One such technique is to pass a bunch of values into a SQL statement as arrays and then use the UNNEST function to convert those arrays into tabular input.

Instead of running a single update like this:

UPDATE test.test_table
SET f1 = 1
WHERE id = ?;

I turned my attention to an approach like this:

UPDATE test_table
SET f1 = data_table.new_f1
FROM (
  SELECT UNNEST(?) AS id, UNNEST(?) AS new_f1
) AS data_table
WHERE test_table.id = data_table.id;

To help visualize what is happening with the UNNEST function, here is a simple example of un-nesting an array in CockroachDB:

root@localhost:26257/test> SELECT UNNEST(ARRAY[1, 2, 3]);
  unnest
----------
       1
       2
       3
(3 rows)
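And here's the two-array version of that, run over JDBC so you can see the shape of the tabular input that the UPDATE ... FROM query gets to join against. This is just a sketch (it assumes the same connection url and BIGINT arrays used in the job below):

import java.sql.DriverManager

// Sketch: two parallel arrays come back as a two-column table
val vizConn = DriverManager.getConnection(url, "root", "")
val vizSt = vizConn.prepareStatement("SELECT UNNEST(?) AS id, UNNEST(?) AS new_f1")

vizSt.setArray(1, vizConn.createArrayOf("BIGINT", Array[AnyRef](Long.box(1L), Long.box(2L), Long.box(3L))))
vizSt.setArray(2, vizConn.createArrayOf("BIGINT", Array[AnyRef](Long.box(1L), Long.box(1L), Long.box(1L))))

val rs = vizSt.executeQuery()
while (rs.next()) {
  println(s"${rs.getLong("id")} -> ${rs.getLong("new_f1")}")  // 1 -> 1, 2 -> 1, 3 -> 1
}

rs.close()
vizSt.close()
vizConn.close()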

With this goal in mind, I re-wrote my Scala job to submit one UPDATE query for each batch (defined by my batch size variable).

import java.sql._
import org.apache.spark.sql._
import scala.collection.mutable.ListBuffer

val url = "jdbc:postgresql://<CRDB host name here>:26257/test?sslmode=require&sslcert=/Users/jimhatcher/spark-cluster-certs/client.root.crt&sslkey=/Users/jimhatcher/spark-cluster-certs/client.root.key.der&sslrootcert=/Users/jimhatcher/spark-cluster-certs/ca.crt&reWriteBatchedInserts=true"


val sql = "( select id from test.test_table where f1 IS NULL) t1"
val df = spark.read
  .format("jdbc")
  .option("url", url)
  .option("dbtable", sql)
  .option("user", "root")
  .option("partitionColumn", "id")
  .option("lowerBound", "1")
  .option("upperBound", "1000000")
  .option("numPartitions", "10")
  .load

df.coalesce(10).foreachPartition(
  (partition: Iterator[Row]) => {

    val dbc: Connection = DriverManager.getConnection(url, "root", "")
    val batchSize = 1000
    val st: PreparedStatement = dbc.prepareStatement("UPDATE test_table SET f1 = data_table.new_f1 FROM ( SELECT UNNEST(?) AS id, UNNEST(?) AS new_f1 ) AS data_table WHERE test_table.id = data_table.id")

    partition.grouped(batchSize).foreach (
      batch => {
        // collect the ids for this batch (and a parallel list of new f1 values)
        val ids = new ListBuffer[BigInt]
        val f1s = new ListBuffer[BigInt]
        batch.foreach (
          row => {
            ids += row.getLong(0)
            f1s += 1
          }
        )
        // bind both lists as BIGINT arrays; UNNEST turns them back into rows on the server
        st.setArray(1, st.getConnection.createArrayOf("BIGINT", ids.toList.toArray))
        st.setArray(2, st.getConnection.createArrayOf("BIGINT", f1s.toList.toArray))
        st.execute()
      }
    )

    dbc.close()

  }
)


Using this job, I was able to UPDATE 1 million records in ~3.6 seconds instead of ~40 seconds.

Note: It's generally considered bad form in Scala to use mutable lists in the way I'm doing above. I could probably re-write my batch loop above using for comprehensions and make this a little more Scala-acceptable. But, hey, that will give me something to write about in my next Spark optimization blog!
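
In the meantime, here's a rough sketch of one possible shape for that rewrite: build the two bind arrays with map instead of mutable ListBuffers. This is just the batch loop; everything else in the job stays the same.

// Sketch: build the bind arrays without mutable collections
partition.grouped(batchSize).foreach { batch =>
  // batch is a Seq[Row]; map it straight into the two parallel arrays
  val ids: Array[AnyRef] = batch.map(row => Long.box(row.getLong(0))).toArray
  val f1s: Array[AnyRef] = batch.map(_ => Long.box(1L)).toArray

  st.setArray(1, st.getConnection.createArrayOf("BIGINT", ids))
  st.setArray(2, st.getConnection.createArrayOf("BIGINT", f1s))
  st.execute()
}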
