Användardefinierade aggregeringsfunktioner (UDAF:er)

Gäller för:check markerad ja Databricks Runtime

Användardefinierade aggregeringsfunktioner (UDAF: er) är användarprogrammabla rutiner som fungerar på flera rader samtidigt och returnerar ett enda aggregerat värde som ett resultat. Den här dokumentationen visar de klasser som krävs för att skapa och registrera UDAF:er. Den innehåller också exempel som visar hur du definierar och registrerar UDAF:er i Scala och anropar dem i Spark SQL.

Nyhetsläsare

SyntaxAggregator[-IN, BUF, OUT]

En basklass för användardefinierade sammansättningar, som kan användas i Datauppsättningsåtgärder för att ta alla element i en grupp och minska dem till ett enda värde.

  • IN: Indatatypen för aggregeringen.

  • BUF: Typen av det mellanliggande värdet för minskningen.

  • OUT: Typen av slutresultatet.

  • bufferEncoder: Encoder[BUF]

    Kodaren för den mellanliggande värdetypen.

  • finish(reduction: BUF): OUT

    Transformera utdata från minskningen.

  • merge(b1: BUF, b2: BUF): BUF

    Sammanfoga två mellanliggande värden.

  • outputEncoder: Encoder[OUT]

    Kodaren för den slutliga utdatavärdetypen.

  • reduce(b: BUF, a: IN): BUF

    Aggregera indatavärdet a till aktuellt mellanliggande värde. För prestanda kan funktionen ändra b och returnera den i stället för att skapa ett nytt objekt för b.

  • noll: BUF

    Det initiala värdet för det mellanliggande resultatet för den här aggregeringen.

Exempel

Typsäkra användardefinierade mängdfunktioner

Användardefinierade aggregeringar för starkt skrivna datauppsättningar kretsar kring den Aggregator abstrakta klassen. Ett typsäkert användardefinierat genomsnitt kan till exempel se ut så här:

Otypade användardefinierade mängdfunktioner

Inskrivna aggregeringar, enligt beskrivningen ovan, kan också registreras som otypade sammansättnings-UDF:er för användning med DataFrames. Ett användardefinierat genomsnitt för otypade DataFrames kan till exempel se ut så här:

Scala

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Long, Average, Double] {
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  def zero: Average = Average(0L, 0L)
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  def reduce(buffer: Average, data: Long): Average = {
    buffer.sum += data
    buffer.count += 1
    buffer
  }
  // Merge two intermediate values
  def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  // Transform the output of the reduction
  def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
  // The Encoder for the intermediate value type
  def bufferEncoder: Encoder[Average] = Encoders.product
  // The Encoder for the final output value type
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Register the function to access it
spark.udf.register("myAverage", functions.udaf(MyAverage))

val df = spark.read.format("json").load("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+

Java

import java.io.Serializable;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.functions;

public static class Average implements Serializable  {
    private long sum;
    private long count;

    // Constructors, getters, setters...

}

public static class MyAverage extends Aggregator<Long, Average, Double> {
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  public Average zero() {
    return new Average(0L, 0L);
  }
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  public Average reduce(Average buffer, Long data) {
    long newSum = buffer.getSum() + data;
    long newCount = buffer.getCount() + 1;
    buffer.setSum(newSum);
    buffer.setCount(newCount);
    return buffer;
  }
  // Merge two intermediate values
  public Average merge(Average b1, Average b2) {
    long mergedSum = b1.getSum() + b2.getSum();
    long mergedCount = b1.getCount() + b2.getCount();
    b1.setSum(mergedSum);
    b1.setCount(mergedCount);
    return b1;
  }
  // Transform the output of the reduction
  public Double finish(Average reduction) {
    return ((double) reduction.getSum()) / reduction.getCount();
  }
  // The Encoder for the intermediate value type
  public Encoder<Average> bufferEncoder() {
    return Encoders.bean(Average.class);
  }
  // The Encoder for the final output value type
  public Encoder<Double> outputEncoder() {
    return Encoders.DOUBLE();
  }
}

// Register the function to access it
spark.udf().register("myAverage", functions.udaf(new MyAverage(), Encoders.LONG()));

Dataset<Row> df = spark.read().format("json").load("examples/src/main/resources/employees.json");
df.createOrReplaceTempView("employees");
df.show();
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
result.show();
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+

SQL

-- Compile and place UDAF MyAverage in a JAR file called `MyAverage.jar` in /tmp.
CREATE FUNCTION myAverage AS 'MyAverage' USING JAR '/tmp/MyAverage.jar';

SHOW USER FUNCTIONS;
+------------------+
|          function|
+------------------+
| default.myAverage|
+------------------+

CREATE TEMPORARY VIEW employees
USING org.apache.spark.sql.json
OPTIONS (
    path "examples/src/main/resources/employees.json"
);

SELECT * FROM employees;
+-------+------+
|   name|salary|
+-------+------+
|Michael|  3000|
|   Andy|  4500|
| Justin|  3500|
|  Berta|  4000|
+-------+------+

SELECT myAverage(salary) as average_salary FROM employees;
+--------------+
|average_salary|
+--------------+
|        3750.0|
+--------------+