次の方法で共有


Python ユーザー定義テーブル関数 (UDF)

Von Bedeutung

この機能は、Databricks Runtime 14.3 LTS 以降の パブリック プレビュー 段階にあります。

ユーザー定義テーブル関数 (UDTF) を使用すると、スカラー値の代わりにテーブルを返す関数を登録できます。 各呼び出しから 1 つの結果値を返すスカラー関数とは異なり、各 UDTF は SQL ステートメントの FROM 句で呼び出され、テーブル全体を出力として返します。

各 UDTF 呼び出しでは、0 個以上の引数を受け取ることができます。 これらの引数には、入力テーブル全体を表すスカラー式またはテーブル引数を指定できます。

基本的な UDTF 構文

Apache Spark は Python UDF を Python クラスとして実装し、evalを使用して出力行を出力する必須のyieldメソッドを使用します。

クラスを UDTF として使用するには、PySpark udtf 関数をインポートする必要があります。 Databricks では、この関数をデコレーターとして使用し、 returnType オプションを使用してフィールド名と型を明示的に指定することをお勧めします (後のセクションで説明するようにクラスで analyze メソッドが定義されていない場合)。

次の UDTF は、2 つの整数引数の固定リストを使用してテーブルを作成します。

from pyspark.sql.functions import lit, udtf

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, x: int, y: int):
        yield x + y, x - y

GetSumDiff(lit(1), lit(2)).show()
+----+-----+
| sum| diff|
+----+-----+
|   3|   -1|
+----+-----+

UDTF を登録する

UDTF はローカル SparkSession に登録され、ノートブックまたはジョブレベルで分離されます。

UDF を Unity カタログのオブジェクトとして登録することはできず、UDF は SQL ウェアハウスで使用できません。

関数SparkSessionで SQL クエリで使用するために、UDTF を現在のspark.udtf.register()に登録できます。 SQL 関数と Python UDTF クラスの名前を指定します。

spark.udtf.register("get_sum_diff", GetSumDiff)

登録済みの UDTF を呼び出す

登録したら、 %sql マジック コマンドまたは spark.sql() 関数を使用して、SQL で UDTF を使用できます。

spark.udtf.register("get_sum_diff", GetSumDiff)
spark.sql("SELECT * FROM get_sum_diff(1,2);").show()
%sql
SELECT * FROM get_sum_diff(1,2);

Apache Arrow を使用する

UDTF が入力として少量のデータを受信し、大きなテーブルを出力する場合、Databricks では Apache Arrow を使用することをお勧めします。 これを有効にするには、UDTF を宣言するときに useArrow パラメーターを指定します。

@udtf(returnType="c1: int, c2: int", useArrow=True)

可変引数リスト - *args および **kwargs

Python *args または **kwargs 構文を使用し、指定されていない数の入力値を処理するロジックを実装できます。

次の例では、引数の入力長と型を明示的に確認しながら、同じ結果を返します。

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, *args):
        assert(len(args) == 2)
        assert(isinstance(arg, int) for arg in args)
        x = args[0]
        y = args[1]
        yield x + y, x - y

GetSumDiff(lit(1), lit(2)).show()

同じ例を次に示しますが、キーワード引数を使用します。

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, **kwargs):
        x = kwargs["x"]
        y = kwargs["y"]
        yield x + y, x - y

GetSumDiff(x=lit(1), y=lit(2)).show()

登録時に静的スキーマを定義する

UDTF は、列の名前と型の順序付けられたシーケンスで構成される出力スキーマを持つ行を返します。 UDTF スキーマが常にすべてのクエリで同じままである必要がある場合は、 @udtf デコレーターの後に静的な固定スキーマを指定できます。 StructTypeのどちらかである必要があります。

StructType().add("c1", StringType())

または、構造体型を表す DDL 文字列。

c1: string

関数呼び出し時に動的スキーマを計算する

UDF は、入力引数の値に応じて、呼び出しごとにプログラムによって出力スキーマを計算することもできます。 これを行うには、特定の UDTF 呼び出しに指定された引数に対応する 0 個以上のパラメーターを受け取る、 analyze という静的メソッドを定義します。

analyze メソッドの各引数は、次のフィールドを含むAnalyzeArgument クラスのインスタンスです。

