Memoization is a powerful technique that lets you improve the performance of repeatable computations. Although it would be a pretty handy feature, there is no built-in memoization or result cache for UDFs in Spark as of today.
In fact, it's something we can easily implement ourselves.
All examples below are in Scala.
The problem
Imagine we have a relatively expensive function:
spark.udf.register("expensive", udf((x: Int) => { Thread.sleep(1); 1 }))
And assume this function needs to be executed many times for a small set of arguments:
spark.range(50000).toDF()
  .withColumn("parent_id", col("id").mod(100))
  .repartition(col("parent_id"))
  .createTempView("myTable")

spark.sql("select id, expensive(parent_id) as hostname from myTable")
Let’s run some tests. I modified the function to increment an invocation-counting accumulator.
I executed the test on a small Dataproc cluster (Spark 2.2.0).
scala> import java.net.InetAddress
import java.net.InetAddress

scala> val invocations = spark.sparkContext.longAccumulator("invocations")
invocations: org.apache.spark.util.LongAccumulator = LongAccumulator(id: 0, name: Some(invocations), value: 0)

scala> def timing[T](body: => T): T = {
     |   val t0 = System.nanoTime()
     |   invocations.reset()
     |   val res = body
     |   val t1 = System.nanoTime()
     |   println(s"invocations=${invocations.value}, time=${(t1 - t0) / 1e9}")
     |   res
     | }
timing: [T](body: => T)T

scala> def expensive(n: Int) = {
     |   Thread.sleep(1)
     |   invocations.add(1)
     |   1
     | }
expensive: (n: Int)Int

scala> spark.udf.register("expensive", udf((x: Int) => expensive(x)))
res0: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,Some(List(IntegerType)))

scala> spark.range(50000).toDF().
     |   withColumn("parent_id", col("id").mod(100)).
     |   createTempView("myTable")

scala> spark.sql("select id, expensive(parent_id) as one from myTable").createTempView("expensive_table")

scala> timing(spark.sql("select sum(one) from expensive_table").show(truncate = false))
+--------+
|sum(one)|
+--------+
|50000   |
+--------+
invocations=50000, time=9.493999374
The expensive function was called 50000 times and the sum(one) query took around 10 seconds.
Simple memoization
How can we improve this timing? We can memoize function results with the following simple code:
def memo[T, U](f: T => U): T => U = {
  lazy val cache = new ConcurrentHashMap[T, U]()
  (t: T) => cache.computeIfAbsent(t, new JF[T, U] {
    def apply(t: T): U = f(t)
  })
}

spark.udf.register("memoized", udf(memo((x: Int) => expensive(x))))
It uses a lazy val captured in the closure, so there is one cache instance per UDF instance.
Let’s run more tests!
scala> import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentHashMap

scala> import java.util.function.{Function => JF}
import java.util.function.{Function=>JF}

scala> implicit def toJF[T, U](f: T => U): JF[T, U] = new JF[T, U] {
     |   def apply(t: T): U = f(t)
     | }
warning: there was one feature warning; re-run with -feature for details
toJF: [T, U](f: T => U)java.util.function.Function[T,U]

scala> def memo[T, U](f: T => U): T => U = {
     |   lazy val cache = new ConcurrentHashMap[T, U]()
     |   (t: T) => cache.computeIfAbsent(t, f)
     | }
memo: [T, U](f: T => U)T => U

scala> spark.udf.register("memoized", udf(memo((x: Int) => expensive(x))))
res4: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,Some(List(IntegerType)))

scala> spark.sql("select id, memoized(parent_id) as one from myTable").createTempView("memoized_table")

scala> timing(spark.sql("select sum(one) from memoized_table").show(truncate = false))
+--------+
|sum(one)|
+--------+
|50000   |
+--------+
invocations=600, time=0.378553505
This time we see that the function was invoked just 600 times, which gave us about 0.4 seconds instead of almost 10! (It is more than the 100 distinct arguments, most likely because each task deserializes its own copy of the UDF and therefore keeps its own cache.)
Filtering and UDFs
Another pain point is filtering on a column computed by a UDF. Let's modify our example a little and add a filter that references the computed column several times.
How many times will the UDF be invoked here?
scala> timing(spark.sql("select sum(one) from expensive_table where one * one = one ").show(truncate = false)) +--------+ |sum(one)| +--------+ |50000 | +--------+ invocations=200000, time=37.449562222
The UDF was executed four times for each input row, 200000 invocations in total, and the query took 37 seconds! That happens because the expression behind the one alias is substituted into every place it is referenced: once in the projection and three times in the filter. What if we use the memoized version?
scala> timing(spark.sql("select sum(one) from memoized_table where one * one = one ").show(truncate = false)) +--------+ |sum(one)| +--------+ |50000 | +--------+ invocations=600, time=0.34141222
With the memoized version there is practically no difference compared to the previous run: the same invocation count and almost the same timing.
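As an aside, memoization is not the only way around repeated UDF evaluation in filters. Another option, not covered by the benchmark above and shown here only as an untested sketch, is to materialize the expensive column once with cache(); the cached relation should then act as a barrier, so the filter reads the materialized column instead of expanding the one alias back into UDF calls. The view name precomputed_table is just an example:

// Sketch: compute expensive(parent_id) once, cache it, and filter on the cached column.
val precomputed = spark.sql("select id, expensive(parent_id) as one from myTable").cache()
precomputed.createOrReplaceTempView("precomputed_table")

timing(spark.sql("select sum(one) from precomputed_table where one * one = one").show(truncate = false))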
Next steps
This is just a simple example and is not meant for production use as is; before relying on it you would have to think about possible problems such as:
- ConcurrentHashMap can't store null keys or values, so you would need to wrap them in Option, for example (see the sketch after this list);
- What if there are not 100 but 100M different arguments? How should we limit the cache size, and what eviction strategy should we use: keep last, keep first, most used, LRU, …?
- What if an invocation of the function may take a long, unknown amount of time or hang? computeIfAbsent blocks other calls for the same key (even if the next invocation of the same function would be instant). Should it be bypassed after some reasonable timeout, or should you use an optimistic locking strategy?
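A possible starting point for the first two items, sticking with ConcurrentHashMap, might look like the sketch below. The Option wrapping handles nulls, while the "clear the whole cache when it is full" policy and the 10000-entry limit are arbitrary placeholders rather than a recommendation:

import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JF}

// Sketch only: Option-wrapped keys and values tolerate nulls, and a crude
// clear-when-full policy bounds memory usage. A real implementation would
// use a proper eviction strategy (LRU, LFU, ...).
def safeMemo[T, U](f: T => U, maxEntries: Int = 10000): T => U = {
  lazy val cache = new ConcurrentHashMap[Option[T], Option[U]]()
  (t: T) => {
    if (cache.size() >= maxEntries) cache.clear()        // simplest possible eviction
    cache.computeIfAbsent(Option(t), new JF[Option[T], Option[U]] {
      def apply(k: Option[T]): Option[U] = Option(f(t))  // null result becomes None
    }).getOrElse(null.asInstanceOf[U])
  }
}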
Comments
Very good article.
Is there a way to do the same for a UDF with a function that takes multiple parameters?
If I try your code I get: Type mismatch: expected: T => NotInferedU, actual: (String, Double, Double) => String
My function is:
val func = (dep: String, x: Double, y: Double) => {
  // returns a String
}
You should use something like:
def memo[T1, T2, T3, U](f: (T1, T2, T3) => U): (T1, T2, T3) => U = {
  lazy val cache = new ConcurrentHashMap[(T1, T2, T3), U]()
  (t1: T1, t2: T2, t3: T3) => cache.computeIfAbsent((t1, t2, t3), new JF[(T1, T2, T3), U] {
    def apply(t: (T1, T2, T3)): U = f(t._1, t._2, t._3)
  })
}
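Registration then looks much like before. Here memoized3 is just an example name, and func is the three-argument function from the question above:

// Register the memoized three-argument UDF under an example name.
spark.udf.register("memoized3", udf(memo(func)))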
Thank you,
Is this solution tested in a distributed Spark environment, i.e. not in standalone or pseudo-distributed modes? How does this work when there are multiple executors?