scala – In Spark SQL, how do I register and use a generic UDF?

In my project, I want to implement an ADD() function, but its arguments may be LongType, DoubleType, or IntType. I use sqlContext.udf.register("add", XXX), but I don't know how to write XXX so that the function is generic.
Best answer
You can make a generic UDF by building a StructType with struct($"col1", $"col2") to hold your values, which insulates the UDF from the individual column types. The struct is passed into the UDF as a Row object, so you can do something like this:

val multiAdd = udf[Double,Row](r => {
  var n = 0.0
  r.toSeq.foreach(n1 => n = n + (n1 match {
    case l: Long => l.toDouble
    case i: Int => i.toDouble
    case d: Double => d
    case f: Float => f.toDouble
  }))
  n
})

val df = Seq((1.0,2),(3.0,4)).toDF("c1","c2")
df.withColumn("add", multiAdd(struct($"c1", $"c2"))).show
+---+---+---+
| c1| c2|add|
+---+---+---+
|1.0|  2|3.0|
|3.0|  4|7.0|
+---+---+---+
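The widening logic at the heart of the UDF can be exercised on its own, without a Spark session. Below is a standalone sketch of the same pattern match (toDoubleSum is a hypothetical helper name, not part of the answer above); a Row's toSeq yields a Seq[Any] just like the one used here:

```scala
// Standalone version of the UDF body: widen each numeric value to Double and sum.
def toDoubleSum(values: Seq[Any]): Double =
  values.foldLeft(0.0) { (acc, v) =>
    acc + (v match {
      case l: Long   => l.toDouble
      case i: Int    => i.toDouble
      case d: Double => d
      case f: Float  => f.toDouble
    })
  }

println(toDoubleSum(Seq(1, 2L, 3.0f, 4.0)))  // 10.0
```

This mirrors exactly what multiAdd does with r.toSeq, which is why it works for any mix of numeric columns.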

You can even do fun things like take a variable number of columns as input. In fact, the UDF we defined above already does that:

val df = Seq((1, 2L, 3.0f,4.0),(5, 6L, 7.0f,8.0)).toDF("int","long","float","double")

df.printSchema
root
 |-- int: integer (nullable = false)
 |-- long: long (nullable = false)
 |-- float: float (nullable = false)
 |-- double: double (nullable = false)

df.withColumn("add", multiAdd(struct($"int", $"long", $"float", $"double"))).show
+---+----+-----+------+----+
|int|long|float|double| add|
+---+----+-----+------+----+
|  1|   2|  3.0|   4.0|10.0|
|  5|   6|  7.0|   8.0|26.0|
+---+----+-----+------+----+

You can even throw hard-coded numbers into the mix:

df.withColumn("add", multiAdd(struct(lit(100), $"int", $"long"))).show
+---+----+-----+------+-----+
|int|long|float|double|  add|
+---+----+-----+------+-----+
|  1|   2|  3.0|   4.0|103.0|
|  5|   6|  7.0|   8.0|111.0|
+---+----+-----+------+-----+

If you want to use the UDF with SQL syntax, you can do the following:

sqlContext.udf.register("multiAdd", (r: Row) => {
  var n = 0.0
  r.toSeq.foreach(n1 => n = n + (n1 match {
    case l: Long => l.toDouble
    case i: Int => i.toDouble
    case d: Double => d
    case f: Float => f.toDouble
  }))
  n
})
df.registerTempTable("df")

//  Note that 'int' and 'long' are column names
sqlContext.sql("SELECT *, multiAdd(struct(int, long)) as add from df").show
+---+----+-----+------+----+
|int|long|float|double| add|
+---+----+-----+------+----+
|  1|   2|  3.0|   4.0| 3.0|
|  5|   6|  7.0|   8.0|11.0|
+---+----+-----+------+----+

This also works:

sqlContext.sql("SELECT *, multiAdd(struct(*)) as add from df").show
+---+----+-----+------+----+
|int|long|float|double| add|
+---+----+-----+------+----+
|  1|   2|  3.0|   4.0|10.0|
|  5|   6|  7.0|   8.0|26.0|
+---+----+-----+------+----+
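One caveat: the pattern match in the UDF throws a MatchError if a column holds a type it doesn't cover (e.g. a String) or a SQL NULL. A more defensive variant might add fallback cases — this is a sketch of one possible choice (treating NULL and non-numeric values as zero), not part of the original answer:

```scala
// Defensive variant: unmatched types and nulls contribute 0.0 instead of crashing.
def safeAdd(values: Seq[Any]): Double =
  values.foldLeft(0.0) { (acc, v) =>
    acc + (v match {
      case l: Long   => l.toDouble
      case i: Int    => i.toDouble
      case d: Double => d
      case f: Float  => f.toDouble
      case null      => 0.0   // SQL NULL: type patterns above never match null
      case _         => 0.0   // any other type (e.g. String): ignore
    })
  }
```

Whether to skip such values or fail fast is a design choice; failing fast surfaces schema mistakes earlier, while skipping keeps the UDF usable with struct(*) on tables that contain non-numeric columns.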
