From 9f78869ea7de9893e0a0ac2670e471fdb64cafdd Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 12 Jun 2026 13:23:43 +0200 Subject: [PATCH 1/5] build: consolidate Rust crates into a Cargo workspace + extract native-common Move the standalone `native` crate into a root Cargo workspace and extract shared JNI plumbing (error->exception mapping, Tokio runtime singleton, StreamingReader) into a new `datafusion-jni-common` crate under `native-common/`. `native/src/errors.rs` moves to `native-common/src/errors.rs`; the nine native modules now import error/runtime helpers from `datafusion_jni_common`. Build glue follows: single root `Cargo.lock`, `.cargo/config.toml` redirects output to `rust-target/`, Makefile/CI/poms updated to build `--workspace` and target `-p datafusion-jni`. Core javadoc build commands updated to match. Pure refactor; no behavior change. First of a 6-PR stack splitting the Spark DataSource V2 connector work. Co-Authored-By: Claude Opus 4.8 (1M context) --- .cargo/config.toml | 21 ++ .github/workflows/build.yml | 4 +- .github/workflows/lint.yml | 8 +- .gitignore | 1 + native/Cargo.lock => Cargo.lock | 269 ++++++++++-------- Cargo.toml | 47 +++ Makefile | 10 +- core/pom.xml | 4 +- .../org/apache/datafusion/SessionContext.java | 11 +- .../SessionContextRuntimeStatsTest.java | 2 +- .../SessionContextSubstraitTest.java | 2 +- docs/source/contributor-guide/development.md | 21 +- .../updating-datafusion-version.md | 10 +- native-common/Cargo.toml | 35 +++ {native => native-common}/src/errors.rs | 7 +- native-common/src/lib.rs | 104 +++++++ native/Cargo.toml | 41 ++- native/src/arrow.rs | 2 +- native/src/avro.rs | 2 +- native/src/cache_manager.rs | 2 +- native/src/csv.rs | 2 +- native/src/json.rs | 2 +- native/src/lib.rs | 78 +---- native/src/object_store.rs | 2 +- native/src/proto.rs | 2 +- native/src/runtime_metrics.rs | 6 +- native/src/schema.rs | 2 +- pom.xml | 17 +- 28 files changed, 467 insertions(+), 247 deletions(-) create mode 100644 .cargo/config.toml rename native/Cargo.lock => Cargo.lock (94%) create mode 100644 Cargo.toml create mode 100644 native-common/Cargo.toml rename {native => native-common}/src/errors.rs (97%) create mode 100644 native-common/src/lib.rs diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..d7e0ee2 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Keep Cargo's workspace output out of `target/` so `mvn clean` (which deletes +# the root `target/`) does not nuke the Rust build cache. +[build] +target-dir = "rust-target" diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c5db936..da8e65a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -83,8 +83,8 @@ jobs: path: | ~/.cargo/registry ~/.cargo/git - native/target - key: ${{ runner.os }}-cargo-${{ hashFiles('native/Cargo.lock') }} + rust-target + key: ${{ runner.os }}-cargo-${{ hashFiles('Cargo.lock') }} restore-keys: ${{ runner.os }}-cargo- - name: Build native and run tests diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 4cf628f..952bf34 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -54,7 +54,7 @@ jobs: run: ./mvnw -q spotless:check - name: Check Rust formatting - run: cd native && cargo fmt --all -- --check + run: cargo fmt --all -- --check clippy: name: Clippy @@ -81,9 +81,9 @@ jobs: path: | ~/.cargo/registry ~/.cargo/git - native/target - key: ${{ runner.os }}-clippy-${{ hashFiles('native/Cargo.lock') }} + rust-target + key: ${{ runner.os }}-clippy-${{ hashFiles('Cargo.lock') }} restore-keys: ${{ runner.os }}-clippy- - name: Run clippy - run: cd native && cargo clippy --all-targets -- -D warnings + run: cargo clippy --workspace --all-targets -- -D warnings diff --git a/.gitignore b/.gitignore index 719a2a4..25c9216 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ target/ +rust-target/ *.class .idea/ .vscode/ diff --git a/native/Cargo.lock b/Cargo.lock similarity index 94% rename from native/Cargo.lock rename to Cargo.lock index 8c56280..286f96f 100644 --- a/native/Cargo.lock +++ b/Cargo.lock @@ -98,9 +98,9 @@ dependencies = [ [[package]] name = "ar_archive_writer" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7eb93bbb63b9c227414f6eb3a0adfddca591a8ce1e9b60661bb08969b87e340b" +checksum = "4087686b4b0a3427190bae57a1d9a478dbb2d40c5dc1bd6e2b6d797913bdd348" dependencies = [ "object", ] @@ -119,9 +119,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "607e64bb911ee4f90483e044fe78f175989148c2892e659a2cd25429e782ec54" +checksum = "378530e55cd479eda3c14eb345310799717e6f76d0c332041e8487022166b471" dependencies = [ "arrow-arith", "arrow-array", @@ -140,9 +140,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e754319ed8a85d817fe7adf183227e0b5308b82790a737b426c1124626b48118" +checksum = "a0ab212d2c1886e802f51c5212d78ebbcbb0bec980fff9dadc1eb8d45cd0b738" dependencies = [ "arrow-array", "arrow-buffer", @@ -154,9 +154,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841321891f247aa86c6112c80d83d89cb36e0addd020fa2425085b8eb6c3f579" +checksum = "cfd33d3e92f207444098c75b42de99d329562be0cf686b307b097cc52b4e999e" dependencies = [ "ahash", "arrow-buffer", @@ -173,9 +173,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f955dfb73fae000425f49c8226d2044dab60fb7ad4af1e24f961756354d996c9" +checksum = "0c6cd424c2693bcdbc150d843dc9d4d137dd2de4782ce6df491ad11a3a0416c0" dependencies = [ "bytes", "half", @@ -185,9 +185,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca5e686972523798f76bef355145bc1ae25a84c731e650268d31ab763c701663" +checksum = "4c5aefb56a2c02e9e2b30746241058b85f8983f0fcff2ba0c6d09006e1cded7f" dependencies = [ "arrow-array", "arrow-buffer", @@ -207,9 +207,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86c276756867fc8186ec380c72c290e6e3b23a1d4fb05df6b1d62d2e62666d48" +checksum = "e94e8cf7e517657a52b91ea1263acf38c4ca62a84655d72458a3359b12ab97de" dependencies = [ "arrow-array", "arrow-cast", @@ -222,9 +222,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3b5846209775b6dc8056d77ff9a032b27043383dd5488abd0b663e265b9373" +checksum = "3c88210023a2bfee1896af366309a3028fc3bcbd6515fa29a7990ee1baa08ee0" dependencies = [ "arrow-buffer", "arrow-schema", @@ -235,9 +235,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd8907ddd8f9fbabf91ec2c85c1d81fe2874e336d2443eb36373595e28b98dd5" +checksum = "238438f0834483703d88896db6fe5a7138b2230debc31b34c0336c2996e3c64f" dependencies = [ "arrow-array", "arrow-buffer", @@ -251,9 +251,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4518c59acc501f10d7dcae397fe12b8db3d81bc7de94456f8a58f9165d6f502" +checksum = "205ca2119e6d679d5c133c6f30e68f027738d95ed948cf77677ea69c7800036b" dependencies = [ "arrow-array", "arrow-buffer", @@ -276,9 +276,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efa70d9d6b1356f1fb9f1f651b84a725b7e0abb93f188cf7d31f14abfa2f2e6f" +checksum = "1bffd8fd2579286a5d63bac898159873e5094a79009940bcb42bbfce4f19f1d0" dependencies = [ "arrow-array", "arrow-buffer", @@ -289,9 +289,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faec88a945338192beffbbd4be0def70135422930caa244ac3cec0cd213b26b4" +checksum = "bab5994731204603c73ba69267616c50f80780774c6bb0476f1f830625115e0c" dependencies = [ "arrow-array", "arrow-buffer", @@ -302,9 +302,9 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18aa020f6bc8e5201dcd2d4b7f98c68f8a410ef37128263243e6ff2a47a67d4f" +checksum = "f633dbfdf39c039ada1bf9e34c694816eb71fbb7dc78f613993b7245e078a1ed" dependencies = [ "bitflags", "serde_core", @@ -313,9 +313,9 @@ dependencies = [ [[package]] name = "arrow-select" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a657ab5132e9c8ca3b24eb15a823d0ced38017fe3930ff50167466b02e2d592c" +checksum = "8cd065c54172ac787cf3f2f8d4107e0d3fdc26edba76fdf4f4cc170258942222" dependencies = [ "ahash", "arrow-array", @@ -327,9 +327,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6de2efbbd1a9f9780ceb8d1ff5d20421b35863b361e3386b4f571f1fc69fcb8" +checksum = "29dd7cda3ab9692f43a2e4acc444d760cc17b12bb6d8232ddf64e9bab7c06b42" dependencies = [ "arrow-array", "arrow-buffer", @@ -393,9 +393,9 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" [[package]] name = "base64" @@ -419,9 +419,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.11.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" +checksum = "b4388bee8683e3d04af747c73422af53102d2bd24d9eadb6cbc100baef4b43f8" [[package]] name = "blake2" @@ -457,9 +457,9 @@ dependencies = [ [[package]] name = "bon" -version = "3.9.1" +version = "3.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f47dbe92550676ee653353c310dfb9cf6ba17ee70396e1f7cf0a2020ad49b2fe" +checksum = "b2f04f6fef12d70d42a77b1433c9e0f065238479a6cefc4f5bab105e9873a3c3" dependencies = [ "bon-macros", "rustversion", @@ -467,9 +467,9 @@ dependencies = [ [[package]] name = "bon-macros" -version = "3.9.1" +version = "3.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c" +checksum = "7d0bd4c2f75335ad98052a37efb54f428b492f64340257143b3429c8a508fa7b" dependencies = [ "darling", "ident_case", @@ -482,9 +482,9 @@ dependencies = [ [[package]] name = "brotli" -version = "8.0.2" +version = "8.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bd8b9603c7aa97359dbd97ecf258968c95f3adddd6db2f7e7a5bef101c84560" +checksum = "8119e4516436f5708bbc474a9d395bf12f1b5395e93a92a56e647ac3388c8610" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -493,9 +493,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "5.0.0" +version = "5.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" +checksum = "5962523e1b92ce1b5e793d9169b9943eece10d39f62550bc04bb605d75b94924" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -503,9 +503,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.20.2" +version = "3.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649" [[package]] name = "byteorder" @@ -530,9 +530,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.62" +version = "1.2.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" +checksum = "556e016178bb5662a08681bbe0f00f8e17631781a4dfc8c45e466e4b185ec27f" dependencies = [ "find-msvc-tools", "jobserver", @@ -571,9 +571,9 @@ dependencies = [ [[package]] name = "chrono" -version = "0.4.44" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +checksum = "1aa79e62e7697b8e29b513a68abacf485adcd1fe8284a4316c5ae868e6633327" dependencies = [ "iana-time-zone", "num-traits", @@ -789,9 +789,9 @@ dependencies = [ [[package]] name = "dashmap" -version = "6.1.0" +version = "6.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +checksum = "e6361d5c062261c78a176addb82d4c821ae42bed6089de0e12603cd25de2059c" dependencies = [ "cfg-if", "crossbeam-utils", @@ -1299,6 +1299,16 @@ dependencies = [ "datafusion-physical-expr-common", ] +[[package]] +name = "datafusion-java-example-bridge" +version = "0.1.0" +dependencies = [ + "arrow", + "datafusion", + "datafusion-spark-bridge", + "tokio", +] + [[package]] name = "datafusion-jni" version = "0.1.0" @@ -1306,6 +1316,7 @@ dependencies = [ "arrow", "async-trait", "datafusion", + "datafusion-jni-common", "datafusion-proto", "datafusion-substrait", "futures", @@ -1319,6 +1330,16 @@ dependencies = [ "url", ] +[[package]] +name = "datafusion-jni-common" +version = "0.1.0" +dependencies = [ + "datafusion", + "futures", + "jni", + "tokio", +] + [[package]] name = "datafusion-macros" version = "53.1.0" @@ -1527,6 +1548,21 @@ dependencies = [ "parking_lot", ] +[[package]] +name = "datafusion-spark-bridge" +version = "0.1.0" +dependencies = [ + "arrow", + "async-trait", + "datafusion", + "datafusion-jni-common", + "datafusion-proto", + "futures", + "jni", + "prost", + "tokio", +] + [[package]] name = "datafusion-sql" version = "53.1.0" @@ -1579,9 +1615,9 @@ dependencies = [ [[package]] name = "displaydoc" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +checksum = "1ac70aa55017e108007fbaf5aa0f54b021c98f92ff8af59d42eda9da96e3dd4f" dependencies = [ "proc-macro2", "quote", @@ -1596,9 +1632,9 @@ checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] name = "either" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" [[package]] name = "equivalent" @@ -1904,9 +1940,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "http" -version = "1.4.0" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +checksum = "6970f50e31d6fc17d3fa27329444bfa74e196cf62e95052a3f6fee181dba6425" dependencies = [ "bytes", "itoa", @@ -1949,9 +1985,9 @@ checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" [[package]] name = "hyper" -version = "1.9.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" +checksum = "55281c53a1894c864990125767da440a4e630446785086f52523b20033b74498" dependencies = [ "atomic-waker", "bytes", @@ -2241,13 +2277,12 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.98" +version = "0.3.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" +checksum = "f2025f20d7a4fa7785846e7b63d10a76d3f1cee98ee5cb79ea59703f95e42162" dependencies = [ "cfg-if", "futures-util", - "once_cell", "wasm-bindgen", ] @@ -2316,9 +2351,9 @@ dependencies = [ [[package]] name = "libbz2-rs-sys" -version = "0.2.3" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3a6a8c165077efc8f3a971534c50ea6a1a18b329ef4a66e897a7e3a1494565f" +checksum = "34b357333733e8260735ba5894eb928c02ecc69c78715f01a8019e7fa7f2db4c" [[package]] name = "libc" @@ -2375,9 +2410,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.29" +version = "0.4.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +checksum = "953f07c43838f8e6f9758cab68bf5bed85465e7587ebe0b823f1bcd81978ad3a" [[package]] name = "lru-slab" @@ -2406,9 +2441,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.8.0" +version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +checksum = "6b947ae49db0d222b1dbc6b113ce7248a3fc3a6ca21b696717bfc000ba4484d8" [[package]] name = "miniz_oxide" @@ -2422,9 +2457,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +checksum = "02bd0af71c67b473010cbbc60715ee815645a4dc942899111f494b4b737d6fda" dependencies = [ "libc", "wasi", @@ -2570,9 +2605,9 @@ dependencies = [ [[package]] name = "parquet" -version = "58.2.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d7efd3052f7d6ef601085559a246bc991e9a8cc77e02753737df6322ce35f1" +checksum = "5dafa7d01085b62a47dd0c1829550a0a36710ea9c4fe358a05a85477cec8a908" dependencies = [ "ahash", "arrow-array", @@ -2734,9 +2769,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" +checksum = "528ac67416ff8646872a3c02cad9cc4ee5dc9f9540c9b10771855c95cb2e5ae1" dependencies = [ "bytes", "prost-derive", @@ -2744,9 +2779,9 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" +checksum = "03da047801ff44bb6a4d407d4860c05fd70bb81714e6b2f3812603d5b145b042" dependencies = [ "heck", "itertools", @@ -2763,9 +2798,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" +checksum = "b570b25f7617e43d59005d0990ccb79e950a423952cea19671b7a876da390adf" dependencies = [ "anyhow", "itertools", @@ -2776,9 +2811,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.14.3" +version = "0.14.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" +checksum = "f94967dc7688f3054c7fac87473ffae4cc4c3904800e2d9f5b857246d8963b0a" dependencies = [ "prost", ] @@ -3035,9 +3070,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.12.3" +version = "1.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +checksum = "f1292b7759ae1cb9ec195452d1390a074f0cd8541ab7a5a8c31cd6db45d4a6ba" dependencies = [ "aho-corasick", "memchr", @@ -3064,9 +3099,9 @@ checksum = "cab834c73d247e67f4fae452806d17d3c7501756d98c8808d7c9c7aa7d18f973" [[package]] name = "regex-syntax" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +checksum = "d6f6ff9a378485b298a5286656da665ba74413d36db0979633275d2e708145d4" [[package]] name = "regress" @@ -3178,9 +3213,9 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +checksum = "dab5152771c58876a2146916e53e35057e1a4dfa2b9df0f0305b07f611fdea4d" dependencies = [ "openssl-probe", "rustls-pki-types", @@ -3361,9 +3396,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.149" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "itoa", "memchr", @@ -3422,9 +3457,9 @@ dependencies = [ [[package]] name = "shlex" -version = "1.3.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +checksum = "f8fadd59c855ef2080decdef8ff161eb6661b86933c9d82e5ba29dc602a55aba" [[package]] name = "simd-adler32" @@ -3464,9 +3499,9 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +checksum = "52d1cfed4120b4d927bf7c0f86d2087a4a7d6027c906d9f9d525a80573b9be51" dependencies = [ "libc", "windows-sys 0.61.2", @@ -3861,9 +3896,9 @@ checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" [[package]] name = "typenum" -version = "1.20.0" +version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" +checksum = "b6f5e870be6c3b371b77fe0ee0bafb859fa4964b4404c27de1d380043c4dda20" [[package]] name = "typify" @@ -3920,9 +3955,9 @@ checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "unicode-segmentation" -version = "1.13.2" +version = "1.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" +checksum = "c6f5d3c3b1bf09027a88a6bc961fc00497d651009560b5463668dc81b0fa87a8" [[package]] name = "unicode-width" @@ -3968,9 +4003,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.23.1" +version = "1.23.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" +checksum = "144d6b123cef80b301b8f72a9e2ca4370ddec21950d0a103dd22c437006d2db7" dependencies = [ "getrandom 0.4.2", "js-sys", @@ -4029,9 +4064,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.121" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" +checksum = "a254a4b10c19a76f09a27640e7ffbf9bc30bf67e16a3bf28aaefa4920fe81563" dependencies = [ "cfg-if", "once_cell", @@ -4042,9 +4077,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.71" +version = "0.4.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96492d0d3ffba25305a7dc88720d250b1401d7edca02cc3bcd50633b424673b8" +checksum = "54568702fabf5d4849ce2b90fadfa64168a097eaf4b351ce9df8b687a0086aaf" dependencies = [ "js-sys", "wasm-bindgen", @@ -4052,9 +4087,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.121" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" +checksum = "24a40fc75b0ec6f3746ceb10d36f53a93dcd68a93b11b6445983945d79eba0dc" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4062,9 +4097,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.121" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" +checksum = "908f34bd9b9ce3d4caf07b72dfab63d61504d156856c6bd3cd87fa350cf3985b" dependencies = [ "bumpalo", "proc-macro2", @@ -4075,9 +4110,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.121" +version = "0.2.123" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" +checksum = "7acbf7616c27b194bbb550bf77ed0c2c3e5b7fd1260a93082b95fb7f47959b92" dependencies = [ "unicode-ident", ] @@ -4131,9 +4166,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.98" +version = "0.3.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b572dff8bcf38bad0fa19729c89bb5748b2b9b1d8be70cf90df697e3a8f32aa" +checksum = "6e0871acf327f283dc6da28a1696cdc64fb355ba9f935d052021fa77f35cce69" dependencies = [ "js-sys", "wasm-bindgen", @@ -4541,9 +4576,9 @@ checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" [[package]] name = "yoke" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +checksum = "709fe23a0424b6a435d82152b1bd3fdfb0833487d5fa90d05d42762a9891fef5" dependencies = [ "stable_deref_trait", "yoke-derive", @@ -4564,18 +4599,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.48" +version = "0.8.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +checksum = "ce1022995ff5ff5d841ad7d994facc23098cd40152f2c1d11cd607c6f530653f" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.48" +version = "0.8.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +checksum = "1ae7f38b72ec2a254e2b87ef277cf2cd4fb97cbebf944faa6f33354da0867930" dependencies = [ "proc-macro2", "quote", @@ -4584,9 +4619,9 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" +checksum = "0ec05a11813ea801ff6d75110ad09cd0824ddba17dfe17128ea0d5f68e6c5272" dependencies = [ "zerofrom-derive", ] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..4be0260 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[workspace] +resolver = "2" +members = [ + "native", + "native-common", +] + +# Every dependency used by any workspace member is declared here so version +# bumps live in one place and the resolver picks a single version of each +# crate across the workspace. Members reference these via `{ workspace = true }` +# and add per-crate flags (optional, features, default-features) at the use +# site. +[workspace.dependencies] +arrow = { version = "58", features = ["ffi"] } +async-trait = "0.1" +datafusion = { version = "53.1.0" } +datafusion-proto = "53.1.0" +datafusion-substrait = "53.1.0" +futures = "0.3" +jni = "0.21" +# Pinned to the major DataFusion 53.1 pulls in transitively (0.13.x) so we +# share the same `dyn ObjectStore` vtable and don't double-link. +object_store = { version = "0.13", default-features = false } +prost = "0.14" +prost-build = "0.14" +protoc-bin-vendored = "3" +tokio = { version = "1", features = ["rt-multi-thread"] } +# Optional, cfg-gated. See `native/Cargo.toml` for the build-flag dance. +tokio-metrics = "0.5" +url = "2" diff --git a/Makefile b/Makefile index 6d9b0ae..d6bcf2c 100644 --- a/Makefile +++ b/Makefile @@ -20,14 +20,14 @@ all: native jvm native: - cd native && cargo build + cargo build --workspace -# Build the native crate with the `runtime-metrics` Cargo feature enabled. +# Build the JNI crate with the `runtime-metrics` Cargo feature enabled. # Requires `--cfg tokio_unstable` because tokio-metrics gates its API there. # Default `make native` does not pull this in; callers who need # SessionContext.runtimeStats() pick this target explicitly. native-runtime-metrics: - cd native && RUSTFLAGS="--cfg tokio_unstable" cargo build --features runtime-metrics + RUSTFLAGS="--cfg tokio_unstable" cargo build -p datafusion-jni --features runtime-metrics jvm: ./mvnw package -DskipTests @@ -39,10 +39,10 @@ test: native # `:check` form inline in .github/workflows/lint.yml. format: ./mvnw -q spotless:apply - cd native && cargo fmt --all + cargo fmt --all clean: - cd native && cargo clean + cargo clean ./mvnw clean tpch-data: diff --git a/core/pom.xml b/core/pom.xml index 5ddf107..1e25736 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -102,8 +102,8 @@ under the License. - + value="${maven.multiModuleProjectDirectory}/rust-target/${datafusion.native.profile}/${datafusion.lib.filename}"/> + diff --git a/core/src/main/java/org/apache/datafusion/SessionContext.java b/core/src/main/java/org/apache/datafusion/SessionContext.java index ffc58dd..27d2b16 100644 --- a/core/src/main/java/org/apache/datafusion/SessionContext.java +++ b/core/src/main/java/org/apache/datafusion/SessionContext.java @@ -113,10 +113,11 @@ public DataFrame fromProto(byte[] planBytes) { * other Substrait-emitting tool — and hand them to DataFusion without round-tripping through SQL. * *

