Spark Scala UDF primitive type bug


I was working on an instrumentation framework for Scala UDFs in Spark when I noticed a subtle difference in the execution plan depending on whether I used wrappers or not. It looked as if a null-check predicate was added in one case but not in the other:

val f = (x: Long) => x
val udf0 = udf(f)
...
  .withColumn("udf0", udf0(...))
...
// in explain
if (isnull(...)) null else UDF(...) AS udf0#111L

vs

def identity[T, U](f: T => U): T => U = (t: T) => f(t)
val udf1 = udf(identity(f))
...
  .withColumn("udf1", udf1(...))
...
// in explain
UDF(...) AS udf1#115L

A quick check of the documentation sheds light on the special case of UDFs based on functions with primitive input arguments:

Note that if you use primitive parameters, you are not able to check if it is null or not, and the UDF will return null for you if the primitive input is null.

In my case I had not really changed any types, but I had used a higher-order function, something like this:

val f = (x: Long) => x
def identity[T, U](f: T => U): T => U = (t: T) => f(t)
val udf0 = udf(f)
val udf1 = udf(identity(f))

Both udf0 and udf1 look pretty much the same at first sight:

scala> def identity[T, U](f: T => U): T => U = (t: T) => f(t)
identity: [T, U](f: T => U)T => U
scala> val udf0 = udf(f)
udf0: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))
scala> val udf1 = udf(identity(f))
udf1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))

During execution, however, they behave differently for null input:

scala> val getNull = udf(() => null.asInstanceOf[java.lang.Long])
getNull: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function0>,LongType,Some(List()))
scala> spark.range(5).toDF().
     |   withColumn("udf0", udf0(getNull())).
     |   withColumn("udf1", udf1(getNull())).
     |   show()
+---+----+----+
| id|udf0|udf1|
+---+----+----+
|  0|null|   0|
|  1|null|   0|
|  2|null|   0|
|  3|null|   0|
|  4|null|   0|
+---+----+----+
scala> spark.range(5).toDF().
     |   withColumn("udf0", udf0(getNull())).
     |   withColumn("udf1", udf1(getNull())).
     |   explain()
== Physical Plan ==
*Project [id#106L, if (isnull(UDF())) null else UDF(UDF()) AS udf0#111L, UDF(UDF()) AS udf1#115L]
+- *Range (0, 5, step=1, splits=2)
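
The 0 in the udf1 column is what you get when a null is unboxed into a primitive Long. The sketch below reproduces that conversion in plain Scala, without Spark. It assumes that Spark ends up calling the wrapped function through an erased (Any => Any) signature, which I have not shown here; the object name NullUnboxing is only for illustration.

```scala
object NullUnboxing {
  val f: Long => Long = (x: Long) => x
  def identity[T, U](f: T => U): T => U = (t: T) => f(t)

  // Call the wrapper through an erased signature and pass null.
  // Assumption: this mirrors how a UDF body is invoked when no null check is added.
  val erased: Any => Any = identity(f).asInstanceOf[Any => Any]
  val result: Any = erased(null)

  // The same conversion written directly: null unboxes to the default value 0L.
  val direct: Long = null.asInstanceOf[Long]
}
```

Both NullUnboxing.result and NullUnboxing.direct evaluate to 0, so without the generated null check the UDF silently sees 0 instead of null.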

I traced why this happens through the Spark sources:

    • udf
        def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
          val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption
          UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
        }
      
    • UserDefinedFunction
      case class UserDefinedFunction protected[sql] (
          f: AnyRef,
          dataType: DataType,
          inputTypes: Option[Seq[DataType]]) {
      ...
        def apply(exprs: Column*): Column = {
          Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil)))
        }
      }
      
    • ScalaUDF
      case class ScalaUDF(
          function: AnyRef,
          dataType: DataType,
          children: Seq[Expression],
          inputTypes: Seq[DataType] = Nil,
          udfName: Option[String] = None)
        extends Expression with ImplicitCastInputTypes with NonSQLExpression {
      ...
      
    • HandleNullInputsForUDF from the Catalyst Analyzer (the TODO in this piece explains the mess with nullability; the rule simply doesn’t work the way I would expect it to):
        /**
         * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the
         * null check.  When user defines a UDF with primitive parameters, there is no way to tell if the
         * primitive parameter is null or not, so here we assume the primitive input is null-propagatable
         * and we should return null if the input is null.
         */
        object HandleNullInputsForUDF extends Rule[LogicalPlan] {
          override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
            case p if !p.resolved => p // Skip unresolved nodes.
            case p => p transformExpressionsUp {
              case udf @ ScalaUDF(func, _, inputs, _, _) =>
                val parameterTypes = ScalaReflection.getParameterTypes(func)
                assert(parameterTypes.length == inputs.length)
                val inputsNullCheck = parameterTypes.zip(inputs)
                  // TODO: skip null handling for not-nullable primitive inputs after we can completely
                  // trust the `nullable` information.
                  // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable }
                  .filter { case (cls, _) => cls.isPrimitive }
                  .map { case (_, expr) => IsNull(expr) }
                  .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
                inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf)
            }
          }
        }
      
    • And the final piece, ScalaReflection.getParameterTypes
        def getParameterTypes(func: AnyRef): Seq[Class[_]] = {
          val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
          assert(methods.length == 1)
          methods.head.getParameterTypes
        }
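
To make the effect of the rule easier to see, here is a toy model of it in plain Scala. It is not Catalyst code: expressions are plain strings, and the names NullCheckModel and wrapWithNullCheck are made up for this sketch. Only the zip / filter on isPrimitive / reduce structure is taken from the rule quoted above.

```scala
object NullCheckModel {
  // Toy stand-in for HandleNullInputsForUDF: expressions are just strings here.
  def wrapWithNullCheck(paramTypes: Seq[Class[_]], inputs: Seq[String]): String = {
    require(paramTypes.length == inputs.length)
    val call = s"UDF(${inputs.mkString(", ")})"
    val check = paramTypes.zip(inputs)
      .filter { case (cls, _) => cls.isPrimitive }   // only primitive parameters count
      .map { case (_, expr) => s"isnull($expr)" }
      .reduceLeftOption((a, b) => s"($a || $b)")
    // No primitive parameter means no check: the call is left untouched.
    check.map(c => s"if ($c) null else $call").getOrElse(call)
  }

  // classOf[Long] is the JVM primitive long, so a null check is added.
  val direct: String = wrapWithNullCheck(Seq(classOf[Long]), Seq("x"))
  // A parameter reported as java.lang.Object is not primitive, so no check is added.
  val wrapped: String = wrapWithNullCheck(Seq(classOf[Object]), Seq("x"))
}
```

In this model, direct is "if (isnull(x)) null else UDF(x)" and wrapped is just "UDF(x)", which matches the shape of the two plans shown earlier.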
      

As you can see, it uses Java runtime class information, so it is no surprise that “isPrimitive” doesn’t work the way we would expect, due to type erasure. In this case:

scala> ScalaReflection.getParameterTypes(f)
res1: Seq[Class[_]] = WrappedArray(long)
scala> ScalaReflection.getParameterTypes(identity(f))
res2: Seq[Class[_]] = WrappedArray(class java.lang.Object)

Instead, it should use the TypeTag we already have in the udf declaration, like this:

scala> def myGetParameterTypes[T : TypeTag, U](func: T => U) = {
     |   typeTag[T].tpe.typeSymbol.asClass
     | }
myGetParameterTypes: [T, U](func: T => U)(implicit evidence$1: reflect.runtime.universe.TypeTag[T])reflect.runtime.universe.ClassSymbol
scala> myGetParameterTypes(f)
res3: reflect.runtime.universe.ClassSymbol = class Long
scala> myGetParameterTypes(f).isPrimitive
res4: Boolean = true

The workaround, although quite ugly, is to use specialization:

scala> def identity2[@specialized(Long) T, U](f: T => U): T => U = (t: T) => f(t)
identity2: [T, U](f: T => U)T => U
scala> val udf2 = udf(identity2(f))
udf2: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))
scala> ScalaReflection.getParameterTypes(identity2(f))
res10: Seq[Class[_]] = WrappedArray(long)

As a result, I submitted the Spark Jira issue SPARK-23833.

Be careful when using UDFs that operate on primitive types if nullable data can be passed to them. There are many scenarios in which the behavior may differ. As a rule: if nullable data can be passed, use boxed types or Option.
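
A minimal sketch of that rule, written as a plain Scala function so it runs without Spark. The names NullAwareUdf and plusOne are hypothetical. It takes a boxed java.lang.Long, so null stays visible to the function body, and returns an Option, which I would expect Spark to map to a nullable column when the function is wrapped with udf.

```scala
object NullAwareUdf {
  // Boxed input: null is representable, so the function decides what null means.
  // Option output: None stands for "no value" instead of a silent 0.
  val plusOne: java.lang.Long => Option[Long] =
    x => Option(x).map(_.longValue + 1L)

  // With Spark on the classpath (not imported in this sketch) this could be
  // wrapped as udf(plusOne), using org.apache.spark.sql.functions.udf.
}
```

Here NullAwareUdf.plusOne(null) gives None rather than 1, so the result no longer depends on whether a null check was generated in the plan.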


About the Author

Valentin is a specialist in Big Data and Cloud solutions. He has extensive expertise in the Cloudera Hadoop Distribution and Google Cloud Platform, and is skilled in building scalable, performance-critical distributed systems and data visualization systems.
