Spark UDF memoization


Memoization is a powerful technique for improving the performance of repeated computations: the result for each distinct argument is computed once, cached, and reused. Although it would be a handy feature, Spark currently has no memoization or result cache for UDFs.
Fortunately, it is something we can easily implement ourselves.
All examples below are in Scala.

The problem

Imagine we have a relatively expensive function

import org.apache.spark.sql.functions.{col, udf}
spark.udf.register("expensive", udf((x: Int) => { Thread.sleep(1); 1 }))

And assume this function needs to be executed many times for a small set of arguments:

spark.range(50000).toDF()
  .withColumn("parent_id", col("id").mod(100))
  .createTempView("myTable")
spark.sql("select id, expensive(parent_id) as one from myTable")

Let’s run some tests. I modified the function to increment an accumulator that counts its invocations.
I ran the tests on a small Dataproc cluster (Spark 2.2.0):

scala> import
scala> val invocations = spark.sparkContext.longAccumulator("invocations")
invocations: org.apache.spark.util.LongAccumulator = LongAccumulator(id: 0, name: Some(invocations), value: 0)
scala> def timing[T](body: => T): T = {
     |   val t0 = System.nanoTime()
     |   invocations.reset()
     |   val res = body
     |   val t1 = System.nanoTime()
     |   println(s"invocations=${invocations.value}, time=${(t1 - t0) / 1e9}")
     |   res
     | }
timing: [T](body: => T)T
scala> def expensive(n: Int) = {
     |   Thread.sleep(1)
     |   invocations.add(1)
     |   1
     | }
expensive: (n: Int)Int
scala> spark.udf.register("expensive", udf((x: Int) => expensive(x)))
res0: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,Some(List(IntegerType)))
scala> spark.range(50000).toDF().
     |   withColumn("parent_id", col("id").mod(100)).
     |   createTempView("myTable")
scala> spark.sql("select id, expensive(parent_id) as one from myTable").createTempView("expensive_table")
scala> timing(spark.sql("select sum(one) from expensive_table").show(truncate = false))
+--------+
|sum(one)|
+--------+
|50000   |
+--------+
invocations=50000, time=9.493999374

The expensive function was called 50000 times, once per row, and the query took around 10 seconds.

Simple memoization

How can we improve this timing? We can memoize the function results with the following simple code:

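import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JF}
def memo[T, U](f: T => U): T => U = {
  lazy val cache = new ConcurrentHashMap[T, U]() // one cache per memoized function, created on first call
  (t: T) => cache.computeIfAbsent(t, new JF[T, U] { // f runs only for keys not cached yet
    def apply(t: T): U = f(t)
  })
}
spark.udf.register("memoized", udf(memo((x: Int) => expensive(x))))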

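The cache is a lazy val captured by the closure, so every instance of the memoized function gets its own cache. Because it is lazy, the map is only created when the function is first called, i.e. on the executor rather than on the driver.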
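To see the caching behaviour without a cluster, here is a minimal plain-Scala sketch (no Spark involved). It reuses the memo shape from above; the names MemoDemo and invocations are hypothetical, and an AtomicLong stands in for the Spark accumulator.

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
import java.util.function.{Function => JF}

object MemoDemo {
  // Same shape as the memo helper above: every call to memo returns
  // a new function with its own, initially empty, cache.
  def memo[T, U](f: T => U): T => U = {
    lazy val cache = new ConcurrentHashMap[T, U]()
    (t: T) => cache.computeIfAbsent(t, new JF[T, U] {
      def apply(key: T): U = f(key)
    })
  }

  // Hypothetical stand-in for the Spark accumulator: counts real invocations.
  val invocations = new AtomicLong(0)

  def expensive(n: Int): Int = {
    invocations.incrementAndGet()
    1
  }

  def main(args: Array[String]): Unit = {
    val memoized = memo((x: Int) => expensive(x))
    // 50000 "rows" but only 100 distinct arguments, like parent_id above
    val total = (0 until 50000).map(id => memoized(id % 100)).sum
    println(s"sum=$total, invocations=${invocations.get()}") // sum=50000, invocations=100

    // Two fresh copies: each fills its own cache, so expensive runs 2 x 100 times
    invocations.set(0)
    val copy1 = memo((x: Int) => expensive(x))
    val copy2 = memo((x: Int) => expensive(x))
    (0 until 50000).foreach { id => copy1(id % 100); copy2(id % 100) }
    println(s"two copies: invocations=${invocations.get()}") // two copies: invocations=200
  }
}
```

The second half shows the flip side of a per-instance cache: two independent copies of the memoized function do not share results, so each one calls the underlying function once per distinct argument.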

Let’s run more tests!

scala> import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentHashMap
scala> import java.util.function.{Function => JF}
import java.util.function.{Function=>JF}
scala> implicit def toJF[T, U](f: T => U): JF[T, U] = new JF[T, U] {
     |   def apply(t: T): U = f(t)
     | }
warning: there was one feature warning; re-run with -feature for details
toJF: [T, U](f: T => U)java.util.function.Function[T,U]
scala> def memo[T, U](f: T => U): T => U = {
     |   lazy val cache = new ConcurrentHashMap[T, U]()
     |   (t: T) => cache.computeIfAbsent(t, f)
     | }
memo: [T, U](f: T => U)T => U
scala> spark.udf.register("memoized", udf(memo((x: Int) => expensive(x))))
res4: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,Some(List(IntegerType)))
scala> spark.sql("select id, memoized(parent_id) as one from myTable").createTempView("memoized_table")
scala> timing(spark.sql("select sum(one) from memoized_table").show(truncate = false))
+--------+
|sum(one)|
+--------+
|50000   |
+--------+
invocations=600, time=0.378553505

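This time the function was invoked just 600 times, and the query took 0.4 seconds instead of almost 10! The count is 600 rather than 100 because the cache is not shared across the cluster: each copy of the UDF running on the executors fills its own cache with the 100 distinct arguments.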

Filtering and UDFs

Another pain point is filtering on a UDF result. Let’s modify our example a little and add a filter that references the UDF column several times.
How many times will the UDF be invoked here?

scala> timing(spark.sql("select sum(one) from expensive_table where one * one = one ").show(truncate = false))
+--------+
|sum(one)|
+--------+
|50000   |
+--------+
invocations=200000, time=37.449562222

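It was executed four times per input row, 200000 times in total, and took 37 seconds! That’s far too much. The optimizer pushes the filter below the projection and replaces each of the three references to one with the UDF call itself, and the projection then calls the UDF once more for sum(one). What if we use the memoized version?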

scala> timing(spark.sql("select sum(one) from memoized_table where one * one = one ").show(truncate = false))
+--------+
|sum(one)|
+--------+
|50000   |
+--------+
invocations=600, time=0.34141222

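With memoization the filter makes practically no difference: the invocation count and the timing are the same as for the unfiltered query, because the repeated calls are served from the cache.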

Next steps

This is just a simple example and is not meant for production use. Before using it in production you would have to think about possible problems such as:

  • ConcurrentHashMap can store neither null keys nor null values, so you would need to wrap them, for example into an Option;
  • What if there are not 100 but 100M distinct arguments? How should we limit the cache size, and which strategy should we use: keep last, keep first, most used, LRU, …?
  • What if an invocation of the function may take a long, unpredictable time, or hang? computeIfAbsent will block other callers asking for the same key (and possibly other keys in the same hash bin), even if a fresh invocation of the function would return instantly. Should the cache be bypassed after some reasonable timeout, or should you use an optimistic locking strategy?
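The first two concerns can be sketched in plain Scala as two variants of memo. This is only one possible approach, under stated assumptions: null keys and results are wrapped in Option, and the size limit uses an LRU policy built on java.util.LinkedHashMap in access order, guarded by a single lock. The names SaferMemo, memoNullSafe and memoLru are hypothetical.

```scala
import java.util.{Collections, LinkedHashMap => JLinkedHashMap, Map => JMap}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JF}

object SaferMemo {
  // Null-tolerant memo: ConcurrentHashMap rejects null keys and values,
  // so both are wrapped in Option before they reach the map
  // (Option(null) is None, which is a valid, non-null key or value).
  def memoNullSafe[T, U](f: T => U): T => U = {
    lazy val cache = new ConcurrentHashMap[Option[T], Option[U]]()
    (t: T) => {
      val wrapped = cache.computeIfAbsent(Option(t), new JF[Option[T], Option[U]] {
        // The key is Option(t), so computing f(t) here is equivalent.
        def apply(key: Option[T]): Option[U] = Option(f(t))
      })
      wrapped.getOrElse(null.asInstanceOf[U]) // a cached None turns back into null
    }
  }

  // Bounded memo: an access-ordered LinkedHashMap drops its least recently
  // used entry once it holds more than maxSize entries. synchronizedMap
  // guards it with one lock, so f runs while that lock is held and
  // concurrent callers wait (simpler than, but slower than, ConcurrentHashMap).
  def memoLru[T, U](maxSize: Int)(f: T => U): T => U = {
    lazy val cache: JMap[T, U] = Collections.synchronizedMap(
      new JLinkedHashMap[T, U](16, 0.75f, true) {
        override def removeEldestEntry(eldest: JMap.Entry[T, U]): Boolean =
          size() > maxSize
      })
    (t: T) => cache.computeIfAbsent(t, new JF[T, U] {
      def apply(key: T): U = f(key)
    })
  }
}
```

With memoLru(2), the call sequence 1, 2, 1, 3, 1, 2 runs the underlying function four times: key 3 evicts 2 (1 was used more recently), and the final 2 has to be recomputed. The timeout question is harder and is not covered by this sketch.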

About the Author

Valentin is a specialist in Big Data and Cloud solutions. He has extensive expertise in the Cloudera Hadoop Distribution and Google Cloud Platform, and is skilled in building scalable, performance-critical distributed systems and data visualization systems.


Very good article,

is there a way to do the same for a UDF whose function takes multiple parameters?
If I try your code, I get: Type mismatch, expected: T => NotInferedU, actual: (String, Double, Double) => String
My function is:
val func= (dep: String, x: Double, y: Double) => {
// returns a String

Valentin Nikotin
July 17, 2018 8:47 am

You should use something like:

def memo[T1, T2, T3, U](f: (T1, T2, T3) => U): (T1, T2, T3) => U = {
  lazy val cache = new ConcurrentHashMap[(T1, T2, T3), U]()
  (t1: T1, t2: T2, t3: T3) => cache.computeIfAbsent((t1, t2, t3), new JF[(T1, T2, T3), U] {
    def apply(t: (T1, T2, T3)): U = f(t._1, t._2, t._3)
  })
}
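A possible alternative that avoids writing a separate cache for every arity: pack the arguments into a tuple (tuples have structural equals/hashCode, so they work as map keys) and reuse the single-argument memo via the standard library's Function3.tupled and Function.untupled. A sketch, with the hypothetical names TupledMemo and memo3:

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JF}

object TupledMemo {
  // Single-argument memo, same shape as in the article.
  def memo[T, U](f: T => U): T => U = {
    lazy val cache = new ConcurrentHashMap[T, U]()
    (t: T) => cache.computeIfAbsent(t, new JF[T, U] {
      def apply(key: T): U = f(key)
    })
  }

  // Three-argument memo: f.tupled takes one (T1, T2, T3) tuple, memo caches
  // on that tuple, and Function.untupled turns it back into a 3-argument function.
  def memo3[T1, T2, T3, U](f: (T1, T2, T3) => U): (T1, T2, T3) => U =
    Function.untupled(memo(f.tupled))
}
```

In Spark this would presumably be registered like the single-argument version, e.g. udf(TupledMemo.memo3(func)), since functions.udf has overloads for functions of up to ten arguments. Note that Function.untupled itself only goes up to five arguments.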


Thank you,

Is this solution tested in a distributed Spark environment, i.e. not in standalone or pseudo-distributed mode? How does this work when there are multiple executors?