Substrait support is gated behind the {@code substrait} Cargo feature on the native crate - * and is off by default. Rebuild the native crate with {@code cargo build - * --features substrait} (or {@code cargo build --features substrait,protoc} for hermetic builds - * that vendor {@code protoc} via {@code cmake}) to enable it. If invoked against a native binary - * built without the feature, this method throws {@link RuntimeException} pointing at the flag. + * and is off by default. Rebuild the native crate with {@code cargo build -p + * datafusion-jni --features substrait} (or {@code ... --features substrait,protoc} for hermetic + * builds that vendor {@code protoc} via {@code cmake}) to enable it. If invoked against a native + * binary built without the feature, this method throws {@link RuntimeException} pointing at the + * flag. * * @throws IllegalArgumentException if {@code planBytes} is {@code null}. * @throws IllegalStateException if this context is closed. @@ -183,7 +184,7 @@ public MemoryUsage memoryUsage() { * Rebuild with: * *

{@code
-   * RUSTFLAGS="--cfg tokio_unstable" cargo build --features runtime-metrics
+   * RUSTFLAGS="--cfg tokio_unstable" cargo build -p datafusion-jni --features runtime-metrics
    * }
* *

If invoked against a native binary built without the feature, this method throws {@link diff --git a/core/src/test/java/org/apache/datafusion/SessionContextRuntimeStatsTest.java b/core/src/test/java/org/apache/datafusion/SessionContextRuntimeStatsTest.java index 120d179..d567275 100644 --- a/core/src/test/java/org/apache/datafusion/SessionContextRuntimeStatsTest.java +++ b/core/src/test/java/org/apache/datafusion/SessionContextRuntimeStatsTest.java @@ -37,7 +37,7 @@ * #checkFeatureEnabled}. Run * *

{@code
- * (cd native && RUSTFLAGS="--cfg tokio_unstable" cargo build --features runtime-metrics)
+ * RUSTFLAGS="--cfg tokio_unstable" cargo build -p datafusion-jni --features runtime-metrics
  * }
* * before {@code ./mvnw test} to exercise this class. diff --git a/core/src/test/java/org/apache/datafusion/SessionContextSubstraitTest.java b/core/src/test/java/org/apache/datafusion/SessionContextSubstraitTest.java index 34db3b5..a2cfb0a 100644 --- a/core/src/test/java/org/apache/datafusion/SessionContextSubstraitTest.java +++ b/core/src/test/java/org/apache/datafusion/SessionContextSubstraitTest.java @@ -50,7 +50,7 @@ * *

The {@code substrait} Cargo feature is off by default in {@code native/Cargo.toml}; if the * native crate was built without it, every test here is skipped (see {@link #checkFeatureEnabled}). - * Run {@code (cd native && cargo build --features substrait)} before {@code ./mvnw test} to + * Run {@code cargo build -p datafusion-jni --features substrait} before {@code ./mvnw test} to * exercise this class. */ class SessionContextSubstraitTest { diff --git a/docs/source/contributor-guide/development.md b/docs/source/contributor-guide/development.md index 984d77c..fdb00f4 100644 --- a/docs/source/contributor-guide/development.md +++ b/docs/source/contributor-guide/development.md @@ -42,7 +42,7 @@ This builds the native Rust crate and runs the JUnit tests. The steps can be run individually: ```sh -cd native && cargo build +cargo build --workspace ./mvnw test ``` @@ -74,14 +74,25 @@ disk space. The repository is a multi-module Maven build: -- `pom.xml` — parent POM declaring the `core` and `examples` modules and - shared plugin/dependency versions. +- `Cargo.toml` — Rust workspace root declaring the three crate members + (`native`, `native-common`, `examples/native`, `spark/bridge`) and `[workspace.dependencies]` + that pin shared versions in one place. Cargo writes artifacts to + `rust-target/` (overridden in `.cargo/config.toml`) so `mvn clean` at the + repo root does not nuke the Rust build cache. +- `pom.xml` — parent POM declaring the `core`, `spark`, and `examples` + modules and shared plugin/dependency versions. - `core/` — `datafusion-java` library module (Java sources, tests, and generated protobuf classes). +- `spark/` — `datafusion-java-spark` Spark DataSource V2 connector + (Scala + Java, pure JVM) and its `spark/bridge/` Rust SDK crate + (`datafusion-spark-bridge`: widening, scan machinery, `export_bridge!`). - `examples/` — `datafusion-java-examples` module containing runnable examples that depend on the library; built alongside the library so they - cannot fall out of sync with the API. -- `native/` — Rust crate (JNI + Arrow C Data Interface). + cannot fall out of sync with the API. Includes `examples/native/`, a + small `export_bridge!` cdylib used by the Spark connector demo + (`ExampleBridgeProviderFactory` + the pyspark script under + `examples/python/`). +- `native/` — `datafusion-jni` Rust crate (JNI + Arrow C Data Interface). - `proto/` — Protobuf definitions shared between Java and Rust. - `Makefile` — top-level build orchestration (`make test`, `make format`, `make tpch-data`). diff --git a/docs/source/contributor-guide/updating-datafusion-version.md b/docs/source/contributor-guide/updating-datafusion-version.md index 56d50dc..ef6cd10 100644 --- a/docs/source/contributor-guide/updating-datafusion-version.md +++ b/docs/source/contributor-guide/updating-datafusion-version.md @@ -21,7 +21,9 @@ under the License. Three things must move together when bumping DataFusion: -1. `native/Cargo.toml` — the `datafusion` crate dependency. +1. `Cargo.toml` (workspace root) — the `datafusion`, `datafusion-ffi`, + `datafusion-proto`, and `datafusion-substrait` entries in + `[workspace.dependencies]`. Members inherit from there. 2. `pom.xml` — the `` Maven property. **Must equal the Cargo version**; a mismatch means JVM-built protobuf plans won't deserialize on the native side. @@ -32,9 +34,9 @@ Three things must move together when bumping DataFusion: ## Recipe ```sh -# 1. Bump the Cargo dep -$EDITOR native/Cargo.toml # set datafusion = "" -(cd native && cargo update -p datafusion) +# 1. Bump the workspace dep +$EDITOR Cargo.toml # set datafusion = "" in [workspace.dependencies] +cargo update -p datafusion # 2. Bump the Maven property to match $EDITOR pom.xml # set diff --git a/native-common/Cargo.toml b/native-common/Cargo.toml new file mode 100644 index 0000000..0a797b4 --- /dev/null +++ b/native-common/Cargo.toml @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-jni-common" +version = "0.1.0" +edition = "2021" +publish = false + +[features] +# `datafusion-jni` builds DataFusion with `avro`, which adds the +# `DataFusionError::AvroError` variant our classifier maps to IoException. +# Feature-forwarded so consumers that don't read Avro (the Spark helper) +# don't pull the apache-avro stack into their cdylib. +avro = ["datafusion/avro"] + +[dependencies] +datafusion = { workspace = true } +futures = { workspace = true } +jni = { workspace = true } +tokio = { workspace = true } diff --git a/native/src/errors.rs b/native-common/src/errors.rs similarity index 97% rename from native/src/errors.rs rename to native-common/src/errors.rs index d926544..caa2540 100644 --- a/native/src/errors.rs +++ b/native-common/src/errors.rs @@ -96,8 +96,11 @@ fn classify(err: &DataFusionError) -> &'static str { } DataFusionError::IoError(_) | DataFusionError::ObjectStore(_) - | DataFusionError::ParquetError(_) - | DataFusionError::AvroError(_) => "org/apache/datafusion/IoException", + | DataFusionError::ParquetError(_) => "org/apache/datafusion/IoException", + // The AvroError variant only exists when DataFusion is built with its + // `avro` feature, forwarded by this crate's own `avro` feature. + #[cfg(feature = "avro")] + DataFusionError::AvroError(_) => "org/apache/datafusion/IoException", // ArrowError is a 21-variant grab bag -- only some of those variants // are actually IO-shaped. DivideByZero / ArithmeticOverflow / Compute // / Cast / InvalidArgument / Memory etc. are execution-time failures diff --git a/native-common/src/lib.rs b/native-common/src/lib.rs new file mode 100644 index 0000000..f143d43 --- /dev/null +++ b/native-common/src/lib.rs @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! JNI plumbing shared by this workspace's native crates (`datafusion-jni` +//! and `datafusion-spark-bridge`, and through the latter every bridge +//! cdylib): the error-to-Java-exception mapping, the per-cdylib Tokio +//! runtime singleton, and the async-stream-to-`FFI_ArrowArrayStream` +//! bridge. +//! +//! Each cdylib statically links its own copy of this rlib, so [`runtime`] is +//! a per-cdylib singleton -- exactly the behaviour each crate had when this +//! code lived inline. Nothing here is exported with `#[no_mangle]`, so +//! linking this crate into several cdylibs loaded in one JVM cannot collide. + +pub mod errors; + +use std::panic::{catch_unwind, AssertUnwindSafe}; +use std::sync::OnceLock; + +use datafusion::arrow::array::RecordBatch; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::error::ArrowError; +use datafusion::arrow::record_batch::RecordBatchReader; +use datafusion::execution::SendableRecordBatchStream; +use futures::StreamExt; +use tokio::runtime::{Handle, Runtime}; + +static RT: OnceLock = OnceLock::new(); + +/// The cdylib-wide Tokio runtime. +pub fn runtime() -> &'static Runtime { + runtime_with_init(|_| {}) +} + +/// Same singleton as [`runtime`], with a hook that runs exactly once, when +/// the runtime is created. `datafusion-jni` uses it to install its +/// runtime-metrics accumulator so the sampling baseline coincides with +/// runtime start; every later call (either entry point) returns the existing +/// runtime without invoking the hook. +pub fn runtime_with_init(init: impl FnOnce(&Handle)) -> &'static Runtime { + RT.get_or_init(|| { + let rt = Runtime::new().expect("failed to create Tokio runtime"); + init(rt.handle()); + rt + }) +} + +/// Bridges DataFusion's async [`SendableRecordBatchStream`] to the synchronous +/// [`RecordBatchReader`] interface that `FFI_ArrowArrayStream` (and therefore +/// the Java `ArrowReader`) consumes. Each call to `next()` drives one +/// `runtime().block_on(stream.next())`, so memory pressure stays bounded by the +/// executor pipeline plus a single in-flight batch. +pub struct StreamingReader { + pub schema: SchemaRef, + pub stream: SendableRecordBatchStream, +} + +impl Iterator for StreamingReader { + type Item = Result; + + fn next(&mut self) -> Option { + // Arrow's C ABI invokes this iterator through FFI_ArrowArrayStream's + // vtable, outside the JNI handler's try_unwrap_or_throw guard. A panic + // here (buggy UDF, arrow cast that panics, runtime poison) would + // unwind across C/FFI -- undefined behaviour. Catch it and surface as + // an ArrowError so the Java side sees a normal exception instead. + let next = catch_unwind(AssertUnwindSafe(|| runtime().block_on(self.stream.next()))); + match next { + Ok(item) => item.map(|r| r.map_err(|e| ArrowError::ExternalError(Box::new(e)))), + Err(panic) => { + let msg = if let Some(s) = panic.downcast_ref::() { + s.clone() + } else if let Some(s) = panic.downcast_ref::<&str>() { + (*s).to_string() + } else { + "rust panic with non-string payload".to_string() + }; + Some(Err(ArrowError::ExternalError( + format!("panic in DataFrame stream: {msg}").into(), + ))) + } + } + } +} + +impl RecordBatchReader for StreamingReader { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} diff --git a/native/Cargo.toml b/native/Cargo.toml index c462408..0f4ca83 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -23,8 +23,8 @@ publish = false [lib] # `rlib` alongside `cdylib` so `cargo test` has a Rust-level harness for -# native-only invariants (e.g. error-classification routing through wrapped -# DataFusionError chains). The `cdylib` is still the artifact the JVM loads. +# native-only invariants (the error-classification tests now live in +# `datafusion-jni-common`). The `cdylib` is still the artifact the JVM loads. crate-type = ["cdylib", "rlib"] [features] @@ -69,24 +69,23 @@ protoc = ["datafusion-substrait?/protoc"] runtime-metrics = ["dep:tokio-metrics"] [dependencies] -arrow = { version = "58", features = ["ffi"] } -async-trait = "0.1" -datafusion = { version = "53.1.0", features = ["avro"] } -datafusion-proto = "53.1.0" -datafusion-substrait = { version = "53.1.0", optional = true } -futures = "0.3" -jni = "0.21" -# Pin to the same major as DataFusion 53.1 pulls in transitively (0.13.x) -# so we share the same `dyn ObjectStore` vtable and don't double-link. -object_store = { version = "0.13", default-features = false } -prost = "0.14" -tokio = { version = "1", features = ["rt-multi-thread"] } -# Tokio runtime metrics. Optional + cfg-gated: this crate's API surface lives -# behind `--cfg tokio_unstable`, so enabling the `runtime-metrics` feature also -# requires the caller to set `RUSTFLAGS="--cfg tokio_unstable"` at build time. -tokio-metrics = { version = "0.5", optional = true } -url = "2" +arrow = { workspace = true } +async-trait = { workspace = true } +datafusion = { workspace = true, features = ["avro"] } +# Shared JNI plumbing (error->exception mapping, runtime singleton, +# StreamingReader). `avro` keeps the classifier's AvroError->IoException arm +# in sync with the `avro` feature on `datafusion` above. +datafusion-jni-common = { path = "../native-common", features = ["avro"] } +datafusion-proto = { workspace = true } +datafusion-substrait = { workspace = true, optional = true } +futures = { workspace = true } +jni = { workspace = true } +object_store = { workspace = true } +prost = { workspace = true } +tokio = { workspace = true } +tokio-metrics = { workspace = true, optional = true } +url = { workspace = true } [build-dependencies] -prost-build = "0.14" -protoc-bin-vendored = "3" +prost-build = { workspace = true } +protoc-bin-vendored = { workspace = true } diff --git a/native/src/arrow.rs b/native/src/arrow.rs index 2bbe7b0..67e5caf 100644 --- a/native/src/arrow.rs +++ b/native/src/arrow.rs @@ -23,10 +23,10 @@ use jni::sys::jlong; use jni::JNIEnv; use prost::Message; -use crate::errors::{try_unwrap_or_throw, JniResult}; use crate::proto_gen::ArrowReadOptionsProto; use crate::runtime; use crate::schema::decode_optional_schema; +use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult}; fn with_arrow_options( env: &mut JNIEnv, diff --git a/native/src/avro.rs b/native/src/avro.rs index 85d4a07..257ae32 100644 --- a/native/src/avro.rs +++ b/native/src/avro.rs @@ -23,10 +23,10 @@ use jni::sys::jlong; use jni::JNIEnv; use prost::Message; -use crate::errors::{try_unwrap_or_throw, JniResult}; use crate::proto_gen::AvroReadOptionsProto; use crate::runtime; use crate::schema::decode_optional_schema; +use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult}; fn with_avro_options( env: &mut JNIEnv, diff --git a/native/src/cache_manager.rs b/native/src/cache_manager.rs index 3b9e286..ec38dc8 100644 --- a/native/src/cache_manager.rs +++ b/native/src/cache_manager.rs @@ -34,8 +34,8 @@ use datafusion::execution::cache::cache_unit::{ }; use datafusion::execution::cache::DefaultListFilesCache; -use crate::errors::JniResult; use crate::proto_gen::CacheManagerOptionsProto; +use datafusion_jni_common::errors::JniResult; /// Build a [`CacheManagerConfig`] from the proto. Returns `Ok(None)` if the /// caller did not set any cache-manager field, so the JNI layer can skip the diff --git a/native/src/csv.rs b/native/src/csv.rs index 3ae4627..b79ed59 100644 --- a/native/src/csv.rs +++ b/native/src/csv.rs @@ -26,12 +26,12 @@ use jni::sys::jlong; use jni::JNIEnv; use prost::Message; -use crate::errors::{try_unwrap_or_throw, JniResult}; use crate::proto_gen::{ CsvReadOptionsProto, CsvWriteOptionsProto, FileCompressionType as ProtoFileCompressionType, }; use crate::runtime; use crate::schema::decode_optional_schema; +use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult}; fn with_csv_options( env: &mut JNIEnv, diff --git a/native/src/json.rs b/native/src/json.rs index 8eea32f..b87be78 100644 --- a/native/src/json.rs +++ b/native/src/json.rs @@ -27,12 +27,12 @@ use jni::sys::jlong; use jni::JNIEnv; use prost::Message; -use crate::errors::{try_unwrap_or_throw, JniResult}; use crate::proto_gen::{ FileCompressionType as ProtoFileCompressionType, JsonWriteOptionsProto, NdJsonReadOptionsProto, }; use crate::runtime; use crate::schema::decode_optional_schema; +use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult}; fn with_json_options( env: &mut JNIEnv, diff --git a/native/src/lib.rs b/native/src/lib.rs index 4fd7a8a..6e1a79f 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -19,7 +19,6 @@ mod arrow; mod avro; mod cache_manager; mod csv; -mod errors; mod jni_util; mod json; mod memory; @@ -34,16 +33,13 @@ pub(crate) mod proto_gen { include!(concat!(env!("OUT_DIR"), "/datafusion_java.rs")); } -use std::panic::{catch_unwind, AssertUnwindSafe}; use std::path::PathBuf; use std::sync::{Arc, OnceLock}; -use datafusion::arrow::array::RecordBatch; use datafusion::arrow::datatypes::SchemaRef; -use datafusion::arrow::error::ArrowError; use datafusion::arrow::ffi_stream::FFI_ArrowArrayStream; use datafusion::arrow::ipc::writer::StreamWriter; -use datafusion::arrow::record_batch::{RecordBatchIterator, RecordBatchReader}; +use datafusion::arrow::record_batch::RecordBatchIterator; use datafusion::common::{JoinType, UnnestOptions}; use datafusion::config::TableParquetOptions; use datafusion::dataframe::DataFrame; @@ -51,11 +47,9 @@ use datafusion::dataframe::DataFrameWriteOptions; use datafusion::error::DataFusionError; use datafusion::execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; -use datafusion::execution::SendableRecordBatchStream; use datafusion::logical_expr::Expr; use datafusion::logical_expr::{col, Partitioning, ScalarUDF, Signature, SortExpr}; use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; -use futures::StreamExt; use jni::objects::{JBooleanArray, JByteArray, JClass, JObject, JObjectArray, JString}; use jni::sys::{jboolean, jbyte, jbyteArray, jint, jlong}; use jni::JNIEnv; @@ -63,7 +57,10 @@ use jni::JavaVM; use prost::Message; use tokio::runtime::Runtime; -use crate::errors::{try_unwrap_or_throw, JniResult}; +use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult}; +// Re-exported so sibling modules keep their crate-local `crate::StreamingReader` path. +pub(crate) use datafusion_jni_common::StreamingReader; + use crate::proto_gen::ParquetReadOptionsProto; use crate::proto_gen::SessionOptions; use crate::schema::decode_optional_schema; @@ -84,18 +81,15 @@ pub(crate) fn jvm() -> &'static JavaVM { } pub(crate) fn runtime() -> &'static Runtime { - static RT: OnceLock = OnceLock::new(); - RT.get_or_init(|| { - let rt = Runtime::new().expect("failed to create Tokio runtime"); - // Eagerly install the runtime-metrics accumulator (no-op when the - // `runtime-metrics` Cargo feature is off). Initialising here -- not - // lazily on the first `runtimeStats()` call -- means the - // RuntimeMonitor's sampling baseline coincides with runtime start, so - // poll/park/busy totals reflect activity from the first query onward - // rather than from the first observation. - crate::runtime_metrics::init(rt.handle()); - rt - }) + // The singleton itself lives in datafusion-jni-common (shared with the + // datafusion-spark-bridge SDK; each cdylib statically links its own + // copy, so the runtime stays per-library). The init hook eagerly installs the + // runtime-metrics accumulator (no-op when the `runtime-metrics` Cargo + // feature is off). Initialising here -- not lazily on the first + // `runtimeStats()` call -- means the RuntimeMonitor's sampling baseline + // coincides with runtime start, so poll/park/busy totals reflect activity + // from the first query onward rather than from the first observation. + datafusion_jni_common::runtime_with_init(crate::runtime_metrics::init) } /// Wrap the (already-built) `RuntimeEnvBuilder`'s memory pool with a @@ -289,50 +283,6 @@ pub extern "system" fn Java_org_apache_datafusion_DataFrame_collectDataFrame<'lo }) } -/// Bridges DataFusion's async [`SendableRecordBatchStream`] to the synchronous -/// [`RecordBatchReader`] interface that `FFI_ArrowArrayStream` (and therefore -/// the Java `ArrowReader`) consumes. Each call to `next()` drives one -/// `runtime().block_on(stream.next())`, so memory pressure stays bounded by the -/// executor pipeline plus a single in-flight batch. -struct StreamingReader { - schema: SchemaRef, - stream: SendableRecordBatchStream, -} - -impl Iterator for StreamingReader { - type Item = Result; - - fn next(&mut self) -> Option { - // Arrow's C ABI invokes this iterator through FFI_ArrowArrayStream's - // vtable, outside the JNI handler's try_unwrap_or_throw guard. A panic - // here (buggy UDF, arrow cast that panics, runtime poison) would - // unwind across C/FFI -- undefined behaviour. Catch it and surface as - // an ArrowError so the Java side sees a normal exception instead. - let next = catch_unwind(AssertUnwindSafe(|| runtime().block_on(self.stream.next()))); - match next { - Ok(item) => item.map(|r| r.map_err(|e| ArrowError::ExternalError(Box::new(e)))), - Err(panic) => { - let msg = if let Some(s) = panic.downcast_ref::() { - s.clone() - } else if let Some(s) = panic.downcast_ref::<&str>() { - (*s).to_string() - } else { - "rust panic with non-string payload".to_string() - }; - Some(Err(ArrowError::ExternalError( - format!("panic in DataFrame stream: {msg}").into(), - ))) - } - } - } -} - -impl RecordBatchReader for StreamingReader { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - #[no_mangle] pub extern "system" fn Java_org_apache_datafusion_DataFrame_executeStreamDataFrame<'local>( mut env: JNIEnv<'local>, diff --git a/native/src/object_store.rs b/native/src/object_store.rs index eefccf2..985d721 100644 --- a/native/src/object_store.rs +++ b/native/src/object_store.rs @@ -28,9 +28,9 @@ use std::sync::Arc; use datafusion::prelude::SessionContext; use url::Url; -use crate::errors::JniResult; use crate::proto_gen::object_store_registration::Backend; use crate::proto_gen::ObjectStoreRegistration; +use datafusion_jni_common::errors::JniResult; #[cfg(feature = "object-store-gcp")] use crate::proto_gen::GcsOptions; diff --git a/native/src/proto.rs b/native/src/proto.rs index 4f187bc..c1315f9 100644 --- a/native/src/proto.rs +++ b/native/src/proto.rs @@ -28,8 +28,8 @@ use jni::sys::{jbyteArray, jlong}; use jni::JNIEnv; use prost::Message; -use crate::errors::{try_unwrap_or_throw, JniResult}; use crate::runtime; +use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult}; #[no_mangle] pub extern "system" fn Java_org_apache_datafusion_SessionContext_createDataFrameFromProto< diff --git a/native/src/runtime_metrics.rs b/native/src/runtime_metrics.rs index e69410e..dd60dcb 100644 --- a/native/src/runtime_metrics.rs +++ b/native/src/runtime_metrics.rs @@ -38,7 +38,7 @@ //! 10 totalOverflowCount #[cfg(not(feature = "runtime-metrics"))] -use crate::errors::JniResult; +use datafusion_jni_common::errors::JniResult; /// Number of i64 values in the snapshot array; kept here so the Java side and /// the feature-off stub agree on the layout. @@ -51,7 +51,7 @@ mod imp { use tokio_metrics::{RuntimeIntervals, RuntimeMonitor}; use super::STATS_FIELD_COUNT; - use crate::errors::JniResult; + use datafusion_jni_common::errors::JniResult; /// `RuntimeMonitor::intervals().next()` returns *delta* metrics covering /// the period since the previous call (or, on the very first call, since @@ -196,7 +196,7 @@ pub fn runtime_stats() -> JniResult<[i64; STATS_FIELD_COUNT]> { Err( "datafusion-jni was built without the `runtime-metrics` Cargo feature; \ rebuild the native crate with \ - `RUSTFLAGS=\"--cfg tokio_unstable\" cargo build --features runtime-metrics` \ + `RUSTFLAGS=\"--cfg tokio_unstable\" cargo build -p datafusion-jni --features runtime-metrics` \ to enable SessionContext.runtimeStats" .into(), ) diff --git a/native/src/schema.rs b/native/src/schema.rs index 968a73a..0c3c7ab 100644 --- a/native/src/schema.rs +++ b/native/src/schema.rs @@ -20,7 +20,7 @@ use datafusion::arrow::ipc::reader::StreamReader; use jni::objects::JByteArray; use jni::JNIEnv; -use crate::errors::JniResult; +use datafusion_jni_common::errors::JniResult; /// Decode an optional Arrow-IPC schema byte array passed in from Java. /// Returns `None` if the byte-array reference is null. diff --git a/pom.xml b/pom.xml index 6210841..b92cf72 100644 --- a/pom.xml +++ b/pom.xml @@ -95,6 +95,11 @@ under the License. + + org.apache.maven.plugins + maven-compiler-plugin + 3.13.0 + org.apache.maven.plugins maven-surefire-plugin @@ -159,6 +164,7 @@ under the License. README.md CONTRIBUTING.md docs/** + **/*.md .gitignore .idea/** @@ -173,12 +179,17 @@ under the License. .mvn/** **/target/** - native/target/** + rust-target/** tpch-data/** - - native/Cargo.lock + + Cargo.lock + + **/META-INF/services/** dev/release/rat_exclude_files.txt + + spark/scaffold/bridge-template/** From 28efd7c971b1200a84202fb75f585a17436027f2 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 12 Jun 2026 13:24:28 +0200 Subject: [PATCH 2/5] feat(spark): add datafusion-spark-bridge SDK for static bridges New `spark/bridge` workspace crate providing the `export_bridge!` macro that generates the six JNI entry points a Spark connector bridge exposes (providerSchemaIpc, createScan, partitionCount, executeStreamPartition, executeStream, closeScan). Includes the options decoder, scan planning/execution glue, and the Arrow type-widening layer (wraps any TableProvider for Spark type compatibility). Self-contained SDK with no Java/Scala coupling. Depends only on datafusion-jni-common. Second of the 6-PR connector stack. Co-Authored-By: Claude Opus 4.8 (1M context) --- Cargo.lock | 10 - Cargo.toml | 1 + spark/bridge/Cargo.toml | 34 +++ spark/bridge/src/lib.rs | 213 ++++++++++++++++ spark/bridge/src/options.rs | 158 ++++++++++++ spark/bridge/src/scan.rs | 325 +++++++++++++++++++++++++ spark/bridge/src/widening.rs | 376 +++++++++++++++++++++++++++++ spark/bridge/tests/export_macro.rs | 52 ++++ 8 files changed, 1159 insertions(+), 10 deletions(-) create mode 100644 spark/bridge/Cargo.toml create mode 100644 spark/bridge/src/lib.rs create mode 100644 spark/bridge/src/options.rs create mode 100644 spark/bridge/src/scan.rs create mode 100644 spark/bridge/src/widening.rs create mode 100644 spark/bridge/tests/export_macro.rs diff --git a/Cargo.lock b/Cargo.lock index 286f96f..307f24b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1299,16 +1299,6 @@ dependencies = [ "datafusion-physical-expr-common", ] -[[package]] -name = "datafusion-java-example-bridge" -version = "0.1.0" -dependencies = [ - "arrow", - "datafusion", - "datafusion-spark-bridge", - "tokio", -] - [[package]] name = "datafusion-jni" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 4be0260..fffea43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ resolver = "2" members = [ "native", "native-common", + "spark/bridge", ] # Every dependency used by any workspace member is declared here so version diff --git a/spark/bridge/Cargo.toml b/spark/bridge/Cargo.toml new file mode 100644 index 0000000..8ed4684 --- /dev/null +++ b/spark/bridge/Cargo.toml @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-spark-bridge" +version = "0.1.0" +edition = "2021" +publish = false +description = "SDK for building Spark connector bridges over DataFusion TableProviders" + +[dependencies] +arrow = { workspace = true } +async-trait = { workspace = true } +datafusion = { workspace = true } +datafusion-jni-common = { path = "../../native-common" } +datafusion-proto = { workspace = true } +futures = { workspace = true } +jni = { workspace = true } +prost = { workspace = true } +tokio = { workspace = true } diff --git a/spark/bridge/src/lib.rs b/spark/bridge/src/lib.rs new file mode 100644 index 0000000..52ef1c1 --- /dev/null +++ b/spark/bridge/src/lib.rs @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SDK for building Spark connector bridges over DataFusion `TableProvider`s. +//! +//! Everything the Spark connector needs DataFusion-side lives here: the +//! Spark-type [`widening`] layer, and the [`scan`] machinery (session from +//! pinned config, projection, proto filters, planning, partition streams). +//! A bridge cdylib depends on this crate and invokes [`export_bridge!`] with +//! a builder that constructs its concrete `TableProvider` from option / +//! partition bytes — one cdylib, no FFI provider boundary; the only foreign +//! interface is JNI plus Arrow's C stream for the results. + +pub mod options; +pub mod scan; +pub mod widening; + +// Re-exported so `export_bridge!` expansions resolve these crates inside the +// bridge author's crate without extra dependencies, and so builder signatures +// can be written against `datafusion_spark_bridge::datafusion::...`. +pub use datafusion; +pub use datafusion_jni_common::errors::JniResult; +pub use jni; + +use tokio::runtime::Handle; + +/// Execution environment handed to a bridge's provider builder. +/// +/// Provider construction frequently needs async IO (remote catalogs, +/// object-store metadata); run it on the bridge runtime via [`block_on`] +/// rather than creating a runtime of your own. +/// +/// [`block_on`]: BridgeContext::block_on +pub struct BridgeContext { + handle: &'static Handle, +} + +impl BridgeContext { + /// Used by `export_bridge!` expansions; not part of the public API. + #[doc(hidden)] + pub fn get() -> Self { + BridgeContext { + handle: runtime_handle(), + } + } + + /// The cdylib-wide Tokio runtime handle (also the runtime scans run on). + pub fn handle(&self) -> &Handle { + self.handle + } + + /// Block the current (JVM) thread on `fut`, driving it on the bridge + /// runtime. + pub fn block_on(&self, fut: F) -> F::Output { + self.handle.block_on(fut) + } +} + +/// Per-cdylib Tokio runtime (the singleton from `datafusion-jni-common`). +pub(crate) fn runtime_handle() -> &'static Handle { + datafusion_jni_common::runtime().handle() +} + +/// Generate the JNI entry points for a bridge cdylib. +/// +/// `jni_class` is the **underscore-mangled** binary name of the Java class +/// declaring the matching `native` methods: dots become underscores +/// (`com.example.mybridge.BridgeNative` → `"com_example_mybridge_BridgeNative"`). +/// If the class or package name itself contains an underscore, JNI mangling +/// requires it written as `_1`. Per-bridge class names are what let several +/// bridges coexist in one Spark JVM. +/// +/// `build_provider` is anything callable as +/// `Fn(&BridgeContext, &[u8], &[u8]) -> JniResult>`, +/// receiving the options bytes and partition bytes your JVM factory encoded. +/// The schema probe calls it with empty partition bytes; the scan path passes +/// each task's payload. Return errors boxed from `DataFusionError` to surface +/// as the typed `org.apache.datafusion.*` exception hierarchy. +/// +/// The generated Java-side surface (declare these as `static native` on the +/// class named by `jni_class`): +/// +/// ```java +/// static native byte[] providerSchemaIpc(byte[] options, byte[] partition); +/// static native long createScan(byte[] options, byte[] partition, +/// int targetPartitions, int batchSize, String[] optionKeys, +/// String[] optionValues, String[] projectionColumns, byte[][] filterProtos); +/// static native int partitionCount(long scanHandle); +/// static native void executeStreamPartition(long scanHandle, int partition, long ffiStreamAddr); +/// static native void executeStream(long scanHandle, long ffiStreamAddr); +/// static native void closeScan(long scanHandle); +/// ``` +#[macro_export] +macro_rules! export_bridge { + (jni_class: $cls:literal, build_provider: $builder:expr $(,)?) => { + const _: () = { + use $crate::jni::objects::{JByteArray, JClass, JObjectArray}; + use $crate::jni::sys::{jbyteArray, jint, jlong}; + use $crate::jni::JNIEnv; + + fn __df_bridge_build( + env: &mut JNIEnv, + options: &JByteArray, + partition: &JByteArray, + ) -> $crate::JniResult> + { + let opts: Vec = if options.is_null() { + Vec::new() + } else { + env.convert_byte_array(options)? + }; + let part: Vec = if partition.is_null() { + Vec::new() + } else { + env.convert_byte_array(partition)? + }; + let ctx = $crate::BridgeContext::get(); + ($builder)(&ctx, opts.as_slice(), part.as_slice()) + } + + #[export_name = concat!("Java_", $cls, "_providerSchemaIpc")] + extern "system" fn __df_bridge_provider_schema_ipc<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + options: JByteArray<'local>, + partition: JByteArray<'local>, + ) -> jbyteArray { + $crate::scan::provider_schema_ipc(&mut env, |env| { + __df_bridge_build(env, &options, &partition) + }) + } + + #[export_name = concat!("Java_", $cls, "_createScan")] + #[allow(clippy::too_many_arguments)] + extern "system" fn __df_bridge_create_scan<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + options: JByteArray<'local>, + partition: JByteArray<'local>, + target_partitions: jint, + batch_size: jint, + option_keys: JObjectArray<'local>, + option_values: JObjectArray<'local>, + projection_columns: JObjectArray<'local>, + filter_protos: JObjectArray<'local>, + ) -> jlong { + $crate::scan::create_scan( + &mut env, + |env| __df_bridge_build(env, &options, &partition), + target_partitions, + batch_size, + &option_keys, + &option_values, + &projection_columns, + &filter_protos, + ) + } + + #[export_name = concat!("Java_", $cls, "_partitionCount")] + extern "system" fn __df_bridge_partition_count<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + ) -> jint { + $crate::scan::partition_count(&mut env, handle) + } + + #[export_name = concat!("Java_", $cls, "_executeStreamPartition")] + extern "system" fn __df_bridge_execute_stream_partition<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + partition: jint, + ffi_stream_addr: jlong, + ) { + $crate::scan::execute_stream_partition(&mut env, handle, partition, ffi_stream_addr) + } + + #[export_name = concat!("Java_", $cls, "_executeStream")] + extern "system" fn __df_bridge_execute_stream<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + ffi_stream_addr: jlong, + ) { + $crate::scan::execute_stream(&mut env, handle, ffi_stream_addr) + } + + #[export_name = concat!("Java_", $cls, "_closeScan")] + extern "system" fn __df_bridge_close_scan<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + handle: jlong, + ) { + $crate::scan::close_scan(&mut env, handle) + } + }; + }; +} diff --git a/spark/bridge/src/options.rs b/spark/bridge/src/options.rs new file mode 100644 index 0000000..117ca9d --- /dev/null +++ b/spark/bridge/src/options.rs @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Decoder for the connector's default options wire format. +//! +//! `BridgeProviderFactory.encodeOptions`'s default (`OptionsCodec` on the JVM +//! side) encodes the Spark options map as length-prefixed UTF-8 pairs, +//! sorted by key: big-endian `i32` entry count, then per entry key length, +//! key bytes, value length, value bytes. Key-sorting makes the bytes a pure +//! function of the map contents — the shared-scan determinism contract uses +//! the options bytes as the scan identity. +//! +//! Bridges using the default JVM encoding read their options here: +//! +//! ```ignore +//! let opts = datafusion_spark_bridge::options::decode_options(options_bytes)?; +//! let url = opts.get("url").ok_or("missing required option 'url'")?; +//! ``` +//! +//! The two implementations are pinned to each other by the shared fixture in +//! the tests below; `OptionsCodecTest` on the JVM side asserts the same +//! bytes. + +use std::collections::BTreeMap; + +/// Decode bytes produced by the JVM `OptionsCodec.encode` (or +/// [`encode_options`]). Empty input decodes as an empty map. +pub fn decode_options(bytes: &[u8]) -> Result, String> { + let mut out = BTreeMap::new(); + if bytes.is_empty() { + return Ok(out); + } + let mut cursor = Cursor { bytes, pos: 0 }; + let count = cursor.read_len("entry count")?; + for i in 0..count { + let key = cursor.read_string(&format!("key of entry {i}"))?; + let value = cursor.read_string(&format!("value of entry {i}"))?; + out.insert(key, value); + } + if cursor.pos != bytes.len() { + return Err(format!( + "options blob has {} trailing byte(s) after {count} entries", + bytes.len() - cursor.pos + )); + } + Ok(out) +} + +/// Encode in the same format (key-sorted via `BTreeMap`). Primarily for +/// tests and Rust-side tooling; production encoding normally happens on the +/// JVM driver. +pub fn encode_options(options: &BTreeMap) -> Vec { + let mut out = Vec::new(); + out.extend_from_slice(&(options.len() as i32).to_be_bytes()); + for (key, value) in options { + out.extend_from_slice(&(key.len() as i32).to_be_bytes()); + out.extend_from_slice(key.as_bytes()); + out.extend_from_slice(&(value.len() as i32).to_be_bytes()); + out.extend_from_slice(value.as_bytes()); + } + out +} + +struct Cursor<'a> { + bytes: &'a [u8], + pos: usize, +} + +impl Cursor<'_> { + fn read_len(&mut self, what: &str) -> Result { + if self.bytes.len() - self.pos < 4 { + return Err(format!("options blob truncated reading {what}")); + } + let raw = i32::from_be_bytes(self.bytes[self.pos..self.pos + 4].try_into().unwrap()); + self.pos += 4; + usize::try_from(raw).map_err(|_| format!("negative length for {what}: {raw}")) + } + + fn read_string(&mut self, what: &str) -> Result { + let len = self.read_len(&format!("length of {what}"))?; + if self.bytes.len() - self.pos < len { + return Err(format!("options blob truncated reading {what}")); + } + let slice = &self.bytes[self.pos..self.pos + len]; + self.pos += len; + String::from_utf8(slice.to_vec()).map_err(|e| format!("{what} is not UTF-8: {e}")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Shared fixture: must stay byte-identical to the one asserted by the + /// JVM-side `OptionsCodecTest`. {"table": "t1", "url": "grpc://h:1"} + /// encodes (sorted: table < url) as below. + fn fixture_bytes() -> Vec { + let mut b = Vec::new(); + b.extend_from_slice(&2i32.to_be_bytes()); + for (k, v) in [("table", "t1"), ("url", "grpc://h:1")] { + b.extend_from_slice(&(k.len() as i32).to_be_bytes()); + b.extend_from_slice(k.as_bytes()); + b.extend_from_slice(&(v.len() as i32).to_be_bytes()); + b.extend_from_slice(v.as_bytes()); + } + b + } + + #[test] + fn decodes_fixture() { + let map = decode_options(&fixture_bytes()).unwrap(); + assert_eq!(map.len(), 2); + assert_eq!(map.get("table").map(String::as_str), Some("t1")); + assert_eq!(map.get("url").map(String::as_str), Some("grpc://h:1")); + } + + #[test] + fn round_trips() { + let mut map = BTreeMap::new(); + map.insert("b".to_string(), "2".to_string()); + map.insert("a".to_string(), "1".to_string()); + map.insert("unicode".to_string(), "héllo→world".to_string()); + let bytes = encode_options(&map); + assert_eq!(decode_options(&bytes).unwrap(), map); + } + + #[test] + fn empty_input_is_empty_map() { + assert!(decode_options(&[]).unwrap().is_empty()); + let empty = encode_options(&BTreeMap::new()); + assert!(decode_options(&empty).unwrap().is_empty()); + } + + #[test] + fn rejects_truncation_and_trailing_bytes() { + let bytes = fixture_bytes(); + assert!(decode_options(&bytes[..bytes.len() - 1]) + .unwrap_err() + .contains("truncated")); + let mut extended = bytes.clone(); + extended.push(0); + assert!(decode_options(&extended).unwrap_err().contains("trailing")); + } +} diff --git a/spark/bridge/src/scan.rs b/spark/bridge/src/scan.rs new file mode 100644 index 0000000..ad27dff --- /dev/null +++ b/spark/bridge/src/scan.rs @@ -0,0 +1,325 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Planning and execution of a Spark scan. +//! +//! Every function here is the body of one JNI entry point generated by a +//! bridge's `export_bridge!` expansion, which supplies only how the provider +//! is obtained, as a `make` closure. The provider is wrapped in a +//! [`WideningTableProvider`] here, so every bridge gets identical +//! Spark-compatible Arrow types. +//! +//! [`create_scan`] registers the widened provider on a private +//! `SessionContext` built from the caller-pinned config, applies the pruned +//! projection and the proto-encoded pushed filters, and plans exactly once. +//! The returned handle supports: +//! +//! - [`partition_count`] — output partitions of the physical plan +//! (shared-scan mode probes this on the driver and indexes tasks by it); +//! - [`execute_stream_partition`] — an independent stream over ONE plan +//! partition, concurrently callable from multiple JVM threads +//! (`ExecutionPlan` and `TaskContext` are `Send + Sync`; each call only +//! clones their `Arc`s). Re-executing the same partition index (Spark +//! task retry / speculative execution) opens its own stream, but only +//! succeeds when every operator in that partition's pipeline supports +//! repeated `execute()` — stateless scans do, `RepartitionExec` +//! pipelines do not; +//! - [`execute_stream`] — the whole plan as one stream (per-partition +//! mode, where the provider itself is the task's slice); +//! - [`close_scan`] — drop the plan. The single unsafe interleaving is +//! closing a handle that still has an in-flight call; the Java consumer +//! (the shared-scan cache) prevents it with a refcount covering every +//! open reader. +//! +//! Pinned-config determinism: the driver resolves `target_partitions` / +//! `batch_size` / option overrides once and ships them to every executor, so +//! a plan that yields N partitions on the driver yields N everywhere. This +//! module applies whatever it is handed and stays policy-free. + +use std::sync::Arc; + +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::ffi_stream::FFI_ArrowArrayStream; +use datafusion::arrow::ipc::writer::StreamWriter; +use datafusion::catalog::TableProvider; +use datafusion::dataframe::DataFrame; +use datafusion::execution::TaskContext; +use datafusion::physical_plan::{execute_stream as df_execute_stream, ExecutionPlan}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_jni_common::errors::{try_unwrap_or_throw, JniResult}; +use datafusion_jni_common::StreamingReader; +use datafusion_proto::logical_plan::from_proto::parse_expr; +use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec; +use datafusion_proto::protobuf::LogicalExprNode; +use jni::objects::{JByteArray, JObjectArray, JString}; +use jni::sys::{jbyteArray, jint, jlong}; +use jni::JNIEnv; +use prost::Message; + +use crate::runtime_handle; +use crate::widening::WideningTableProvider; + +/// Registration name of the (single) provider on the scan's private context. +/// Never surfaces in SQL — the plan is built through the DataFrame API — so +/// no quoting/collision concerns. +const SCAN_TABLE_NAME: &str = "df_spark_scan"; + +struct ScanState { + /// Kept alive for the plan's lifetime; the registered provider and the + /// runtime env both hang off it. + _ctx: SessionContext, + plan: Arc, + task_ctx: Arc, +} + +fn widen(provider: Arc) -> Arc { + Arc::new(WideningTableProvider::new(provider)) +} + +fn collect_string_array(env: &mut JNIEnv, arr: &JObjectArray) -> JniResult> { + if arr.is_null() { + return Ok(Vec::new()); + } + let len = env.get_array_length(arr)?; + let mut owned: Vec = Vec::with_capacity(len as usize); + for i in 0..len { + let elem = env.get_object_array_element(arr, i)?; + let jstr: JString = elem.into(); + owned.push(env.get_string(&jstr)?.into()); + } + Ok(owned) +} + +fn collect_byte_arrays(env: &mut JNIEnv, arr: &JObjectArray) -> JniResult>> { + if arr.is_null() { + return Ok(Vec::new()); + } + let len = env.get_array_length(arr)?; + let mut owned: Vec> = Vec::with_capacity(len as usize); + for i in 0..len { + let elem = env.get_object_array_element(arr, i)?; + let bytes: JByteArray = elem.into(); + owned.push(env.convert_byte_array(&bytes)?); + } + Ok(owned) +} + +/// Driver-side schema probe: widened Arrow schema of the provider, as IPC +/// bytes (deserialized JVM-side with `MessageSerializer.deserializeSchema`). +/// `make` runs once; the provider drops before returning. +pub fn provider_schema_ipc( + env: &mut JNIEnv, + make: impl FnOnce(&mut JNIEnv) -> JniResult>, +) -> jbyteArray { + try_unwrap_or_throw(env, std::ptr::null_mut(), |env| -> JniResult { + let widened = widen(make(env)?); + let schema = widened.schema(); + let mut buf: Vec = Vec::new(); + { + let mut writer = StreamWriter::try_new(&mut buf, schema.as_ref())?; + writer.finish()?; + } + let arr = env.byte_array_from_slice(&buf)?; + Ok(arr.into_raw()) + }) +} + +/// Build the scan: widen the provider from `make`, register it on a private +/// context with the pinned config, apply projection + pushed filters, plan +/// once. +/// +/// `target_partitions` / `batch_size` <= 0 leave the DataFusion defaults; +/// `option_keys`/`option_values` are parallel arrays of config overrides; +/// empty `projection_columns` selects all columns; each element of +/// `filter_protos` is a serialized `datafusion.LogicalExprNode`. +#[allow(clippy::too_many_arguments)] +pub fn create_scan( + env: &mut JNIEnv, + make: impl FnOnce(&mut JNIEnv) -> JniResult>, + target_partitions: jint, + batch_size: jint, + option_keys: &JObjectArray, + option_values: &JObjectArray, + projection_columns: &JObjectArray, + filter_protos: &JObjectArray, +) -> jlong { + try_unwrap_or_throw(env, 0, |env| -> JniResult { + let widened = widen(make(env)?); + + let keys = collect_string_array(env, option_keys)?; + let values = collect_string_array(env, option_values)?; + if keys.len() != values.len() { + return Err(format!( + "option key/value arrays differ in length: {} vs {}", + keys.len(), + values.len() + ) + .into()); + } + let projection = collect_string_array(env, projection_columns)?; + let filters = collect_byte_arrays(env, filter_protos)?; + + let mut config = SessionConfig::new(); + if target_partitions > 0 { + config = config.with_target_partitions(target_partitions as usize); + } + if batch_size > 0 { + config = config.with_batch_size(batch_size as usize); + } + for (key, value) in keys.iter().zip(values.iter()) { + config.options_mut().set(key, value)?; + } + + let ctx = SessionContext::new_with_config(config); + ctx.register_table(SCAN_TABLE_NAME, widened)?; + + let mut df: DataFrame = runtime_handle().block_on(ctx.table(SCAN_TABLE_NAME))?; + if !projection.is_empty() { + let refs: Vec<&str> = projection.iter().map(String::as_str).collect(); + df = df.select_columns(&refs)?; + } + for bytes in &filters { + let node = LogicalExprNode::decode(bytes.as_slice())?; + // TaskContext implements FunctionRegistry; the default codec is + // enough because the translator only emits column/literal/builtin + // expressions. + let registry = df.task_ctx(); + let expr = parse_expr(&node, ®istry, &DefaultLogicalExtensionCodec {})?; + df = df.filter(expr)?; + } + + // task_ctx() borrows; capture before create_physical_plan consumes df. + let task_ctx = Arc::new(df.task_ctx()); + let plan = runtime_handle().block_on(df.create_physical_plan())?; + + let state = ScanState { + _ctx: ctx, + plan, + task_ctx, + }; + Ok(Box::into_raw(Box::new(state)) as jlong) + }) +} + +/// Output partition count of the planned physical plan. +pub fn partition_count(env: &mut JNIEnv, handle: jlong) -> jint { + try_unwrap_or_throw(env, 0, |_env| -> JniResult { + if handle == 0 { + return Err("scan handle is null".into()); + } + let state = unsafe { &*(handle as *const ScanState) }; + Ok(state + .plan + .properties() + .output_partitioning() + .partition_count() as jint) + }) +} + +/// Open an independent stream over one plan partition, writing an +/// `FFI_ArrowArrayStream` into the caller-allocated struct at +/// `ffi_stream_addr`. +pub fn execute_stream_partition( + env: &mut JNIEnv, + handle: jlong, + partition: jint, + ffi_stream_addr: jlong, +) { + try_unwrap_or_throw(env, (), |_env| -> JniResult<()> { + if handle == 0 { + return Err("scan handle is null".into()); + } + if ffi_stream_addr == 0 { + return Err("ffi stream address is null".into()); + } + let state = unsafe { &*(handle as *const ScanState) }; + + let partition_count = state + .plan + .properties() + .output_partitioning() + .partition_count(); + if partition < 0 || partition as usize >= partition_count { + return Err(format!( + "partition index {partition} out of range: plan has {partition_count} partition(s)" + ) + .into()); + } + + let plan = Arc::clone(&state.plan); + let task_ctx = Arc::clone(&state.task_ctx); + let schema: SchemaRef = plan.schema(); + + // ExecutionPlan::execute is synchronous, but operators may + // tokio::spawn at execute() time (RepartitionExec et al.), which + // requires a runtime context to be entered. + let stream = { + let _guard = runtime_handle().enter(); + plan.execute(partition as usize, task_ctx)? + }; + + let reader = StreamingReader { schema, stream }; + let ffi = FFI_ArrowArrayStream::new(Box::new(reader)); + unsafe { + std::ptr::write(ffi_stream_addr as *mut FFI_ArrowArrayStream, ffi); + } + Ok(()) + }) +} + +/// Whole-plan stream for per-partition mode (the provider +/// itself is the task's slice, so all plan partitions merge into one reader). +pub fn execute_stream(env: &mut JNIEnv, handle: jlong, ffi_stream_addr: jlong) { + try_unwrap_or_throw(env, (), |_env| -> JniResult<()> { + if handle == 0 { + return Err("scan handle is null".into()); + } + if ffi_stream_addr == 0 { + return Err("ffi stream address is null".into()); + } + let state = unsafe { &*(handle as *const ScanState) }; + + let plan = Arc::clone(&state.plan); + let task_ctx = Arc::clone(&state.task_ctx); + let schema: SchemaRef = plan.schema(); + + // execute_stream coalesces multi-partition plans behind one stream. + let stream = { + let _guard = runtime_handle().enter(); + df_execute_stream(plan, task_ctx)? + }; + + let reader = StreamingReader { schema, stream }; + let ffi = FFI_ArrowArrayStream::new(Box::new(reader)); + unsafe { + std::ptr::write(ffi_stream_addr as *mut FFI_ArrowArrayStream, ffi); + } + Ok(()) + }) +} + +/// Drop the planned scan. Must not race an in-flight stream-open on the same +/// handle; the Java consumer's refcount enforces this. +pub fn close_scan(env: &mut JNIEnv, handle: jlong) { + try_unwrap_or_throw(env, (), |_env| -> JniResult<()> { + if handle == 0 { + return Err("scan handle is null".into()); + } + drop(unsafe { Box::from_raw(handle as *mut ScanState) }); + Ok(()) + }) +} diff --git a/spark/bridge/src/widening.rs b/spark/bridge/src/widening.rs new file mode 100644 index 0000000..86c4abf --- /dev/null +++ b/spark/bridge/src/widening.rs @@ -0,0 +1,376 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Kernel-level Arrow type widening for Spark consumption. +//! +//! Spark 3.5's `ArrowColumnVector` has no accessor for unsigned ints, Time*, +//! Float16, or non-microsecond Timestamp. The widening machinery here wraps +//! an inner `TableProvider` with one that exposes a "widened" schema — +//! UInt*→Int wider, Float16→Float32, Time*→Int wider, Timestamp(*, tz)→ +//! Timestamp(Microsecond, tz), recursing into List/LargeList/FixedSizeList +//! children — and applies `arrow::compute::cast` to each produced +//! RecordBatch column-wise. No SQL, no SessionContext, no view machinery. + +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +use arrow::array::RecordBatch; +use arrow::compute::cast; +use arrow::datatypes::{DataType, Field, Schema as ArrowSchema, SchemaRef, TimeUnit}; +use async_trait::async_trait; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::common::{DataFusionError, Result}; +use datafusion::execution::TaskContext; +use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType}; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, +}; +use futures::stream::StreamExt; + +/// Compute the cast-target DataType for an Arrow type not directly readable +/// by Spark's `ArrowColumnVector`. Returns `None` if the type passes through. +pub fn arrow_cast_widening(dt: &DataType) -> Option { + match dt { + DataType::UInt8 => Some(DataType::Int16), + DataType::UInt16 => Some(DataType::Int32), + DataType::UInt32 => Some(DataType::Int64), + // UInt64 → Int64: lossy for values ≥ 2^63. Documented in REARCHITECTURE.md. + DataType::UInt64 => Some(DataType::Int64), + DataType::Float16 => Some(DataType::Float32), + DataType::Time32(_) => Some(DataType::Int32), + DataType::Time64(_) => Some(DataType::Int64), + DataType::Timestamp(unit, tz) => { + if *unit == TimeUnit::Microsecond { + None + } else { + Some(DataType::Timestamp(TimeUnit::Microsecond, tz.clone())) + } + } + DataType::List(field) => arrow_cast_widening(field.data_type()) + .map(|inner| DataType::List(widened_child(field, inner))), + DataType::LargeList(field) => arrow_cast_widening(field.data_type()) + .map(|inner| DataType::LargeList(widened_child(field, inner))), + // Spark 3.5's ArrowColumnVector cannot read FixedSizeList at all, so + // always convert it to a (variable) List — which Spark maps to + // ArrayType — widening the child element type when needed too. + DataType::FixedSizeList(field, _size) => { + let child = match arrow_cast_widening(field.data_type()) { + Some(inner) => widened_child(field, inner), + None => Arc::clone(field), + }; + Some(DataType::List(child)) + } + _ => None, + } +} + +fn widened_child(field: &Arc, new_type: DataType) -> Arc { + Arc::new(Field::new(field.name(), new_type, field.is_nullable())) +} + +/// Build the widened schema by walking inner fields and replacing types. +/// Returns the widened schema plus per-column target types (None where no cast). +fn widened_schema(inner: &ArrowSchema) -> (SchemaRef, Vec>) { + let mut fields = Vec::with_capacity(inner.fields().len()); + let mut targets = Vec::with_capacity(inner.fields().len()); + for f in inner.fields() { + match arrow_cast_widening(f.data_type()) { + Some(target) => { + fields.push(Arc::new(Field::new( + f.name(), + target.clone(), + f.is_nullable(), + ))); + targets.push(Some(target)); + } + None => { + fields.push(Arc::clone(f)); + targets.push(None); + } + } + } + (Arc::new(ArrowSchema::new(fields)), targets) +} + +/// TableProvider wrapping an inner provider, exposing a widened schema and +/// emitting RecordBatches whose columns have been cast to the widened types. +#[derive(Debug)] +pub struct WideningTableProvider { + inner: Arc, + widened: SchemaRef, + /// Targets indexed by the inner-schema column position; `None` = pass through. + targets: Vec>, +} + +impl WideningTableProvider { + pub fn new(inner: Arc) -> Self { + let (widened, targets) = widened_schema(&inner.schema()); + Self { + inner, + widened, + targets, + } + } +} + +#[async_trait] +impl TableProvider for WideningTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.widened) + } + + fn table_type(&self) -> TableType { + self.inner.table_type() + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + self.inner.supports_filters_pushdown(filters) + } + + async fn scan( + &self, + session: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let inner_plan = self.inner.scan(session, projection, filters, limit).await?; + let (projected_widened, projected_targets) = match projection { + Some(idxs) => { + let fields: Vec> = idxs + .iter() + .map(|i| Arc::clone(&self.widened.fields()[*i])) + .collect(); + let targets: Vec> = + idxs.iter().map(|i| self.targets[*i].clone()).collect(); + (Arc::new(ArrowSchema::new(fields)) as SchemaRef, targets) + } + None => (Arc::clone(&self.widened), self.targets.clone()), + }; + Ok(Arc::new(WideningExec::new( + inner_plan, + projected_widened, + projected_targets, + ))) + } +} + +/// ExecutionPlan that runs the inner plan and casts each output RecordBatch +/// column-wise per the supplied targets. Pure stream-map wrapper; no +/// buffering, no internal state. +pub struct WideningExec { + inner: Arc, + schema: SchemaRef, + /// One entry per output column; `None` = pass through. + targets: Vec>, + properties: Arc, +} + +impl WideningExec { + fn new( + inner: Arc, + schema: SchemaRef, + targets: Vec>, + ) -> Self { + let inner_props = inner.properties(); + let properties = Arc::new(PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + inner_props.partitioning.clone(), + inner_props.emission_type, + inner_props.boundedness, + )); + Self { + inner, + schema, + targets, + properties, + } + } +} + +impl fmt::Debug for WideningExec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WideningExec") + .field("schema", &self.schema) + .field("targets", &self.targets) + .finish() + } +} + +impl DisplayAs for WideningExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let cast_count = self.targets.iter().filter(|t| t.is_some()).count(); + write!(f, "WideningExec: casts={cast_count}") + } +} + +impl ExecutionPlan for WideningExec { + fn name(&self) -> &str { + "WideningExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.inner] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() != 1 { + return Err(DataFusionError::Internal( + "WideningExec::with_new_children expects exactly one child".to_string(), + )); + } + Ok(Arc::new(WideningExec::new( + children.into_iter().next().unwrap(), + Arc::clone(&self.schema), + self.targets.clone(), + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let inner_stream = self.inner.execute(partition, context)?; + let schema = Arc::clone(&self.schema); + let targets = self.targets.clone(); + let mapped = inner_stream.map(move |batch_res| match batch_res { + Err(e) => Err(e), + Ok(batch) => cast_batch(&batch, &schema, &targets), + }); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema.clone(), + mapped, + ))) + } +} + +fn cast_batch( + batch: &RecordBatch, + out_schema: &SchemaRef, + targets: &[Option], +) -> Result { + if batch.num_columns() != targets.len() { + return Err(DataFusionError::Internal(format!( + "WideningExec: produced batch has {} columns, expected {}", + batch.num_columns(), + targets.len() + ))); + } + let mut new_cols = Vec::with_capacity(batch.num_columns()); + for (col, target) in batch.columns().iter().zip(targets.iter()) { + match target { + Some(t) => new_cols.push(cast(col, t).map_err(DataFusionError::from)?), + None => new_cols.push(Arc::clone(col)), + } + } + RecordBatch::try_new(Arc::clone(out_schema), new_cols).map_err(DataFusionError::from) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn unsigned_ints_widen_to_signed_wider() { + assert_eq!(arrow_cast_widening(&DataType::UInt8), Some(DataType::Int16)); + assert_eq!( + arrow_cast_widening(&DataType::UInt16), + Some(DataType::Int32) + ); + assert_eq!( + arrow_cast_widening(&DataType::UInt32), + Some(DataType::Int64) + ); + assert_eq!( + arrow_cast_widening(&DataType::UInt64), + Some(DataType::Int64) + ); + } + + #[test] + fn float16_widens_to_float32() { + assert_eq!( + arrow_cast_widening(&DataType::Float16), + Some(DataType::Float32) + ); + } + + #[test] + fn time_widens_to_int() { + assert_eq!( + arrow_cast_widening(&DataType::Time32(TimeUnit::Millisecond)), + Some(DataType::Int32) + ); + assert_eq!( + arrow_cast_widening(&DataType::Time64(TimeUnit::Nanosecond)), + Some(DataType::Int64) + ); + } + + #[test] + fn timestamp_normalizes_unit_preserving_tz() { + let ns = DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("UTC"))); + assert_eq!( + arrow_cast_widening(&ns), + Some(DataType::Timestamp( + TimeUnit::Microsecond, + Some(Arc::from("UTC")) + )) + ); + let us_no_tz = DataType::Timestamp(TimeUnit::Microsecond, None); + assert_eq!(arrow_cast_widening(&us_no_tz), None); + } + + #[test] + fn list_recurses_into_children() { + let inner_field = Arc::new(Field::new("item", DataType::UInt16, true)); + let list_ty = DataType::List(inner_field); + let widened = arrow_cast_widening(&list_ty).expect("should widen"); + match widened { + DataType::List(field) => assert_eq!(field.data_type(), &DataType::Int32), + other => panic!("expected List, got {other:?}"), + } + } + + #[test] + fn signed_int_passes_through() { + assert_eq!(arrow_cast_widening(&DataType::Int32), None); + assert_eq!(arrow_cast_widening(&DataType::Utf8), None); + } +} diff --git a/spark/bridge/tests/export_macro.rs b/spark/bridge/tests/export_macro.rs new file mode 100644 index 0000000..14751c8 --- /dev/null +++ b/spark/bridge/tests/export_macro.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Compile-level test of `export_bridge!`: the macro must expand to valid +//! `extern "system"` items against a plain builder function. JNI entry +//! points can't be exercised without a live JVM, so the assertion here is +//! that this test crate links with the generated symbols present. + +use std::sync::Arc; + +use datafusion_spark_bridge::datafusion::arrow::datatypes::Schema; +use datafusion_spark_bridge::datafusion::catalog::TableProvider; +use datafusion_spark_bridge::datafusion::datasource::MemTable; +use datafusion_spark_bridge::{export_bridge, BridgeContext, JniResult}; + +fn build_provider( + _ctx: &BridgeContext, + _options: &[u8], + _partition: &[u8], +) -> JniResult> { + let schema = Arc::new(Schema::empty()); + let table = MemTable::try_new(schema, vec![vec![]])?; + Ok(Arc::new(table)) +} + +export_bridge! { + jni_class: "com_example_testbridge_BridgeNative", + build_provider: build_provider, +} + +#[test] +fn builder_contract_runs_outside_jvm() { + // Expansion + linking is the macro test; this additionally runs the + // builder through the same BridgeContext the expansion hands it. + let ctx = BridgeContext::get(); + let provider = build_provider(&ctx, &[], &[]).expect("builder failed"); + assert_eq!(provider.schema().fields().len(), 0); +} From fd9d42ba1f8387777e16a7b0bbab36104b4694d8 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 12 Jun 2026 13:25:53 +0200 Subject: [PATCH 3/5] feat(spark): add Spark connector Java SPI module Introduce the `spark` Maven module and the pure-Java contracts a bridge implements: BridgeProviderFactory (no-arg factory + scanBackend()), ScanBackend (delegates to the bridge's JNI methods), NativeLibraryLoader (cdylib extraction/loading), OptionsCodec (cross-language options encoder), PartitionInfo (one entry per Spark task), and ReportedPartitioning (optional shuffle-elision declaration). Compiles standalone with no Scala main yet. Includes the two SPI-only tests (OptionsCodecTest, BridgeProviderFactoryDefaultsTest). Third of the 6-PR stack. Co-Authored-By: Claude Opus 4.8 (1M context) --- pom.xml | 1 + spark/pom.xml | 150 ++++++++++++++++ .../spark/BridgeProviderFactory.java | 160 ++++++++++++++++++ .../datafusion/spark/NativeLibraryLoader.java | 107 ++++++++++++ .../io/datafusion/spark/OptionsCodec.java | 113 +++++++++++++ .../io/datafusion/spark/PartitionInfo.java | 74 ++++++++ .../spark/ReportedPartitioning.java | 87 ++++++++++ .../java/io/datafusion/spark/ScanBackend.java | 79 +++++++++ .../BridgeProviderFactoryDefaultsTest.scala | 112 ++++++++++++ .../datafusion/spark/OptionsCodecTest.scala | 89 ++++++++++ 10 files changed, 972 insertions(+) create mode 100644 spark/pom.xml create mode 100644 spark/src/main/java/io/datafusion/spark/BridgeProviderFactory.java create mode 100644 spark/src/main/java/io/datafusion/spark/NativeLibraryLoader.java create mode 100644 spark/src/main/java/io/datafusion/spark/OptionsCodec.java create mode 100644 spark/src/main/java/io/datafusion/spark/PartitionInfo.java create mode 100644 spark/src/main/java/io/datafusion/spark/ReportedPartitioning.java create mode 100644 spark/src/main/java/io/datafusion/spark/ScanBackend.java create mode 100644 spark/src/test/scala/io/datafusion/spark/BridgeProviderFactoryDefaultsTest.scala create mode 100644 spark/src/test/scala/io/datafusion/spark/OptionsCodecTest.scala diff --git a/pom.xml b/pom.xml index b92cf72..6baeb94 100644 --- a/pom.xml +++ b/pom.xml @@ -32,6 +32,7 @@ under the License. core + spark examples diff --git a/spark/pom.xml b/spark/pom.xml new file mode 100644 index 0000000..90e4e6d --- /dev/null +++ b/spark/pom.xml @@ -0,0 +1,150 @@ + + + + 4.0.0 + + + org.apache.datafusion + datafusion-java-parent + 0.2.0-SNAPSHOT + + + datafusion-java-spark_2.13 + jar + + Apache DataFusion Java Spark Connector + + Generic Spark DataSource V2 connector for DataFusion TableProviders. + Domain bridges implement BridgeProviderFactory over a cdylib built + with the datafusion-spark-bridge Rust SDK; this module supplies the + Spark plumbing, predicate translation, Arrow-to-Spark schema + conversion, and the shared-scan cache. Pure JVM artifact — the + native code ships inside each bridge's own jar. + + + + 2.13 + 2.13.14 + 3.5.7 + + + + + org.scala-lang + scala-library + ${scala.version} + + + + org.apache.spark + spark-core_${scala.compat.version} + ${spark.version} + provided + + + org.apache.spark + spark-sql_${scala.compat.version} + ${spark.version} + provided + + + + org.apache.datafusion + datafusion-java + + + + org.apache.arrow + arrow-vector + + + org.apache.arrow + arrow-c-data + + + org.apache.arrow + arrow-memory-netty + runtime + + + + org.scalatest + scalatest_${scala.compat.version} + 3.2.18 + test + + + + + + + net.alchim31.maven + scala-maven-plugin + 4.8.1 + + + + compile + testCompile + + + + + ${scala.version} + + -deprecation + -feature + -unchecked + + all + + + + org.apache.maven.plugins + maven-surefire-plugin + + --add-opens=java.base/java.nio=ALL-UNNAMED + true + + + + org.scalatest + scalatest-maven-plugin + 2.2.0 + + ${project.build.directory}/scalatest-reports + . + WDF TestSuite.txt + --add-opens=java.base/java.nio=ALL-UNNAMED + + + + test + test + + + + + + + diff --git a/spark/src/main/java/io/datafusion/spark/BridgeProviderFactory.java b/spark/src/main/java/io/datafusion/spark/BridgeProviderFactory.java new file mode 100644 index 0000000..3bcf7ad --- /dev/null +++ b/spark/src/main/java/io/datafusion/spark/BridgeProviderFactory.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark; + +import java.util.Map; + +/** + * Bridge interface implemented per domain (HDF5, custom Iceberg, an in-house format, etc.). A + * bridge owns its options encoding and a native scan implementation built with {@code + * datafusion_spark_bridge::export_bridge!}; the connector Spark plumbing is generic — it knows only + * this interface. + * + *

The single required method is {@link #scanBackend()}, returning the delegations to the JNI + * class the bridge named in its {@code export_bridge!} invocation. Everything else has a working + * default: {@link #encodeOptions(Map)} encodes the Spark options via {@link OptionsCodec}, and + * {@link #listPartitions(byte[])} reports a single partition. + * + *

Implementations must be no-arg constructable so the Spark connector can instantiate them + * reflectively via {@link Class#forName(String)} on the executor. + */ +public interface BridgeProviderFactory { + + /** + * The native scan implementation this bridge talks to: delegations to the JNI class named in the + * bridge's {@code export_bridge!} invocation, whose generated {@code createScan} builds the + * provider from the options/partition bytes in process. Called wherever the connector needs + * native work — driver-side schema/plan probes and executor-side streams — always on a factory + * freshly instantiated from its class name, so the returned backend never has to be serializable. + */ + ScanBackend scanBackend(); + + /** + * Convert Spark's flat option map to the bridge's encoded options. Driver-side only; the bytes + * ship verbatim through {@code DatafusionInputPartition} and are the scan's identity in + * shared-scan mode (encode deterministically). + * + *

Default: {@link OptionsCodec#encode(Map)} — the key-sorted length-prefixed pair format that + * {@code datafusion_spark_bridge::options} decodes on the Rust side. Override only if the bridge + * already has its own options schema (e.g. a protobuf). + * + * @throws IllegalArgumentException if required options are missing or invalid + */ + default byte[] encodeOptions(Map sparkOptions) { + return OptionsCodec.encode(sparkOptions); + } + + /** + * Enumerate partitions for this dataset. One Spark task is created per returned {@link + * PartitionInfo}. Driver-side only. + * + *

Each partition's {@code partitionBytes} ships verbatim through {@code + * DatafusionInputPartition} to the executor, where it is passed to {@link + * ScanBackend#createScan}. Use it to encode whatever slice metadata (row range, sub-options, file + * offsets, segment id, …) the bridge needs to materialise *that* partition. + * + *

Each partition's {@code preferredLocations} hostnames are returned from {@code + * InputPartition.preferredLocations()} so Spark co-locates the task with the data; empty array = + * no preference. + * + *

Default: one partition ({@code "p0"}, empty payload, no host preference) — one Spark task + * scans the whole dataset. Fine for small tables and first bring-up; override (or opt into {@link + * #sharedScan(byte[])}) before pointing it at anything large. Size guidance lives in {@code + * spark/README.md}. + */ + default PartitionInfo[] listPartitions(byte[] optionsBytes) { + return new PartitionInfo[] {new PartitionInfo("p0", new byte[0], new String[0])}; + } + + /** + * Filter-aware variant of {@link #listPartitions(byte[])}. The connector calls this overload with + * the pushed-down predicates ({@code LogicalExprNode} proto bytes, one array per predicate, same + * encoding the executor later replays via {@link ScanBackend#createScan}). Bridges that can map + * predicates onto their partition layout (e.g. {@code segment_id = 'x'}) should prune partitions + * that cannot match — pruning here eliminates whole Spark tasks, whereas the per-task filter only + * reduces rows inside a task. + * + *

Pruning must be conservative: only drop a partition when NO row in it can satisfy the + * conjunction of all pushed predicates. The default delegates to the filter-unaware overload (no + * pruning), which is always correct. + */ + default PartitionInfo[] listPartitions(byte[] optionsBytes, byte[][] filterProtoBytes) { + return listPartitions(optionsBytes); + } + + /** + * Opt into shared-scan mode for this dataset. Default {@code false} (per-partition payload mode, + * the {@link #listPartitions(byte[])} path). + * + *

When {@code true}, the connector builds ONE provider per (executor JVM × scan) with empty + * {@code partitionBytes}, plans it once, and runs one Spark task per DataFusion output partition + * — task {@code i} streams plan partition {@code i} from the shared, cached plan. This amortises + * provider construction cost across all tasks on an executor and is the right model when the + * dataset has many small partitions or provider construction is expensive (remote metadata, + * connections). {@link #listPartitions(byte[])} and {@link #reportPartitioning(byte[])} are NOT + * called in this mode, and the scan reports {@code UnknownPartitioning} (DataFusion-native + * partitions carry no key contract). + * + *

Determinism contract. The driver counts partitions by planning once; every executor + * re-plans independently and must arrive at the same result. A bridge returning {@code true} + * guarantees: + * + *

    + *
  • The provider's schema, partitioning, and per-partition row content are a pure function of + * {@code optionsBytes}. Remote sources must pin a snapshot (version, timestamp) inside + * the options; data that compacts or moves between driver planning and executor execution + * otherwise yields wrong results that no runtime check can catch. + *
  • The provider's {@code ExecutionPlan} supports calling {@code execute(i)} more than once + * per plan instance (Spark task retry and speculative execution re-execute a partition + * index, sometimes concurrently). Stateless scans satisfy this; single-shot streams do not. + *
+ * + *

The connector fails tasks with a clear error when the executor's partition count diverges + * from the driver's — but identical counts with different contents cannot be detected. + */ + default boolean sharedScan(byte[] optionsBytes) { + return false; + } + + /** + * Declare how rows are partitioned across the {@link PartitionInfo} entries returned by {@link + * #listPartitions(byte[])}. Driver-side only. + * + *

When non-null, the connector surfaces a {@code KeyGroupedPartitioning(keys, + * listPartitions(...).length)} to Spark via {@code SupportsReportPartitioning} so the optimizer + * can elide shuffles ahead of joins/aggregations on the declared keys. + * + *

Default returns {@code null} — no partitioning guarantees, Spark plans as if the scan's + * output ordering and grouping are unknown. + * + *

If a bridge implements this, it must hold the {@link ReportedPartitioning} contract: every + * row in a given partition evaluates to the same tuple of key values under the declared + * transforms. + * + *

Spark 3.3+ caveat: the reported partitioning only takes effect when every {@link + * PartitionInfo} also carries {@link PartitionInfo#partitionKeyValues()} (surfaced to Spark via + * {@code HasPartitionKey}); without key values Spark ignores the declared {@code + * KeyGroupedPartitioning} entirely. Storage-partitioned joins additionally require {@code + * spark.sql.sources.v2.bucketing.enabled=true}. + */ + default ReportedPartitioning reportPartitioning(byte[] optionsBytes) { + return null; + } +} diff --git a/spark/src/main/java/io/datafusion/spark/NativeLibraryLoader.java b/spark/src/main/java/io/datafusion/spark/NativeLibraryLoader.java new file mode 100644 index 0000000..eb4766a --- /dev/null +++ b/spark/src/main/java/io/datafusion/spark/NativeLibraryLoader.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.util.Locale; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Extracts a cdylib bundled inside a jar to a temp file and loads it via {@link System#load}. + * Expected layout inside the jar: + * + *

+ *   <resourcePrefix>/<os>/<arch>/lib<name>.<ext>
+ * 
+ * + * where {@code } is one of {@code linux}, {@code darwin}, {@code windows} and {@code } is + * {@code x86_64} or {@code aarch64}. + * + *

Bridges call {@link #load(Class, String, String)} from their native class's static + * initializer, with their own resource prefix, instead of hand-rolling extraction. Bundle the + * cdylib with the antrun-copy pattern shown in "Packaging your bridge" in {@code spark/README.md}. + */ +public final class NativeLibraryLoader { + + /** {@code /} entries already extracted and loaded by this classloader. */ + private static final Set LOADED = ConcurrentHashMap.newKeySet(); + + private NativeLibraryLoader() {} + + /** + * Extract {@code ///} from {@code anchor}'s classloader + * and {@link System#load} it. Idempotent per (prefix, name): repeated calls — e.g. one per Spark + * task instantiating the bridge's native class — load once. + * + * @param anchor class whose classloader holds the resource (the bridge's own native class, so the + * lookup works under Spark's per-application classloaders) + * @param resourcePrefix jar-internal directory, no leading or trailing slash (e.g. {@code + * "com/example/mybridge"}) + * @param name unmapped library name (e.g. {@code "my_bridge"} for {@code libmy_bridge.so}) + * @throws UnsatisfiedLinkError if the resource is missing or extraction fails + */ + public static void load(Class anchor, String resourcePrefix, String name) { + String key = resourcePrefix + "/" + name; + if (!LOADED.add(key)) { + return; + } + String resource = + String.format( + "/%s/%s/%s/%s", + resourcePrefix, currentOs(), currentArch(), System.mapLibraryName(name)); + try (InputStream in = anchor.getResourceAsStream(resource)) { + if (in == null) { + LOADED.remove(key); + throw new UnsatisfiedLinkError("Native library not found on classpath: " + resource); + } + Path tmp = Files.createTempFile("libdatafusion-spark-", "-" + System.mapLibraryName(name)); + tmp.toFile().deleteOnExit(); + Files.copy(in, tmp, StandardCopyOption.REPLACE_EXISTING); + System.load(tmp.toAbsolutePath().toString()); + } catch (IOException e) { + LOADED.remove(key); + throw new UnsatisfiedLinkError( + "Failed to extract native library " + resource + ": " + e.getMessage()); + } catch (RuntimeException | Error e) { + LOADED.remove(key); + throw e; + } + } + + private static String currentOs() { + String os = System.getProperty("os.name", "").toLowerCase(Locale.ROOT); + if (os.contains("linux")) return "linux"; + if (os.contains("mac") || os.contains("darwin")) return "darwin"; + if (os.contains("windows")) return "windows"; + throw new UnsupportedOperationException("Unsupported OS: " + os); + } + + private static String currentArch() { + String arch = System.getProperty("os.arch", "").toLowerCase(Locale.ROOT); + if (arch.equals("amd64") || arch.equals("x86_64")) return "x86_64"; + if (arch.equals("aarch64") || arch.equals("arm64")) return "aarch64"; + throw new UnsupportedOperationException("Unsupported arch: " + arch); + } +} diff --git a/spark/src/main/java/io/datafusion/spark/OptionsCodec.java b/spark/src/main/java/io/datafusion/spark/OptionsCodec.java new file mode 100644 index 0000000..0d16a28 --- /dev/null +++ b/spark/src/main/java/io/datafusion/spark/OptionsCodec.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.TreeMap; + +/** + * Default wire format for {@link BridgeProviderFactory#encodeOptions(Map)}: the Spark options map + * as length-prefixed UTF-8 pairs, sorted by key. + * + *

Layout (all integers big-endian {@code int32}): entry count, then per entry key length, key + * bytes, value length, value bytes. Key-sorting makes the bytes a pure function of the map's + * contents regardless of source iteration order — required by the shared-scan determinism contract, + * where the options bytes are the cache/plan identity. + * + *

The Rust decoder lives in {@code datafusion_spark_bridge::options}; bridges using the default + * {@code encodeOptions} read their options there as a {@code BTreeMap}. The two + * implementations are pinned to each other by a shared test fixture. + */ +public final class OptionsCodec { + + private OptionsCodec() {} + + /** Encode {@code options} sorted by key. {@code null} or empty map encodes as count 0. */ + public static byte[] encode(Map options) { + TreeMap sorted = options == null ? new TreeMap<>() : new TreeMap<>(options); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + writeInt(out, sorted.size()); + for (Map.Entry e : sorted.entrySet()) { + if (e.getKey() == null || e.getValue() == null) { + throw new IllegalArgumentException("OptionsCodec does not accept null keys or values"); + } + writeBytes(out, e.getKey().getBytes(StandardCharsets.UTF_8)); + writeBytes(out, e.getValue().getBytes(StandardCharsets.UTF_8)); + } + return out.toByteArray(); + } + + /** Decode bytes produced by {@link #encode(Map)}. Preserves the encoded (sorted) order. */ + public static Map decode(byte[] bytes) { + Map out = new LinkedHashMap<>(); + if (bytes == null || bytes.length == 0) { + return out; + } + ByteBuffer buf = ByteBuffer.wrap(bytes); + int count = readCount(buf, "entry count"); + for (int i = 0; i < count; i++) { + String key = readString(buf, "key of entry " + i); + String value = readString(buf, "value of entry " + i); + out.put(key, value); + } + if (buf.hasRemaining()) { + throw new IllegalArgumentException( + "OptionsCodec: " + buf.remaining() + " trailing byte(s) after " + count + " entries"); + } + return out; + } + + private static void writeInt(ByteArrayOutputStream out, int v) { + out.write((v >>> 24) & 0xFF); + out.write((v >>> 16) & 0xFF); + out.write((v >>> 8) & 0xFF); + out.write(v & 0xFF); + } + + private static void writeBytes(ByteArrayOutputStream out, byte[] bytes) { + writeInt(out, bytes.length); + out.write(bytes, 0, bytes.length); + } + + private static int readCount(ByteBuffer buf, String what) { + if (buf.remaining() < 4) { + throw new IllegalArgumentException("OptionsCodec: truncated " + what); + } + int v = buf.getInt(); + if (v < 0) { + throw new IllegalArgumentException("OptionsCodec: negative " + what + ": " + v); + } + return v; + } + + private static String readString(ByteBuffer buf, String what) { + int len = readCount(buf, "length of " + what); + if (buf.remaining() < len) { + throw new IllegalArgumentException("OptionsCodec: truncated " + what); + } + byte[] bytes = new byte[len]; + buf.get(bytes); + return new String(bytes, StandardCharsets.UTF_8); + } +} diff --git a/spark/src/main/java/io/datafusion/spark/PartitionInfo.java b/spark/src/main/java/io/datafusion/spark/PartitionInfo.java new file mode 100644 index 0000000..e6e061b --- /dev/null +++ b/spark/src/main/java/io/datafusion/spark/PartitionInfo.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark; + +/** + * Driver-side descriptor for a single partition produced by {@link + * BridgeProviderFactory#listPartitions(byte[])}. Carries the bridge-specific slice payload that the + * executor passes back into {@link ScanBackend#createScan}, plus + * optional host hints for Spark's scheduler. + * + *

Fields: + * + *

    + *
  • {@code id} — stable, human-readable identifier for this partition (e.g. a segment id). + * Surfaces in Spark UI, logs, and exception messages. Must be non-empty. + *
  • {@code partitionBytes} — opaque per-partition payload. Bridge encodes whatever the executor + * needs to materialise *this* slice (offsets, row ranges, sub-options, etc.). Combined with + * the global {@code optionsBytes} in {@link ScanBackend#createScan}. Empty array = no + * per-partition state (single-partition table). + *
  • {@code preferredLocations} — hostnames where this partition's data lives. Returned from + * {@code InputPartition.preferredLocations()} so Spark can co-locate the task with the data. + * Empty array = no preference. Honoured subject to {@code spark.locality.wait}. + *
  • {@code partitionKeyValues} — optional values of the partitioning keys for every row in this + * partition, in the same order as {@link BridgeProviderFactory#reportPartitioning(byte[])}'s + * declared transforms. {@code null} = no key (the default). When the bridge reports a + * partitioning AND every partition carries key values, the connector exposes them to Spark + * via {@code HasPartitionKey} — required on Spark 3.3+ for the reported {@code + * KeyGroupedPartitioning} to have any effect (and storage-partitioned joins additionally + * require {@code spark.sql.sources.v2.bucketing.enabled=true}). Values must be Java types + * that Spark's {@code CatalystTypeConverters} can convert for the key columns' data types + * (e.g. {@code String}, {@code Long}, {@code Integer}, {@code java.time.Instant}, {@code + * java.time.LocalDate}, {@code java.math.BigDecimal}), and the array length must equal the + * number of declared keys. + *
+ */ +public record PartitionInfo( + String id, byte[] partitionBytes, String[] preferredLocations, Object[] partitionKeyValues) { + + public PartitionInfo { + if (id == null || id.isEmpty()) { + throw new IllegalArgumentException("PartitionInfo: id must be non-empty"); + } + if (partitionBytes == null) { + partitionBytes = new byte[0]; + } + if (preferredLocations == null) { + preferredLocations = new String[0]; + } + // partitionKeyValues stays null when absent: null and "no key" are the same state, + // and DatafusionBatch distinguishes keyed from unkeyed partitions by it. + } + + /** Without partition key values — the common case. */ + public PartitionInfo(String id, byte[] partitionBytes, String[] preferredLocations) { + this(id, partitionBytes, preferredLocations, null); + } +} diff --git a/spark/src/main/java/io/datafusion/spark/ReportedPartitioning.java b/spark/src/main/java/io/datafusion/spark/ReportedPartitioning.java new file mode 100644 index 0000000..639fec9 --- /dev/null +++ b/spark/src/main/java/io/datafusion/spark/ReportedPartitioning.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark; + +import java.util.Arrays; + +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.Transform; + +/** + * Driver-side declaration of how a bridge's data is partitioned on the key columns. When supplied + * via {@link BridgeProviderFactory#reportPartitioning(byte[])}, the connector surfaces a {@link + * org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning} from {@link + * org.apache.spark.sql.connector.read.SupportsReportPartitioning#outputPartitioning()} — Spark's + * optimizer can then skip the shuffle ahead of joins/aggregations whose grouping keys line up with + * these transforms. + * + *

Contract: for any partition reported by {@link BridgeProviderFactory#listPartitions(byte[])}, + * every row produced by that partition must evaluate to the same tuple of key values under these + * transforms. Different partitions may share key values (Spark will fuse them); they must + * not straddle key values. + * + *

The partition count Spark sees is {@code listPartitions(...).length}; it is not carried here + * to keep a single source of truth. + */ +public final class ReportedPartitioning { + + private final Transform[] keys; + + public ReportedPartitioning(Transform[] keys) { + if (keys == null || keys.length == 0) { + throw new IllegalArgumentException( + "ReportedPartitioning: keys must contain at least one transform"); + } + this.keys = keys; + } + + public Transform[] keys() { + return keys; + } + + /** + * Convenience: declare identity partitioning on one or more columns (a row in partition P has the + * same {@code (col1, col2, …)} values as every other row in P). + */ + public static ReportedPartitioning identity(String... columns) { + if (columns == null || columns.length == 0) { + throw new IllegalArgumentException( + "ReportedPartitioning.identity: at least one column required"); + } + Transform[] ts = Arrays.stream(columns).map(Expressions::identity).toArray(Transform[]::new); + return new ReportedPartitioning(ts); + } + + /** + * Convenience: declare hash-bucket partitioning. Mirrors Spark's {@code bucket(N, cols…)} + * transform — each row is assigned to bucket {@code hash(cols) mod numBuckets}. + */ + public static ReportedPartitioning bucket(int numBuckets, String... columns) { + if (numBuckets <= 0) { + throw new IllegalArgumentException( + "ReportedPartitioning.bucket: numBuckets must be > 0, got " + numBuckets); + } + if (columns == null || columns.length == 0) { + throw new IllegalArgumentException( + "ReportedPartitioning.bucket: at least one column required"); + } + return new ReportedPartitioning(new Transform[] {Expressions.bucket(numBuckets, columns)}); + } +} diff --git a/spark/src/main/java/io/datafusion/spark/ScanBackend.java b/spark/src/main/java/io/datafusion/spark/ScanBackend.java new file mode 100644 index 0000000..a994c98 --- /dev/null +++ b/spark/src/main/java/io/datafusion/spark/ScanBackend.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark; + +/** + * Native scan surface the connector plumbing talks to: one method per JNI entry point generated by + * the bridge's {@code datafusion_spark_bridge::export_bridge!} invocation. A bridge's + * implementation is six one-line delegations to the JNI class named in that macro, whose {@code + * createScan} builds the provider from {@code options}/{@code partitionBytes} in process. + * + *

Implementations must be stateless or thread-safe: the driver probes schemas and plans through + * one instance while executor tasks stream through others, and scan handles are shared across + * threads by the shared-scan cache. Handle-based methods accept handles produced by {@code + * createScan} on any instance of the same implementation. + */ +public interface ScanBackend { + + /** + * Driver-side schema probe: the widened Arrow schema of the provider described by {@code options} + * + {@code partitionBytes}, serialized as Arrow IPC bytes (deserialize with {@code + * MessageSerializer.deserializeSchema}). + */ + byte[] providerSchemaIpc(byte[] options, byte[] partitionBytes); + + /** + * Build a planned scan and return its handle. {@code targetPartitions}/{@code batchSize} {@code + * <= 0} leave DataFusion defaults; {@code optionKeys}/{@code optionValues} are parallel config + * override arrays; empty {@code projectionColumns} selects all columns; each {@code filterProtos} + * element is a serialized {@code datafusion.LogicalExprNode}. + * + *

The caller owns the handle and must pair it with {@link #closeScan(long)}. Closing while a + * stream opened from the handle is in flight is undefined behaviour — the shared-scan cache's + * refcount enforces this; any other caller must serialize close itself. + */ + long createScan( + byte[] options, + byte[] partitionBytes, + int targetPartitions, + int batchSize, + String[] optionKeys, + String[] optionValues, + String[] projectionColumns, + byte[][] filterProtos); + + /** Output partition count of the planned physical plan. */ + int partitionCount(long scanHandle); + + /** + * Open an independent stream over ONE plan partition, writing an {@code FFI_ArrowArrayStream} + * into the caller-allocated struct at {@code ffiStreamAddr}. Concurrent-safe across JVM threads. + */ + void executeStreamPartition(long scanHandle, int partition, long ffiStreamAddr); + + /** + * Stream the WHOLE plan (all partitions coalesced) into the caller-allocated {@code + * FFI_ArrowArrayStream} at {@code ffiStreamAddr}. Used by per-partition mode. + */ + void executeStream(long scanHandle, long ffiStreamAddr); + + /** Drop the planned scan. See {@link #createScan} for the close-vs-in-flight contract. */ + void closeScan(long scanHandle); +} diff --git a/spark/src/test/scala/io/datafusion/spark/BridgeProviderFactoryDefaultsTest.scala b/spark/src/test/scala/io/datafusion/spark/BridgeProviderFactoryDefaultsTest.scala new file mode 100644 index 0000000..0b94eee --- /dev/null +++ b/spark/src/test/scala/io/datafusion/spark/BridgeProviderFactoryDefaultsTest.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.scalatest.funsuite.AnyFunSuite + +class BridgeProviderFactoryDefaultsTest extends AnyFunSuite { + + /** Backend stub: the defaults under test never touch native code. */ + private object StubBackend extends ScanBackend { + def providerSchemaIpc(options: Array[Byte], partitionBytes: Array[Byte]): Array[Byte] = + throw new UnsupportedOperationException + def createScan( + options: Array[Byte], + partitionBytes: Array[Byte], + targetPartitions: Int, + batchSize: Int, + optionKeys: Array[String], + optionValues: Array[String], + projectionColumns: Array[String], + filterProtos: Array[Array[Byte]]): Long = throw new UnsupportedOperationException + def partitionCount(scanHandle: Long): Int = throw new UnsupportedOperationException + def executeStreamPartition(scanHandle: Long, partition: Int, ffiStreamAddr: Long): Unit = + throw new UnsupportedOperationException + def executeStream(scanHandle: Long, ffiStreamAddr: Long): Unit = + throw new UnsupportedOperationException + def closeScan(scanHandle: Long): Unit = throw new UnsupportedOperationException + } + + /** Factory overriding only listPartitions (to spy on its inputs). */ + private class MinimalFactory extends BridgeProviderFactory { + var lastListPartitionsOpts: Array[Byte] = _ + + override def scanBackend(): ScanBackend = StubBackend + + override def listPartitions(optionsBytes: Array[Byte]): Array[PartitionInfo] = { + lastListPartitionsOpts = optionsBytes + Array(new PartitionInfo("p0", Array.emptyByteArray, Array.empty[String])) + } + } + + /** Only the required method implemented — the literal minimum a bridge can ship. */ + private class EmptyFactory extends BridgeProviderFactory { + override def scanBackend(): ScanBackend = StubBackend + } + + test("sharedScan defaults to false") { + assert(!new MinimalFactory().sharedScan(Array[Byte](1, 2, 3))) + } + + test("default encodeOptions uses OptionsCodec") { + val opts = new java.util.HashMap[String, String]() + opts.put("url", "grpc://h:1") + val bytes = new EmptyFactory().encodeOptions(opts) + assert(java.util.Arrays.equals(bytes, OptionsCodec.encode(opts))) + assert(OptionsCodec.decode(bytes).get("url") == "grpc://h:1") + } + + test("default listPartitions reports a single whole-dataset partition") { + val partitions = new EmptyFactory().listPartitions(Array[Byte](1)) + assert(partitions.length == 1) + assert(partitions(0).id == "p0") + assert(partitions(0).partitionBytes().isEmpty) + assert(partitions(0).preferredLocations().isEmpty) + } + + test("filter-aware listPartitions delegates to the filter-unaware overload") { + val factory = new MinimalFactory + val opts = Array[Byte](7, 8) + val filters = Array(Array[Byte](1), Array[Byte](2)) + val partitions = factory.listPartitions(opts, filters) + assert(partitions.length == 1) + assert(partitions(0).id == "p0") + assert(factory.lastListPartitionsOpts eq opts) + } + + test("reportPartitioning defaults to null") { + assert(new MinimalFactory().reportPartitioning(Array.emptyByteArray) == null) + } + + test("PartitionInfo 3-arg constructor leaves partitionKeyValues null") { + val p = new PartitionInfo("p0", Array.emptyByteArray, Array.empty[String]) + assert(p.partitionKeyValues() == null) + } + + test("PartitionInfo 4-arg constructor carries key values") { + val p = new PartitionInfo( + "p0", + Array.emptyByteArray, + Array.empty[String], + Array[AnyRef]("segment-a", Long.box(42L))) + assert(p.partitionKeyValues().length == 2) + assert(p.partitionKeyValues()(0) == "segment-a") + } +} diff --git a/spark/src/test/scala/io/datafusion/spark/OptionsCodecTest.scala b/spark/src/test/scala/io/datafusion/spark/OptionsCodecTest.scala new file mode 100644 index 0000000..59f6c8f --- /dev/null +++ b/spark/src/test/scala/io/datafusion/spark/OptionsCodecTest.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import java.io.ByteArrayOutputStream +import java.nio.charset.StandardCharsets + +import org.scalatest.funsuite.AnyFunSuite + +class OptionsCodecTest extends AnyFunSuite { + + /** + * Shared fixture: must stay byte-identical to the one asserted by the Rust-side + * `datafusion_spark_bridge::options` tests. {"table": "t1", "url": "grpc://h:1"} encodes + * (sorted: table < url) as below. + */ + private def fixtureBytes(): Array[Byte] = { + val out = new ByteArrayOutputStream() + def writeInt(v: Int): Unit = { + out.write((v >>> 24) & 0xFF); out.write((v >>> 16) & 0xFF) + out.write((v >>> 8) & 0xFF); out.write(v & 0xFF) + } + def writeString(s: String): Unit = { + val b = s.getBytes(StandardCharsets.UTF_8) + writeInt(b.length) + out.write(b, 0, b.length) + } + writeInt(2) + Seq("table" -> "t1", "url" -> "grpc://h:1").foreach { case (k, v) => + writeString(k); writeString(v) + } + out.toByteArray + } + + test("encodes the cross-language fixture byte-identically, sorted by key") { + // Insertion order deliberately unsorted; encoding must sort. + val opts = new java.util.LinkedHashMap[String, String]() + opts.put("url", "grpc://h:1") + opts.put("table", "t1") + assert(java.util.Arrays.equals(OptionsCodec.encode(opts), fixtureBytes())) + } + + test("round-trips including unicode values") { + val opts = new java.util.HashMap[String, String]() + opts.put("a", "1") + opts.put("unicode", "héllo→world") + val decoded = OptionsCodec.decode(OptionsCodec.encode(opts)) + assert(decoded.size() == 2) + assert(decoded.get("unicode") == "héllo→world") + } + + test("null and empty maps encode to a zero count and decode back empty") { + assert(OptionsCodec.decode(OptionsCodec.encode(null)).isEmpty) + assert(OptionsCodec.decode(Array.emptyByteArray).isEmpty) + } + + test("rejects truncation and trailing bytes") { + val bytes = fixtureBytes() + intercept[IllegalArgumentException] { + OptionsCodec.decode(java.util.Arrays.copyOf(bytes, bytes.length - 1)) + } + intercept[IllegalArgumentException] { + OptionsCodec.decode(java.util.Arrays.copyOf(bytes, bytes.length + 1)) + } + } + + test("rejects null keys or values") { + val opts = new java.util.HashMap[String, String]() + opts.put("k", null) + intercept[IllegalArgumentException] { OptionsCodec.encode(opts) } + } +} From 5db93ae3c5e29411855ba7a86ff0668f7831337c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 12 Jun 2026 13:26:29 +0200 Subject: [PATCH 4/5] feat(spark): Spark DataSource V2 connector (Scala) The connector implementation on top of the Java SPI and the bridge SDK: DatafusionSource/Table/Scan/ScanBuilder DSv2 wiring, per-partition columnar read path (FfiStream + Arrow->Spark batch conversion), V2 predicate pushdown (SparkPredicateTranslator), shared-scan mode with a per-executor refcounted cache (SharedScanCache, SharedScanPartitionReader, NativeSharedScanResources, PinnedSessionConfig), and SupportsReportPartitioning for shuffle elision. These pieces share the DatafusionScanMode sealed trait and the scan builder, so they land together. Includes the connector test suite and the module README. DataSourceRegister SPI file registers DatafusionSource. Fourth of the 6-PR stack. Co-Authored-By: Claude Opus 4.8 (1M context) --- spark/README.md | 454 ++++++++++++++++++ ...pache.spark.sql.sources.DataSourceRegister | 1 + .../spark/ArrowColumnarBatchIteration.scala | 58 +++ .../datafusion/spark/ArrowToSparkSchema.scala | 152 ++++++ .../io/datafusion/spark/DatafusionBatch.scala | 127 +++++ .../DatafusionColumnarPartitionReader.scala | 97 ++++ .../spark/DatafusionInputPartition.scala | 115 +++++ .../DatafusionPartitionReaderFactory.scala | 53 ++ .../io/datafusion/spark/DatafusionScan.scala | 104 ++++ .../spark/DatafusionScanBuilder.scala | 151 ++++++ .../datafusion/spark/DatafusionSource.scala | 95 ++++ .../io/datafusion/spark/DatafusionTable.scala | 51 ++ .../scala/io/datafusion/spark/FfiStream.scala | 44 ++ .../spark/NativeSharedScanResources.scala | 100 ++++ .../spark/NonClosingArrowColumnVector.scala | 33 ++ .../spark/PinnedSessionConfig.scala | 80 +++ .../io/datafusion/spark/SharedScanCache.scala | 197 ++++++++ .../spark/SharedScanPartitionReader.scala | 82 ++++ .../spark/SparkPredicateTranslator.scala | 214 +++++++++ .../spark/ArrowToSparkSchemaTest.scala | 106 ++++ .../spark/PartitionKeyConversionTest.scala | 76 +++ .../spark/SharedScanCacheTest.scala | 195 ++++++++ .../spark/SparkPredicateTranslatorTest.scala | 87 ++++ 23 files changed, 2672 insertions(+) create mode 100644 spark/README.md create mode 100644 spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 spark/src/main/scala/io/datafusion/spark/ArrowColumnarBatchIteration.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/ArrowToSparkSchema.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/DatafusionBatch.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/DatafusionColumnarPartitionReader.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/DatafusionInputPartition.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/DatafusionPartitionReaderFactory.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/DatafusionScan.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/DatafusionScanBuilder.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/DatafusionSource.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/DatafusionTable.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/FfiStream.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/NativeSharedScanResources.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/NonClosingArrowColumnVector.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/PinnedSessionConfig.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/SharedScanCache.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/SharedScanPartitionReader.scala create mode 100644 spark/src/main/scala/io/datafusion/spark/SparkPredicateTranslator.scala create mode 100644 spark/src/test/scala/io/datafusion/spark/ArrowToSparkSchemaTest.scala create mode 100644 spark/src/test/scala/io/datafusion/spark/PartitionKeyConversionTest.scala create mode 100644 spark/src/test/scala/io/datafusion/spark/SharedScanCacheTest.scala create mode 100644 spark/src/test/scala/io/datafusion/spark/SparkPredicateTranslatorTest.scala diff --git a/spark/README.md b/spark/README.md new file mode 100644 index 0000000..5cc3d3c --- /dev/null +++ b/spark/README.md @@ -0,0 +1,454 @@ +# DataFusion Spark Connector + +This module (`datafusion-java-spark`) lets you expose a [DataFusion +`TableProvider`](https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProvider.html) +written in Rust as an [Apache Spark DataSource +V2](https://spark.apache.org/docs/latest/sql-data-sources.html) table. If you +have data that DataFusion can already read — an in-house file format, a custom +catalog, a remote service — this connector is the bridge that makes +`spark.read.format(...)` work against it, with predicate pushdown, column +pruning, and partitioned parallel reads. + +You write two small pieces (a Rust function and a Java class); the connector +supplies everything else. + +## How it fits together + +Two layers, one of which already exists: + +``` + your bridge (you write this) this module (already written) ++--------------------------------+ +----------------------------------+ +| cdylib on datafusion-spark- | | Scala/Java DSv2 plumbing | +| bridge (spark/bridge SDK): | | (spark/src) schema inference, | +| your TableProvider + one |<--| pushdown, task planning, | +| export_bridge! invocation; |-->| shared-scan cache | +| the SDK supplies widening, | | | +| session, filters, planning, | | (pure JVM — all native code | +| partition streams | | ships inside YOUR jar) | ++--------------------------------+ +----------------------------------+ + | + v + spark.read.format("...").load() +``` + +The only things that cross between the JVM and your cdylib are opaque +`byte[]` blobs that *you* define (options and per-partition payloads — the +connector never inspects them) going in, and Arrow C streams coming back. +Everything DataFusion-side (planning, filter application, execution) happens +inside your bridge's native library. There is no DataFusion session on the +JVM side at all, and no `FFI_TableProvider` boundary anywhere — your +concrete provider is linked into the same cdylib as the scan machinery. + +## Getting started: generate a bridge + +Don't hand-assemble the pieces below — stamp them out: + +```bash +python3 spark/scaffold/new_bridge.py --name acme --package com.example.acme +``` + +generates a standalone project (Rust cdylib with a working demo provider, +the four Java classes, service registration, shaded-jar pom with the cdylib +bundled, pyspark smoke test, README with the build commands). Replace the +demo `MemTable` in its `native/src/lib.rs` and you have a connector. The +sections below explain what each generated piece is for. + +## What you implement + +| # | Piece | Language | Contract lives at | Working example | +|---|-------|----------|-------------------|-----------------| +| 1 | A provider builder + one `export_bridge!` invocation | Rust | [`bridge/src/lib.rs`](bridge/src/lib.rs) (macro rustdoc) | [`examples/native/src/lib.rs`](../examples/native/src/lib.rs) | +| 2 | A `BridgeProviderFactory` implementation (one required method) + the JNI/backend boilerplate | Java | [`src/main/java/io/datafusion/spark/BridgeProviderFactory.java`](src/main/java/io/datafusion/spark/BridgeProviderFactory.java) | [`examples/.../ExampleBridgeProviderFactory.java`](../examples/src/main/java/org/apache/datafusion/examples/ExampleBridgeProviderFactory.java) | +| 3 | (optional) A `DatafusionSource` subclass giving your source a short name | Scala/Java | [`src/main/scala/io/datafusion/spark/DatafusionSource.scala`](src/main/scala/io/datafusion/spark/DatafusionSource.scala) | see "Wiring it into Spark" below | + +An end-to-end runnable version of all three — in-memory table, factory, and a +PySpark script that scans, filters, and projects it — lives under +[`examples/python/`](../examples/python/). + +### 1. The Rust side + +Depend on the [`datafusion-spark-bridge`](bridge/) SDK crate and let it +generate the JNI surface; you supply one builder turning your option / +partition bytes into a concrete `TableProvider`: + +```rust +use std::sync::Arc; +use datafusion_spark_bridge::datafusion::catalog::TableProvider; +use datafusion_spark_bridge::{export_bridge, BridgeContext, JniResult}; + +fn build_provider( + ctx: &BridgeContext, + options: &[u8], + partition: &[u8], +) -> JniResult> { + let opts = MyOptions::decode(options)?; + Ok(ctx.block_on(MyProvider::connect(opts, partition))?) +} + +export_bridge! { + // Underscore-mangled name of YOUR Java class declaring the native + // methods (dots -> underscores). Per-bridge names let several bridges + // coexist in one Spark JVM. + jni_class: "com_example_mybridge_BridgeNative", + build_provider: build_provider, +} +``` + +The macro's rustdoc lists the exact `static native` method set the named +Java class must declare; your factory routes the connector to it by +overriding `scanBackend()` (see section 2). One cdylib total: your provider +and the SDK's scan machinery are the same library, so there is no provider +hand-off across a binary boundary and no `datafusion-ffi` anywhere. The +builder receives empty partition bytes for the driver-side schema probe — +schema must not depend on per-partition state. + +[`examples/native/src/lib.rs`](../examples/native/src/lib.rs) +is a complete, commented version of this for a `MemTable`. + +### 2. The Java factory + +`BridgeProviderFactory` is the contract between Spark and your bridge. It +must have a no-arg constructor (executors instantiate it reflectively by +class name). The single required method is `scanBackend()` — Spark options +are encoded with `OptionsCodec` by default (decode them in Rust via +`datafusion_spark_bridge::options::decode_options`), and `listPartitions` +defaults to one whole-dataset partition: + +```java +public final class MyBridgeProviderFactory implements BridgeProviderFactory { + + @Override + public ScanBackend scanBackend() { + return new MyBridgeBackend(); // six one-line delegations to BridgeNative + } +} + +/** Declares the native methods generated by export_bridge! and loads the cdylib. */ +final class BridgeNative { + static { + NativeLibraryLoader.load(BridgeNative.class, "com/example/mybridge", "my_bridge"); + } + static native byte[] providerSchemaIpc(byte[] options, byte[] partition); + static native long createScan(byte[] options, byte[] partition, + int targetPartitions, int batchSize, String[] optionKeys, + String[] optionValues, String[] projectionColumns, byte[][] filterProtos); + static native int partitionCount(long scanHandle); + static native void executeStreamPartition(long scanHandle, int partition, long ffiStreamAddr); + static native void executeStream(long scanHandle, long ffiStreamAddr); + static native void closeScan(long scanHandle); +} +``` + +(`MyBridgeBackend implements ScanBackend` forwards each method to +`BridgeNative` — pure boilerplate the scaffold generates.) + +Override `encodeOptions` only if the bridge already has its own options +schema (e.g. a protobuf), and `listPartitions` when the dataset should split +into more than one Spark task: + +```java + @Override + public PartitionInfo[] listPartitions(byte[] optionsBytes) { + MySlice[] slices = MyBridgeNative.listSlices(optionsBytes); + PartitionInfo[] out = new PartitionInfo[slices.length]; + for (int i = 0; i < slices.length; i++) { + out[i] = new PartitionInfo(slices[i].id(), slices[i].payload(), slices[i].hosts()); + } + return out; + } +``` + +The remaining optional methods — `sharedScan`, `reportPartitioning`, and the +filter-aware `listPartitions(opts, filters)` overload — are covered in their +own sections below. Their javadoc in +[`BridgeProviderFactory.java`](src/main/java/io/datafusion/spark/BridgeProviderFactory.java) +is the authoritative contract. + +### 3. Wiring it into Spark + +Either pass your factory class per read: + +```python +df = (spark.read.format("datafusion") + .option("df.factory", "com.example.MyBridgeProviderFactory") + .option("url", "...") + .option("table", "my_dataset") + .load()) +``` + +or ship a ~10-line subclass so users get a short format name: + +```scala +class MyDataSource extends DatafusionSource { + override def shortName(): String = "my_format" + override protected def factoryFqcn(opts: CaseInsensitiveStringMap): String = + "com.example.MyBridgeProviderFactory" +} +``` + +registered via a +`META-INF/services/org.apache.spark.sql.sources.DataSourceRegister` file +(this module registers `datafusion` the same way — see +[`src/main/resources/META-INF/services/`](src/main/resources/META-INF/services/)). + +## Packaging your bridge + +The end-user experience to aim for is one artifact: + +```python +# spark.jars (or --packages) gets exactly one jar, then: +df = spark.read.format("my_format").option("url", "...").load() +``` + +Three pieces make that work: + +**Bundle your cdylib inside the jar.** Copy it into your jar's resources at +`///` and load it from your native +class's static initializer with the connector's loader — no hand-rolled +extraction code: + +```java +static { + NativeLibraryLoader.load(BridgeNative.class, "com/example/mybridge", "my_bridge"); +} +``` + +The pom side is one antrun copy execution plus per-host profiles; the +examples module is a complete working copy of the pattern (see the +`copy-example-bridge-cdylib` execution and the `native-*` profiles in +[`examples/pom.xml`](../examples/pom.xml), and the loader call in +[`ExampleBridgeNative.java`](../examples/src/main/java/org/apache/datafusion/examples/ExampleBridgeNative.java)). +For a multi-platform jar, build the cdylib per platform in CI and copy each +into its own `//` directory before `mvn package` — the layout +supports them side by side. + +**Shade your dependencies into one fat jar** with `maven-shade-plugin`, so +users don't assemble a jar list: + +```xml + + org.apache.maven.plugins + maven-shade-plugin + + + package + shade + + + + + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + +``` + +Include in the shaded jar: this connector (`datafusion-java-spark`), the core +jar (`datafusion-java` — exception classes and, if you push predicates, the +generated proto classes), the Arrow Java artifacts you compile against, and +your own classes + cdylib. Keep `spark-sql`/`scala-library` `provided` — the +cluster supplies them. + +**Do NOT relocate JNI-bound or JNI-loading packages.** JNI binds native +methods by the class's fully-qualified name; `arrow-c-data` and the Arrow +memory modules likewise load their own natives. Relocating +`io.datafusion.spark`, `org.apache.arrow`, or your own native class breaks +the symbol lookup at runtime. Practical consequences: + +- Ship a plain (unrelocated) fat jar. Two bridges in one Spark app then share + one copy of the connector classes — fine when they're built against the + same connector version, which is the only configuration we support anyway + (their cdylibs stay distinct via per-bridge JNI class names). +- Spark bundles its own (often older) Arrow. Since yours can't be relocated + away, have users set `spark.executor.userClassPathFirst=true` and + `spark.driver.userClassPathFirst=true` (the pyspark demo under + [`examples/python/`](../examples/python/) shows the working incantation), + or build with Arrow pinned to the cluster's version. + +## Spark tasks vs. DataFusion partitions + +This is the most important design decision when building a connector, so it +gets its own section. + +Spark parallelism and DataFusion parallelism are different things: + +- A **Spark task** is the unit Spark schedules onto an executor core. Each + task carries fixed overhead: scheduling on the driver, (de)serializing the + task, instantiating your factory, building a provider, planning a scan. +- A **DataFusion partition** is one output stream of a planned physical + query. A single plan usually has several. + +The connector supports two ways of mapping one onto the other: + +### Default mode: one Spark task per `PartitionInfo` + +`listPartitions` returns N entries → Spark runs N tasks. Each task calls +`createProvider(opts, partitionBytes)` with *its own* entry's payload, so each +task plans and scans only its slice. If DataFusion happens to plan that slice +into multiple internal partitions, they are merged into one stream for the +task — within a task there is no extra parallelism, by design (the +parallelism budget belongs to Spark). + +You control the mapping entirely through what you return from +`listPartitions`. Sizing guidance: + +- **Don't emit one `PartitionInfo` per tiny fragment.** A Spark task should + do meaningfully more work than its overhead — as a rule of thumb at least + ~100 ms of scan time, or order-100 MB of data (Spark's own file sources + default to 128 MB per task for the same reason). If your natural unit is a + small chunk (an object-store key, a time slice, a recording segment), + **bin-pack several into one entry**: `partitionBytes` is opaque, so encode + a *list* of chunk ids and have your `createProvider` materialise all of + them in one provider. +- **Watch the total task count.** The Spark driver schedules and tracks every + task; beyond the low thousands of tasks per stage you pay growing driver + CPU/memory and UI lag for no extra throughput once the cluster's cores are + saturated. A healthy target is roughly 2–3 tasks per available core, and + rarely more than a few thousand per scan. Tens of thousands of + single-digit-megabyte tasks is a smell — bin-pack first. +- **Locality and partition keys only exist here.** `preferredLocations` + (host affinity) and `HasPartitionKey`/`reportPartitioning` (shuffle + elision) are properties of `PartitionInfo` entries. If you need either, + use this mode. + +### Shared-scan mode: one Spark task per DataFusion partition + +When provider construction itself is expensive (remote metadata, connection +setup) or the dataset has thousands of small natural partitions, per-task +provider builds dominate. Opting in via + +```java +@Override +public boolean sharedScan(byte[] optionsBytes) { return true; } +``` + +flips the mapping: the provider is built **once per executor JVM per query** +(with empty `partitionBytes`), planned once, and Spark runs one task per +*DataFusion output partition* — task `i` streams plan partition `i` from the +executor-local cached plan. `listPartitions` is not called at all. + +The DataFusion partition count — and therefore the Spark task count — is +pinned by `spark.datafusion.sharedScan.targetPartitions` (default 8). The +value is resolved on the driver and shipped to executors, because +DataFusion's default would vary with each machine's core count and the +partition indices must mean the same thing everywhere. + +Choosing between the modes: + +| Choose | When | +|--------|------| +| Default (per-partition payload) | slices have host affinity, you want partition-key semantics, per-slice provider construction is cheap. Bin-pack small slices before abandoning this mode. | +| Shared-scan | provider construction is expensive, there are thousands of small partitions with no locality story, the workload is scan + filter + projection. Provider builds drop from one-per-task to one-per-executor (plus one driver probe per query). | + +Shared-scan's price of admission is a **determinism contract**: the +provider's schema, partitioning, and per-partition contents must be a pure +function of `optionsBytes`. Remote sources must pin a snapshot +(version/timestamp) inside the options. The connector fails tasks when an +executor's partition count diverges from the driver's, but equal counts with +different contents are undetectable by construction. The provider's +`ExecutionPlan` must also tolerate `execute(i)` being called more than once +per plan instance (Spark retries and speculatively re-executes tasks). Full +contract: `BridgeProviderFactory.sharedScan` javadoc. + +Shared-scan operational details: + +- Executor cache ([`SharedScanCache.scala`](src/main/scala/io/datafusion/spark/SharedScanCache.scala)): + entries keyed per query (`scanId`), refcounted by open readers, evicted + after an idle TTL. Build failures are not cached; eviction between task + waves just rebuilds. +- Spark conf (read on the driver at planning time, shipped to executors): + - `spark.datafusion.sharedScan.targetPartitions` (default 8) + - `spark.datafusion.sharedScan.batchSize` (default 8192) + - `spark.datafusion.sharedScan.idleTtlMs` (default 120000) + +## What the connector does for you + +- **Schema inference** — your provider's Arrow schema, widened, becomes the + Spark schema. Driver-side, one probe build with empty `partitionBytes`. +- **Type widening** — Spark's columnar readers reject several Arrow types + DataFusion happily produces. The SDK (inside your bridge's cdylib) + transparently casts + unsigned ints → wider signed, `Float16` → `Float32`, `Time*` → wider ints, + any-unit/tz `Timestamp` → microsecond, recursively through + `List`/`LargeList`/`FixedSizeList` (see + [`native/src/widening.rs`](native/src/widening.rs)). Caveat: unsigned types + nested inside `Struct`/`Map` are not yet covered. +- **Predicate pushdown** — Spark V2 `Predicate`s are translated to DataFusion + expressions ([`SparkPredicateTranslator.scala`](src/main/scala/io/datafusion/spark/SparkPredicateTranslator.scala)), + shipped as `datafusion-proto` bytes, and applied inside the native plan, so + your provider's `supports_filters_pushdown`/`scan` sees real Rust `Expr`s. + Anything untranslatable stays in Spark as a residual filter — over-claiming + is impossible by construction. +- **Column pruning** — Spark's required-columns projection becomes a + DataFusion projection on the native plan. +- **Partition-aware joins/aggregations** (default mode, optional) — declare + `reportPartitioning` + per-partition key values and Spark can elide + shuffles. See the javadoc on + [`ReportedPartitioning.java`](src/main/java/io/datafusion/spark/ReportedPartitioning.java) + and [`PartitionInfo.java`](src/main/java/io/datafusion/spark/PartitionInfo.java); + note Spark 3.3+ additionally requires + `spark.sql.sources.v2.bucketing.enabled=true` for storage-partitioned + joins. + +## What runs where + +| Phase | Where | Path | +| ----- | ----- | ---- | +| Schema inference | Driver | `factory.encodeOptions` → `backend.providerSchemaIpc(opts, EMPTY)` — bridge cdylib builds + widens the provider, returns the Arrow schema | +| Scan planning (default mode) | Driver | `factory.listPartitions(opts[, filters])` → one task per entry, with its `partitionBytes` + `preferredLocations` | +| Scan planning (shared-scan) | Driver | probe build (same code path executors use) → plan partition count `N` → `N` tasks | +| Predicate translation | Driver | `SparkPredicateTranslator` → proto bytes per pushed predicate | +| Per-task scan (default mode) | Executor | `backend.createScan(opts, partitionBytes, ...)` (build provider, widen, project, filter, plan) → stream whole plan | +| Per-task scan (shared-scan) | Executor | cache-acquire by `scanId` (first task builds) → stream plan partition `i` → release | + +The native machinery backing all of this is +[`bridge/src/scan.rs`](bridge/src/scan.rs), exported into each bridge's +cdylib by `export_bridge!` and reached through its [`ScanBackend`](src/main/java/io/datafusion/spark/ScanBackend.java). + +## Module layout + +``` +spark/ +├── src/main/java/io/datafusion/spark/ public SPI (Java on purpose: +│ bridge jars stay Scala-free) +│ BridgeProviderFactory.java <- the contract you implement +│ ScanBackend.java <- native scan surface (delegations +│ to your bridge's JNI class) +│ NativeLibraryLoader.java <- bundled-cdylib extraction/loading +│ PartitionInfo.java <- one entry = one Spark task +│ ReportedPartitioning.java <- optional shuffle-elision declaration +├── src/main/scala/io/datafusion/spark/ connector internals (DSv2 wiring, +│ readers, pushdown, shared-scan cache) +└── bridge/ datafusion-spark-bridge SDK rlib: + widening + scan machinery + + export_bridge! (the native side of + every bridge cdylib) +``` + +## Caveats + +- Pushed filter expressions are deserialized with DataFusion's default + logical-extension codec, which covers columns, literals, and built-in + functions. Anything the Spark-side translator can't express stays in Spark + as a residual filter, so coverage gaps cost performance, never + correctness. +- The bridge cdylib's DataFusion version is the SDK's: cargo resolves one + `datafusion` for your provider and the scan machinery together, pinned in + this repo's workspace [`Cargo.toml`](../Cargo.toml). Upgrading DataFusion + means rebuilding the bridge against a newer SDK. +- The SDK's Tokio runtime is per-cdylib and `Once`-gated; TLS-using bridges + should `Once`-gate their rustls install the same way. diff --git a/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000..3e612e0 --- /dev/null +++ b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +io.datafusion.spark.DatafusionSource diff --git a/spark/src/main/scala/io/datafusion/spark/ArrowColumnarBatchIteration.scala b/spark/src/main/scala/io/datafusion/spark/ArrowColumnarBatchIteration.scala new file mode 100644 index 0000000..30f62f8 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/ArrowColumnarBatchIteration.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.arrow.vector.FieldVector +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} + +/** + * Shared `next()`/`get()` loop for the connector's columnar readers: each `loadNextBatch()` + * yields a `VectorSchemaRoot` wrapped as a `ColumnarBatch` of [[NonClosingArrowColumnVector]]s + * (the reader owns the vectors; Spark must not close them per batch). + */ +private[spark] trait ArrowColumnarBatchIteration { + + /** The Arrow stream this reader drains. Stable for the reader's lifetime. */ + protected def arrowReader: ArrowReader + + private var currentBatch: ColumnarBatch = _ + + def next(): Boolean = { + if (currentBatch != null) { + currentBatch = null + } + if (!arrowReader.loadNextBatch()) return false + val root = arrowReader.getVectorSchemaRoot + val vectors: java.util.List[FieldVector] = root.getFieldVectors + val cols = new Array[ColumnVector](vectors.size()) + var i = 0 + while (i < vectors.size()) { + cols(i) = new NonClosingArrowColumnVector(vectors.get(i)) + i += 1 + } + val batch = new ColumnarBatch(cols) + batch.setNumRows(root.getRowCount) + currentBatch = batch + true + } + + def get(): ColumnarBatch = currentBatch +} diff --git a/spark/src/main/scala/io/datafusion/spark/ArrowToSparkSchema.scala b/spark/src/main/scala/io/datafusion/spark/ArrowToSparkSchema.scala new file mode 100644 index 0000000..2e8f1a5 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/ArrowToSparkSchema.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit} +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} +import org.apache.spark.sql.types._ + +/** + * Arrow Schema → Spark StructType converter. + * + * The reported Spark schema MUST be one whose runtime ArrowColumnVector accessor Spark can pick + * for the underlying Arrow vector. Spark 3.5's `ArrowColumnVector` supports the following + * accessors: Boolean, Byte, Short, Int, Long, Float, Double, Decimal, Date, Timestamp, + * TimestampNTZ, Duration (DayTimeInterval), String, LargeString, Binary, Array, Map, Struct, + * Null. No unsigned-int or Time accessor exists; we surface a clear error at schema discovery + * for those — the alternative is silent corruption. + * + * The widening layer (datafusion-spark-bridge, compiled into every bridge cdylib) inserts a + * `WideningTableProvider` upstream of the Spark reader that casts unsupported types kernel-side + * (UInt*→signed wider, Float16→Float32, non-µs Timestamp→µs Timestamp, Time→Int) so Spark only + * ever sees compatible Arrow types. + */ +object ArrowToSparkSchema { + + def toSparkSchema(schema: Schema): StructType = + StructType(schema.getFields.asScala.toSeq.map(toSparkField)) + + private def toSparkField(f: Field): StructField = { + val dt = + Option(f.getDictionary) match { + case Some(_) => + unsupported(f, "dictionary-encoded fields (need dictionary value schema in JNI)") + case None => toSparkType(f) + } + StructField(f.getName, dt, f.isNullable) + } + + private def toSparkType(f: Field): DataType = f.getType match { + case _: ArrowType.Bool => BooleanType + + case t: ArrowType.Int => + (t.getBitWidth, t.getIsSigned) match { + case (8, true) => ByteType + case (16, true) => ShortType + case (32, true) => IntegerType + case (64, true) => LongType + case (bits, false) => + unsupported( + f, + s"unsigned integer UInt$bits (Spark ArrowColumnVector has no unsigned accessor; " + + "widening layer casts these before Spark sees them — this branch indicates the " + + "WideningTableProvider was bypassed)" + ) + case (bits, signed) => unsupported(f, s"Int(bits=$bits, signed=$signed)") + } + + case t: ArrowType.FloatingPoint => + t.getPrecision match { + case FloatingPointPrecision.HALF => + unsupported(f, "Float16 (widening layer must cast to Float32 before Spark)") + case FloatingPointPrecision.SINGLE => FloatType + case FloatingPointPrecision.DOUBLE => DoubleType + case other => unsupported(f, s"FloatingPoint($other)") + } + + case _: ArrowType.Utf8 => StringType + case _: ArrowType.LargeUtf8 => StringType + case _: ArrowType.Binary => BinaryType + case _: ArrowType.LargeBinary => BinaryType + case _: ArrowType.FixedSizeBinary => BinaryType + + case d: ArrowType.Date => + d.getUnit match { + case DateUnit.DAY | DateUnit.MILLISECOND => DateType + case other => unsupported(f, s"Date($other)") + } + + case t: ArrowType.Timestamp => + val _unused = t.getUnit + if (t.getTimezone == null) TimestampNTZType else TimestampType + + case ti: ArrowType.Time => + unsupported( + f, + s"Time(${ti.getUnit}, ${ti.getBitWidth}-bit) — Spark has no time-of-day type" + ) + + case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) + + case _: ArrowType.Null => NullType + + case _: ArrowType.Duration => DayTimeIntervalType() + + case iv: ArrowType.Interval => + iv.getUnit match { + case IntervalUnit.YEAR_MONTH => YearMonthIntervalType() + case IntervalUnit.DAY_TIME => DayTimeIntervalType() + case IntervalUnit.MONTH_DAY_NANO => + unsupported(f, "Interval(MONTH_DAY_NANO) — no clean Spark equivalent") + } + + case _: ArrowType.Struct => + StructType(f.getChildren.asScala.toSeq.map(toSparkField)) + + case _: ArrowType.List => + val child = f.getChildren.get(0) + ArrayType(toSparkType(child), containsNull = child.isNullable) + case _: ArrowType.LargeList => + val child = f.getChildren.get(0) + ArrayType(toSparkType(child), containsNull = child.isNullable) + case _: ArrowType.FixedSizeList => + val child = f.getChildren.get(0) + ArrayType(toSparkType(child), containsNull = child.isNullable) + + case _: ArrowType.Map => + val entries = f.getChildren.get(0) + val keyValue = entries.getChildren + val keyField = keyValue.get(0) + val valueField = keyValue.get(1) + MapType(toSparkType(keyField), toSparkType(valueField), valueField.isNullable) + + case _: ArrowType.Union => + unsupported(f, "Union (Spark has no equivalent)") + + case other => unsupported(f, s"$other") + } + + private def unsupported(f: Field, detail: String): Nothing = + throw new UnsupportedOperationException( + s"Column '${f.getName}': $detail" + ) +} diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionBatch.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionBatch.scala new file mode 100644 index 0000000..684e9fd --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/DatafusionBatch.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory} + +/** + * Spark `Batch` for a DataFusion-backed scan. Driver-side partition planning: + * - [[PerPartitionMode]]: one task per `PartitionInfo` (resolved by [[DatafusionScanBuilder]]); when + * the bridge reported a partitioning and every entry carries key values, tasks implement + * `HasPartitionKey` so Spark can actually use the `KeyGroupedPartitioning`. + * - [[SharedScanMode]]: one task per DataFusion plan partition index. + */ +class DatafusionBatch(val scan: DatafusionScan) extends Batch { + + override def planInputPartitions(): Array[InputPartition] = { + val projection = scan.prunedSchema.fieldNames + val filterBytes: Array[Array[Byte]] = scan.pushedPredicateBytes + + scan.mode match { + case PerPartitionMode(partitions, reported) => + val keyed = DatafusionBatch.validateKeyedState(scan.factoryFqcn, partitions, reported) + partitions.iterator.map { p => + val base = DatafusionInputPartition( + factoryFqcn = scan.factoryFqcn, + optionsBytes = scan.optionsBytes, + projectionColumnNames = projection, + filterProtoBytes = filterBytes, + partitionId = p.id, + partitionBytes = p.partitionBytes, + preferredLocs = p.preferredLocations + ) + val out: DatafusionPartition = + if (keyed) { + DatafusionKeyedInputPartition( + base, + DatafusionBatch.toKeyRow(p.id, p.partitionKeyValues, reported)) + } else base + out.asInstanceOf[InputPartition] + }.toArray + + case SharedScanMode(scanId, numPartitions, pinnedConfig, idleTtlMs) => + Array.tabulate[InputPartition](numPartitions) { i => + DatafusionSharedScanPartition( + factoryFqcn = scan.factoryFqcn, + optionsBytes = scan.optionsBytes, + projectionColumnNames = projection, + filterProtoBytes = filterBytes, + scanId = scanId, + partitionIndex = i, + numPartitions = numPartitions, + pinnedConfig = pinnedConfig, + idleTtlMs = idleTtlMs + ) + } + } + } + + override def createReaderFactory(): PartitionReaderFactory = + new DatafusionPartitionReaderFactory(scan.prunedSchema) +} + +private[spark] object DatafusionBatch { + + /** + * Keyed partitions require a reported partitioning AND key values on EVERY partition. A mixed + * state means the bridge violated its own contract; failing driver-side beats Spark silently + * planning without the declared grouping. + */ + def validateKeyedState( + factoryFqcn: String, + partitions: Array[PartitionInfo], + reported: ReportedPartitioning): Boolean = { + if (reported == null) { + return false + } + val withKeys = partitions.count(_.partitionKeyValues != null) + if (withKeys == 0) { + return false + } + if (withKeys != partitions.length) { + throw new IllegalStateException( + s"BridgeProviderFactory '$factoryFqcn' reported a partitioning but only $withKeys of " + + s"${partitions.length} PartitionInfo entries carry partitionKeyValues; either all " + + "partitions must carry key values or none") + } + true + } + + /** + * Convert a bridge-supplied `Object[]` of key values into Spark's internal row representation + * (String → UTF8String, Instant → micros, LocalDate → days, BigDecimal → Decimal, ...). + */ + def toKeyRow( + partitionId: String, + values: Array[AnyRef], + reported: ReportedPartitioning): InternalRow = { + val keyCount = reported.keys().length + if (values.length != keyCount) { + throw new IllegalStateException( + s"PartitionInfo '$partitionId' carries ${values.length} partitionKeyValues but the " + + s"reported partitioning declares $keyCount key(s)") + } + val converted = values.map(v => CatalystTypeConverters.convertToCatalyst(v)) + new GenericInternalRow(converted) + } +} diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionColumnarPartitionReader.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionColumnarPartitionReader.scala new file mode 100644 index 0000000..96b7548 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/DatafusionColumnarPartitionReader.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Per-task columnar reader for the per-partition path. Lifecycle: + * + * 1. Reflectively instantiate the bridge's `BridgeProviderFactory` (no-arg) and take its + * [[ScanBackend]]. + * 2. `backend.createScan(options, partitionBytes, ...)` — builds the provider for the slice + * described by `partitionBytes` and does the rest natively: widening wrap, private + * `SessionContext`, projection, pushed proto filters, physical plan. + * 3. `backend.executeStream` streams the whole plan (the provider already IS the task's + * slice); batches surface through [[ArrowColumnarBatchIteration]]. + */ +class DatafusionColumnarPartitionReader( + partition: DatafusionInputPartition, + readSchema: StructType +) extends PartitionReader[ColumnarBatch] + with ArrowColumnarBatchIteration { + + private val allocator = new RootAllocator(Long.MaxValue) + + private val backend: ScanBackend = instantiateFactory(partition.factoryFqcn).scanBackend() + + private val scanHandle: Long = + try { + backend.createScan( + partition.optionsBytes, + partition.partitionBytes, + /* targetPartitions = */ -1, + /* batchSize = */ -1, + Array.empty[String], + Array.empty[String], + partition.projectionColumnNames, + partition.filterProtoBytes + ) + } catch { + case t: Throwable => + try allocator.close() + catch { case suppressed: Throwable => t.addSuppressed(suppressed) } + throw t + } + + override protected val arrowReader: ArrowReader = + try { + FfiStream.importReader(allocator) { addr => + backend.executeStream(scanHandle, addr) + } + } catch { + case t: Throwable => + try backend.closeScan(scanHandle) + catch { case suppressed: Throwable => t.addSuppressed(suppressed) } + try allocator.close() + catch { case suppressed: Throwable => t.addSuppressed(suppressed) } + throw t + } + + override def close(): Unit = { + var first: Throwable = null + def safe(f: => Unit): Unit = + try f + catch { case t: Throwable => if (first == null) first = t else first.addSuppressed(t) } + safe(arrowReader.close()) + safe(backend.closeScan(scanHandle)) + safe(allocator.close()) + if (first != null) throw first + } + + private def instantiateFactory(fqcn: String): BridgeProviderFactory = { + val cls = Class.forName(fqcn) + cls.getDeclaredConstructor().newInstance().asInstanceOf[BridgeProviderFactory] + } +} diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionInputPartition.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionInputPartition.scala new file mode 100644 index 0000000..5255644 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/DatafusionInputPartition.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition} + +/** + * Marker for the connector's task payloads, shipped driver → executor via Java serialization. + * [[DatafusionPartitionReaderFactory]] dispatches on the concrete type. + */ +sealed trait DatafusionPartition extends InputPartition + +/** + * Per-task payload for the per-partition read path. + * + * - `factoryFqcn`: fully-qualified class name of the bridge's `BridgeProviderFactory`. The + * executor reflectively instantiates this and calls + * `scanBackend().createScan(optionsBytes, partitionBytes, …)`. + * - `optionsBytes`: bridge-specific global connection options, encoded by the bridge. + * Opaque to connector-core. Same bytes ride along on every partition. + * - `projectionColumnNames`: pruned column list (post-`pruneColumns`). + * - `filterProtoBytes`: V2 `Predicate` → DataFusion `LogicalExprNode` proto bytes; each one is + * applied natively via `ScanBackend.createScan`. + * - `partitionId`: stable identifier (e.g. a segment or file id) — surfaces in Spark UI/logs/errors. + * - `partitionBytes`: opaque per-partition payload from `PartitionInfo.partitionBytes`. Passed + * back into `ScanBackend.createScan` so the bridge materialises *this* slice. + * - `preferredLocs`: hostnames where this partition's data lives; returned from + * `preferredLocations()` so Spark schedules the task there subject to `spark.locality.wait`. + */ +final case class DatafusionInputPartition( + factoryFqcn: String, + optionsBytes: Array[Byte], + projectionColumnNames: Array[String], + filterProtoBytes: Array[Array[Byte]], + partitionId: String, + partitionBytes: Array[Byte], + preferredLocs: Array[String] +) extends DatafusionPartition { + + override def preferredLocations(): Array[String] = preferredLocs +} + +/** + * Per-partition payload that additionally carries this partition's key values, precomputed + * driver-side into an [[InternalRow]]. Emitted by [[DatafusionBatch]] when the bridge reported a + * partitioning AND every `PartitionInfo` carries `partitionKeyValues` — implementing + * [[HasPartitionKey]] is what makes the reported `KeyGroupedPartitioning` visible to Spark 3.3+ + * (`DataSourceV2ScanExecBase.groupPartitions` ignores it otherwise). + */ +final case class DatafusionKeyedInputPartition( + base: DatafusionInputPartition, + keyRow: InternalRow +) extends DatafusionPartition + with HasPartitionKey { + + override def preferredLocations(): Array[String] = base.preferredLocations() + + override def partitionKey(): InternalRow = keyRow +} + +/** + * Per-task payload for shared-scan mode: task `partitionIndex` streams that DataFusion plan + * partition from the executor's cached entry (see [[SharedScanCache]]). + * + * - `scanId`: driver-minted UUID identifying this scan; the executor cache key. + * - `partitionIndex`: DataFusion output partition this task drives. + * - `numPartitions`: the driver probe's partition count; executors fail fast when their re-plan + * diverges (determinism guard). + * - `pinnedConfig`: DataFusion session knobs resolved once on the driver and replicated on + * every executor so both plan identically. + * - `idleTtlMs`: cache-entry idle eviction window, resolved from driver conf. + * + * No preferred locations: the shared plan materialises the whole dataset on whichever executors + * Spark picks; there is no per-slice host mapping in this mode. + */ +final case class DatafusionSharedScanPartition( + factoryFqcn: String, + optionsBytes: Array[Byte], + projectionColumnNames: Array[String], + filterProtoBytes: Array[Array[Byte]], + scanId: String, + partitionIndex: Int, + numPartitions: Int, + pinnedConfig: PinnedSessionConfig, + idleTtlMs: Long +) extends DatafusionPartition { + + def toSpec: SharedScanSpec = + SharedScanSpec( + scanId = scanId, + factoryFqcn = factoryFqcn, + optionsBytes = optionsBytes, + projectionColumnNames = projectionColumnNames, + filterProtoBytes = filterProtoBytes, + pinnedConfig = pinnedConfig + ) +} diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionPartitionReaderFactory.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionPartitionReaderFactory.scala new file mode 100644 index 0000000..ba5409c --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/DatafusionPartitionReaderFactory.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Per-task `PartitionReader` factory. Columnar-only: row-based reads would force the connector + * to convert Arrow → `InternalRow` per row, defeating the zero-copy path that is the whole + * reason we are in-process. + */ +class DatafusionPartitionReaderFactory(val readSchema: StructType) extends PartitionReaderFactory { + + override def supportColumnarReads(partition: InputPartition): Boolean = true + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = + throw new UnsupportedOperationException( + "DatafusionPartitionReaderFactory: row-based read not supported; consumers must opt into columnar" + ) + + override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = + partition match { + case p: DatafusionInputPartition => + new DatafusionColumnarPartitionReader(p, readSchema) + case p: DatafusionKeyedInputPartition => + new DatafusionColumnarPartitionReader(p.base, readSchema) + case p: DatafusionSharedScanPartition => + new SharedScanPartitionReader(p, SharedScanCache.global) + case other => + throw new IllegalArgumentException( + s"unexpected InputPartition type: ${other.getClass.getName}") + } +} diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionScan.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionScan.scala new file mode 100644 index 0000000..38f0a8b --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/DatafusionScan.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.{Batch, Scan, SupportsReportPartitioning} +import org.apache.spark.sql.connector.read.partitioning.{ + KeyGroupedPartitioning, + Partitioning, + UnknownPartitioning +} +import org.apache.spark.sql.types.StructType + +/** + * How the scan maps to Spark tasks — resolved once, driver-side, in + * [[DatafusionScanBuilder.build]]. + */ +sealed trait DatafusionScanMode extends Serializable + +/** + * Per-partition payload mode: one task per [[PartitionInfo]], each task builds its own provider + * from that entry's `partitionBytes`. `reported` is the bridge's optional partitioning + * declaration (may be null). + */ +final case class PerPartitionMode( + partitions: Array[PartitionInfo], + reported: ReportedPartitioning +) extends DatafusionScanMode + +/** + * Shared-scan mode: one cached provider + plan per (executor × scan), `numPartitions` tasks each + * driving one DataFusion output partition. See [[BridgeProviderFactory#sharedScan]] for the + * determinism contract. + */ +final case class SharedScanMode( + scanId: String, + numPartitions: Int, + pinnedConfig: PinnedSessionConfig, + idleTtlMs: Long +) extends DatafusionScanMode + +/** + * Read plan for a DataFusion-backed scan. Holds pruning state, the pushed predicates (for + * `description()` / `explain(True)`), the corresponding `LogicalExprNode` proto byte arrays the + * executor applies natively via `ScanBackend.createScan`, and the driver-resolved + * [[DatafusionScanMode]]. + * + * Per-partition mode with a bridge-declared [[ReportedPartitioning]] surfaces `KeyGroupedPartitioning` + * via `SupportsReportPartitioning`; note Spark 3.3+ only consumes it when the input partitions + * also implement `HasPartitionKey` (see [[DatafusionBatch]]). Shared-scan mode always reports + * `UnknownPartitioning` — DataFusion-native partitions carry no key contract. + */ +class DatafusionScan( + val factoryFqcn: String, + val optionsBytes: Array[Byte], + val fullSchema: StructType, + val prunedSchema: StructType, + val pushedPredicates: Array[Predicate], + val pushedPredicateBytes: Array[Array[Byte]], + val mode: DatafusionScanMode +) extends Scan + with SupportsReportPartitioning { + + override def readSchema(): StructType = prunedSchema + + override def description(): String = { + val modeDesc = mode match { + case PerPartitionMode(partitions, reported) => + s"mode=per-partition, partitions=${partitions.length}," + + s" reportedPartitioning=${if (reported == null) "unknown" else "key-grouped"}" + case SharedScanMode(scanId, n, _, _) => + s"mode=shared-scan, scanId=$scanId, partitions=$n" + } + s"DatafusionScan(factory=$factoryFqcn, projection=${prunedSchema.fieldNames.mkString(",")}," + + s" pushedPredicates=${pushedPredicates.length}, $modeDesc)" + } + + override def toBatch: Batch = new DatafusionBatch(this) + + override def outputPartitioning(): Partitioning = mode match { + case PerPartitionMode(partitions, reported) => + if (reported == null) new UnknownPartitioning(partitions.length) + else new KeyGroupedPartitioning(reported.keys().toArray, partitions.length) + case SharedScanMode(_, numPartitions, _, _) => + new UnknownPartitioning(numPartitions) + } +} diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionScanBuilder.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionScanBuilder.scala new file mode 100644 index 0000000..a74029c --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/DatafusionScanBuilder.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import java.util.UUID + +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +/** + * ScanBuilder with V2 Predicate pushdown + column pruning. Every translatable predicate is + * marked Exact and dropped from Spark's post-scan Filter; the rest stay residual. + * + * Pushdown discipline: over-claiming Exact = wrong results, under-claiming = full scans. The + * translator (see [[SparkPredicateTranslator]]) only emits proto for predicates it can encode + * losslessly — anything else returns `None` and lands in residuals. + * + * `build()` resolves the driver-side facts the optimizer needs *before* it starts asking the + * [[DatafusionScan]] about its output partitioning. Spark guarantees `pushPredicates` and + * `pruneColumns` run first, so both paths see the final projection + filters: + * - per-partition payload mode: `listPartitions(opts, filters)` (filter-aware overload — the + * bridge can prune whole partitions) + the optional [[ReportedPartitioning]]; + * - shared-scan mode: a probe build of the provider + plan (via the same code path executors + * use) to count DataFusion output partitions, plus a freshly minted scanId and the pinned + * session config that makes executor re-plans comparable. + */ +class DatafusionScanBuilder( + factoryFqcn: String, + optionsBytes: Array[Byte], + fullSchema: StructType +) extends ScanBuilder + with SupportsPushDownV2Filters + with SupportsPushDownRequiredColumns { + + private var pushed: Array[Predicate] = Array.empty + private var pushedBytes: Array[Array[Byte]] = Array.empty + private var pruned: StructType = fullSchema + + override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { + val pushedBuf = scala.collection.mutable.ArrayBuffer[Predicate]() + val bytesBuf = scala.collection.mutable.ArrayBuffer[Array[Byte]]() + val residual = scala.collection.mutable.ArrayBuffer[Predicate]() + + var i = 0 + while (i < predicates.length) { + val p = predicates(i) + SparkPredicateTranslator.translate(p) match { + case Some(node) => + pushedBuf += p + bytesBuf += node.toByteArray + case None => + residual += p + } + i += 1 + } + pushed = pushedBuf.toArray + pushedBytes = bytesBuf.toArray + residual.toArray + } + + override def pushedPredicates(): Array[Predicate] = pushed + + override def pruneColumns(requiredSchema: StructType): Unit = { + pruned = requiredSchema + } + + override def build(): Scan = { + val factory = instantiateFactory(factoryFqcn) + val mode: DatafusionScanMode = + if (factory.sharedScan(optionsBytes)) buildSharedScanMode() + else buildPerPartitionMode(factory) + new DatafusionScan( + factoryFqcn, + optionsBytes, + fullSchema, + pruned, + pushed, + pushedBytes, + mode + ) + } + + private def buildPerPartitionMode(factory: BridgeProviderFactory): PerPartitionMode = { + val partitions: Array[PartitionInfo] = + factory.listPartitions(optionsBytes, pushedBytes) + if (partitions == null || partitions.isEmpty) { + throw new IllegalStateException( + s"BridgeProviderFactory '$factoryFqcn' returned no partitions to scan" + ) + } + PerPartitionMode(partitions, factory.reportPartitioning(optionsBytes)) + } + + /** + * Driver plan probe: build the provider + plan exactly as executors will (same widening, SQL, + * filters, pinned config — one code path in [[NativeSharedScanResources]]) and read the + * physical plan's output partition count. All Spark conf is resolved here, driver-side; + * executors only see the shipped copies. + */ + private def buildSharedScanMode(): SharedScanMode = { + val conf = SQLConf.get + val pinned = PinnedSessionConfig.fromConf(conf) + val idleTtlMs = PinnedSessionConfig.idleTtlMs(conf) + val scanId = UUID.randomUUID().toString + + val probeSpec = SharedScanSpec( + scanId = scanId, + factoryFqcn = factoryFqcn, + optionsBytes = optionsBytes, + projectionColumnNames = pruned.fieldNames, + filterProtoBytes = pushedBytes, + pinnedConfig = pinned + ) + val probe = NativeSharedScanResources.build(probeSpec) + val numPartitions = + try { + probe.partitionCount + } finally { + probe.close() + } + if (numPartitions <= 0) { + throw new IllegalStateException( + s"shared-scan probe for factory '$factoryFqcn' produced a plan with no partitions") + } + SharedScanMode(scanId, numPartitions, pinned, idleTtlMs) + } + + private def instantiateFactory(fqcn: String): BridgeProviderFactory = { + val cls = Class.forName(fqcn) + cls.getDeclaredConstructor().newInstance().asInstanceOf[BridgeProviderFactory] + } +} diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionSource.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionSource.scala new file mode 100644 index 0000000..125b3a1 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/DatafusionSource.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import java.io.ByteArrayInputStream +import java.nio.channels.Channels +import java.util + +import org.apache.arrow.vector.ipc.ReadChannel +import org.apache.arrow.vector.ipc.message.MessageSerializer +import org.apache.spark.sql.connector.catalog.{Table, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Generic Spark DataSource V2 entry point. Concrete bridges either: + * - Subclass and override [[shortName]] + [[factoryFqcn]] (the short-name shim pattern), or + * - Use this class directly with `option("df.factory", "fully.qualified.FactoryClass")`. + * + * Schema discovery happens driver-side inside the bridge's native scan backend + * (`ScanBackend.providerSchemaIpc`), which widens the provider and returns its Arrow schema as + * IPC bytes. The same `optionsBytes` (and the factory FQCN) is then carried verbatim through + * `DatafusionInputPartition`, so each executor task repeats the same factory → backend pipeline + * locally. + */ +class DatafusionSource extends TableProvider with DataSourceRegister { + + override def shortName(): String = "datafusion" + + /** Spark option key carrying the BridgeProviderFactory FQCN when no override is provided. */ + protected val FactoryOptionKey: String = "df.factory" + + /** + * Resolve the bridge factory class name from the Spark options. Subclasses override to return a + * hard-coded FQCN so users don't need to set `df.factory` themselves. + */ + protected def factoryFqcn(options: CaseInsensitiveStringMap): String = { + val v = options.get(FactoryOptionKey) + if (v == null || v.isEmpty) + throw new IllegalArgumentException( + s"DatafusionSource: option '$FactoryOptionKey' is required when no subclass override is set" + ) + v + } + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { + val fqcn = factoryFqcn(options) + val factory = instantiateFactory(fqcn) + val optionsBytes = factory.encodeOptions(options.asCaseSensitiveMap()) + // Schema probe: pass empty partitionBytes — bridges are required to honour an empty + // payload for the driver-side probe (schema must not depend on per-partition state). + val ipcBytes = factory.scanBackend().providerSchemaIpc(optionsBytes, Array.emptyByteArray) + val arrowSchema = MessageSerializer.deserializeSchema( + new ReadChannel(Channels.newChannel(new ByteArrayInputStream(ipcBytes)))) + ArrowToSparkSchema.toSparkSchema(arrowSchema) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String] + ): Table = { + val options = new CaseInsensitiveStringMap(properties) + val fqcn = factoryFqcn(options) + val factory = instantiateFactory(fqcn) + val optionsBytes = factory.encodeOptions(options.asCaseSensitiveMap()) + new DatafusionTable(fqcn, optionsBytes, schema) + } + + override def supportsExternalMetadata(): Boolean = false + + private def instantiateFactory(fqcn: String): BridgeProviderFactory = { + val cls = Class.forName(fqcn) + cls.getDeclaredConstructor().newInstance().asInstanceOf[BridgeProviderFactory] + } +} diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionTable.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionTable.scala new file mode 100644 index 0000000..a0e8ec4 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/DatafusionTable.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import java.util + +import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Read-only DataFusion-backed table. Capabilities advertise only batch read. + */ +class DatafusionTable( + val factoryFqcn: String, + val optionsBytes: Array[Byte], + val sparkSchema: StructType +) extends Table + with SupportsRead { + + override def name(): String = s"datafusion.${factoryFqcn.split('.').last}" + + override def schema(): StructType = sparkSchema + + override def capabilities(): util.Set[TableCapability] = { + val caps = new util.HashSet[TableCapability]() + caps.add(TableCapability.BATCH_READ) + caps + } + + override def newScanBuilder(scanOpts: CaseInsensitiveStringMap): ScanBuilder = + new DatafusionScanBuilder(factoryFqcn, optionsBytes, sparkSchema) +} diff --git a/spark/src/main/scala/io/datafusion/spark/FfiStream.scala b/spark/src/main/scala/io/datafusion/spark/FfiStream.scala new file mode 100644 index 0000000..eb1149a --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/FfiStream.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.arrow.c.{ArrowArrayStream, Data} +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader + +/** + * Arrow C-data import of a native-produced `FFI_ArrowArrayStream`: allocate the empty struct, + * let the native side write into it, then hand it to Arrow Java. On any failure the struct is + * released so a half-written stream can't leak. + */ +private[spark] object FfiStream { + + def importReader(allocator: BufferAllocator)(writeStream: Long => Unit): ArrowReader = { + val stream = ArrowArrayStream.allocateNew(allocator) + try { + writeStream(stream.memoryAddress()) + Data.importArrayStream(allocator, stream) + } catch { + case t: Throwable => + stream.close() + throw t + } + } +} diff --git a/spark/src/main/scala/io/datafusion/spark/NativeSharedScanResources.scala b/spark/src/main/scala/io/datafusion/spark/NativeSharedScanResources.scala new file mode 100644 index 0000000..b541c8a --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/NativeSharedScanResources.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.arrow.memory.{BufferAllocator, RootAllocator} +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.spark.internal.Logging + +/** + * JNI-backed shared-scan entry: one provider, one planned scan handle inside the bridge's native + * scan backend. + * + * The build sequence is the single code path for BOTH the driver-side partition-count probe and + * every executor's cache entry — identical widening, registration, projection, filters, and + * pinned session config are what make the partition count comparable across machines (the + * bridge's determinism contract covers the rest). + */ +private[spark] final class NativeSharedScanResources( + allocator: RootAllocator, + backend: ScanBackend, + scanHandle: Long +) extends SharedScanResources { + + override def partitionCount: Int = backend.partitionCount(scanHandle) + + override def newTaskAllocator(name: String): BufferAllocator = + allocator.newChildAllocator(name, 0, Long.MaxValue) + + override def openPartitionStream( + partition: Int, + taskAllocator: BufferAllocator): ArrowReader = + FfiStream.importReader(taskAllocator) { addr => + backend.executeStreamPartition(scanHandle, partition, addr) + } + + override def close(): Unit = { + var first: Throwable = null + def safe(f: => Unit): Unit = + try f + catch { case t: Throwable => if (first == null) first = t else first.addSuppressed(t) } + safe(backend.closeScan(scanHandle)) + safe(allocator.close()) + if (first != null) throw first + } +} + +private[spark] object NativeSharedScanResources extends Logging { + + def build(spec: SharedScanSpec): SharedScanResources = { + logInfo( + s"Building shared-scan entry for scanId=${spec.scanId} " + + s"(factory=${spec.factoryFqcn}, filters=${spec.filterProtoBytes.length})") + + val factory = Class + .forName(spec.factoryFqcn) + .getDeclaredConstructor() + .newInstance() + .asInstanceOf[BridgeProviderFactory] + val backend = factory.scanBackend() + + val allocator = new RootAllocator(Long.MaxValue) + try { + // Shared mode builds the dataset-wide provider: empty partitionBytes, like the + // driver-side schema probe. DataFusion-native partitioning replaces listPartitions. + val scanHandle = backend.createScan( + spec.optionsBytes, + Array.emptyByteArray, + spec.pinnedConfig.targetPartitions, + spec.pinnedConfig.batchSize, + spec.pinnedConfig.options.map(_._1).toArray, + spec.pinnedConfig.options.map(_._2).toArray, + spec.projectionColumnNames, + spec.filterProtoBytes + ) + new NativeSharedScanResources(allocator, backend, scanHandle) + } catch { + case t: Throwable => + try allocator.close() + catch { case suppressed: Throwable => t.addSuppressed(suppressed) } + throw t + } + } +} diff --git a/spark/src/main/scala/io/datafusion/spark/NonClosingArrowColumnVector.scala b/spark/src/main/scala/io/datafusion/spark/NonClosingArrowColumnVector.scala new file mode 100644 index 0000000..4fa74bd --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/NonClosingArrowColumnVector.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.arrow.vector.FieldVector +import org.apache.spark.sql.vectorized.ArrowColumnVector + +/** + * `ArrowColumnVector` whose `close()` is a no-op. The `ArrowReader`'s `VectorSchemaRoot` owns + * the underlying `ValueVector` lifecycles across `loadNextBatch()` calls; closing them per Spark + * batch would break the next read. Lifecycle is centralised in + * `DatafusionColumnarPartitionReader.close()`. + */ +final class NonClosingArrowColumnVector(vec: FieldVector) extends ArrowColumnVector(vec) { + override def close(): Unit = { /* intentional no-op */ } +} diff --git a/spark/src/main/scala/io/datafusion/spark/PinnedSessionConfig.scala b/spark/src/main/scala/io/datafusion/spark/PinnedSessionConfig.scala new file mode 100644 index 0000000..1340978 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/PinnedSessionConfig.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.spark.sql.internal.SQLConf + +/** + * DataFusion session knobs pinned by the driver and replicated verbatim on every executor in + * shared-scan mode. + * + * DataFusion's default `SessionConfig` derives `target_partitions` from the machine's core count, + * so a plan that yields N partitions on the driver could yield M ≠ N on a differently-sized + * executor — and partition-indexed execution would silently drop or duplicate data. The driver + * resolves these values once in `DatafusionScanBuilder.build()`, ships them inside every + * [[DatafusionSharedScanPartition]], and both the driver probe and the executors hand the same + * values to `ScanBackend.createScan`, which builds the native `SessionContext` from them. + * + * `options` additionally disables the optimizer's plan-reshaping repartition passes so the + * physical partitioning is exactly what the provider's `scan()` reports, on every machine. + */ +final case class PinnedSessionConfig( + targetPartitions: Int, + batchSize: Int, + options: Vector[(String, String)] +) extends Serializable + +object PinnedSessionConfig { + + val TargetPartitionsConf = "spark.datafusion.sharedScan.targetPartitions" + val BatchSizeConf = "spark.datafusion.sharedScan.batchSize" + val IdleTtlConf = "spark.datafusion.sharedScan.idleTtlMs" + + val DefaultTargetPartitions = 8 + val DefaultBatchSize = 8192 + val DefaultIdleTtlMs = 120000L + + /** + * Optimizer knobs that must not vary with the host. Round-robin repartition and file-scan + * repartition would let the optimizer change the plan's output partition count based on + * `target_partitions` heuristics; statistics collection could steer per-host plan differences. + */ + private val DeterminismOptions: Vector[(String, String)] = Vector( + "datafusion.optimizer.enable_round_robin_repartition" -> "false", + "datafusion.optimizer.repartition_file_scans" -> "false", + "datafusion.execution.collect_statistics" -> "false" + ) + + /** + * Resolve the pinned config from the driver's session conf. Called exactly once per scan, on + * the driver; executors never read Spark conf for these values — they use the shipped copy. + */ + def fromConf(conf: SQLConf): PinnedSessionConfig = { + PinnedSessionConfig( + targetPartitions = + conf.getConfString(TargetPartitionsConf, DefaultTargetPartitions.toString).toInt, + batchSize = conf.getConfString(BatchSizeConf, DefaultBatchSize.toString).toInt, + options = DeterminismOptions + ) + } + + def idleTtlMs(conf: SQLConf): Long = + conf.getConfString(IdleTtlConf, DefaultIdleTtlMs.toString).toLong +} diff --git a/spark/src/main/scala/io/datafusion/spark/SharedScanCache.scala b/spark/src/main/scala/io/datafusion/spark/SharedScanCache.scala new file mode 100644 index 0000000..a134746 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/SharedScanCache.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import java.util.concurrent.{ConcurrentHashMap, Executors, ScheduledExecutorService, TimeUnit} + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader + +/** + * Everything the driver resolved that an executor needs to rebuild the shared scan: identity + * (scanId) plus the exact build inputs (factory, options, projection, filters, pinned config). + */ +final case class SharedScanSpec( + scanId: String, + factoryFqcn: String, + optionsBytes: Array[Byte], + projectionColumnNames: Array[String], + filterProtoBytes: Array[Array[Byte]], + pinnedConfig: PinnedSessionConfig +) + +/** + * What one cached shared-scan entry exposes to readers. Implemented by + * [[NativeSharedScanResources]] (JNI-backed) and by fakes in tests. + */ +trait SharedScanResources extends AutoCloseable { + + /** Output partition count of the planned physical plan. */ + def partitionCount: Int + + /** Child allocator for one task's reader; closed by the task, before release. */ + def newTaskAllocator(name: String): BufferAllocator + + /** Open an independent stream over one plan partition. Concurrent-safe. */ + def openPartitionStream(partition: Int, taskAllocator: BufferAllocator): ArrowReader +} + +/** + * Executor-JVM cache of shared-scan entries, keyed by the driver-minted scanId. + * + * Semantics: + * - `acquire` builds the entry exactly once per attempt wave: the first caller builds under + * the entry's lock, concurrent callers block and share the result. Each successful acquire + * increments a refcount that the caller MUST pair with `release(scanId)`. + * - Build failures propagate to the builder AND all waiters of that attempt, and are not + * cached: the next acquire rebuilds. + * - Eviction closes entries with refcount 0 that have been idle longer than their TTL. The + * refcount covers every open reader, so native close never races an in-flight stream. + * Acquire after eviction rebuilds — correct, just slower. + * + * The cache itself is JNI-free: the entry builder is injected, so tests run without native libs. + */ +final class SharedScanCache( + buildEntry: SharedScanSpec => SharedScanResources, + nanoClock: () => Long = () => System.nanoTime() +) { + + /** + * Per-scanId slot. All state transitions are guarded by `this` (the holder's monitor); the + * build itself also runs under the monitor, which is what blocks concurrent acquirers of the + * same scan until the entry exists. + */ + private final class EntryHolder(spec: SharedScanSpec, idleTtlMs: Long) { + private var resources: SharedScanResources = _ + private var refCount: Int = 0 + private var lastReleaseNanos: Long = nanoClock() + private var closed: Boolean = false + + /** Returns the resources with refcount incremented, or None if this holder was evicted. */ + def acquire(): Option[SharedScanResources] = synchronized { + if (closed) return None + if (resources == null) { + resources = buildEntry(spec) // throws -> caller removes holder + } + refCount += 1 + Some(resources) + } + + def release(): Unit = synchronized { + refCount -= 1 + lastReleaseNanos = nanoClock() + } + + /** Close if idle past TTL; returns true when this holder is now closed. */ + def closeIfIdle(nowNanos: Long): Boolean = synchronized { + if (closed) return true + val idle = refCount == 0 && + (nowNanos - lastReleaseNanos) >= TimeUnit.MILLISECONDS.toNanos(idleTtlMs) + if (idle) forceCloseLocked() + closed + } + + def forceClose(): Unit = synchronized { forceCloseLocked() } + + private def forceCloseLocked(): Unit = { + if (!closed) { + closed = true + if (resources != null) { + val r = resources + resources = null + r.close() + } + } + } + } + + private val entries = new ConcurrentHashMap[String, EntryHolder]() + + def acquire(spec: SharedScanSpec, idleTtlMs: Long): SharedScanResources = { + while (true) { + val holder = + entries.computeIfAbsent(spec.scanId, _ => new EntryHolder(spec, idleTtlMs)) + val acquired = + try { + holder.acquire() + } catch { + case t: Throwable => + // Build failed: drop the holder so the next acquire rebuilds, then propagate. + entries.remove(spec.scanId, holder) + throw t + } + acquired match { + case Some(resources) => return resources + case None => + // Holder was evicted between map lookup and acquire; retry with a fresh one. + entries.remove(spec.scanId, holder) + } + } + throw new IllegalStateException("unreachable") + } + + def release(scanId: String): Unit = { + val holder = entries.get(scanId) + if (holder == null) { + throw new IllegalStateException( + s"release($scanId) without a cached entry: unbalanced acquire/release") + } + holder.release() + } + + /** Close and remove every idle-past-TTL entry. Called by the evictor daemon and by tests. */ + private[spark] def evictIdleNow(): Unit = { + val now = nanoClock() + entries.forEach { (scanId, holder) => + if (holder.closeIfIdle(now)) { + entries.remove(scanId, holder) + } + } + } + + /** Close everything regardless of refcounts. JVM-shutdown path only. */ + def shutdown(): Unit = { + entries.forEach { (_, holder) => holder.forceClose() } + entries.clear() + } +} + +object SharedScanCache { + + /** Evictor period. Short relative to any sane TTL; cheap when the map is empty. */ + private val EvictorPeriodMs = 5000L + + /** JVM singleton used by executor tasks. Lazily started together with its evictor daemon. */ + lazy val global: SharedScanCache = { + val cache = new SharedScanCache(NativeSharedScanResources.build) + val evictor: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor { r => + val t = new Thread(r, "datafusion-shared-scan-evictor") + t.setDaemon(true) + t + } + evictor.scheduleWithFixedDelay( + () => cache.evictIdleNow(), + EvictorPeriodMs, + EvictorPeriodMs, + TimeUnit.MILLISECONDS) + Runtime.getRuntime.addShutdownHook(new Thread(() => cache.shutdown())) + cache + } +} diff --git a/spark/src/main/scala/io/datafusion/spark/SharedScanPartitionReader.scala b/spark/src/main/scala/io/datafusion/spark/SharedScanPartitionReader.scala new file mode 100644 index 0000000..4f0c9c1 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/SharedScanPartitionReader.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.spark.TaskContext +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Shared-scan task reader: acquires the executor's cached (provider, plan) entry and streams ONE + * DataFusion plan partition from it. The acquire/release refcount pair brackets the reader's + * whole lifetime, so the cache can never close the native plan under an open stream. + */ +class SharedScanPartitionReader( + partition: DatafusionSharedScanPartition, + cache: SharedScanCache +) extends PartitionReader[ColumnarBatch] + with ArrowColumnarBatchIteration { + + private val resources: SharedScanResources = cache.acquire(partition.toSpec, partition.idleTtlMs) + + // Determinism guard: the driver counted partitions by planning once; if this executor's + // re-plan disagrees, partition indices are meaningless and every task of the scan must fail + // rather than silently drop or duplicate data. + if (resources.partitionCount != partition.numPartitions) { + val executorCount = resources.partitionCount + cache.release(partition.scanId) + throw new IllegalStateException( + s"shared-scan determinism violation for scanId=${partition.scanId}: driver planned " + + s"${partition.numPartitions} partition(s) but this executor planned $executorCount. " + + "The provider's partitioning must be a pure function of optionsBytes; pin your " + + "source snapshot (see BridgeProviderFactory.sharedScan).") + } + + private val taskAllocator: BufferAllocator = { + val attempt = Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + resources.newTaskAllocator( + s"shared-${partition.scanId}-p${partition.partitionIndex}-attempt$attempt") + } + + override protected val arrowReader: ArrowReader = + try { + resources.openPartitionStream(partition.partitionIndex, taskAllocator) + } catch { + case t: Throwable => + try taskAllocator.close() + catch { case suppressed: Throwable => t.addSuppressed(suppressed) } + cache.release(partition.scanId) + throw t + } + + override def close(): Unit = { + var first: Throwable = null + def safe(f: => Unit): Unit = + try f + catch { case t: Throwable => if (first == null) first = t else first.addSuppressed(t) } + safe(arrowReader.close()) + safe(taskAllocator.close()) + // Release LAST: the refcount must cover the open stream and the task allocator. + safe(cache.release(partition.scanId)) + if (first != null) throw first + } +} diff --git a/spark/src/main/scala/io/datafusion/spark/SparkPredicateTranslator.scala b/spark/src/main/scala/io/datafusion/spark/SparkPredicateTranslator.scala new file mode 100644 index 0000000..3092914 --- /dev/null +++ b/spark/src/main/scala/io/datafusion/spark/SparkPredicateTranslator.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import datafusion_common.DatafusionCommon.{Column, ScalarValue} +import org.apache.datafusion.protobuf.{ + BinaryExprNode, + InListNode, + IsNotNull, + IsNull, + LikeNode, + LogicalExprNode, + Not => NotNode +} +import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference} +import org.apache.spark.sql.connector.expressions.filter.Predicate + +/** + * Translate Spark V2 `Predicate` → DataFusion `LogicalExprNode` proto. Only emits expressions + * that the producer can apply EXACTLY — anything else returns `None` and the caller marks the + * predicate as residual so Spark re-applies it above the scan. + */ +object SparkPredicateTranslator { + + def translate(p: Predicate): Option[LogicalExprNode] = p.name() match { + case "=" => binary(p, "Eq") + case "<>" => binary(p, "NotEq") + case "<" => binary(p, "Lt") + case "<=" => binary(p, "LtEq") + case ">" => binary(p, "Gt") + case ">=" => binary(p, "GtEq") + case "IS_NULL" => unary(p, "IsNull") + case "IS_NOT_NULL" => unary(p, "IsNotNull") + case "AND" => combine(p, "And") + case "OR" => combine(p, "Or") + case "NOT" => translateNot(p) + case "IN" => translateIn(p) + case "STARTS_WITH" => like(p, prefix = false, suffix = true) + case "ENDS_WITH" => like(p, prefix = true, suffix = false) + case "CONTAINS" => like(p, prefix = true, suffix = true) + case _ => None + } + + private def binary(p: Predicate, op: String): Option[LogicalExprNode] = { + val cs = p.children() + if (cs.length != 2) return None + val left = expr(cs(0)) + val right = expr(cs(1)) + if (left.isEmpty || right.isEmpty) return None + Some( + LogicalExprNode + .newBuilder() + .setBinaryExpr( + BinaryExprNode + .newBuilder() + .addOperands(left.get) + .addOperands(right.get) + .setOp(op) + .build() + ) + .build() + ) + } + + private def unary(p: Predicate, op: String): Option[LogicalExprNode] = { + val cs = p.children() + if (cs.length != 1) return None + val inner = expr(cs(0)) + if (inner.isEmpty) return None + val builder = LogicalExprNode.newBuilder() + op match { + case "IsNull" => builder.setIsNullExpr(IsNull.newBuilder().setExpr(inner.get).build()) + case "IsNotNull" => + builder.setIsNotNullExpr(IsNotNull.newBuilder().setExpr(inner.get).build()) + case _ => return None + } + Some(builder.build()) + } + + private def combine(p: Predicate, op: String): Option[LogicalExprNode] = { + val cs = p.children() + if (cs.length != 2) return None + val (l, r) = (cs(0), cs(1)) + if (!l.isInstanceOf[Predicate] || !r.isInstanceOf[Predicate]) return None + val ln = translate(l.asInstanceOf[Predicate]) + val rn = translate(r.asInstanceOf[Predicate]) + if (ln.isEmpty || rn.isEmpty) return None + Some( + LogicalExprNode + .newBuilder() + .setBinaryExpr( + BinaryExprNode + .newBuilder() + .addOperands(ln.get) + .addOperands(rn.get) + .setOp(op) + .build() + ) + .build() + ) + } + + private def translateNot(p: Predicate): Option[LogicalExprNode] = { + val cs = p.children() + if (cs.length != 1 || !cs(0).isInstanceOf[Predicate]) return None + val inner = translate(cs(0).asInstanceOf[Predicate]) + if (inner.isEmpty) return None + Some(LogicalExprNode.newBuilder().setNotExpr(NotNode.newBuilder().setExpr(inner.get).build()).build()) + } + + private def translateIn(p: Predicate): Option[LogicalExprNode] = { + val cs = p.children() + if (cs.length < 2) return None + val target = expr(cs(0)) + if (target.isEmpty) return None + val values = new java.util.ArrayList[LogicalExprNode]() + var i = 1 + while (i < cs.length) { + val v = expr(cs(i)) + if (v.isEmpty) return None + values.add(v.get) + i += 1 + } + val node = InListNode + .newBuilder() + .setExpr(target.get) + .addAllList(values) + .setNegated(false) + .build() + Some(LogicalExprNode.newBuilder().setInList(node).build()) + } + + private def like(p: Predicate, prefix: Boolean, suffix: Boolean): Option[LogicalExprNode] = { + val cs = p.children() + if (cs.length != 2) return None + val target = expr(cs(0)) + val pat = cs(1) match { + case lit: Literal[_] => + val raw = lit.value() match { + case s: String => Some(s) + case u: org.apache.spark.unsafe.types.UTF8String => Some(u.toString) + case _ => None + } + raw.map { r => + val escaped = r.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + (if (prefix) "%" else "") + escaped + (if (suffix) "%" else "") + } + case _ => None + } + if (target.isEmpty || pat.isEmpty) return None + val patternExpr = stringLiteral(pat.get) + val like = LikeNode + .newBuilder() + .setExpr(target.get) + .setPattern(patternExpr) + .setNegated(false) + .setEscapeChar("\\") + .build() + Some(LogicalExprNode.newBuilder().setLike(like).build()) + } + + private def expr(e: Expression): Option[LogicalExprNode] = e match { + case nr: NamedReference => + val parts = nr.fieldNames() + if (parts.length != 1) None + else + Some( + LogicalExprNode + .newBuilder() + .setColumn(Column.newBuilder().setName(parts(0)).build()) + .build() + ) + case lit: Literal[_] => literal(lit.value()) + case _ => None + } + + private def literal(v: Any): Option[LogicalExprNode] = { + val sv = ScalarValue.newBuilder() + val ok: Boolean = v match { + case b: java.lang.Boolean => sv.setBoolValue(b.booleanValue()); true + case b: java.lang.Byte => sv.setInt8Value(b.intValue()); true + case s: java.lang.Short => sv.setInt16Value(s.intValue()); true + case i: java.lang.Integer => sv.setInt32Value(i.intValue()); true + case l: java.lang.Long => sv.setInt64Value(l.longValue()); true + case f: java.lang.Float => sv.setFloat32Value(f.floatValue()); true + case d: java.lang.Double => sv.setFloat64Value(d.doubleValue()); true + case s: String => sv.setUtf8Value(s); true + case u: org.apache.spark.unsafe.types.UTF8String => sv.setUtf8Value(u.toString); true + case _ => false + } + if (!ok) None + else Some(LogicalExprNode.newBuilder().setLiteral(sv.build()).build()) + } + + private def stringLiteral(s: String): LogicalExprNode = + LogicalExprNode.newBuilder().setLiteral(ScalarValue.newBuilder().setUtf8Value(s).build()).build() +} diff --git a/spark/src/test/scala/io/datafusion/spark/ArrowToSparkSchemaTest.scala b/spark/src/test/scala/io/datafusion/spark/ArrowToSparkSchemaTest.scala new file mode 100644 index 0000000..2b59601 --- /dev/null +++ b/spark/src/test/scala/io/datafusion/spark/ArrowToSparkSchemaTest.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import java.util.Collections + +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit} +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.spark.sql.types._ +import org.scalatest.funsuite.AnyFunSuite + +class ArrowToSparkSchemaTest extends AnyFunSuite { + + private def primField(name: String, t: ArrowType, nullable: Boolean = true): Field = + new Field(name, new FieldType(nullable, t, /*dict=*/ null), Collections.emptyList()) + + test("signed ints map to matching Spark int types") { + val arrow = new Schema( + java.util.Arrays.asList( + primField("i8", new ArrowType.Int(8, true)), + primField("i16", new ArrowType.Int(16, true)), + primField("i32", new ArrowType.Int(32, true)), + primField("i64", new ArrowType.Int(64, true)) + ) + ) + val s = ArrowToSparkSchema.toSparkSchema(arrow) + assert(s.fields(0).dataType == ByteType) + assert(s.fields(1).dataType == ShortType) + assert(s.fields(2).dataType == IntegerType) + assert(s.fields(3).dataType == LongType) + } + + test("unsigned ints are rejected with a clear error") { + val arrow = new Schema( + java.util.Arrays.asList(primField("u32", new ArrowType.Int(32, false))) + ) + val ex = intercept[UnsupportedOperationException](ArrowToSparkSchema.toSparkSchema(arrow)) + assert(ex.getMessage.contains("u32")) + assert(ex.getMessage.toLowerCase.contains("unsigned")) + } + + test("timestamps split on timezone presence") { + val withTz = primField("t_utc", new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")) + val noTz = primField("t_local", new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)) + val s = ArrowToSparkSchema.toSparkSchema( + new Schema(java.util.Arrays.asList(withTz, noTz)) + ) + assert(s.fields(0).dataType == TimestampType) + assert(s.fields(1).dataType == TimestampNTZType) + } + + test("decimal preserves precision and scale") { + val s = ArrowToSparkSchema.toSparkSchema( + new Schema(java.util.Arrays.asList(primField("d", new ArrowType.Decimal(18, 4, 128)))) + ) + assert(s.fields(0).dataType == DecimalType(18, 4)) + } + + test("Time and Float16 are rejected (no Spark accessor)") { + intercept[UnsupportedOperationException] { + ArrowToSparkSchema.toSparkSchema( + new Schema(java.util.Arrays.asList(primField("t", new ArrowType.Time(TimeUnit.MICROSECOND, 64)))) + ) + } + intercept[UnsupportedOperationException] { + ArrowToSparkSchema.toSparkSchema( + new Schema(java.util.Arrays.asList(primField("h", new ArrowType.FloatingPoint(FloatingPointPrecision.HALF)))) + ) + } + } + + test("list element nullability propagates") { + val child = + new Field( + "el", + new FieldType(/*nullable=*/ true, new ArrowType.Int(32, true), null), + Collections.emptyList() + ) + val listField = new Field( + "xs", + new FieldType(true, new ArrowType.List(), null), + java.util.Arrays.asList(child) + ) + val s = ArrowToSparkSchema.toSparkSchema( + new Schema(java.util.Arrays.asList(listField)) + ) + assert(s.fields(0).dataType == ArrayType(IntegerType, containsNull = true)) + } +} diff --git a/spark/src/test/scala/io/datafusion/spark/PartitionKeyConversionTest.scala b/spark/src/test/scala/io/datafusion/spark/PartitionKeyConversionTest.scala new file mode 100644 index 0000000..e2f876d --- /dev/null +++ b/spark/src/test/scala/io/datafusion/spark/PartitionKeyConversionTest.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.spark.unsafe.types.UTF8String +import org.scalatest.funsuite.AnyFunSuite + +class PartitionKeyConversionTest extends AnyFunSuite { + + private def info(id: String, keys: Array[AnyRef]): PartitionInfo = + new PartitionInfo(id, Array.emptyByteArray, Array.empty[String], keys) + + private def infoNoKeys(id: String): PartitionInfo = + new PartitionInfo(id, Array.emptyByteArray, Array.empty[String]) + + test("String and Long key values convert to catalyst representations") { + val reported = ReportedPartitioning.identity("segment_id", "bucket") + val row = + DatafusionBatch.toKeyRow("p0", Array[AnyRef]("segment-a", Long.box(42L)), reported) + assert(row.numFields == 2) + assert(row.get(0, org.apache.spark.sql.types.StringType) == UTF8String.fromString("segment-a")) + assert(row.getLong(1) == 42L) + } + + test("arity mismatch between key values and declared keys throws") { + val reported = ReportedPartitioning.identity("segment_id", "bucket") + val e = intercept[IllegalStateException] { + DatafusionBatch.toKeyRow("p0", Array[AnyRef]("only-one"), reported) + } + assert(e.getMessage.contains("declares 2 key(s)")) + } + + test("keyed state requires reported partitioning") { + val partitions = Array(info("p0", Array[AnyRef]("a"))) + assert(!DatafusionBatch.validateKeyedState("F", partitions, null)) + } + + test("no partitions with keys means unkeyed, even with reported partitioning") { + val reported = ReportedPartitioning.identity("segment_id") + val partitions = Array(infoNoKeys("p0"), infoNoKeys("p1")) + assert(!DatafusionBatch.validateKeyedState("F", partitions, reported)) + } + + test("all partitions with keys means keyed") { + val reported = ReportedPartitioning.identity("segment_id") + val partitions = + Array(info("p0", Array[AnyRef]("a")), info("p1", Array[AnyRef]("b"))) + assert(DatafusionBatch.validateKeyedState("F", partitions, reported)) + } + + test("mixed keyed and unkeyed partitions throw driver-side") { + val reported = ReportedPartitioning.identity("segment_id") + val partitions = Array(info("p0", Array[AnyRef]("a")), infoNoKeys("p1")) + val e = intercept[IllegalStateException] { + DatafusionBatch.validateKeyedState("F", partitions, reported) + } + assert(e.getMessage.contains("only 1 of 2")) + } +} diff --git a/spark/src/test/scala/io/datafusion/spark/SharedScanCacheTest.scala b/spark/src/test/scala/io/datafusion/spark/SharedScanCacheTest.scala new file mode 100644 index 0000000..dae49eb --- /dev/null +++ b/spark/src/test/scala/io/datafusion/spark/SharedScanCacheTest.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import java.util.concurrent.{CountDownLatch, Executors, TimeUnit} +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader +import org.scalatest.funsuite.AnyFunSuite + +class SharedScanCacheTest extends AnyFunSuite { + + private def spec(scanId: String): SharedScanSpec = + SharedScanSpec( + scanId = scanId, + factoryFqcn = "test.Factory", + optionsBytes = Array.emptyByteArray, + projectionColumnNames = Array.empty, + filterProtoBytes = Array.empty, + pinnedConfig = PinnedSessionConfig(8, 8192, Vector.empty) + ) + + /** JNI-free fake entry; records close. */ + private final class FakeResources extends SharedScanResources { + @volatile var closed = false + override def partitionCount: Int = 3 + override def newTaskAllocator(name: String): BufferAllocator = + throw new UnsupportedOperationException("not used in cache tests") + override def openPartitionStream(p: Int, a: BufferAllocator): ArrowReader = + throw new UnsupportedOperationException("not used in cache tests") + override def close(): Unit = closed = true + } + + private final class Fixture { + val clock = new AtomicLong(0L) + val buildCount = new AtomicInteger(0) + var failBuilds = false + var lastBuilt: FakeResources = _ + + val cache = new SharedScanCache( + buildEntry = _ => { + buildCount.incrementAndGet() + if (failBuilds) throw new RuntimeException("synthetic build failure") + lastBuilt = new FakeResources + lastBuilt + }, + nanoClock = () => clock.get() + ) + + def advanceMillis(ms: Long): Unit = clock.addAndGet(TimeUnit.MILLISECONDS.toNanos(ms)) + } + + test("acquire builds once, second acquire reuses, refcount pairs with release") { + val f = new Fixture + val r1 = f.cache.acquire(spec("s1"), idleTtlMs = 1000) + val r2 = f.cache.acquire(spec("s1"), idleTtlMs = 1000) + assert(f.buildCount.get() == 1) + assert(r1 eq r2) + f.cache.release("s1") + f.cache.release("s1") + } + + test("concurrent acquires build exactly once") { + val f = new Fixture + val n = 8 + val pool = Executors.newFixedThreadPool(n) + val ready = new CountDownLatch(n) + val go = new CountDownLatch(1) + try { + val futures = (0 until n).map { _ => + pool.submit { () => + ready.countDown() + go.await() + f.cache.acquire(spec("s1"), idleTtlMs = 1000) + } + } + ready.await() + go.countDown() + val results = futures.map(_.get(10, TimeUnit.SECONDS)) + assert(f.buildCount.get() == 1) + assert(results.forall(_ eq results.head)) + (0 until n).foreach(_ => f.cache.release("s1")) + } finally { + pool.shutdownNow() + } + } + + test("build failure propagates and is not cached") { + val f = new Fixture + f.failBuilds = true + val e = intercept[RuntimeException](f.cache.acquire(spec("s1"), idleTtlMs = 1000)) + assert(e.getMessage == "synthetic build failure") + f.failBuilds = false + val r = f.cache.acquire(spec("s1"), idleTtlMs = 1000) + assert(f.buildCount.get() == 2) + assert(r eq f.lastBuilt) + f.cache.release("s1") + } + + test("idle entry past TTL is evicted and closed") { + val f = new Fixture + f.cache.acquire(spec("s1"), idleTtlMs = 1000) + f.cache.release("s1") + val built = f.lastBuilt + f.advanceMillis(999) + f.cache.evictIdleNow() + assert(!built.closed) + f.advanceMillis(2) + f.cache.evictIdleNow() + assert(built.closed) + } + + test("entry in use is never evicted, regardless of idle time") { + val f = new Fixture + f.cache.acquire(spec("s1"), idleTtlMs = 1000) + val built = f.lastBuilt + f.advanceMillis(100000) + f.cache.evictIdleNow() + assert(!built.closed) + f.cache.release("s1") + f.advanceMillis(100000) + f.cache.evictIdleNow() + assert(built.closed) + } + + test("release then reacquire within TTL resets idleness") { + val f = new Fixture + f.cache.acquire(spec("s1"), idleTtlMs = 1000) + f.cache.release("s1") + f.advanceMillis(900) + // Next task wave lands before TTL: same entry, no rebuild. + val r = f.cache.acquire(spec("s1"), idleTtlMs = 1000) + assert(f.buildCount.get() == 1) + assert(r eq f.lastBuilt) + f.cache.release("s1") + f.advanceMillis(900) + f.cache.evictIdleNow() + assert(!f.lastBuilt.closed, "idle clock must restart at the last release") + } + + test("acquire after eviction rebuilds") { + val f = new Fixture + f.cache.acquire(spec("s1"), idleTtlMs = 1000) + f.cache.release("s1") + val first = f.lastBuilt + f.advanceMillis(2000) + f.cache.evictIdleNow() + assert(first.closed) + val r = f.cache.acquire(spec("s1"), idleTtlMs = 1000) + assert(f.buildCount.get() == 2) + assert(r ne first) + f.cache.release("s1") + } + + test("distinct scanIds get distinct entries") { + val f = new Fixture + val r1 = f.cache.acquire(spec("s1"), idleTtlMs = 1000) + val r2 = f.cache.acquire(spec("s2"), idleTtlMs = 1000) + assert(f.buildCount.get() == 2) + assert(r1 ne r2) + f.cache.release("s1") + f.cache.release("s2") + } + + test("unbalanced release throws") { + val f = new Fixture + intercept[IllegalStateException](f.cache.release("never-acquired")) + } + + test("shutdown closes everything, even entries in use") { + val f = new Fixture + f.cache.acquire(spec("s1"), idleTtlMs = 1000) + val built = f.lastBuilt + f.cache.shutdown() + assert(built.closed) + } +} diff --git a/spark/src/test/scala/io/datafusion/spark/SparkPredicateTranslatorTest.scala b/spark/src/test/scala/io/datafusion/spark/SparkPredicateTranslatorTest.scala new file mode 100644 index 0000000..b7faac1 --- /dev/null +++ b/spark/src/test/scala/io/datafusion/spark/SparkPredicateTranslatorTest.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.datafusion.spark + +import org.apache.datafusion.protobuf.LogicalExprNode +import org.apache.spark.sql.connector.expressions.{Expression, Expressions, NamedReference} +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.scalatest.funsuite.AnyFunSuite + +class SparkPredicateTranslatorTest extends AnyFunSuite { + + private def col(name: String): NamedReference = Expressions.column(name) + private def litInt(v: Int): Expression = Expressions.literal(Int.box(v)) + private def litLong(v: Long): Expression = Expressions.literal(Long.box(v)) + private def litStr(v: String): Expression = + Expressions.literal(org.apache.spark.unsafe.types.UTF8String.fromString(v)) + + test("LessThan(timeline, 1_000_000) translates to a non-empty proto") { + val p = new Predicate("<", Array[Expression](col("timeline"), litLong(1000000L))) + val node = SparkPredicateTranslator.translate(p).getOrElse(fail("expected Some")) + val bytes = node.toByteArray + assert(bytes.nonEmpty) + val parsed = LogicalExprNode.parseFrom(bytes) + assert(parsed.hasBinaryExpr) + assert(parsed.getBinaryExpr.getOp == "Lt") + } + + test("AND of two translatable predicates round-trips through binary op 'And'") { + val lt = new Predicate("<", Array[Expression](col("a"), litInt(10))) + val eq = new Predicate("=", Array[Expression](col("b"), litStr("x"))) + val and = new Predicate("AND", Array[Expression](lt, eq)) + val node = SparkPredicateTranslator.translate(and).getOrElse(fail("expected Some")) + val parsed = LogicalExprNode.parseFrom(node.toByteArray) + assert(parsed.hasBinaryExpr) + assert(parsed.getBinaryExpr.getOp == "And") + } + + test("AND becomes residual when an operand is untranslatable") { + val nse = new Predicate("<=>", Array[Expression](col("a"), litInt(1))) + val eq = new Predicate("=", Array[Expression](col("b"), litInt(2))) + val and = new Predicate("AND", Array[Expression](nse, eq)) + assert(SparkPredicateTranslator.translate(and).isEmpty) + } + + test("IS_NULL and IS_NOT_NULL emit the dedicated proto variants") { + val isNull = new Predicate("IS_NULL", Array[Expression](col("x"))) + val isNotNull = new Predicate("IS_NOT_NULL", Array[Expression](col("x"))) + val n1 = SparkPredicateTranslator.translate(isNull).getOrElse(fail()).toByteArray + val n2 = SparkPredicateTranslator.translate(isNotNull).getOrElse(fail()).toByteArray + val p1 = LogicalExprNode.parseFrom(n1) + val p2 = LogicalExprNode.parseFrom(n2) + assert(p1.hasIsNullExpr) + assert(p2.hasIsNotNullExpr) + } + + test("STARTS_WITH translates to a LIKE with a '%' suffix") { + val p = + new Predicate("STARTS_WITH", Array[Expression](col("name"), litStr("foo"))) + val node = SparkPredicateTranslator.translate(p).getOrElse(fail()) + val parsed = LogicalExprNode.parseFrom(node.toByteArray) + assert(parsed.hasLike) + val patStr = parsed.getLike.getPattern.getLiteral.getUtf8Value + assert(patStr == "foo%") + } + + test("unknown predicate name returns None (becomes residual)") { + val p = new Predicate("UNKNOWN_OP", Array[Expression](col("x"), litInt(1))) + assert(SparkPredicateTranslator.translate(p).isEmpty) + } +} From 98a6381c94aab49aa8bf955b408922aa955be300 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 12 Jun 2026 13:27:54 +0200 Subject: [PATCH 5/5] feat(spark): bridge scaffold generator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `spark/scaffold/new_bridge.py` plus the `bridge-template/` it stamps out: a standalone Maven+Cargo bridge project wired to the datafusion-spark-bridge SDK — a Rust cdylib with `export_bridge!` + a demo in-memory provider, the four Java classes, the DataSourceRegister service file, a shaded-jar pom that bundles the cdylib, and a pyspark smoke test. Stdlib-only generator. Standalone tooling. Fifth of the 6-PR stack. Co-Authored-By: Claude Opus 4.8 (1M context) --- spark/scaffold/bridge-template/.gitignore | 3 + spark/scaffold/bridge-template/README.md | 54 ++++++ .../bridge-template/native/Cargo.toml | 21 +++ .../bridge-template/native/src/lib.rs | 59 ++++++ spark/scaffold/bridge-template/pom.xml | 174 ++++++++++++++++++ spark/scaffold/bridge-template/smoke_test.py | 55 ++++++ .../main/java/__PKG_PATH__/BridgeNative.java | 40 ++++ .../__PKG_PATH__/__PREFIX__DataSource.java | 21 +++ .../__PREFIX__ProviderFactory.java | 28 +++ .../__PKG_PATH__/__PREFIX__ScanBackend.java | 53 ++++++ ...pache.spark.sql.sources.DataSourceRegister | 1 + spark/scaffold/new_bridge.py | 138 ++++++++++++++ 12 files changed, 647 insertions(+) create mode 100644 spark/scaffold/bridge-template/.gitignore create mode 100644 spark/scaffold/bridge-template/README.md create mode 100644 spark/scaffold/bridge-template/native/Cargo.toml create mode 100644 spark/scaffold/bridge-template/native/src/lib.rs create mode 100644 spark/scaffold/bridge-template/pom.xml create mode 100644 spark/scaffold/bridge-template/smoke_test.py create mode 100644 spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/BridgeNative.java create mode 100644 spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__DataSource.java create mode 100644 spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ProviderFactory.java create mode 100644 spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ScanBackend.java create mode 100644 spark/scaffold/bridge-template/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 spark/scaffold/new_bridge.py diff --git a/spark/scaffold/bridge-template/.gitignore b/spark/scaffold/bridge-template/.gitignore new file mode 100644 index 0000000..e2777a5 --- /dev/null +++ b/spark/scaffold/bridge-template/.gitignore @@ -0,0 +1,3 @@ +target/ +native/target/ +*.class diff --git a/spark/scaffold/bridge-template/README.md b/spark/scaffold/bridge-template/README.md new file mode 100644 index 0000000..8259e53 --- /dev/null +++ b/spark/scaffold/bridge-template/README.md @@ -0,0 +1,54 @@ +# __PREFIX__ Spark Bridge + +A Spark DataSource V2 connector for the `__FORMAT__` format, built on the +[datafusion-java Spark connector](https://github.com/apache/datafusion-java) +and its `datafusion-spark-bridge` Rust SDK. Generated by `spark/scaffold/new_bridge.py`; +the only code you need to touch is marked `TODO`. + +## What's here + +| File | Role | +| --- | --- | +| `native/src/lib.rs` | **Your provider.** `build_provider` turns option/partition bytes into a DataFusion `TableProvider` (demo: an in-memory table). `export_bridge!` generates the whole JNI surface. | +| `src/main/java/.../BridgeNative.java` | Declares the generated native methods and loads the bundled cdylib. Must keep the name/package the Rust macro was generated with. | +| `src/main/java/.../__PREFIX__ScanBackend.java` | Routes the connector's scan calls to `BridgeNative`. Pure delegation. | +| `src/main/java/.../__PREFIX__ProviderFactory.java` | The connector contract. Override `listPartitions` / `sharedScan` / `encodeOptions` here as the bridge grows. | +| `src/main/java/.../__PREFIX__DataSource.java` + `META-INF/services/...` | `spark.read.format("__FORMAT__")`. | +| `pom.xml` | One shaded fat jar with the cdylib bundled inside. | + +## Build + +```bash +# 0. Once: install datafusion-java to your local Maven repo (from its checkout): +# cargo build && ./mvnw install -DskipTests + +# 1. The cdylib: +cargo build --manifest-path native/Cargo.toml + +# 2. The shaded jar (target/__CRATE__-0.1.0-SNAPSHOT.jar): +mvn package +``` + +Release builds: `cargo build --release --manifest-path native/Cargo.toml` and +`mvn package -Dnative.profile=release`. + +## Use + +```python +df = (spark.read.format("__FORMAT__") + .option("rows", "5") # demo option; replace with your own + .load()) +df.show() +``` + +with the shaded jar on `spark.jars`. `python3 smoke_test.py` runs exactly this +against a local Spark (needs `SPARK_HOME` pointing at a Scala 2.13 distro). + +## Where to go next + +- Replace the demo `MemTable` in `native/src/lib.rs` with your real provider. +- Split the dataset into Spark tasks (`listPartitions`) or switch to + shared-scan mode (`sharedScan`) — task-sizing guidance lives in the + connector's `spark/README.md` ("Spark tasks vs. DataFusion partitions"). +- Multi-platform jars: build the cdylib per platform in CI and copy each into + `src`-side `//` directories before `mvn package`. diff --git a/spark/scaffold/bridge-template/native/Cargo.toml b/spark/scaffold/bridge-template/native/Cargo.toml new file mode 100644 index 0000000..c0d2996 --- /dev/null +++ b/spark/scaffold/bridge-template/native/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "__CRATE__-native" +version = "0.1.0" +edition = "2021" +publish = false + +# Standalone crate: the empty [workspace] table stops cargo from adopting +# this crate into any workspace found in a parent directory. +[workspace] + +[lib] +name = "__LIB__" +crate-type = ["cdylib"] + +[dependencies] +# TODO: replace the path with a git or crates.io dependency once you build +# outside a local datafusion-java checkout. +datafusion-spark-bridge = { path = "__BRIDGE_SDK_PATH__" } + +[profile.release] +strip = "debuginfo" diff --git a/spark/scaffold/bridge-template/native/src/lib.rs b/spark/scaffold/bridge-template/native/src/lib.rs new file mode 100644 index 0000000..8439217 --- /dev/null +++ b/spark/scaffold/bridge-template/native/src/lib.rs @@ -0,0 +1,59 @@ +//! Native side of the `__FORMAT__` Spark bridge. +//! +//! `export_bridge!` generates the whole JNI surface for +//! `__PKG__.BridgeNative`; the only code you own is [`build_provider`], +//! which turns the option/partition bytes your JVM factory encoded into a +//! concrete `TableProvider`. Everything downstream — type widening, session +//! construction, projection, pushed filters, planning, partition streams — +//! is the SDK's job. + +use std::sync::Arc; + +use datafusion_spark_bridge::datafusion::arrow::array::{Int64Array, StringArray}; +use datafusion_spark_bridge::datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion_spark_bridge::datafusion::arrow::record_batch::RecordBatch; +use datafusion_spark_bridge::datafusion::catalog::TableProvider; +use datafusion_spark_bridge::datafusion::datasource::MemTable; +use datafusion_spark_bridge::options::decode_options; +use datafusion_spark_bridge::{export_bridge, BridgeContext, JniResult}; + +/// Build the provider for one scan. +/// +/// `options` is whatever the JVM factory's `encodeOptions` produced — with +/// the default factory that is the connector's `OptionsCodec` format, decoded +/// below into a string map. `partition` is the per-task payload from +/// `listPartitions` (empty for the schema probe, for shared-scan mode, and +/// for the default single-partition layout). +/// +/// TODO: replace the demo `MemTable` with your real `TableProvider`. For +/// async construction (remote catalogs, object stores), use +/// `ctx.block_on(...)`. +fn build_provider( + _ctx: &BridgeContext, + options: &[u8], + _partition: &[u8], +) -> JniResult> { + let opts = decode_options(options)?; + let rows: i64 = match opts.get("rows") { + Some(v) => v + .parse() + .map_err(|e| format!("option 'rows' is not an integer: {e}"))?, + None => 3, + }; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("greeting", DataType::Utf8, false), + ])); + let ids = Int64Array::from_iter_values(0..rows); + let greetings = + StringArray::from_iter_values((0..rows).map(|i| format!("hello from __FORMAT__ #{i}"))); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(greetings)])?; + + Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?)) +} + +export_bridge! { + jni_class: "__JNI_CLASS__", + build_provider: build_provider, +} diff --git a/spark/scaffold/bridge-template/pom.xml b/spark/scaffold/bridge-template/pom.xml new file mode 100644 index 0000000..4e8c2b6 --- /dev/null +++ b/spark/scaffold/bridge-template/pom.xml @@ -0,0 +1,174 @@ + + + 4.0.0 + + __PKG__ + __CRATE__ + 0.1.0-SNAPSHOT + jar + + __PREFIX__ Spark Bridge + + + UTF-8 + 17 + 2.13 + 3.5.7 + __DF_JAVA_VERSION__ + + debug + + + + + + org.apache.datafusion + datafusion-java-spark_${scala.compat.version} + ${datafusion.java.version} + + + org.apache.spark + spark-sql_${scala.compat.version} + ${spark.version} + provided + + + + + + + + org.apache.maven.plugins + maven-antrun-plugin + 3.1.0 + + + copy-bridge-cdylib + process-classes + run + + + + + + + + + + + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.6.0 + + + package + shade + + false + + + + org.scala-lang:scala-library + + + + + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + + + + + + + native-linux-amd64 + + unixlinuxamd64 + + + linux + x86_64 + lib__LIB__.so + + + + native-linux-x86_64 + + unixlinuxx86_64 + + + linux + x86_64 + lib__LIB__.so + + + + native-linux-aarch64 + + unixlinuxaarch64 + + + linux + aarch64 + lib__LIB__.so + + + + native-mac-x86_64 + + macx86_64 + + + darwin + x86_64 + lib__LIB__.dylib + + + + native-mac-aarch64 + + macaarch64 + + + darwin + aarch64 + lib__LIB__.dylib + + + + diff --git a/spark/scaffold/bridge-template/smoke_test.py b/spark/scaffold/bridge-template/smoke_test.py new file mode 100644 index 0000000..ca3925e --- /dev/null +++ b/spark/scaffold/bridge-template/smoke_test.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +"""Smoke test: scan the __FORMAT__ bridge's demo table through PySpark. + +Prerequisites: + - cargo build --manifest-path native/Cargo.toml (the bridge cdylib) + - mvn package (the shaded jar) + - a Scala 2.13 Spark distribution; the PyPI pyspark wheel embeds 2.12, so + point SPARK_HOME at e.g. spark-3.5.7-bin-hadoop3-scala2.13. + +Run: python3 smoke_test.py +""" + +import glob +import os +import sys +from pathlib import Path + +PROJECT_ROOT = Path(__file__).resolve().parent + +spark_home = os.environ.get("SPARK_HOME") +if not spark_home or not Path(spark_home, "jars").is_dir(): + sys.exit("Set SPARK_HOME to a Scala 2.13 Spark distribution.") +os.environ["SPARK_HOME"] = spark_home + +jars = glob.glob(str(PROJECT_ROOT / "target" / "__CRATE__-*.jar")) +jars = [j for j in jars if not j.endswith(("-sources.jar", "-javadoc.jar"))] +if not jars: + sys.exit("Shaded jar not found under target/. Run 'mvn package' first.") +jar = jars[0] + +from pyspark.sql import SparkSession # noqa: E402 + +spark = ( + SparkSession.builder.appName("__FORMAT__-smoke") + .master("local[2]") + .config("spark.jars", jar) + # extraClassPath PREPENDS, so the fat jar's Arrow wins over Spark's + # bundled (older) copy on both driver and executors. + .config("spark.driver.extraClassPath", jar) + .config("spark.executor.extraClassPath", jar) + .config("spark.driver.extraJavaOptions", "--add-opens=java.base/java.nio=ALL-UNNAMED") + .config("spark.executor.extraJavaOptions", "--add-opens=java.base/java.nio=ALL-UNNAMED") + .getOrCreate() +) + +df = spark.read.format("__FORMAT__").option("rows", "5").load() +df.printSchema() +df.show(truncate=False) +count = df.count() +filtered = df.filter("id >= 2").count() +spark.stop() + +assert count == 5, f"expected 5 rows, got {count}" +assert filtered == 3, f"expected 3 rows with id >= 2, got {filtered}" +print("smoke test OK: 5 rows scanned, filter returned 3") diff --git a/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/BridgeNative.java b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/BridgeNative.java new file mode 100644 index 0000000..7cf02de --- /dev/null +++ b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/BridgeNative.java @@ -0,0 +1,40 @@ +package __PKG__; + +import io.datafusion.spark.NativeLibraryLoader; + +/** + * JNI surface generated on the Rust side by {@code export_bridge!} with {@code jni_class = + * "__JNI_CLASS__"} — the mangled binary name of THIS class. Renaming or moving this class + * requires regenerating the Rust macro invocation to match. + * + *

The cdylib is bundled in this jar under {@code __PKG_PATH__///} (see the antrun + * execution in pom.xml) and extracted/loaded once per JVM by the connector's loader. + */ +final class BridgeNative { + + private BridgeNative() {} + + static { + NativeLibraryLoader.load(BridgeNative.class, "__PKG_PATH__", "__LIB__"); + } + + static native byte[] providerSchemaIpc(byte[] options, byte[] partition); + + static native long createScan( + byte[] options, + byte[] partition, + int targetPartitions, + int batchSize, + String[] optionKeys, + String[] optionValues, + String[] projectionColumns, + byte[][] filterProtos); + + static native int partitionCount(long scanHandle); + + static native void executeStreamPartition(long scanHandle, int partition, long ffiStreamAddr); + + static native void executeStream(long scanHandle, long ffiStreamAddr); + + static native void closeScan(long scanHandle); +} diff --git a/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__DataSource.java b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__DataSource.java new file mode 100644 index 0000000..c888a0d --- /dev/null +++ b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__DataSource.java @@ -0,0 +1,21 @@ +package __PKG__; + +import io.datafusion.spark.DatafusionSource; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * Gives the bridge its Spark format name: {@code spark.read.format("__FORMAT__")}. Registered via + * {@code META-INF/services/org.apache.spark.sql.sources.DataSourceRegister}. + */ +public class __PREFIX__DataSource extends DatafusionSource { + + @Override + public String shortName() { + return "__FORMAT__"; + } + + @Override + public String factoryFqcn(CaseInsensitiveStringMap options) { + return __PREFIX__ProviderFactory.class.getName(); + } +} diff --git a/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ProviderFactory.java b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ProviderFactory.java new file mode 100644 index 0000000..03498e4 --- /dev/null +++ b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ProviderFactory.java @@ -0,0 +1,28 @@ +package __PKG__; + +import io.datafusion.spark.BridgeProviderFactory; +import io.datafusion.spark.ScanBackend; + +/** + * The bridge's contract with the Spark connector: the provider is built inside this bridge's own + * cdylib, and {@link #scanBackend()} is the only required method. + * + *

Useful optional overrides (see their javadoc on {@link BridgeProviderFactory}): + * + *

    + *
  • {@code encodeOptions} — only if you have your own options schema; the default ships the + * Spark options map in the connector's {@code OptionsCodec} format, which the Rust side + * already decodes via {@code datafusion_spark_bridge::options::decode_options}. + *
  • {@code listPartitions} — the default is ONE whole-dataset partition. Override to split + * into more Spark tasks (with optional preferred hosts and partition keys), or… + *
  • {@code sharedScan} — …opt into shared-scan mode: one provider per executor, one Spark + * task per DataFusion output partition. Mind the determinism contract. + *
+ */ +public final class __PREFIX__ProviderFactory implements BridgeProviderFactory { + + @Override + public ScanBackend scanBackend() { + return new __PREFIX__ScanBackend(); + } +} diff --git a/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ScanBackend.java b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ScanBackend.java new file mode 100644 index 0000000..eb78dd1 --- /dev/null +++ b/spark/scaffold/bridge-template/src/main/java/__PKG_PATH__/__PREFIX__ScanBackend.java @@ -0,0 +1,53 @@ +package __PKG__; + +import io.datafusion.spark.ScanBackend; + +/** Routes the connector's scan calls to this bridge's own cdylib. Pure delegation. */ +public final class __PREFIX__ScanBackend implements ScanBackend { + + @Override + public byte[] providerSchemaIpc(byte[] options, byte[] partitionBytes) { + return BridgeNative.providerSchemaIpc(options, partitionBytes); + } + + @Override + public long createScan( + byte[] options, + byte[] partitionBytes, + int targetPartitions, + int batchSize, + String[] optionKeys, + String[] optionValues, + String[] projectionColumns, + byte[][] filterProtos) { + return BridgeNative.createScan( + options, + partitionBytes, + targetPartitions, + batchSize, + optionKeys, + optionValues, + projectionColumns, + filterProtos); + } + + @Override + public int partitionCount(long scanHandle) { + return BridgeNative.partitionCount(scanHandle); + } + + @Override + public void executeStreamPartition(long scanHandle, int partition, long ffiStreamAddr) { + BridgeNative.executeStreamPartition(scanHandle, partition, ffiStreamAddr); + } + + @Override + public void executeStream(long scanHandle, long ffiStreamAddr) { + BridgeNative.executeStream(scanHandle, ffiStreamAddr); + } + + @Override + public void closeScan(long scanHandle) { + BridgeNative.closeScan(scanHandle); + } +} diff --git a/spark/scaffold/bridge-template/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/scaffold/bridge-template/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000..e72a178 --- /dev/null +++ b/spark/scaffold/bridge-template/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +__PKG__.__PREFIX__DataSource diff --git a/spark/scaffold/new_bridge.py b/spark/scaffold/new_bridge.py new file mode 100644 index 0000000..03b8de7 --- /dev/null +++ b/spark/scaffold/new_bridge.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Scaffold a new Spark bridge project from spark/scaffold/bridge-template/. + +Stamps out a standalone project (Maven + Cargo) wired to the +datafusion-spark-bridge SDK: a Rust cdylib with `export_bridge!` and a demo +in-memory provider, the four Java classes (native surface, ScanBackend, +factory, DataSource shim), the DataSourceRegister service file, a shaded-jar +pom that bundles the cdylib, a pyspark smoke test, and a README with the +build/run commands. + +Usage: + python3 spark/scaffold/new_bridge.py --name acme --package com.example.acme \ + [--output DIR] [--datafusion-java REPO_ROOT] + +`--name` is the Spark format short name (spark.read.format("acme")); it also +derives the class prefix (acme -> Acme, my_format -> MyFormat), the cargo +crate name, and the cdylib name. Stdlib only; no dependencies. +""" + +import argparse +import re +import sys +from pathlib import Path + +TEMPLATE_DIR = Path(__file__).resolve().parent / "bridge-template" + + +def jni_mangle(binary_class_name: str) -> str: + """JNI symbol mangling for a class's binary name: '_' -> '_1', '.' -> '_'.""" + return binary_class_name.replace("_", "_1").replace(".", "_") + + +def class_prefix(name: str) -> str: + return "".join(part.capitalize() for part in name.split("_")) + + +def validate(name: str, package: str) -> None: + if not re.fullmatch(r"[a-z][a-z0-9_]*", name): + sys.exit(f"--name must match [a-z][a-z0-9_]*, got: {name}") + if not re.fullmatch(r"[a-z][a-z0-9_]*(\.[a-z][a-z0-9_]*)+", package): + sys.exit(f"--package must be a dotted lowercase Java package, got: {package}") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--name", required=True, help="Spark format short name, e.g. acme") + parser.add_argument( + "--package", required=True, help="Java package for the bridge, e.g. com.example.acme" + ) + parser.add_argument( + "--output", + help="Directory to create (default: ./-spark-bridge; must not exist)", + ) + parser.add_argument( + "--datafusion-java", + help="datafusion-java repo root providing the spark/bridge SDK crate " + "(default: the repo this script lives in)", + ) + args = parser.parse_args() + + validate(args.name, args.package) + prefix = class_prefix(args.name) + crate = args.name.replace("_", "-") + "-spark-bridge" + lib = args.name + "_spark_bridge" + repo = Path(args.datafusion_java).resolve() if args.datafusion_java else TEMPLATE_DIR.parents[2] + sdk_path = repo / "spark" / "bridge" + if not (sdk_path / "Cargo.toml").is_file(): + sys.exit(f"datafusion-spark-bridge crate not found at {sdk_path}") + out = Path(args.output) if args.output else Path.cwd() / crate + if out.exists(): + sys.exit(f"output directory already exists: {out}") + + tokens = { + "__PKG__": args.package, + "__PKG_PATH__": args.package.replace(".", "/"), + "__JNI_CLASS__": jni_mangle(args.package + ".BridgeNative"), + "__PREFIX__": prefix, + "__FORMAT__": args.name, + "__CRATE__": crate, + "__LIB__": lib, + "__BRIDGE_SDK_PATH__": str(sdk_path), + "__DF_JAVA_VERSION__": read_repo_version(repo), + } + + generated = [] + for src in sorted(TEMPLATE_DIR.rglob("*")): + if not src.is_file(): + continue + rel = str(src.relative_to(TEMPLATE_DIR)) + for token, value in tokens.items(): + rel = rel.replace(token, value) + dst = out / rel + dst.parent.mkdir(parents=True, exist_ok=True) + text = src.read_text() + for token, value in tokens.items(): + text = text.replace(token, value) + dst.write_text(text) + generated.append(rel) + + print(f"Generated {len(generated)} files under {out}:") + for rel in generated: + print(f" {rel}") + print() + print("Next steps (see the generated README.md):") + print(f" 1. cd {out}") + print(" 2. cargo build --release --manifest-path native/Cargo.toml") + print(" 3. mvn package -Dnative.profile=release") + print(f" 4. spark.read.format(\"{args.name}\") with the shaded jar on spark.jars") + + +def read_repo_version(repo: Path) -> str: + """datafusion-java's maven version, scraped from the parent pom.""" + pom = (repo / "pom.xml").read_text() + m = re.search(r"([^<]+)", pom) + if not m: + sys.exit(f"could not find in {repo}/pom.xml") + return m.group(1) + + +if __name__ == "__main__": + main()