AnalyzeArgument クラス フィールド 説明
dataType DataTypeとしての入力引数の型。 入力テーブルの引数の場合、これはテーブルの列を表す StructType です。
value Optional[Any]としての入力引数の値。 これは、定数ではないテーブル引数またはリテラル スカラー引数に対して None です。
isTable 入力引数が BooleanTypeとしてのテーブルであるかどうかを示します。
isConstantExpression 入力引数が、 BooleanTypeとして定数で折りたたみ可能な式であるかどうかを示します。

analyze メソッドは、AnalyzeResult クラスのインスタンスを返します。これには、結果テーブルのスキーマがStructTypeと省略可能なフィールドが含まれます。 UDTF が入力テーブル引数を受け入れる場合、 AnalyzeResult には、後で説明するように、複数の UDTF 呼び出しで入力テーブルの行をパーティション分割して並べ替える要求された方法を含めることもできます。

AnalyzeResult クラス フィールド 説明
schema StructTypeとしての結果テーブルのスキーマ。
withSinglePartition すべての入力行を、 BooleanTypeと同じ UDTF クラス インスタンスに送信するかどうか。
partitionBy 非空に設定されている場合、パーティション分割式の値のユニークな組み合わせごとに、UDTF クラスの異なるインスタンスによってすべての行が処理されます。
orderBy 空以外に設定すると、各パーティション内の行の順序が指定されます。
select 設定が空以外の場合、これは UDTF が入力 TABLE 引数の列に対して Catalyst に評価させるよう一連の式を指定するものです。 UDTF は、リスト内の名前ごとに 1 つの入力属性を一覧表示順に受け取ります。

次の analyze 例では、入力文字列引数の単語ごとに 1 つの出力列を返します。

from pyspark.sql.functions import lit, udtf
from pyspark.sql.types import StructType, IntegerType
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult


@udtf
class MyUDTF:
  @staticmethod
  def analyze(text: AnalyzeArgument) -> AnalyzeResult:
    schema = StructType()
    for index, word in enumerate(sorted(list(set(text.value.split(" "))))):
      schema = schema.add(f"word_{index}", IntegerType())
    return AnalyzeResult(schema=schema)

  def eval(self, text: str):
    counts = {}
    for word in text.split(" "):
      if word not in counts:
            counts[word] = 0
      counts[word] += 1
    result = []
    for word in sorted(list(set(text.split(" ")))):
      result.append(counts[word])
    yield result

MyUDTF(lit("hello world")).columns
['word_0', 'word_1']

将来の eval 呼び出しに状態を転送する

analyze メソッドは、初期化を実行し、同じ UDTF 呼び出しに対する将来の eval メソッド呼び出しに結果を転送するのに便利な場所として機能します。

これを行うには、 AnalyzeResult のサブクラスを作成し、 analyze メソッドからサブクラスのインスタンスを返します。 次に、 __init__ メソッドに追加の引数を追加して、そのインスタンスを受け入れます。

次の analyze 例では、定数出力スキーマを返しますが、結果メタデータにカスタム情報を追加して、将来の __init__ メソッド呼び出しで使用できるようにします。

from pyspark.sql.functions import lit, udtf
from pyspark.sql.types import StructType, IntegerType
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult

@dataclass
class AnalyzeResultWithBuffer(AnalyzeResult):
    buffer: str = ""

@udtf
class TestUDTF:
  def __init__(self, analyze_result=None):
    self._total = 0
    if analyze_result is not None:
      self._buffer = analyze_result.buffer
    else:
      self._buffer = ""

  @staticmethod
  def analyze(argument, _) -> AnalyzeResult:
    if (
      argument.value is None
      or argument.isTable
      or not isinstance(argument.value, str)
      or len(argument.value) == 0
    ):
      raise Exception("The first argument must be a non-empty string")
    assert argument.dataType == StringType()
    assert not argument.isTable
    return AnalyzeResultWithBuffer(
      schema=StructType()
        .add("total", IntegerType())
        .add("buffer", StringType()),
      withSinglePartition=True,
      buffer=argument.value,
    )

  def eval(self, argument, row: Row):
    self._total += 1

  def terminate(self):
    yield self._total, self._buffer

