Spark Scala UDF Primitive Type Bug

Posted in: Big Data, Technical Track

I was working on an instrumentation framework for Scala UDFs in Spark when I noticed a subtle difference in the execution plan depending on whether I used wrappers or not. It looked like a predicate to check for nulls was added in one case but not in the other:

val f = (x: Long) => x
val udf0 = udf(f)
...
  .withColumn("udf0", udf0(...))
...
// in explain
if (isnull(...)) null else UDF(...) AS udf0#111L

vs

def identity[T, U](f: T => U): T => U = (t: T) => f(t)
val udf1 = udf(identity(f))
...
  .withColumn("udf1", udf1(...))
...
// in explain
UDF(...) AS udf1#115L

A quick check of the documentation sheds light on the special case of UDFs based on functions with primitive input arguments:

Note that if you use primitive parameters, you are not able to check if it is null or not, and the UDF will return null for you if the primitive input is null.
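
To see the documented behavior on its own, here is a minimal sketch of my own (not from the original post); it relies on the spark.implicits._ import that the shell provides:

// The function body never runs for the null row: because the parameter is a
// primitive Long, Spark substitutes null for the result itself.
val plusOne = udf((x: Long) => x + 1)
val df = Seq(Some(1L), None).toDF("x")
df.withColumn("y", plusOne($"x")).show()
// expected: y = 2 for x = 1, and y = null for x = null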

In my case I hadn't really changed the types, but I had used a higher-order function, something like this:

val f = (x: Long) => x
def identity[T, U](f: T => U): T => U = (t: T) => f(t)
val udf0 = udf(f)
val udf1 = udf(identity(f))

Both udf0 and udf1 look pretty much the same at first sight:

scala> def identity[T, U](f: T => U): T => U = (t: T) => f(t)
identity: [T, U](f: T => U)T => U

scala> val udf0 = udf(f)
udf0: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))

scala> val udf1 = udf(identity(f))
udf1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))

Yet during execution they behaved differently for null input:

scala> val getNull = udf(() => null.asInstanceOf[java.lang.Long])
getNull: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function0>,LongType,Some(List()))

scala> spark.range(5).toDF().
     |   withColumn("udf0", udf0(getNull())).
     |   withColumn("udf1", udf1(getNull())).
     |   show()
+---+----+----+
| id|udf0|udf1|
+---+----+----+
|  0|null|   0|
|  1|null|   0|
|  2|null|   0|
|  3|null|   0|
|  4|null|   0|
+---+----+----+
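
The 0 in the udf1 column is not arbitrary: with no null guard in place, the null reaches the generic apply of the wrapped function and is unboxed to a primitive, and unboxing a null Long yields 0 rather than throwing. A one-liner (my own check, not part of the original session) shows where that 0 comes from:

// Scala's generic-to-primitive bridge goes through BoxesRunTime, which maps
// a null box to the primitive's zero value instead of throwing an NPE.
scala.runtime.BoxesRunTime.unboxToLong(null)   // 0L

The explain output below confirms that only udf0 gets the null guard.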

scala> spark.range(5).toDF().
     |   withColumn("udf0", udf0(getNull())).
     |   withColumn("udf1", udf1(getNull())).
     |   explain()  
== Physical Plan ==
*Project [id#106L, if (isnull(UDF())) null else UDF(UDF()) AS udf0#111L, UDF(UDF()) AS udf1#115L]
+- *Range (0, 5, step=1, splits=2)
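
The if (isnull(...)) guard that the analyzer adds around udf0 is roughly what one would write by hand with when/otherwise; here is a sketch of my own using the standard org.apache.spark.sql.functions API:

import org.apache.spark.sql.functions.{when, lit}

// Roughly equivalent to the guard in the plan above: evaluate the UDF only
// when the input is not null, otherwise produce a null of the result type.
val guarded = when(getNull().isNull, lit(null).cast("long")).otherwise(udf0(getNull()))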

I tracked down why this happens through the Spark sources:

  • udf
      def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
        val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption
        UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
      }
    
  • UserDefinedFunction
    case class UserDefinedFunction protected (
        f: AnyRef,
        dataType: DataType,
        inputTypes: Option[Seq[DataType]]) {
    ...
      def apply(exprs: Column*): Column = {
        Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil)))
      }
    }
    
  • ScalaUDF
    case class ScalaUDF(
        function: AnyRef,
        dataType: DataType,
        children: Seq[Expression],
        inputTypes: Seq[DataType] = Nil,
        udfName: Option[String] = None)
      extends Expression with ImplicitCastInputTypes with NonSQLExpression {
    ...
    
  • HandleNullInputsForUDF from the Catalyst Analyzer (the TODO in this piece explains the mess with nullability: the check simply isn't applied where I would expect it to be):
      /**
       * Correctly handle null primitive inputs for UDF by adding extra [[If]] expression to do the
       * null check.  When user defines a UDF with primitive parameters, there is no way to tell if the
       * primitive parameter is null or not, so here we assume the primitive input is null-propagatable
       * and we should return null if the input is null.
       */
      object HandleNullInputsForUDF extends Rule[LogicalPlan] {
        override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
          case p if !p.resolved => p // Skip unresolved nodes.
    
          case p => p transformExpressionsUp {
    
            case udf @ ScalaUDF(func, _, inputs, _, _) =>
              val parameterTypes = ScalaReflection.getParameterTypes(func)
              assert(parameterTypes.length == inputs.length)
    
              val inputsNullCheck = parameterTypes.zip(inputs)
                // TODO: skip null handling for not-nullable primitive inputs after we can completely
                // trust the `nullable` information.
                // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable }
                .filter { case (cls, _) => cls.isPrimitive }
                .map { case (_, expr) => IsNull(expr) }
                .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
              inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf)
          }
        }
      }
    
  • And the final piece:
      def getParameterTypes(func: AnyRef): Seq[Class[_]] = {
        val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
        assert(methods.length == 1)
        methods.head.getParameterTypes
      }
    
  • As you can see, it uses Java runtime class information, so it's no surprise that "isPrimitive" doesn't work the way we would expect, due to type erasure. In this case:

    scala> ScalaReflection.getParameterTypes(f)
    res1: Seq[Class[_]] = WrappedArray(long)
    
    scala> ScalaReflection.getParameterTypes(identity(f))
    res2: Seq[Class[_]] = WrappedArray(class java.lang.Object)
    

    Instead it should use the TypeTag we have in the udf declaration, like this:

    scala> def myGetParameterTypes[T : TypeTag, U](func: T => U) = {
         |   typeTag[T].tpe.typeSymbol.asClass
         | }
    myGetParameterTypes: [T, U](func: T => U)(implicit evidence$1: reflect.runtime.universe.TypeTag[T])reflect.runtime.universe.ClassSymbol
    
    scala> myGetParameterTypes(f)
    res3: reflect.runtime.universe.ClassSymbol = class Long
    
    scala> myGetParameterTypes(f).isPrimitive
    res4: Boolean = true
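
    For completeness (my own check, not part of the original session), the TypeTag-based lookup also sees the declared parameter type through the wrapper, which is exactly where the runtime-class approach falls down:

    // Unlike ScalaReflection.getParameterTypes, which only sees the erased
    // java.lang.Object parameter of the wrapper, the TypeTag still carries Long.
    myGetParameterTypes(identity(f)).isPrimitive   // expected: true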
    

    The workaround, though quite ugly, is to use specialization:

    scala> def identity2[@specialized(Long) T, U](f: T => U): T => U = (t: T) => f(t)
    identity2: [T, U](f: T => U)T => U
    
    scala> val udf2 = udf(identity2(f))
    udf2: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))
    
    scala> ScalaReflection.getParameterTypes(identity2(f))
    res10: Seq[Class[_]] = WrappedArray(long)
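
    As a quick sanity check one can run (not part of the original session), the specialized wrapper makes the analyzer see a primitive parameter again, so the null guard should reappear and udf2 should behave like udf0:

    // With identity2, getParameterTypes reports long, so HandleNullInputsForUDF
    // wraps the call in the if (isnull(...)) guard again; expected: all nulls.
    spark.range(5).toDF().
      withColumn("udf2", udf2(getNull())).
      show()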
    

    As a result, I submitted Spark Jira issue SPARK-23833.

    Be careful when using a UDF that operates on primitive types if nullable data can be passed to it. There are many scenarios in which the behavior may differ. Make it a rule: if nullable data can be passed, use boxed types or Option.
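
    For example, here is a sketch of the safer pattern (the names are mine, not from the post): take the boxed type so the null actually reaches the function body, and decide explicitly what to return for it.

    // The boxed parameter lets the function see the null itself instead of
    // relying on Spark's implicit null-propagation for primitives.
    val safePlusOne = udf((x: java.lang.Long) =>
      if (x == null) null.asInstanceOf[java.lang.Long]
      else java.lang.Long.valueOf(x.longValue + 1L)
    )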


About the Author

Valentin is a specialist in Big Data and Cloud solutions. He has extensive expertise in the Cloudera Hadoop Distribution and Google Cloud Platform, and is skilled in building scalable, performance-critical distributed systems and data visualization systems.
