手記
この記事では、Databricks Runtime 14.1 以降の Databricks Connect について説明します。
この記事では、Databricks Connect for Scala を使用してユーザー定義関数を実行する方法について説明します。 Databricks Connect を使用すると、一般的な IDE、ノートブック サーバー、カスタム アプリケーションを Azure Databricks クラスターに接続できます。 この記事の Python バージョンについては、Databricks Connect for Pythonのユーザー定義関数の
手記
Databricks Connect の使用を開始する前に、Databricks Connect クライアントを設定する必要があります。
Databricks Runtime 14.1 以降では、Databricks Connect for Scala ではユーザー定義関数 (UDF) の実行がサポートされています。
UDF を実行するには、UDF に必要なコンパイル済みクラスと JAR をクラスターにアップロードする必要があります。
addCompiledArtifacts()
API を使用して、アップロードする必要があるコンパイル済みクラスファイルと JAR ファイルを指定できます。
手記
クライアントによって使用される Scala は、Azure Databricks クラスターの Scala バージョンと一致している必要があります。 クラスターの Scala のバージョンを確認するには、 Databricks Runtime のリリース ノートのバージョンと互換性に関するクラスターの Databricks Runtime バージョンの「システム環境」セクションを参照してください。
次の Scala プログラムは、列の値を 2 乗する単純な UDF を設定します。
import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}
object Main {
def main(args: Array[String]): Unit = {
val sourceLocation = getClass.getProtectionDomain.getCodeSource.getLocation.toURI
val spark = DatabricksSession.builder()
.addCompiledArtifacts(sourceLocation)
.getOrCreate()
def squared(x: Int): Int = x * x
val squared_udf = udf(squared _)
spark.range(3)
.withColumn("squared", squared_udf(col("id")))
.select("squared")
.show()
}
}
前の例では、UDF が Main
内に完全に含まれているため、Main
のコンパイル済み成果物のみが追加されます。
UDF が他のクラスに分散している場合、または外部ライブラリ (JAR) を使用している場合は、これらのライブラリもすべて含める必要があります。
Spark セッションが既に初期化されている場合は、spark.addArtifact()
API を使用して、さらにコンパイルされたクラスと JAR をアップロードできます。
手記
JAR をアップロードするときは、すべての推移的な依存関係 JAR をアップロードに含める必要があります。 API では、推移的な依存関係の自動検出は実行されません。
型付きデータセット API
UDF の前のセクションで説明したのと同じメカニズムは、型指定されたデータセット API にも適用されます。
型指定されたデータセット API を使用すると、結果のデータセットに対してマップ、フィルター、集計などの変換を実行できます。 これらは、Databricks クラスター上の UDF と同様に実行されます。
次の Scala アプリケーションでは、map()
API を使用して、結果列の数値をプレフィックス付き文字列に変更します。
import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}
object Main {
def main(args: Array[String]): Unit = {
val sourceLocation = getClass.getProtectionDomain.getCodeSource.getLocation.toURI
val spark = DatabricksSession.builder()
.addCompiledArtifacts(sourceLocation)
.getOrCreate()
spark.range(3).map(f => s"row-$f").show()
}
}
この例では map()
API を使用していますが、これは、filter()
、mapPartitions()
など、型指定された他のデータセット API にも適用されます。