spark.udtf.register("test_udtf", TestUDTF)

spark.sql(
  """
  WITH t AS (
    SELECT id FROM range(1, 21)
  )
  SELECT total, buffer
  FROM test_udtf("abc", TABLE(t))
  """
).show()
+-------+-------+
| count | buffer|
+-------+-------+
|    20 |  "abc"|
+-------+-------+

出力行を生成する

eval メソッドは、入力テーブル引数の行ごとに 1 回 (またはテーブル引数が指定されていない場合は 1 回だけ) 実行され、最後に terminate メソッドが 1 回呼び出されます。 いずれかのメソッドは、タプル、リスト、または pyspark.sql.Row オブジェクトを生成することによって、結果スキーマに準拠する 0 行以上の行を出力します。

次の例では、次の 3 つの要素のタプルを指定して行を返します。

def eval(self, x, y, z):
  yield (x, y, z)

かっこは省略することもできます。

def eval(self, x, y, z):
  yield x, y, z

列を 1 つだけ含む行を返すために、末尾のコンマを追加します。

def eval(self, x, y, z):
  yield x,

pyspark.sql.Row オブジェクトを生成することもできます。

def eval(self, x, y, z)
  from pyspark.sql.types import Row
  yield Row(x, y, z)

この例では、Python リストを使用して、 terminate メソッドから出力行を生成します。 この目的のために、UDTF 評価の前の手順からクラス内に状態を格納できます。

def terminate(self):
  yield [self.x, self.y, self.z]

UDTF にスカラー引数を渡す

スカラー引数は、それらに基づいてリテラル値または関数を構成する定数式として UDTF に渡すことができます。 例えば次が挙げられます。

SELECT * FROM get_sum_diff(1, y => 2)

UDTF にテーブル引数を渡す

Python UDF では、スカラー入力引数に加えて、入力テーブルを引数として受け入れることもできます。 1 つの UDTF で、テーブル引数と複数のスカラー引数を受け入れることもできます。

その後、任意の SQL クエリで、 TABLE キーワードを使用し、その後に適切なテーブル識別子を囲むかっこ ( TABLE(t)など) を使用して入力テーブルを提供できます。 または、 TABLE(SELECT a, b, c FROM t)TABLE(SELECT t1.a, t2.b FROM t1 INNER JOIN t2 USING (key))などのテーブル サブクエリを渡すことができます。

入力テーブル引数は、pyspark.sql.Row メソッドのeval引数として表され、入力テーブルの各行に対して eval メソッドが 1 回呼び出されます。 標準の PySpark 列フィールド注釈を使用して、各行の列を操作できます。 次の例では、PySpark Row 型を明示的にインポートし、 id フィールドで渡されたテーブルをフィルター処理する方法を示します。

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="id: int")
class FilterUDTF:
    def eval(self, row: Row):
        if row["id"] > 5:
            yield row["id"],

spark.udtf.register("filter_udtf", FilterUDTF)

関数に対してクエリを実行するには、 TABLE SQL キーワードを使用します。

SELECT * FROM filter_udtf(TABLE(SELECT * FROM range(10)));
+---+
| id|
+---+
|  6|
|  7|
|  8|
|  9|
+---+

関数呼び出しからの入力行のパーティション分割を指定する

テーブル引数を使用して UDTF を呼び出す場合、SQL クエリでは、1 つ以上の入力テーブル列の値に基づいて、複数の UDTF 呼び出しで入力テーブルをパーティション分割できます。

パーティションを指定するには、PARTITION BY引数の後の関数呼び出しでTABLE句を使用します。 これにより、パーティション分割列の値を一意に組み合わせたすべての入力行が、UDTF クラスの 1 つのインスタンスで使用されるようになります。

PARTITION BY句は、単純な列参照に加えて、入力テーブルの列に基づく任意の式も受け入れることに注意してください。 たとえば、文字列の LENGTH を指定したり、日付から月を抽出したり、2 つの値を連結したりできます。

また、WITH SINGLE PARTITIONではなくPARTITION BYを指定して、すべての入力行を UDTF クラスの 1 つのインスタンスで使用する必要があるパーティションを 1 つだけ要求することもできます。

各パーティション内で、UDTF の eval メソッドが入力行を使用する際に必要な順序を必要に応じて指定できます。 これを行うには、前述のORDER BY句またはPARTITION BY句の後にWITH SINGLE PARTITION句を指定します。

たとえば、次の UDTF を考えてみましょう。

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="a: string, b: int")
class FilterUDTF:
  def __init__(self):
    self.key = ""
    self.max = 0

  def eval(self, row: Row):
    self.key = row["a"]
    self.max = max(self.max, row["b"])

  def terminate(self):
    yield self.key, self.max

spark.udtf.register("filter_udtf", FilterUDTF)

入力テーブルに対して UDTF を呼び出すときに、パーティション分割オプションを複数の方法で指定できます。

-- Create an input table with some example values.
DROP TABLE IF EXISTS values_table;
CREATE TABLE values_table (a STRING, b INT);
INSERT INTO values_table VALUES ('abc', 2), ('abc', 4), ('def', 6), ('def', 8)";
SELECT * FROM values_table;
+-------+----+
|     a |  b |
+-------+----+
| "abc" | 2  |
| "abc" | 4  |
| "def" | 6  |
| "def" | 8  |
+-------+----+
-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique value in the `a` column are processed by the same
-- instance of the UDTF class. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY a ORDER BY b) ORDER BY 1;
+-------+----+
|     a |  b |
+-------+----+
| "abc" | 4  |
| "def" | 8  |
+-------+----+

-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique result of evaluating the "LENGTH(a)" expression are
-- processed by the same instance of the UDTF class. Within each partition, the rows are ordered
-- by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY LENGTH(a) ORDER BY b) ORDER BY 1;
+-------+---+
|     a | b |
+-------+---+
| "def" | 8 |
+-------+---+
-- Query the UDTF with the input table as an argument and a directive to consider all the input
-- rows in one single partition such that exactly one instance of the UDTF class consumes all of
-- the input rows. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) WITH SINGLE PARTITION ORDER BY b) ORDER BY 1;
+-------+----+
|     a |  b |
+-------+----+
| "def" | 8 |
+-------+----+

analyze メソッドから入力行のパーティション分割を指定する

SQL クエリで UDF を呼び出すときに入力テーブルをパーティション分割する上記の各方法では、UDTF の analyze メソッドで同じパーティション分割方法を自動的に指定する方法があります。

  • SELECT * FROM udtf(TABLE(t) PARTITION BY a)として UDTF を呼び出す代わりに、analyze メソッドを更新してフィールドpartitionBy=[PartitioningColumn("a")]を設定し、SELECT * FROM udtf(TABLE(t))を使用して関数を呼び出すだけです。
  • 同じトークンを使用して、SQL クエリでTABLE(t) WITH SINGLE PARTITION ORDER BY bを指定する代わりに、analyzewithSinglePartition=trueおよびorderBy=[OrderingColumn("b")]フィールドを設定し、TABLE(t)渡すことができます。
  • SQL クエリでTABLE(SELECT a FROM t)を渡す代わりに、analyzeselect=[SelectedColumn("a")]設定してから、TABLE(t)渡すことができます。

次の例では、 analyze は定数出力スキーマを返し、入力テーブルから列のサブセットを選択し、 date 列の値に基づいて複数の UDTF 呼び出しで入力テーブルをパーティション分割することを指定します。

@staticmethod
def analyze(*args) -> AnalyzeResult:
  """
  The input table will be partitioned across several UDTF calls based on the monthly
  values of each `date` column. The rows within each partition will arrive ordered by the `date`
  column. The UDTF will only receive the `date` and `word` columns from the input table.
  """
  from pyspark.sql.functions import (
    AnalyzeResult,
    OrderingColumn,
    PartitioningColumn,
  )

  assert len(args) == 1, "This function accepts one argument only"
  assert args[0].isTable, "Only table arguments are supported"
  return AnalyzeResult(
    schema=StructType()
      .add("month", DateType())
      .add('longest_word", IntegerType()),
    partitionBy=[
      PartitioningColumn("extract(month from date)")],
    orderBy=[
      OrderingColumn("date")],
    select=[
      SelectedColumn("date"),
      SelectedColumn(
        name="length(word),
        alias="length_word")])