> {
+ let schema = Arc::new(Schema::empty());
+ let table = MemTable::try_new(schema, vec![vec![]])?;
+ Ok(Arc::new(table))
+}
+
+export_bridge! { // expanding + linking this invocation is itself the macro half of the test
+ jni_class: "com_example_testbridge_BridgeNative", // JNI class name handed to the macro
+ build_provider: build_provider, // builder fn; presumably the one ending just above — confirm
+}
+
+#[test]
+fn builder_contract_runs_outside_jvm() {
+    // The macro itself is covered simply by expanding and linking above. On top
+    // of that, drive the builder directly with the same BridgeContext that the
+    // expansion hands to it, with no options and no partition payload.
+    let bridge_ctx = BridgeContext::get();
+    let built = build_provider(&bridge_ctx, &[], &[]);
+    let provider = built.expect("builder failed");
+    let field_count = provider.schema().fields().len();
+    assert_eq!(field_count, 0);
+}
diff --git a/spark/pom.xml b/spark/pom.xml
new file mode 100644
index 0000000..90e4e6d
--- /dev/null
+++ b/spark/pom.xml
@@ -0,0 +1,150 @@
+
+
+
+ 4.0.0
+
+
+ org.apache.datafusion
+ datafusion-java-parent
+ 0.2.0-SNAPSHOT
+
+
+ datafusion-java-spark_2.13
+ jar
+
+ Apache DataFusion Java Spark Connector
+
+ Generic Spark DataSource V2 connector for DataFusion TableProviders.
+ Domain bridges implement BridgeProviderFactory over a cdylib built
+ with the datafusion-spark-bridge Rust SDK; this module supplies the
+ Spark plumbing, predicate translation, Arrow-to-Spark schema
+ conversion, and the shared-scan cache. Pure JVM artifact — the
+ native code ships inside each bridge's own jar.
+
+
+
+ 2.13
+ 2.13.14
+ 3.5.7
+
+
+
+
+ org.scala-lang
+ scala-library
+ ${scala.version}
+
+
+
+ org.apache.spark
+ spark-core_${scala.compat.version}
+ ${spark.version}
+ provided
+
+
+ org.apache.spark
+ spark-sql_${scala.compat.version}
+ ${spark.version}
+ provided
+
+
+
+ org.apache.datafusion
+ datafusion-java
+
+
+
+ org.apache.arrow
+ arrow-vector
+
+
+ org.apache.arrow
+ arrow-c-data
+
+
+ org.apache.arrow
+ arrow-memory-netty
+ runtime
+
+
+
+ org.scalatest
+ scalatest_${scala.compat.version}
+ 3.2.18
+ test
+
+
+
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ 4.8.1
+
+
+
+ compile
+ testCompile
+
+
+
+
+ ${scala.version}
+
+ -deprecation
+ -feature
+ -unchecked
+
+ all
+
+
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+
+ --add-opens=java.base/java.nio=ALL-UNNAMED
+ true
+
+
+
+ org.scalatest
+ scalatest-maven-plugin
+ 2.2.0
+
+ ${project.build.directory}/scalatest-reports
+ .
+ WDF TestSuite.txt
+ --add-opens=java.base/java.nio=ALL-UNNAMED
+
+
+
+ test
+ test
+
+
+
+
+
+
+
diff --git a/spark/src/main/java/io/datafusion/spark/BridgeProviderFactory.java b/spark/src/main/java/io/datafusion/spark/BridgeProviderFactory.java
new file mode 100644
index 0000000..3bcf7ad
--- /dev/null
+++ b/spark/src/main/java/io/datafusion/spark/BridgeProviderFactory.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark;
+
+import java.util.Map;
+
+/**
+ * Bridge interface implemented per domain (HDF5, custom Iceberg, an in-house format, etc.). A
+ * bridge owns its options encoding and a native scan implementation built with {@code
+ * datafusion_spark_bridge::export_bridge!}; the connector's Spark plumbing is generic — it knows
+ * only this interface.
+ *
+ * <p>The single required method is {@link #scanBackend()}, returning the delegations to the JNI
+ * class the bridge named in its {@code export_bridge!} invocation. Everything else has a working
+ * default: {@link #encodeOptions(Map)} encodes the Spark options via {@link OptionsCodec}, and
+ * {@link #listPartitions(byte[])} reports a single partition.
+ *
+ * <p>Implementations must be no-arg constructable so the Spark connector can instantiate them
+ * reflectively via {@link Class#forName(String)} on the executor.
+ */
+public interface BridgeProviderFactory {
+
+  /**
+   * The native scan implementation this bridge talks to: delegations to the JNI class named in the
+   * bridge's {@code export_bridge!} invocation, whose generated {@code createScan} builds the
+   * provider from the options/partition bytes in process. Called wherever the connector needs
+   * native work — driver-side schema/plan probes and executor-side streams — always on a factory
+   * freshly instantiated from its class name, so the returned backend never has to be serializable.
+   *
+   * @return the backend the connector routes all native calls through
+   */
+  ScanBackend scanBackend();
+
+  /**
+   * Convert Spark's flat option map to the bridge's encoded options. Driver-side only; the bytes
+   * ship verbatim through {@code DatafusionInputPartition} and are the scan's identity in
+   * shared-scan mode (encode deterministically).
+   *
+   * <p>Default: {@link OptionsCodec#encode(Map)} — the key-sorted length-prefixed pair format that
+   * {@code datafusion_spark_bridge::options} decodes on the Rust side. Override only if the bridge
+   * already has its own options schema (e.g. a protobuf).
+   *
+   * @param sparkOptions the Spark data source options, key to value
+   * @return the encoded options bytes
+   * @throws IllegalArgumentException if required options are missing or invalid
+   */
+  default byte[] encodeOptions(Map<String, String> sparkOptions) {
+    return OptionsCodec.encode(sparkOptions);
+  }
+
+  /**
+   * Enumerate partitions for this dataset. One Spark task is created per returned {@link
+   * PartitionInfo}. Driver-side only.
+   *
+   * <p>Each partition's {@code partitionBytes} ships verbatim through {@code
+   * DatafusionInputPartition} to the executor, where it is passed to {@link
+   * ScanBackend#createScan}. Use it to encode whatever slice metadata (row range, sub-options, file
+   * offsets, segment id, …) the bridge needs to materialise *that* partition.
+   *
+   * <p>Each partition's {@code preferredLocations} hostnames are returned from {@code
+   * InputPartition.preferredLocations()} so Spark co-locates the task with the data; empty array =
+   * no preference.
+   *
+   * <p>Default: one partition ({@code "p0"}, empty payload, no host preference) — one Spark task
+   * scans the whole dataset. Fine for small tables and first bring-up; override (or opt into {@link
+   * #sharedScan(byte[])}) before pointing it at anything large. Size guidance lives in {@code
+   * spark/README.md}.
+   *
+   * @param optionsBytes the bytes produced by {@link #encodeOptions(Map)}
+   * @return one descriptor per Spark task
+   */
+  default PartitionInfo[] listPartitions(byte[] optionsBytes) {
+    return new PartitionInfo[] {new PartitionInfo("p0", new byte[0], new String[0])};
+  }
+
+  /**
+   * Filter-aware variant of {@link #listPartitions(byte[])}. The connector calls this overload with
+   * the pushed-down predicates ({@code LogicalExprNode} proto bytes, one array per predicate, same
+   * encoding the executor later replays via {@link ScanBackend#createScan}). Bridges that can map
+   * predicates onto their partition layout (e.g. {@code segment_id = 'x'}) should prune partitions
+   * that cannot match — pruning here eliminates whole Spark tasks, whereas the per-task filter only
+   * reduces rows inside a task.
+   *
+   * <p>Pruning must be conservative: only drop a partition when NO row in it can satisfy the
+   * conjunction of all pushed predicates. The default delegates to the filter-unaware overload (no
+   * pruning), which is always correct.
+   *
+   * @param optionsBytes the bytes produced by {@link #encodeOptions(Map)}
+   * @param filterProtoBytes the pushed-down predicates, one serialized proto per element
+   * @return one descriptor per Spark task that may still produce rows
+   */
+  default PartitionInfo[] listPartitions(byte[] optionsBytes, byte[][] filterProtoBytes) {
+    return listPartitions(optionsBytes);
+  }
+
+  /**
+   * Opt into shared-scan mode for this dataset. Default {@code false} (per-partition payload mode,
+   * the {@link #listPartitions(byte[])} path).
+   *
+   * <p>When {@code true}, the connector builds ONE provider per (executor JVM × scan) with empty
+   * {@code partitionBytes}, plans it once, and runs one Spark task per DataFusion output partition
+   * — task {@code i} streams plan partition {@code i} from the shared, cached plan. This amortises
+   * provider construction cost across all tasks on an executor and is the right model when the
+   * dataset has many small partitions or provider construction is expensive (remote metadata,
+   * connections). {@link #listPartitions(byte[])} and {@link #reportPartitioning(byte[])} are NOT
+   * called in this mode, and the scan reports {@code UnknownPartitioning} (DataFusion-native
+   * partitions carry no key contract).
+   *
+   * <p><b>Determinism contract.</b> The driver counts partitions by planning once; every executor
+   * re-plans independently and must arrive at the same result. A bridge returning {@code true}
+   * guarantees:
+   *
+   * <ul>
+   *   <li>The provider's schema, partitioning, and per-partition row content are a pure function of
+   *       {@code optionsBytes}. Remote sources must pin a snapshot (version, timestamp) inside
+   *       the options; data that compacts or moves between driver planning and executor execution
+   *       otherwise yields wrong results that no runtime check can catch.
+   *   <li>The provider's {@code ExecutionPlan} supports calling {@code execute(i)} more than once
+   *       per plan instance (Spark task retry and speculative execution re-execute a partition
+   *       index, sometimes concurrently). Stateless scans satisfy this; single-shot streams do not.
+   * </ul>
+   *
+   * <p>The connector fails tasks with a clear error when the executor's partition count diverges
+   * from the driver's — but identical counts with different contents cannot be detected.
+   *
+   * @param optionsBytes the bytes produced by {@link #encodeOptions(Map)}
+   * @return {@code true} to use shared-scan mode, {@code false} for per-partition payload mode
+   */
+  default boolean sharedScan(byte[] optionsBytes) {
+    return false;
+  }
+
+  /**
+   * Declare how rows are partitioned across the {@link PartitionInfo} entries returned by {@link
+   * #listPartitions(byte[])}. Driver-side only.
+   *
+   * <p>When non-null, the connector surfaces a {@code KeyGroupedPartitioning(keys,
+   * listPartitions(...).length)} to Spark via {@code SupportsReportPartitioning} so the optimizer
+   * can elide shuffles ahead of joins/aggregations on the declared keys.
+   *
+   * <p>Default returns {@code null} — no partitioning guarantees, Spark plans as if the scan's
+   * output ordering and grouping are unknown.
+   *
+   * <p>If a bridge implements this, it must hold the {@link ReportedPartitioning} contract: every
+   * row in a given partition evaluates to the same tuple of key values under the declared
+   * transforms.
+   *
+   * <p>Spark 3.3+ caveat: the reported partitioning only takes effect when every {@link
+   * PartitionInfo} also carries {@link PartitionInfo#partitionKeyValues()} (surfaced to Spark via
+   * {@code HasPartitionKey}); without key values Spark ignores the declared {@code
+   * KeyGroupedPartitioning} entirely. Storage-partitioned joins additionally require {@code
+   * spark.sql.sources.v2.bucketing.enabled=true}.
+   *
+   * @param optionsBytes the bytes produced by {@link #encodeOptions(Map)}
+   * @return the declared partitioning, or {@code null} for none
+   */
+  default ReportedPartitioning reportPartitioning(byte[] optionsBytes) {
+    return null;
+  }
+}
diff --git a/spark/src/main/java/io/datafusion/spark/NativeLibraryLoader.java b/spark/src/main/java/io/datafusion/spark/NativeLibraryLoader.java
new file mode 100644
index 0000000..eb4766a
--- /dev/null
+++ b/spark/src/main/java/io/datafusion/spark/NativeLibraryLoader.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardCopyOption;
+import java.util.Locale;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Extracts a cdylib bundled inside a jar to a temp file and loads it via {@link System#load}.
+ * Expected layout inside the jar:
+ *
+ * <pre>
+ *   &lt;resourcePrefix&gt;/&lt;os&gt;/&lt;arch&gt;/lib&lt;name&gt;.&lt;ext&gt;
+ * </pre>
+ *
+ * where {@code <os>} is one of {@code linux}, {@code darwin}, {@code windows} and {@code <arch>} is
+ * {@code x86_64} or {@code aarch64}. The file name is exactly {@link System#mapLibraryName(String)}
+ * of the library name, so on Windows it carries no {@code lib} prefix.
+ *
+ * <p>Bridges call {@link #load(Class, String, String)} from their native class's static
+ * initializer, with their own resource prefix, instead of hand-rolling extraction. Bundle the
+ * cdylib with the antrun-copy pattern shown in "Packaging your bridge" in {@code spark/README.md}.
+ */
+public final class NativeLibraryLoader {
+
+  /**
+   * {@code <resourcePrefix>/<name>} entries already extracted and loaded by this classloader. Only
+   * read and written while holding its own monitor, and only after a load has succeeded.
+   */
+  private static final Set<String> LOADED = ConcurrentHashMap.newKeySet();
+
+  private NativeLibraryLoader() {}
+
+  /**
+   * Extract {@code <resourcePrefix>/<os>/<arch>/<mapped name>} from {@code anchor}'s classloader
+   * and {@link System#load} it. Idempotent per (prefix, name): repeated calls — e.g. one per Spark
+   * task instantiating the bridge's native class — load once.
+   *
+   * <p>Calls are serialized, and a (prefix, name) pair is only recorded as loaded after {@link
+   * System#load} has returned. A concurrent caller therefore never returns before the library is
+   * usable, and a failed attempt (missing resource, unsupported platform, extraction error) is
+   * retried — and fails again — on the next call instead of silently succeeding.
+   *
+   * @param anchor class whose classloader holds the resource (the bridge's own native class, so the
+   *     lookup works under Spark's per-application classloaders)
+   * @param resourcePrefix jar-internal directory, no leading or trailing slash (e.g. {@code
+   *     "com/example/mybridge"})
+   * @param name unmapped library name (e.g. {@code "my_bridge"} for {@code libmy_bridge.so})
+   * @throws UnsatisfiedLinkError if the resource is missing or extraction fails
+   * @throws UnsupportedOperationException if the current OS or CPU architecture is not supported
+   */
+  public static void load(Class<?> anchor, String resourcePrefix, String name) {
+    String key = resourcePrefix + "/" + name;
+    synchronized (LOADED) {
+      if (LOADED.contains(key)) {
+        return;
+      }
+      String mappedName = System.mapLibraryName(name);
+      // May throw UnsupportedOperationException; nothing has been recorded yet, so a retry
+      // reports the same error rather than pretending the library is loaded.
+      String resource =
+          String.format("/%s/%s/%s/%s", resourcePrefix, currentOs(), currentArch(), mappedName);
+      try (InputStream in = anchor.getResourceAsStream(resource)) {
+        if (in == null) {
+          throw new UnsatisfiedLinkError("Native library not found on classpath: " + resource);
+        }
+        Path tmp = Files.createTempFile("libdatafusion-spark-", "-" + mappedName);
+        tmp.toFile().deleteOnExit();
+        Files.copy(in, tmp, StandardCopyOption.REPLACE_EXISTING);
+        System.load(tmp.toAbsolutePath().toString());
+      } catch (IOException e) {
+        throw new UnsatisfiedLinkError(
+            "Failed to extract native library " + resource + ": " + e.getMessage());
+      }
+      // Reached only when extraction and System.load both succeeded.
+      LOADED.add(key);
+    }
+  }
+
+  /** Map {@code os.name} onto the jar layout's OS directory name. */
+  private static String currentOs() {
+    String os = System.getProperty("os.name", "").toLowerCase(Locale.ROOT);
+    if (os.contains("linux")) return "linux";
+    if (os.contains("mac") || os.contains("darwin")) return "darwin";
+    if (os.contains("windows")) return "windows";
+    throw new UnsupportedOperationException("Unsupported OS: " + os);
+  }
+
+  /** Map {@code os.arch} onto the jar layout's architecture directory name. */
+  private static String currentArch() {
+    String arch = System.getProperty("os.arch", "").toLowerCase(Locale.ROOT);
+    if (arch.equals("amd64") || arch.equals("x86_64")) return "x86_64";
+    if (arch.equals("aarch64") || arch.equals("arm64")) return "aarch64";
+    throw new UnsupportedOperationException("Unsupported arch: " + arch);
+  }
+}
diff --git a/spark/src/main/java/io/datafusion/spark/OptionsCodec.java b/spark/src/main/java/io/datafusion/spark/OptionsCodec.java
new file mode 100644
index 0000000..0d16a28
--- /dev/null
+++ b/spark/src/main/java/io/datafusion/spark/OptionsCodec.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark;
+
+import java.io.ByteArrayOutputStream;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.TreeMap;
+
+/**
+ * Default wire format for {@link BridgeProviderFactory#encodeOptions(Map)}: the Spark options map
+ * as length-prefixed UTF-8 pairs, sorted by key.
+ *
+ * Layout (all integers big-endian {@code int32}): entry count, then per entry key length, key
+ * bytes, value length, value bytes. Key-sorting makes the bytes a pure function of the map's
+ * contents regardless of source iteration order — required by the shared-scan determinism contract,
+ * where the options bytes are the cache/plan identity.
+ *
+ *
The Rust decoder lives in {@code datafusion_spark_bridge::options}; bridges using the default
+ * {@code encodeOptions} read their options there as a {@code BTreeMap}. The two
+ * implementations are pinned to each other by a shared test fixture.
+ */
+public final class OptionsCodec {
+
+ private OptionsCodec() {}
+
+ /** Encode {@code options} sorted by key. {@code null} or empty map encodes as count 0. */
+ public static byte[] encode(Map options) {
+ TreeMap sorted = options == null ? new TreeMap<>() : new TreeMap<>(options);
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ writeInt(out, sorted.size());
+ for (Map.Entry e : sorted.entrySet()) {
+ if (e.getKey() == null || e.getValue() == null) {
+ throw new IllegalArgumentException("OptionsCodec does not accept null keys or values");
+ }
+ writeBytes(out, e.getKey().getBytes(StandardCharsets.UTF_8));
+ writeBytes(out, e.getValue().getBytes(StandardCharsets.UTF_8));
+ }
+ return out.toByteArray();
+ }
+
+ /** Decode bytes produced by {@link #encode(Map)}. Preserves the encoded (sorted) order. */
+ public static Map decode(byte[] bytes) {
+ Map out = new LinkedHashMap<>();
+ if (bytes == null || bytes.length == 0) {
+ return out;
+ }
+ ByteBuffer buf = ByteBuffer.wrap(bytes);
+ int count = readCount(buf, "entry count");
+ for (int i = 0; i < count; i++) {
+ String key = readString(buf, "key of entry " + i);
+ String value = readString(buf, "value of entry " + i);
+ out.put(key, value);
+ }
+ if (buf.hasRemaining()) {
+ throw new IllegalArgumentException(
+ "OptionsCodec: " + buf.remaining() + " trailing byte(s) after " + count + " entries");
+ }
+ return out;
+ }
+
+ private static void writeInt(ByteArrayOutputStream out, int v) {
+ out.write((v >>> 24) & 0xFF);
+ out.write((v >>> 16) & 0xFF);
+ out.write((v >>> 8) & 0xFF);
+ out.write(v & 0xFF);
+ }
+
+ private static void writeBytes(ByteArrayOutputStream out, byte[] bytes) {
+ writeInt(out, bytes.length);
+ out.write(bytes, 0, bytes.length);
+ }
+
+ private static int readCount(ByteBuffer buf, String what) {
+ if (buf.remaining() < 4) {
+ throw new IllegalArgumentException("OptionsCodec: truncated " + what);
+ }
+ int v = buf.getInt();
+ if (v < 0) {
+ throw new IllegalArgumentException("OptionsCodec: negative " + what + ": " + v);
+ }
+ return v;
+ }
+
+ private static String readString(ByteBuffer buf, String what) {
+ int len = readCount(buf, "length of " + what);
+ if (buf.remaining() < len) {
+ throw new IllegalArgumentException("OptionsCodec: truncated " + what);
+ }
+ byte[] bytes = new byte[len];
+ buf.get(bytes);
+ return new String(bytes, StandardCharsets.UTF_8);
+ }
+}
diff --git a/spark/src/main/java/io/datafusion/spark/PartitionInfo.java b/spark/src/main/java/io/datafusion/spark/PartitionInfo.java
new file mode 100644
index 0000000..e6e061b
--- /dev/null
+++ b/spark/src/main/java/io/datafusion/spark/PartitionInfo.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark;
+
+/**
+ * Driver-side descriptor for a single partition produced by {@link
+ * BridgeProviderFactory#listPartitions(byte[])}. Carries the bridge-specific slice payload that the
+ * executor passes back into {@link ScanBackend#createScan}, plus optional host hints for Spark's
+ * scheduler.
+ *
+ * <p>Fields:
+ *
+ * <ul>
+ *   <li>{@code id} — stable, human-readable identifier for this partition (e.g. a segment id).
+ *       Surfaces in Spark UI, logs, and exception messages. Must be non-empty.
+ *   <li>{@code partitionBytes} — opaque per-partition payload. Bridge encodes whatever the executor
+ *       needs to materialise *this* slice (offsets, row ranges, sub-options, etc.). Combined with
+ *       the global {@code optionsBytes} in {@link ScanBackend#createScan}. Empty array = no
+ *       per-partition state (single-partition table).
+ *   <li>{@code preferredLocations} — hostnames where this partition's data lives. Returned from
+ *       {@code InputPartition.preferredLocations()} so Spark can co-locate the task with the data.
+ *       Empty array = no preference. Honoured subject to {@code spark.locality.wait}.
+ *   <li>{@code partitionKeyValues} — optional values of the partitioning keys for every row in this
+ *       partition, in the same order as {@link BridgeProviderFactory#reportPartitioning(byte[])}'s
+ *       declared transforms. {@code null} = no key (the default). When the bridge reports a
+ *       partitioning AND every partition carries key values, the connector exposes them to Spark
+ *       via {@code HasPartitionKey} — required on Spark 3.3+ for the reported {@code
+ *       KeyGroupedPartitioning} to have any effect (and storage-partitioned joins additionally
+ *       require {@code spark.sql.sources.v2.bucketing.enabled=true}). Values must be Java types
+ *       that Spark's {@code CatalystTypeConverters} can convert for the key columns' data types
+ *       (e.g. {@code String}, {@code Long}, {@code Integer}, {@code java.time.Instant}, {@code
+ *       java.time.LocalDate}, {@code java.math.BigDecimal}), and the array length must equal the
+ *       number of declared keys.
+ * </ul>
+ */
+public record PartitionInfo(
+    String id, byte[] partitionBytes, String[] preferredLocations, Object[] partitionKeyValues) {
+
+  /**
+   * Validates {@code id} and normalises absent arrays: a {@code null} payload or host list becomes
+   * an empty array. {@code partitionKeyValues} is deliberately left as given — {@code null} is the
+   * "no key" state, and DatafusionBatch tells keyed from unkeyed partitions apart by it.
+   */
+  public PartitionInfo {
+    boolean idMissing = id == null || id.isEmpty();
+    if (idMissing) {
+      throw new IllegalArgumentException("PartitionInfo: id must be non-empty");
+    }
+    partitionBytes = (partitionBytes != null) ? partitionBytes : new byte[0];
+    preferredLocations = (preferredLocations != null) ? preferredLocations : new String[0];
+  }
+
+  /** Three-argument form for the common case of a partition with no key values. */
+  public PartitionInfo(String id, byte[] partitionBytes, String[] preferredLocations) {
+    this(id, partitionBytes, preferredLocations, null);
+  }
+}
diff --git a/spark/src/main/java/io/datafusion/spark/ReportedPartitioning.java b/spark/src/main/java/io/datafusion/spark/ReportedPartitioning.java
new file mode 100644
index 0000000..639fec9
--- /dev/null
+++ b/spark/src/main/java/io/datafusion/spark/ReportedPartitioning.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark;
+
+import java.util.Arrays;
+
+import org.apache.spark.sql.connector.expressions.Expressions;
+import org.apache.spark.sql.connector.expressions.Transform;
+
+/**
+ * Driver-side declaration of how a bridge's data is partitioned on the key columns. When supplied
+ * via {@link BridgeProviderFactory#reportPartitioning(byte[])}, the connector surfaces a {@link
+ * org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning} from {@link
+ * org.apache.spark.sql.connector.read.SupportsReportPartitioning#outputPartitioning()} — Spark's
+ * optimizer can then skip the shuffle ahead of joins/aggregations whose grouping keys line up with
+ * these transforms.
+ *
+ * <p>Contract: for any partition reported by {@link BridgeProviderFactory#listPartitions(byte[])},
+ * every row produced by that partition must evaluate to the same tuple of key values under these
+ * transforms. Different partitions may share key values (Spark will fuse them); they must not
+ * straddle key values.
+ *
+ * <p>The partition count Spark sees is {@code listPartitions(...).length}; it is not carried here
+ * to keep a single source of truth.
+ */
+public final class ReportedPartitioning {
+
+  /** Private copy of the declared transforms; never handed out directly. */
+  private final Transform[] keys;
+
+  /**
+   * @param keys the partitioning transforms, at least one; the array is copied, so later changes
+   *     to the caller's array do not affect this declaration
+   * @throws IllegalArgumentException if {@code keys} is {@code null} or empty
+   */
+  public ReportedPartitioning(Transform[] keys) {
+    if (keys == null || keys.length == 0) {
+      throw new IllegalArgumentException(
+          "ReportedPartitioning: keys must contain at least one transform");
+    }
+    this.keys = keys.clone();
+  }
+
+  /** The declared transforms, as a fresh copy on every call. */
+  public Transform[] keys() {
+    return keys.clone();
+  }
+
+  /**
+   * Convenience: declare identity partitioning on one or more columns (a row in partition P has the
+   * same {@code (col1, col2, …)} values as every other row in P).
+   *
+   * @throws IllegalArgumentException if no column is given
+   */
+  public static ReportedPartitioning identity(String... columns) {
+    if (columns == null || columns.length == 0) {
+      throw new IllegalArgumentException(
+          "ReportedPartitioning.identity: at least one column required");
+    }
+    Transform[] ts = Arrays.stream(columns).map(Expressions::identity).toArray(Transform[]::new);
+    return new ReportedPartitioning(ts);
+  }
+
+  /**
+   * Convenience: declare hash-bucket partitioning. Mirrors Spark's {@code bucket(N, cols…)}
+   * transform — each row is assigned to bucket {@code hash(cols) mod numBuckets}.
+   *
+   * @throws IllegalArgumentException if {@code numBuckets} is not positive or no column is given
+   */
+  public static ReportedPartitioning bucket(int numBuckets, String... columns) {
+    if (numBuckets <= 0) {
+      throw new IllegalArgumentException(
+          "ReportedPartitioning.bucket: numBuckets must be > 0, got " + numBuckets);
+    }
+    if (columns == null || columns.length == 0) {
+      throw new IllegalArgumentException(
+          "ReportedPartitioning.bucket: at least one column required");
+    }
+    return new ReportedPartitioning(new Transform[] {Expressions.bucket(numBuckets, columns)});
+  }
+}
diff --git a/spark/src/main/java/io/datafusion/spark/ScanBackend.java b/spark/src/main/java/io/datafusion/spark/ScanBackend.java
new file mode 100644
index 0000000..a994c98
--- /dev/null
+++ b/spark/src/main/java/io/datafusion/spark/ScanBackend.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark;
+
+/**
+ * Native scan surface the connector plumbing talks to: one method per JNI entry point generated by
+ * the bridge's {@code datafusion_spark_bridge::export_bridge!} invocation. A bridge's
+ * implementation is six one-line delegations to the JNI class named in that macro, whose {@code
+ * createScan} builds the provider from {@code options}/{@code partitionBytes} in process.
+ *
+ * <p>Implementations must be stateless or thread-safe: the driver probes schemas and plans through
+ * one instance while executor tasks stream through others, and scan handles are shared across
+ * threads by the shared-scan cache. Handle-based methods accept handles produced by {@code
+ * createScan} on any instance of the same implementation.
+ */
+public interface ScanBackend {
+
+ /**
+ * Driver-side schema probe: the widened Arrow schema of the provider described by {@code options}
+ * + {@code partitionBytes}, serialized as Arrow IPC bytes (deserialize with {@code
+ * MessageSerializer.deserializeSchema}).
+ *
+ * @param options the bridge's encoded options
+ * @param partitionBytes the per-partition payload
+ * @return the schema as Arrow IPC bytes
+ */
+ byte[] providerSchemaIpc(byte[] options, byte[] partitionBytes);
+
+ /**
+ * Build a planned scan and return its handle. {@code targetPartitions}/{@code batchSize} {@code
+ * <= 0} leave DataFusion defaults; {@code optionKeys}/{@code optionValues} are parallel config
+ * override arrays; empty {@code projectionColumns} selects all columns; each {@code filterProtos}
+ * element is a serialized {@code datafusion.LogicalExprNode}.
+ *
+ * <p>The caller owns the handle and must pair it with {@link #closeScan(long)}. Closing while a
+ * stream opened from the handle is in flight is undefined behaviour — the shared-scan cache's
+ * refcount enforces this; any other caller must serialize close itself.
+ *
+ * @return the handle to pass to the other handle-based methods
+ */
+ long createScan(
+ byte[] options,
+ byte[] partitionBytes,
+ int targetPartitions,
+ int batchSize,
+ String[] optionKeys,
+ String[] optionValues,
+ String[] projectionColumns,
+ byte[][] filterProtos);
+
+ /**
+ * Output partition count of the planned physical plan.
+ *
+ * @param scanHandle a handle produced by {@code createScan}
+ */
+ int partitionCount(long scanHandle);
+
+ /**
+ * Open an independent stream over ONE plan partition, writing an {@code FFI_ArrowArrayStream}
+ * into the caller-allocated struct at {@code ffiStreamAddr}. Concurrent-safe across JVM threads.
+ *
+ * @param scanHandle a handle produced by {@code createScan}
+ * @param partition index of the plan partition to stream
+ * @param ffiStreamAddr address of the caller-allocated stream struct
+ */
+ void executeStreamPartition(long scanHandle, int partition, long ffiStreamAddr);
+
+ /**
+ * Stream the WHOLE plan (all partitions coalesced) into the caller-allocated {@code
+ * FFI_ArrowArrayStream} at {@code ffiStreamAddr}. Used by per-partition mode.
+ *
+ * @param scanHandle a handle produced by {@code createScan}
+ * @param ffiStreamAddr address of the caller-allocated stream struct
+ */
+ void executeStream(long scanHandle, long ffiStreamAddr);
+
+ /**
+ * Drop the planned scan. See {@link #createScan} for the close-vs-in-flight contract.
+ *
+ * @param scanHandle a handle produced by {@code createScan}
+ */
+ void closeScan(long scanHandle);
+}
diff --git a/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
new file mode 100644
index 0000000..3e612e0
--- /dev/null
+++ b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -0,0 +1 @@
+io.datafusion.spark.DatafusionSource
diff --git a/spark/src/main/scala/io/datafusion/spark/ArrowColumnarBatchIteration.scala b/spark/src/main/scala/io/datafusion/spark/ArrowColumnarBatchIteration.scala
new file mode 100644
index 0000000..30f62f8
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/ArrowColumnarBatchIteration.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.arrow.vector.FieldVector
+import org.apache.arrow.vector.ipc.ArrowReader
+import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
+
+/**
+ * Common `next()`/`get()` plumbing for the connector's columnar partition readers. Every
+ * successful `loadNextBatch()` on the underlying Arrow reader is surfaced as a `ColumnarBatch`
+ * whose columns are [[NonClosingArrowColumnVector]]s over the reader's `VectorSchemaRoot` — the
+ * reader keeps ownership of the vectors, so Spark must not close them batch by batch.
+ */
+private[spark] trait ArrowColumnarBatchIteration {
+
+  /** Arrow stream drained by this reader; expected to stay the same for the reader's lifetime. */
+  protected def arrowReader: ArrowReader
+
+  /** Batch produced by the latest successful `next()`, or null when there is none. */
+  private var currentBatch: ColumnarBatch = _
+
+  /**
+   * Advance to the next Arrow batch.
+   *
+   * @return true when a batch was loaded and is available through `get()`; false once
+   *         `loadNextBatch()` reports no further batch
+   */
+  def next(): Boolean = {
+    // Forget the previous batch before touching the stream again.
+    currentBatch = null
+    if (arrowReader.loadNextBatch()) {
+      val root = arrowReader.getVectorSchemaRoot
+      val fieldVectors: java.util.List[FieldVector] = root.getFieldVectors
+      val columns = Array.tabulate[ColumnVector](fieldVectors.size()) { idx =>
+        new NonClosingArrowColumnVector(fieldVectors.get(idx))
+      }
+      val wrapped = new ColumnarBatch(columns)
+      wrapped.setNumRows(root.getRowCount)
+      currentBatch = wrapped
+      true
+    } else {
+      false
+    }
+  }
+
+  /** The batch loaded by the most recent `next()`; null before the first batch or past the end. */
+  def get(): ColumnarBatch = currentBatch
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/ArrowToSparkSchema.scala b/spark/src/main/scala/io/datafusion/spark/ArrowToSparkSchema.scala
new file mode 100644
index 0000000..2e8f1a5
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/ArrowToSparkSchema.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit}
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
+import org.apache.spark.sql.types._
+
+/**
+ * Arrow Schema → Spark StructType converter.
+ *
+ * The reported Spark schema MUST be one whose runtime ArrowColumnVector accessor Spark can pick
+ * for the underlying Arrow vector. Spark 3.5's `ArrowColumnVector` supports the following
+ * accessors: Boolean, Byte, Short, Int, Long, Float, Double, Decimal, Date, Timestamp,
+ * TimestampNTZ, Duration (DayTimeInterval), String, LargeString, Binary, Array, Map, Struct,
+ * Null. No unsigned-int or Time accessor exists; we surface a clear error at schema discovery
+ * for those — the alternative is silent corruption.
+ *
+ * The widening layer (datafusion-spark-bridge, compiled into every bridge cdylib) inserts a
+ * `WideningTableProvider` upstream of the Spark reader that casts unsupported types kernel-side
+ * (UInt*→signed wider, Float16→Float32, non-µs Timestamp→µs Timestamp, Time→Int) so Spark only
+ * ever sees compatible Arrow types. If one of those types still reaches this converter, the
+ * widening layer was bypassed and conversion fails with `UnsupportedOperationException`.
+ */
+object ArrowToSparkSchema {
+
+  /**
+   * Convert an Arrow schema into the Spark schema reported for the scan.
+   *
+   * @param schema Arrow schema to convert
+   * @return a `StructType` with one field per top-level Arrow field, in the same order
+   * @throws UnsupportedOperationException when any (nested) field has no Spark equivalent
+   */
+  def toSparkSchema(schema: Schema): StructType =
+    StructType(schema.getFields.asScala.toSeq.map(toSparkField))
+
+  /** Convert one Arrow field, keeping its name and nullability; dictionary encoding is rejected. */
+  private def toSparkField(f: Field): StructField = {
+    if (f.getDictionary != null) {
+      unsupported(f, "dictionary-encoded fields (need dictionary value schema in JNI)")
+    }
+    StructField(f.getName, toSparkType(f), f.isNullable)
+  }
+
+  /** Map the Arrow type of `f` (recursing into children) to the matching Spark `DataType`. */
+  private def toSparkType(f: Field): DataType = f.getType match {
+    case _: ArrowType.Bool => BooleanType
+
+    case t: ArrowType.Int =>
+      (t.getBitWidth, t.getIsSigned) match {
+        case (8, true) => ByteType
+        case (16, true) => ShortType
+        case (32, true) => IntegerType
+        case (64, true) => LongType
+        case (bits, false) =>
+          unsupported(
+            f,
+            s"unsigned integer UInt$bits (Spark ArrowColumnVector has no unsigned accessor; " +
+              "widening layer casts these before Spark sees them — this branch indicates the " +
+              "WideningTableProvider was bypassed)"
+          )
+        case (bits, signed) => unsupported(f, s"Int(bits=$bits, signed=$signed)")
+      }
+
+    case t: ArrowType.FloatingPoint =>
+      t.getPrecision match {
+        case FloatingPointPrecision.HALF =>
+          unsupported(f, "Float16 (widening layer must cast to Float32 before Spark)")
+        case FloatingPointPrecision.SINGLE => FloatType
+        case FloatingPointPrecision.DOUBLE => DoubleType
+        case other => unsupported(f, s"FloatingPoint($other)")
+      }
+
+    case _: ArrowType.Utf8 | _: ArrowType.LargeUtf8 => StringType
+
+    case _: ArrowType.Binary | _: ArrowType.LargeBinary | _: ArrowType.FixedSizeBinary =>
+      BinaryType
+
+    case d: ArrowType.Date =>
+      // NOTE(review): Spark's own Arrow mapping only accepts Date(DAY); confirm that a
+      // Date(MILLISECOND) vector is readable by the column vector used downstream.
+      d.getUnit match {
+        case DateUnit.DAY | DateUnit.MILLISECOND => DateType
+        case other => unsupported(f, s"Date($other)")
+      }
+
+    case t: ArrowType.Timestamp =>
+      // Only microsecond timestamps are compatible (see the object doc); the widening layer is
+      // expected to have cast every other unit. Fail at schema discovery like the other
+      // widened types instead of accepting a unit the reader cannot interpret.
+      if (t.getUnit != org.apache.arrow.vector.types.TimeUnit.MICROSECOND) {
+        unsupported(
+          f,
+          s"Timestamp(${t.getUnit}) (widening layer must cast to microsecond Timestamp " +
+            "before Spark)"
+        )
+      }
+      if (t.getTimezone == null) TimestampNTZType else TimestampType
+
+    case ti: ArrowType.Time =>
+      unsupported(
+        f,
+        s"Time(${ti.getUnit}, ${ti.getBitWidth}-bit) — Spark has no time-of-day type"
+      )
+
+    case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
+
+    case _: ArrowType.Null => NullType
+
+    // NOTE(review): the Duration unit is not inspected here — confirm non-µs durations are
+    // normalised upstream before they are read as a day-time interval.
+    case _: ArrowType.Duration => DayTimeIntervalType()
+
+    case iv: ArrowType.Interval =>
+      iv.getUnit match {
+        case IntervalUnit.YEAR_MONTH => YearMonthIntervalType()
+        case IntervalUnit.DAY_TIME => DayTimeIntervalType()
+        case IntervalUnit.MONTH_DAY_NANO =>
+          unsupported(f, "Interval(MONTH_DAY_NANO) — no clean Spark equivalent")
+      }
+
+    case _: ArrowType.Struct =>
+      StructType(f.getChildren.asScala.toSeq.map(toSparkField))
+
+    // All list flavours carry their element as the single child field.
+    case _: ArrowType.List | _: ArrowType.LargeList | _: ArrowType.FixedSizeList =>
+      val element = f.getChildren.get(0)
+      ArrayType(toSparkType(element), containsNull = element.isNullable)
+
+    case _: ArrowType.Map =>
+      // Arrow models a map as a list of `entries` structs whose children are (key, value).
+      val entries = f.getChildren.get(0)
+      val keyValue = entries.getChildren
+      val keyField = keyValue.get(0)
+      val valueField = keyValue.get(1)
+      MapType(toSparkType(keyField), toSparkType(valueField), valueField.isNullable)
+
+    case _: ArrowType.Union =>
+      unsupported(f, "Union (Spark has no equivalent)")
+
+    case other => unsupported(f, s"$other")
+  }
+
+  /** Abort conversion with an error naming the offending column. */
+  private def unsupported(f: Field, detail: String): Nothing =
+    throw new UnsupportedOperationException(
+      s"Column '${f.getName}': $detail"
+    )
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionBatch.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionBatch.scala
new file mode 100644
index 0000000..684e9fd
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/DatafusionBatch.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory}
+
+/**
+ * Spark `Batch` of a DataFusion-backed scan; plans the input partitions on the driver.
+ *   - [[PerPartitionMode]]: one task per `PartitionInfo` (resolved by [[DatafusionScanBuilder]]).
+ *     When the bridge reported a partitioning and every entry carries key values, the tasks
+ *     implement `HasPartitionKey` so Spark can actually use the `KeyGroupedPartitioning`.
+ *   - [[SharedScanMode]]: one task per DataFusion plan partition index.
+ */
+class DatafusionBatch(val scan: DatafusionScan) extends Batch {
+
+  override def planInputPartitions(): Array[InputPartition] = {
+    val projection = scan.prunedSchema.fieldNames
+    val filterBytes: Array[Array[Byte]] = scan.pushedPredicateBytes
+    scan.mode match {
+      case perPartition: PerPartitionMode =>
+        perPartitionTasks(perPartition, projection, filterBytes)
+      case shared: SharedScanMode =>
+        sharedScanTasks(shared, projection, filterBytes)
+    }
+  }
+
+  /** One payload per `PartitionInfo`; wrapped with its key row when the scan is keyed. */
+  private def perPartitionTasks(
+      mode: PerPartitionMode,
+      projection: Array[String],
+      filterBytes: Array[Array[Byte]]): Array[InputPartition] = {
+    val keyed =
+      DatafusionBatch.validateKeyedState(scan.factoryFqcn, mode.partitions, mode.reported)
+    mode.partitions.map[InputPartition] { info =>
+      val plain = DatafusionInputPartition(
+        factoryFqcn = scan.factoryFqcn,
+        optionsBytes = scan.optionsBytes,
+        projectionColumnNames = projection,
+        filterProtoBytes = filterBytes,
+        partitionId = info.id,
+        partitionBytes = info.partitionBytes,
+        preferredLocs = info.preferredLocations
+      )
+      if (keyed) {
+        val keyRow = DatafusionBatch.toKeyRow(info.id, info.partitionKeyValues, mode.reported)
+        DatafusionKeyedInputPartition(plain, keyRow)
+      } else {
+        plain
+      }
+    }
+  }
+
+  /** One payload per DataFusion plan partition index. */
+  private def sharedScanTasks(
+      mode: SharedScanMode,
+      projection: Array[String],
+      filterBytes: Array[Array[Byte]]): Array[InputPartition] =
+    Array.tabulate[InputPartition](mode.numPartitions) { index =>
+      DatafusionSharedScanPartition(
+        factoryFqcn = scan.factoryFqcn,
+        optionsBytes = scan.optionsBytes,
+        projectionColumnNames = projection,
+        filterProtoBytes = filterBytes,
+        scanId = mode.scanId,
+        partitionIndex = index,
+        numPartitions = mode.numPartitions,
+        pinnedConfig = mode.pinnedConfig,
+        idleTtlMs = mode.idleTtlMs
+      )
+    }
+
+  override def createReaderFactory(): PartitionReaderFactory =
+    new DatafusionPartitionReaderFactory(scan.prunedSchema)
+}
+
+private[spark] object DatafusionBatch {
+
+  /**
+   * Decide whether the per-partition tasks are keyed. That requires a reported partitioning AND
+   * key values on EVERY partition; a partial set means the bridge broke its own contract, and
+   * failing on the driver beats Spark silently planning without the declared grouping.
+   *
+   * @return true when a partitioning was reported and all partitions carry key values; false
+   *         when nothing was reported or no partition carries key values
+   * @throws IllegalStateException when only some partitions carry key values
+   */
+  def validateKeyedState(
+      factoryFqcn: String,
+      partitions: Array[PartitionInfo],
+      reported: ReportedPartitioning): Boolean =
+    if (reported == null) {
+      false
+    } else {
+      val total = partitions.length
+      partitions.count(_.partitionKeyValues != null) match {
+        case 0 => false
+        case `total` => true
+        case withKeys =>
+          throw new IllegalStateException(
+            s"BridgeProviderFactory '$factoryFqcn' reported a partitioning but only $withKeys of " +
+              s"$total PartitionInfo entries carry partitionKeyValues; either all " +
+              "partitions must carry key values or none")
+      }
+    }
+
+  /**
+   * Turn the bridge-supplied `Object[]` key values into Spark's internal row representation
+   * (String → UTF8String, Instant → micros, LocalDate → days, BigDecimal → Decimal, ...).
+   *
+   * @throws IllegalStateException when the number of values differs from the number of keys
+   *                               declared by `reported`
+   */
+  def toKeyRow(
+      partitionId: String,
+      values: Array[AnyRef],
+      reported: ReportedPartitioning): InternalRow = {
+    val expected = reported.keys().length
+    val actual = values.length
+    if (actual != expected) {
+      throw new IllegalStateException(
+        s"PartitionInfo '$partitionId' carries $actual partitionKeyValues but the " +
+          s"reported partitioning declares $expected key(s)")
+    }
+    new GenericInternalRow(values.map(raw => CatalystTypeConverters.convertToCatalyst(raw)))
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionColumnarPartitionReader.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionColumnarPartitionReader.scala
new file mode 100644
index 0000000..96b7548
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/DatafusionColumnarPartitionReader.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.ipc.ArrowReader
+import org.apache.spark.sql.connector.read.PartitionReader
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Per-task columnar reader for the per-partition path. Lifecycle:
+ *
+ *   1. Reflectively instantiate the bridge's `BridgeProviderFactory` (no-arg) and take its
+ *      [[ScanBackend]].
+ *   2. `backend.createScan(options, partitionBytes, ...)` — builds the provider for the slice
+ *      described by `partitionBytes` and does the rest natively: widening wrap, private
+ *      `SessionContext`, projection, pushed proto filters, physical plan.
+ *   3. `backend.executeStream` streams the whole plan (the provider already IS the task's
+ *      slice); batches surface through [[ArrowColumnarBatchIteration]].
+ *
+ * Every construction step that fails releases what the earlier steps acquired before
+ * rethrowing, and `close()` is idempotent so the native scan handle is dropped at most once.
+ *
+ * @param partition  per-partition payload describing the slice to read
+ * @param readSchema schema passed by the reader factory; not consulted by this class
+ */
+class DatafusionColumnarPartitionReader(
+    partition: DatafusionInputPartition,
+    readSchema: StructType
+) extends PartitionReader[ColumnarBatch]
+    with ArrowColumnarBatchIteration {
+
+  private val allocator = new RootAllocator(Long.MaxValue)
+
+  // Factory instantiation can fail (missing class, constructor or backend error); release the
+  // allocator in that case, exactly like the later construction steps do.
+  private val backend: ScanBackend =
+    try instantiateFactory(partition.factoryFqcn).scanBackend()
+    catch {
+      case t: Throwable =>
+        releaseAfterFailure(t)(allocator.close())
+        throw t
+    }
+
+  private val scanHandle: Long =
+    try {
+      backend.createScan(
+        partition.optionsBytes,
+        partition.partitionBytes,
+        /* targetPartitions = */ -1,
+        /* batchSize = */ -1,
+        Array.empty[String],
+        Array.empty[String],
+        partition.projectionColumnNames,
+        partition.filterProtoBytes
+      )
+    } catch {
+      case t: Throwable =>
+        releaseAfterFailure(t)(allocator.close())
+        throw t
+    }
+
+  override protected val arrowReader: ArrowReader =
+    try {
+      FfiStream.importReader(allocator) { addr =>
+        backend.executeStream(scanHandle, addr)
+      }
+    } catch {
+      case t: Throwable =>
+        releaseAfterFailure(t)(backend.closeScan(scanHandle))
+        releaseAfterFailure(t)(allocator.close())
+        throw t
+    }
+
+  /** Set by the first `close()` so later calls are no-ops. */
+  private var closed = false
+
+  /**
+   * Release the Arrow reader, the native scan and the allocator, in that order. All three are
+   * attempted even when one fails; the first failure is rethrown with the others suppressed.
+   * Calling `close()` again has no effect, so the scan handle is never closed twice.
+   */
+  override def close(): Unit = {
+    if (!closed) {
+      closed = true
+      var first: Throwable = null
+      def safe(f: => Unit): Unit =
+        try f
+        catch { case t: Throwable => if (first == null) first = t else first.addSuppressed(t) }
+      safe(arrowReader.close())
+      safe(backend.closeScan(scanHandle))
+      safe(allocator.close())
+      if (first != null) throw first
+    }
+  }
+
+  /** Run `cleanup` while unwinding from `primary`, attaching any cleanup failure as suppressed. */
+  private def releaseAfterFailure(primary: Throwable)(cleanup: => Unit): Unit =
+    try cleanup
+    catch { case suppressed: Throwable => primary.addSuppressed(suppressed) }
+
+  /** Load `fqcn` and build the bridge factory through its no-arg constructor. */
+  private def instantiateFactory(fqcn: String): BridgeProviderFactory = {
+    val cls = Class.forName(fqcn)
+    cls.getDeclaredConstructor().newInstance().asInstanceOf[BridgeProviderFactory]
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionInputPartition.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionInputPartition.scala
new file mode 100644
index 0000000..5255644
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/DatafusionInputPartition.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition}
+
+/**
+ * Marker for the connector's task payloads, shipped driver → executor via Java serialization.
+ * [[DatafusionPartitionReaderFactory]] dispatches on the concrete type.
+ *
+ * Sealed, so every implementation lives in this file: [[DatafusionInputPartition]],
+ * [[DatafusionKeyedInputPartition]] and [[DatafusionSharedScanPartition]].
+ */
+sealed trait DatafusionPartition extends InputPartition
+
+/**
+ * Per-task payload for the per-partition read path.
+ *
+ * - `factoryFqcn`: fully-qualified class name of the bridge's `BridgeProviderFactory`. The
+ * executor reflectively instantiates this and calls
+ * `scanBackend().createScan(optionsBytes, partitionBytes, …)`.
+ * - `optionsBytes`: bridge-specific global connection options, encoded by the bridge.
+ * Opaque to connector-core. Same bytes ride along on every partition.
+ * - `projectionColumnNames`: pruned column list (post-`pruneColumns`).
+ * - `filterProtoBytes`: V2 `Predicate` → DataFusion `LogicalExprNode` proto bytes; each one is
+ * applied natively via `ScanBackend.createScan`.
+ * - `partitionId`: stable identifier (e.g. a segment or file id) — surfaces in Spark UI/logs/errors.
+ * - `partitionBytes`: opaque per-partition payload from `PartitionInfo.partitionBytes`. Passed
+ * back into `ScanBackend.createScan` so the bridge materialises *this* slice.
+ * - `preferredLocs`: hostnames where this partition's data lives; returned from
+ * `preferredLocations()` so Spark schedules the task there subject to `spark.locality.wait`.
+ *
+ * Note: the `Array` fields compare by reference, so the case-class `equals`/`hashCode` do not
+ * compare their contents.
+ */
+final case class DatafusionInputPartition(
+ factoryFqcn: String,
+ optionsBytes: Array[Byte],
+ projectionColumnNames: Array[String],
+ filterProtoBytes: Array[Array[Byte]],
+ partitionId: String,
+ partitionBytes: Array[Byte],
+ preferredLocs: Array[String]
+) extends DatafusionPartition {
+
+ // Hands Spark the bridge-supplied host list unchanged.
+ override def preferredLocations(): Array[String] = preferredLocs
+}
+
+/**
+ * Per-partition payload that additionally carries this partition's key values, precomputed
+ * driver-side into an [[InternalRow]]. Emitted by [[DatafusionBatch]] when the bridge reported a
+ * partitioning AND every `PartitionInfo` carries `partitionKeyValues` — implementing
+ * [[HasPartitionKey]] is what makes the reported `KeyGroupedPartitioning` visible to Spark 3.3+
+ * (`DataSourceV2ScanExecBase.groupPartitions` ignores it otherwise).
+ *
+ * - `base`: the plain per-partition payload this wraps.
+ * - `keyRow`: the partition's key values as an internal row.
+ */
+final case class DatafusionKeyedInputPartition(
+ base: DatafusionInputPartition,
+ keyRow: InternalRow
+) extends DatafusionPartition
+ with HasPartitionKey {
+
+ // Locality is delegated to the wrapped payload.
+ override def preferredLocations(): Array[String] = base.preferredLocations()
+
+ // Key row exposed to Spark through HasPartitionKey.
+ override def partitionKey(): InternalRow = keyRow
+}
+
+/**
+ * Per-task payload for shared-scan mode: task `partitionIndex` streams that DataFusion plan
+ * partition from the executor's cached entry (see [[SharedScanCache]]).
+ *
+ * - `scanId`: driver-minted UUID identifying this scan; the executor cache key.
+ * - `partitionIndex`: DataFusion output partition this task drives.
+ * - `numPartitions`: the driver probe's partition count; executors fail fast when their re-plan
+ * diverges (determinism guard).
+ * - `pinnedConfig`: DataFusion session knobs resolved once on the driver and replicated on
+ * every executor so both plan identically.
+ * - `idleTtlMs`: cache-entry idle eviction window, resolved from driver conf.
+ *
+ * No preferred locations: the shared plan materialises the whole dataset on whichever executors
+ * Spark picks; there is no per-slice host mapping in this mode.
+ */
+final case class DatafusionSharedScanPartition(
+ factoryFqcn: String,
+ optionsBytes: Array[Byte],
+ projectionColumnNames: Array[String],
+ filterProtoBytes: Array[Array[Byte]],
+ scanId: String,
+ partitionIndex: Int,
+ numPartitions: Int,
+ pinnedConfig: PinnedSessionConfig,
+ idleTtlMs: Long
+) extends DatafusionPartition {
+
+ /**
+ * Build the [[SharedScanSpec]] for this task's scan. Only the scan-wide fields are copied;
+ * `partitionIndex`, `numPartitions` and `idleTtlMs` are not part of the spec.
+ */
+ def toSpec: SharedScanSpec =
+ SharedScanSpec(
+ scanId = scanId,
+ factoryFqcn = factoryFqcn,
+ optionsBytes = optionsBytes,
+ projectionColumnNames = projectionColumnNames,
+ filterProtoBytes = filterProtoBytes,
+ pinnedConfig = pinnedConfig
+ )
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionPartitionReaderFactory.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionPartitionReaderFactory.scala
new file mode 100644
index 0000000..ba5409c
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/DatafusionPartitionReaderFactory.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Creates the per-task `PartitionReader`s. Columnar reads only: a row-based path would make the
+ * connector convert Arrow → `InternalRow` row by row, which defeats the zero-copy path that is
+ * the whole reason for running in-process.
+ */
+class DatafusionPartitionReaderFactory(val readSchema: StructType) extends PartitionReaderFactory {
+
+  /** Every partition of this connector is read in columnar form. */
+  override def supportColumnarReads(partition: InputPartition): Boolean = true
+
+  /** Row-based reads are not offered; always throws `UnsupportedOperationException`. */
+  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+    val message =
+      "DatafusionPartitionReaderFactory: row-based read not supported; consumers must opt into columnar"
+    throw new UnsupportedOperationException(message)
+  }
+
+  /**
+   * Pick the reader matching the concrete payload type: shared-scan payloads read from the
+   * global [[SharedScanCache]], plain and keyed per-partition payloads get their own reader.
+   *
+   * @throws IllegalArgumentException for a payload type this connector did not create
+   */
+  override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] =
+    partition match {
+      case shared: DatafusionSharedScanPartition =>
+        new SharedScanPartitionReader(shared, SharedScanCache.global)
+      case keyed: DatafusionKeyedInputPartition =>
+        perPartitionReader(keyed.base)
+      case plain: DatafusionInputPartition =>
+        perPartitionReader(plain)
+      case other =>
+        throw new IllegalArgumentException(
+          s"unexpected InputPartition type: ${other.getClass.getName}")
+    }
+
+  /** Reader for the per-partition path; a keyed payload is read through its wrapped base. */
+  private def perPartitionReader(payload: DatafusionInputPartition): PartitionReader[ColumnarBatch] =
+    new DatafusionColumnarPartitionReader(payload, readSchema)
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionScan.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionScan.scala
new file mode 100644
index 0000000..38f0a8b
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/DatafusionScan.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.read.{Batch, Scan, SupportsReportPartitioning}
+import org.apache.spark.sql.connector.read.partitioning.{
+ KeyGroupedPartitioning,
+ Partitioning,
+ UnknownPartitioning
+}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * How the scan maps to Spark tasks — resolved once, driver-side, in
+ * [[DatafusionScanBuilder.build]].
+ *
+ * Sealed; the two modes are [[PerPartitionMode]] and [[SharedScanMode]].
+ */
+sealed trait DatafusionScanMode extends Serializable
+
+/**
+ * Per-partition payload mode: one task per [[PartitionInfo]], each task builds its own provider
+ * from that entry's `partitionBytes`. `reported` is the bridge's optional partitioning
+ * declaration (may be null).
+ *
+ * @param partitions the partitions to read, one Spark task each
+ * @param reported the bridge's partitioning declaration, or null when none was reported
+ */
+final case class PerPartitionMode(
+ partitions: Array[PartitionInfo],
+ reported: ReportedPartitioning
+) extends DatafusionScanMode
+
+/**
+ * Shared-scan mode: one cached provider + plan per (executor × scan), `numPartitions` tasks each
+ * driving one DataFusion output partition. See [[BridgeProviderFactory#sharedScan]] for the
+ * determinism contract.
+ *
+ * @param scanId identifier of this scan, copied into every shared-scan task payload
+ * @param numPartitions number of tasks / DataFusion output partitions
+ * @param pinnedConfig DataFusion session configuration copied into every task payload
+ * @param idleTtlMs idle window in milliseconds, copied into every task payload
+ */
+final case class SharedScanMode(
+ scanId: String,
+ numPartitions: Int,
+ pinnedConfig: PinnedSessionConfig,
+ idleTtlMs: Long
+) extends DatafusionScanMode
+
+/**
+ * Read plan of a DataFusion-backed scan. Carries the pruning state, the pushed predicates (for
+ * `description()` / `explain(true)`), the matching `LogicalExprNode` proto byte arrays that the
+ * executor applies natively via `ScanBackend.createScan`, and the driver-resolved
+ * [[DatafusionScanMode]].
+ *
+ * In per-partition mode a bridge-declared [[ReportedPartitioning]] is surfaced as
+ * `KeyGroupedPartitioning` through `SupportsReportPartitioning`; note that Spark 3.3+ only
+ * consumes it when the input partitions also implement `HasPartitionKey` (see
+ * [[DatafusionBatch]]). Shared-scan mode always reports `UnknownPartitioning` —
+ * DataFusion-native partitions carry no key contract.
+ */
+class DatafusionScan(
+    val factoryFqcn: String,
+    val optionsBytes: Array[Byte],
+    val fullSchema: StructType,
+    val prunedSchema: StructType,
+    val pushedPredicates: Array[Predicate],
+    val pushedPredicateBytes: Array[Array[Byte]],
+    val mode: DatafusionScanMode
+) extends Scan
+    with SupportsReportPartitioning {
+
+  /** The pruned schema is what the readers produce. */
+  override def readSchema(): StructType = prunedSchema
+
+  /** One-line summary: factory, projection, pushed-predicate count and mode details. */
+  override def description(): String = {
+    val modeSummary = mode match {
+      case shared: SharedScanMode =>
+        s"mode=shared-scan, scanId=${shared.scanId}, partitions=${shared.numPartitions}"
+      case perPartition: PerPartitionMode =>
+        val partitioningLabel = if (perPartition.reported == null) "unknown" else "key-grouped"
+        s"mode=per-partition, partitions=${perPartition.partitions.length}," +
+          s" reportedPartitioning=$partitioningLabel"
+    }
+    val projectionList = prunedSchema.fieldNames.mkString(",")
+    s"DatafusionScan(factory=$factoryFqcn, projection=$projectionList," +
+      s" pushedPredicates=${pushedPredicates.length}, $modeSummary)"
+  }
+
+  override def toBatch: Batch = new DatafusionBatch(this)
+
+  /**
+   * Key-grouped when per-partition mode has a reported partitioning; unknown otherwise. The
+   * partition count is the number of tasks of the active mode.
+   */
+  override def outputPartitioning(): Partitioning = mode match {
+    case shared: SharedScanMode =>
+      new UnknownPartitioning(shared.numPartitions)
+    case perPartition: PerPartitionMode =>
+      Option(perPartition.reported) match {
+        case Some(declared) =>
+          new KeyGroupedPartitioning(declared.keys().toArray, perPartition.partitions.length)
+        case None =>
+          new UnknownPartitioning(perPartition.partitions.length)
+      }
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionScanBuilder.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionScanBuilder.scala
new file mode 100644
index 0000000..a74029c
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/DatafusionScanBuilder.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import java.util.UUID
+
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+/**
+ * ScanBuilder with V2 Predicate pushdown + column pruning. Every translatable predicate is
+ * marked Exact and dropped from Spark's post-scan Filter; the rest stay residual.
+ *
+ * Pushdown discipline: over-claiming Exact = wrong results, under-claiming = full scans. The
+ * translator (see [[SparkPredicateTranslator]]) only emits proto for predicates it can encode
+ * losslessly — anything else returns `None` and lands in residuals.
+ *
+ * `build()` resolves the driver-side facts the optimizer needs *before* it starts asking the
+ * [[DatafusionScan]] about its output partitioning. Spark guarantees `pushPredicates` and
+ * `pruneColumns` run first, so both paths see the final projection + filters:
+ * - per-partition payload mode: `listPartitions(opts, filters)` (filter-aware overload — the
+ * bridge can prune whole partitions) + the optional [[ReportedPartitioning]];
+ * - shared-scan mode: a probe build of the provider + plan (via the same code path executors
+ * use) to count DataFusion output partitions, plus a freshly minted scanId and the pinned
+ * session config that makes executor re-plans comparable.
+ */
+class DatafusionScanBuilder(
+  factoryFqcn: String,
+  optionsBytes: Array[Byte],
+  fullSchema: StructType
+) extends ScanBuilder
+  with SupportsPushDownV2Filters
+  with SupportsPushDownRequiredColumns {
+
+  // Predicates accepted for pushdown; index-aligned with their encoded form in `pushedBytes`.
+  private var pushed: Array[Predicate] = Array.empty
+  private var pushedBytes: Array[Array[Byte]] = Array.empty
+  // Projection requested by Spark; stays the full schema until `pruneColumns` is called.
+  private var pruned: StructType = fullSchema
+
+  /**
+   * Splits `predicates` into those the translator can encode (recorded as pushed, together
+   * with their serialized form) and those it cannot.
+   *
+   * @return the predicates that were NOT pushed, in their original relative order
+   */
+  override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
+    val pushedBuf = scala.collection.mutable.ArrayBuffer[Predicate]()
+    val bytesBuf = scala.collection.mutable.ArrayBuffer[Array[Byte]]()
+    val residual = scala.collection.mutable.ArrayBuffer[Predicate]()
+
+    predicates.foreach { p =>
+      SparkPredicateTranslator.translate(p) match {
+        case Some(node) =>
+          pushedBuf += p
+          bytesBuf += node.toByteArray
+        case None =>
+          residual += p
+      }
+    }
+    pushed = pushedBuf.toArray
+    pushedBytes = bytesBuf.toArray
+    residual.toArray
+  }
+
+  /** Predicates accepted by the last `pushPredicates` call (empty before it). */
+  override def pushedPredicates(): Array[Predicate] = pushed
+
+  /** Records the projection Spark actually needs. */
+  override def pruneColumns(requiredSchema: StructType): Unit = {
+    pruned = requiredSchema
+  }
+
+  /**
+   * Resolves the scan mode through the bridge factory and assembles the [[DatafusionScan]]
+   * from the final projection and pushed predicates.
+   */
+  override def build(): Scan = {
+    val factory = instantiateFactory(factoryFqcn)
+    val mode: DatafusionScanMode =
+      if (factory.sharedScan(optionsBytes)) buildSharedScanMode()
+      else buildPerPartitionMode(factory)
+    new DatafusionScan(
+      factoryFqcn,
+      optionsBytes,
+      fullSchema,
+      pruned,
+      pushed,
+      pushedBytes,
+      mode
+    )
+  }
+
+  /**
+   * Per-partition mode: asks the bridge for its partitions (passing the pushed filters) and
+   * its optional reported partitioning.
+   *
+   * @throws IllegalStateException if the bridge returns null or no partitions
+   */
+  private def buildPerPartitionMode(factory: BridgeProviderFactory): PerPartitionMode = {
+    val partitions: Array[PartitionInfo] =
+      factory.listPartitions(optionsBytes, pushedBytes)
+    if (partitions == null || partitions.isEmpty) {
+      throw new IllegalStateException(
+        s"BridgeProviderFactory '$factoryFqcn' returned no partitions to scan"
+      )
+    }
+    PerPartitionMode(partitions, factory.reportPartitioning(optionsBytes))
+  }
+
+  /**
+   * Driver plan probe: build the provider + plan exactly as executors will (same widening, SQL,
+   * filters, pinned config — one code path in [[NativeSharedScanResources]]) and read the
+   * physical plan's output partition count. All Spark conf is resolved here, driver-side;
+   * executors only see the shipped copies.
+   *
+   * @throws IllegalStateException if the probe plan reports no partitions
+   */
+  private def buildSharedScanMode(): SharedScanMode = {
+    val conf = SQLConf.get
+    val pinned = PinnedSessionConfig.fromConf(conf)
+    val idleTtlMs = PinnedSessionConfig.idleTtlMs(conf)
+    val scanId = UUID.randomUUID().toString
+
+    val probeSpec = SharedScanSpec(
+      scanId = scanId,
+      factoryFqcn = factoryFqcn,
+      optionsBytes = optionsBytes,
+      projectionColumnNames = pruned.fieldNames,
+      filterProtoBytes = pushedBytes,
+      pinnedConfig = pinned
+    )
+    val probe = NativeSharedScanResources.build(probeSpec)
+    // The probe only exists to be counted; always release it, even if counting fails.
+    val numPartitions =
+      try {
+        probe.partitionCount
+      } finally {
+        probe.close()
+      }
+    if (numPartitions <= 0) {
+      throw new IllegalStateException(
+        s"shared-scan probe for factory '$factoryFqcn' produced a plan with no partitions")
+    }
+    SharedScanMode(scanId, numPartitions, pinned, idleTtlMs)
+  }
+
+  /**
+   * Instantiates the bridge factory through its no-arg constructor.
+   *
+   * The class is looked up through the thread context classloader first: jars added at submit
+   * time are only guaranteed to be visible there, not through the loader that defined this
+   * connector. The previous lookup (this class's own loader) is kept as a fallback.
+   */
+  private def instantiateFactory(fqcn: String): BridgeProviderFactory = {
+    val contextLoader = Thread.currentThread().getContextClassLoader
+    val cls: Class[_] =
+      if (contextLoader == null) {
+        Class.forName(fqcn)
+      } else {
+        try Class.forName(fqcn, true, contextLoader)
+        catch { case _: ClassNotFoundException => Class.forName(fqcn) }
+      }
+    cls.getDeclaredConstructor().newInstance().asInstanceOf[BridgeProviderFactory]
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionSource.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionSource.scala
new file mode 100644
index 0000000..125b3a1
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/DatafusionSource.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import java.io.ByteArrayInputStream
+import java.nio.channels.Channels
+import java.util
+
+import org.apache.arrow.vector.ipc.ReadChannel
+import org.apache.arrow.vector.ipc.message.MessageSerializer
+import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * Generic Spark DataSource V2 entry point. Concrete bridges either:
+ * - Subclass and override [[shortName]] + [[factoryFqcn]] (the short-name shim pattern), or
+ * - Use this class directly with `option("df.factory", "fully.qualified.FactoryClass")`.
+ *
+ * Schema discovery happens driver-side inside the bridge's native scan backend
+ * (`ScanBackend.providerSchemaIpc`), which widens the provider and returns its Arrow schema as
+ * IPC bytes. The same `optionsBytes` (and the factory FQCN) is then carried verbatim through
+ * `DatafusionInputPartition`, so each executor task repeats the same factory → backend pipeline
+ * locally.
+ */
+class DatafusionSource extends TableProvider with DataSourceRegister {
+
+  override def shortName(): String = "datafusion"
+
+  /** Spark option key carrying the BridgeProviderFactory FQCN when no override is provided. */
+  protected val FactoryOptionKey: String = "df.factory"
+
+  /**
+   * Resolve the bridge factory class name from the Spark options. Subclasses override to return a
+   * hard-coded FQCN so users don't need to set `df.factory` themselves.
+   *
+   * @throws IllegalArgumentException if the option is missing or empty
+   */
+  protected def factoryFqcn(options: CaseInsensitiveStringMap): String = {
+    val v = options.get(FactoryOptionKey)
+    if (v == null || v.isEmpty)
+      throw new IllegalArgumentException(
+        s"DatafusionSource: option '$FactoryOptionKey' is required when no subclass override is set"
+      )
+    v
+  }
+
+  /**
+   * Driver-side schema discovery: encodes the options through the factory, asks the scan
+   * backend for the provider's Arrow schema as IPC bytes and converts it to a Spark schema.
+   */
+  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
+    val fqcn = factoryFqcn(options)
+    val factory = instantiateFactory(fqcn)
+    val optionsBytes = factory.encodeOptions(options.asCaseSensitiveMap())
+    // Schema probe: pass empty partitionBytes — bridges are required to honour an empty
+    // payload for the driver-side probe (schema must not depend on per-partition state).
+    val ipcBytes = factory.scanBackend().providerSchemaIpc(optionsBytes, Array.emptyByteArray)
+    val arrowSchema = MessageSerializer.deserializeSchema(
+      new ReadChannel(Channels.newChannel(new ByteArrayInputStream(ipcBytes))))
+    ArrowToSparkSchema.toSparkSchema(arrowSchema)
+  }
+
+  /**
+   * Builds the table for an already-resolved schema. The options are re-encoded here so the
+   * table carries the factory FQCN plus the opaque option bytes; `partitioning` is not used.
+   */
+  override def getTable(
+    schema: StructType,
+    partitioning: Array[Transform],
+    properties: util.Map[String, String]
+  ): Table = {
+    val options = new CaseInsensitiveStringMap(properties)
+    val fqcn = factoryFqcn(options)
+    val factory = instantiateFactory(fqcn)
+    val optionsBytes = factory.encodeOptions(options.asCaseSensitiveMap())
+    new DatafusionTable(fqcn, optionsBytes, schema)
+  }
+
+  override def supportsExternalMetadata(): Boolean = false
+
+  /**
+   * Instantiates the bridge factory through its no-arg constructor.
+   *
+   * The class is looked up through the thread context classloader first: the factory name is
+   * user-supplied and its jar may have been added at submit time, in which case it is only
+   * guaranteed to be visible there. The previous lookup (this class's own loader) is kept as
+   * a fallback.
+   */
+  private def instantiateFactory(fqcn: String): BridgeProviderFactory = {
+    val contextLoader = Thread.currentThread().getContextClassLoader
+    val cls: Class[_] =
+      if (contextLoader == null) {
+        Class.forName(fqcn)
+      } else {
+        try Class.forName(fqcn, true, contextLoader)
+        catch { case _: ClassNotFoundException => Class.forName(fqcn) }
+      }
+    cls.getDeclaredConstructor().newInstance().asInstanceOf[BridgeProviderFactory]
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/DatafusionTable.scala b/spark/src/main/scala/io/datafusion/spark/DatafusionTable.scala
new file mode 100644
index 0000000..a0e8ec4
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/DatafusionTable.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import java.util
+
+import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * Read-only DataFusion-backed table. Capabilities advertise only batch read.
+ */
+class DatafusionTable(
+  val factoryFqcn: String,
+  val optionsBytes: Array[Byte],
+  val sparkSchema: StructType
+) extends Table
+  with SupportsRead {
+
+  /** `datafusion.` followed by the last dot-separated segment of the factory class name. */
+  override def name(): String = {
+    val simpleName = factoryFqcn.split('.').last
+    s"datafusion.$simpleName"
+  }
+
+  override def schema(): StructType = sparkSchema
+
+  /** A fresh mutable set containing only `BATCH_READ`. */
+  override def capabilities(): util.Set[TableCapability] =
+    new util.HashSet[TableCapability](util.Collections.singleton(TableCapability.BATCH_READ))
+
+  /** Scan options are not consulted; the builder works from the table's own state. */
+  override def newScanBuilder(scanOpts: CaseInsensitiveStringMap): ScanBuilder =
+    new DatafusionScanBuilder(factoryFqcn, optionsBytes, sparkSchema)
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/FfiStream.scala b/spark/src/main/scala/io/datafusion/spark/FfiStream.scala
new file mode 100644
index 0000000..eb1149a
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/FfiStream.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.arrow.c.{ArrowArrayStream, Data}
+import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.vector.ipc.ArrowReader
+
+/**
+ * Arrow C-data import of a native-produced `FFI_ArrowArrayStream`: allocate the empty struct,
+ * let the native side write into it, then hand it to Arrow Java. On any failure the struct is
+ * released so a half-written stream can't leak.
+ */
+private[spark] object FfiStream {
+
+  /**
+   * Allocates an empty C stream struct from `allocator`, passes its address to `writeStream`
+   * so the native side can fill it, and imports it as an `ArrowReader`.
+   *
+   * If either the native write or the import throws, the struct is closed before the error
+   * propagates.
+   */
+  def importReader(allocator: BufferAllocator)(writeStream: Long => Unit): ArrowReader = {
+    val cStream = ArrowArrayStream.allocateNew(allocator)
+    var imported = false
+    try {
+      writeStream(cStream.memoryAddress())
+      val reader = Data.importArrayStream(allocator, cStream)
+      imported = true
+      reader
+    } finally {
+      // Only the failure path releases the struct here.
+      if (!imported) cStream.close()
+    }
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/NativeSharedScanResources.scala b/spark/src/main/scala/io/datafusion/spark/NativeSharedScanResources.scala
new file mode 100644
index 0000000..b541c8a
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/NativeSharedScanResources.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
+import org.apache.arrow.vector.ipc.ArrowReader
+import org.apache.spark.internal.Logging
+
+/**
+ * JNI-backed shared-scan entry: one provider, one planned scan handle inside the bridge's native
+ * scan backend.
+ *
+ * The build sequence is the single code path for BOTH the driver-side partition-count probe and
+ * every executor's cache entry — identical widening, registration, projection, filters, and
+ * pinned session config are what make the partition count comparable across machines (the
+ * bridge's determinism contract covers the rest).
+ */
+private[spark] final class NativeSharedScanResources(
+  allocator: RootAllocator,
+  backend: ScanBackend,
+  scanHandle: Long
+) extends SharedScanResources {
+
+  /** Partition count as reported by the backend for this scan handle. */
+  override def partitionCount: Int = backend.partitionCount(scanHandle)
+
+  /** Unbounded child of this entry's root allocator, with no initial reservation. */
+  override def newTaskAllocator(name: String): BufferAllocator = {
+    val initialReservation = 0L
+    allocator.newChildAllocator(name, initialReservation, Long.MaxValue)
+  }
+
+  /** Has the backend write one partition's stream into a C struct and imports it. */
+  override def openPartitionStream(
+    partition: Int,
+    taskAllocator: BufferAllocator): ArrowReader = {
+    val fillStream: Long => Unit =
+      structAddress => backend.executeStreamPartition(scanHandle, partition, structAddress)
+    FfiStream.importReader(taskAllocator)(fillStream)
+  }
+
+  /**
+   * Closes the native scan, then the allocator. Both steps always run; the first failure is
+   * rethrown with any later one attached as suppressed.
+   */
+  override def close(): Unit = {
+    val steps = Seq[() => Unit](
+      () => backend.closeScan(scanHandle),
+      () => allocator.close()
+    )
+    val failures = steps.flatMap { step =>
+      try {
+        step()
+        None
+      } catch {
+        case t: Throwable => Some(t)
+      }
+    }
+    failures.headOption.foreach { first =>
+      failures.tail.foreach(later => first.addSuppressed(later))
+      throw first
+    }
+  }
+}
+
+private[spark] object NativeSharedScanResources extends Logging {
+
+  /**
+   * Builds one shared-scan entry: instantiates the bridge factory, obtains its scan backend
+   * and creates the native scan from the spec's options, projection, filters and pinned
+   * session config.
+   *
+   * The root allocator created here is owned by the returned resources; if scan creation
+   * fails it is closed before the error propagates.
+   */
+  def build(spec: SharedScanSpec): SharedScanResources = {
+    logInfo(
+      s"Building shared-scan entry for scanId=${spec.scanId} " +
+        s"(factory=${spec.factoryFqcn}, filters=${spec.filterProtoBytes.length})")
+
+    val factory = loadFactoryClass(spec.factoryFqcn)
+      .getDeclaredConstructor()
+      .newInstance()
+      .asInstanceOf[BridgeProviderFactory]
+    val backend = factory.scanBackend()
+
+    val allocator = new RootAllocator(Long.MaxValue)
+    try {
+      // Shared mode builds the dataset-wide provider: empty partitionBytes, like the
+      // driver-side schema probe. DataFusion-native partitioning replaces listPartitions.
+      val scanHandle = backend.createScan(
+        spec.optionsBytes,
+        Array.emptyByteArray,
+        spec.pinnedConfig.targetPartitions,
+        spec.pinnedConfig.batchSize,
+        spec.pinnedConfig.options.map(_._1).toArray,
+        spec.pinnedConfig.options.map(_._2).toArray,
+        spec.projectionColumnNames,
+        spec.filterProtoBytes
+      )
+      new NativeSharedScanResources(allocator, backend, scanHandle)
+    } catch {
+      case t: Throwable =>
+        try allocator.close()
+        catch { case suppressed: Throwable => t.addSuppressed(suppressed) }
+        throw t
+    }
+  }
+
+  /**
+   * Looks the factory class up through the thread context classloader first: jars added at
+   * submit time are only guaranteed to be visible there, not through the loader that defined
+   * this connector. The previous lookup (this class's own loader) is kept as a fallback.
+   */
+  private def loadFactoryClass(fqcn: String): Class[_] = {
+    val contextLoader = Thread.currentThread().getContextClassLoader
+    if (contextLoader == null) {
+      Class.forName(fqcn)
+    } else {
+      try Class.forName(fqcn, true, contextLoader)
+      catch { case _: ClassNotFoundException => Class.forName(fqcn) }
+    }
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/NonClosingArrowColumnVector.scala b/spark/src/main/scala/io/datafusion/spark/NonClosingArrowColumnVector.scala
new file mode 100644
index 0000000..4fa74bd
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/NonClosingArrowColumnVector.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.arrow.vector.FieldVector
+import org.apache.spark.sql.vectorized.ArrowColumnVector
+
+/**
+ * `ArrowColumnVector` whose `close()` is a no-op. The `ArrowReader`'s `VectorSchemaRoot` owns
+ * the underlying `ValueVector` lifecycles across `loadNextBatch()` calls; closing them per Spark
+ * batch would break the next read. Lifecycle is centralised in
+ * `DatafusionColumnarPartitionReader.close()`.
+ */
+final class NonClosingArrowColumnVector(vec: FieldVector) extends ArrowColumnVector(vec) {
+ override def close(): Unit = { /* intentional no-op */ }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/PinnedSessionConfig.scala b/spark/src/main/scala/io/datafusion/spark/PinnedSessionConfig.scala
new file mode 100644
index 0000000..1340978
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/PinnedSessionConfig.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * DataFusion session knobs pinned by the driver and replicated verbatim on every executor in
+ * shared-scan mode.
+ *
+ * DataFusion's default `SessionConfig` derives `target_partitions` from the machine's core count,
+ * so a plan that yields N partitions on the driver could yield M ≠ N on a differently-sized
+ * executor — and partition-indexed execution would silently drop or duplicate data. The driver
+ * resolves these values once in `DatafusionScanBuilder.build()`, ships them inside every
+ * [[DatafusionSharedScanPartition]], and both the driver probe and the executors hand the same
+ * values to `ScanBackend.createScan`, which builds the native `SessionContext` from them.
+ *
+ * `options` additionally disables the optimizer's plan-reshaping repartition passes so the
+ * physical partitioning is exactly what the provider's `scan()` reports, on every machine.
+ */
+final case class PinnedSessionConfig(
+ // Passed to `ScanBackend.createScan` as the session's target partition count.
+ targetPartitions: Int,
+ // Passed to `ScanBackend.createScan` as the session's batch size.
+ batchSize: Int,
+ // Extra session options as (key, value) pairs; shipped to `createScan` as two parallel
+ // arrays, so the order of this vector is preserved.
+ options: Vector[(String, String)]
+) extends Serializable
+
+object PinnedSessionConfig {
+
+  val TargetPartitionsConf = "spark.datafusion.sharedScan.targetPartitions"
+  val BatchSizeConf = "spark.datafusion.sharedScan.batchSize"
+  val IdleTtlConf = "spark.datafusion.sharedScan.idleTtlMs"
+
+  val DefaultTargetPartitions = 8
+  val DefaultBatchSize = 8192
+  val DefaultIdleTtlMs = 120000L
+
+  /**
+   * Optimizer knobs that must not vary with the host. Round-robin repartition and file-scan
+   * repartition would let the optimizer change the plan's output partition count based on
+   * `target_partitions` heuristics; statistics collection could steer per-host plan differences.
+   */
+  private val DeterminismOptions: Vector[(String, String)] = Vector(
+    "datafusion.optimizer.enable_round_robin_repartition" -> "false",
+    "datafusion.optimizer.repartition_file_scans" -> "false",
+    "datafusion.execution.collect_statistics" -> "false"
+  )
+
+  /**
+   * Resolve the pinned config from the driver's session conf. Called exactly once per scan, on
+   * the driver; executors never read Spark conf for these values — they use the shipped copy.
+   *
+   * Values are validated here so a bad setting fails on the driver with the offending key in
+   * the message, instead of being shipped to every executor's native session.
+   *
+   * @throws NumberFormatException if a value is not an integer
+   * @throws IllegalArgumentException if target partitions or batch size is not positive
+   */
+  def fromConf(conf: SQLConf): PinnedSessionConfig = {
+    PinnedSessionConfig(
+      targetPartitions = positiveIntConf(conf, TargetPartitionsConf, DefaultTargetPartitions),
+      batchSize = positiveIntConf(conf, BatchSizeConf, DefaultBatchSize),
+      options = DeterminismOptions
+    )
+  }
+
+  /**
+   * Idle TTL, in milliseconds, for executor-side shared-scan cache entries.
+   *
+   * @throws NumberFormatException if the configured value is not an integer
+   */
+  def idleTtlMs(conf: SQLConf): Long = {
+    val raw = conf.getConfString(IdleTtlConf, DefaultIdleTtlMs.toString)
+    try raw.trim.toLong
+    catch {
+      case _: NumberFormatException =>
+        throw new NumberFormatException(
+          s"$IdleTtlConf must be an integer number of milliseconds, got '$raw'")
+    }
+  }
+
+  /** Reads `key` (or `default` when unset) as an Int and requires it to be > 0. */
+  private def positiveIntConf(conf: SQLConf, key: String, default: Int): Int = {
+    val raw = conf.getConfString(key, default.toString)
+    val parsed =
+      try raw.trim.toInt
+      catch {
+        case _: NumberFormatException =>
+          throw new NumberFormatException(s"$key must be an integer, got '$raw'")
+      }
+    if (parsed <= 0) {
+      throw new IllegalArgumentException(s"$key must be a positive integer, got $parsed")
+    }
+    parsed
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/SharedScanCache.scala b/spark/src/main/scala/io/datafusion/spark/SharedScanCache.scala
new file mode 100644
index 0000000..a134746
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/SharedScanCache.scala
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import java.util.concurrent.{ConcurrentHashMap, Executors, ScheduledExecutorService, TimeUnit}
+
+import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.vector.ipc.ArrowReader
+
+/**
+ * Everything the driver resolved that an executor needs to rebuild the shared scan: identity
+ * (scanId) plus the exact build inputs (factory, options, projection, filters, pinned config).
+ */
+final case class SharedScanSpec(
+ // Driver-minted identity of the scan; the only field SharedScanCache uses as its map key.
+ scanId: String,
+ // Class name of the BridgeProviderFactory to instantiate when building the entry.
+ factoryFqcn: String,
+ // Opaque, factory-encoded options; forwarded unchanged to the scan backend.
+ optionsBytes: Array[Byte],
+ // Names of the projected columns.
+ projectionColumnNames: Array[String],
+ // One serialized filter per pushed predicate.
+ filterProtoBytes: Array[Array[Byte]],
+ // Session settings resolved once on the driver.
+ // NOTE: the Array fields make the generated equals/hashCode reference-based for those
+ // fields, so two specs with equal contents are not `==` unless they share the arrays.
+ pinnedConfig: PinnedSessionConfig
+)
+
+/**
+ * What one cached shared-scan entry exposes to readers. Implemented by
+ * [[NativeSharedScanResources]] (JNI-backed) and by fakes in tests.
+ */
+trait SharedScanResources extends AutoCloseable {
+
+ /**
+ * Output partition count of the planned physical plan. Readers compare it with the count
+ * the driver shipped in the partition.
+ */
+ def partitionCount: Int
+
+ /**
+ * Child allocator for one task's reader; closed by the task, before release.
+ *
+ * @param name name given to the new allocator
+ */
+ def newTaskAllocator(name: String): BufferAllocator
+
+ /**
+ * Open an independent stream over one plan partition. Concurrent-safe.
+ *
+ * @param partition index of the plan partition to stream
+ * @param taskAllocator allocator the returned reader is created with
+ */
+ def openPartitionStream(partition: Int, taskAllocator: BufferAllocator): ArrowReader
+}
+
+/**
+ * Executor-JVM cache of shared-scan entries, keyed by the driver-minted scanId.
+ *
+ * Semantics:
+ * - `acquire` builds the entry exactly once per attempt wave: the first caller builds under
+ * the entry's lock, concurrent callers block and share the result. Each successful acquire
+ * increments a refcount that the caller MUST pair with `release(scanId)`.
+ * - Build failures propagate to the builder AND all waiters of that attempt, and are not
+ * cached: the next acquire rebuilds.
+ * - Eviction closes entries with refcount 0 that have been idle longer than their TTL. The
+ * refcount covers every open reader, so native close never races an in-flight stream.
+ * Acquire after eviction rebuilds — correct, just slower.
+ *
+ * The cache itself is JNI-free: the entry builder is injected, so tests run without native libs.
+ */
+final class SharedScanCache(
+  buildEntry: SharedScanSpec => SharedScanResources,
+  nanoClock: () => Long = () => System.nanoTime()
+) {
+
+  /**
+   * Per-scanId slot. All state transitions are guarded by `this` (the holder's monitor); the
+   * build itself also runs under the monitor, which is what blocks concurrent acquirers of the
+   * same scan until the entry exists.
+   *
+   * A holder whose build failed is terminal: it is marked closed and remembers the failure,
+   * so threads that were already waiting on its monitor receive the same error instead of
+   * rebuilding inside a holder that the failing thread is removing from the map.
+   */
+  private final class EntryHolder(spec: SharedScanSpec, idleTtlMs: Long) {
+    private var resources: SharedScanResources = _
+    private var refCount: Int = 0
+    private var lastReleaseNanos: Long = nanoClock()
+    private var closed: Boolean = false
+    private var buildFailure: Throwable = _
+
+    /**
+     * Returns the resources with refcount incremented, or None if this holder was evicted.
+     * Throws the build error if the build fails now or already failed in this holder.
+     */
+    def acquire(): Option[SharedScanResources] = synchronized {
+      if (buildFailure != null) {
+        throw buildFailure
+      } else if (closed) {
+        None
+      } else {
+        if (resources == null) {
+          try {
+            resources = buildEntry(spec)
+          } catch {
+            case t: Throwable =>
+              // Poison the holder before the monitor is released: waiters must not rebuild
+              // here, because a successful rebuild would live in a holder the map no longer
+              // tracks and its refcount could never be released or evicted.
+              buildFailure = t
+              closed = true
+              throw t
+          }
+        }
+        refCount += 1
+        Some(resources)
+      }
+    }
+
+    /** Drops one reference and restarts the idle clock. */
+    def release(): Unit = synchronized {
+      if (refCount <= 0) {
+        // A negative refcount would never equal 0 again and would block eviction forever.
+        throw new IllegalStateException(
+          s"release(${spec.scanId}) with refcount $refCount: unbalanced acquire/release")
+      }
+      refCount -= 1
+      lastReleaseNanos = nanoClock()
+    }
+
+    /** Close if idle past TTL; returns true when this holder is now closed. */
+    def closeIfIdle(nowNanos: Long): Boolean = synchronized {
+      if (!closed) {
+        val idle = refCount == 0 &&
+          (nowNanos - lastReleaseNanos) >= TimeUnit.MILLISECONDS.toNanos(idleTtlMs)
+        if (idle) forceCloseLocked()
+      }
+      closed
+    }
+
+    def forceClose(): Unit = synchronized { forceCloseLocked() }
+
+    /** Caller must hold the monitor. State is flipped before close() so a throw can't undo it. */
+    private def forceCloseLocked(): Unit = {
+      if (!closed) {
+        closed = true
+        if (resources != null) {
+          val r = resources
+          resources = null
+          r.close()
+        }
+      }
+    }
+  }
+
+  private val entries = new ConcurrentHashMap[String, EntryHolder]()
+
+  /**
+   * Returns the shared entry for `spec.scanId`, building it if needed. Every successful call
+   * must be paired with `release(spec.scanId)`.
+   *
+   * A build failure is thrown to the builder and to the callers that were waiting on the same
+   * holder; the holder is dropped, so a later acquire starts a fresh build.
+   */
+  def acquire(spec: SharedScanSpec, idleTtlMs: Long): SharedScanResources = {
+    while (true) {
+      val holder =
+        entries.computeIfAbsent(spec.scanId, _ => new EntryHolder(spec, idleTtlMs))
+      val acquired =
+        try {
+          holder.acquire()
+        } catch {
+          case t: Throwable =>
+            // Build failed: drop the holder so the next acquire rebuilds, then propagate.
+            entries.remove(spec.scanId, holder)
+            throw t
+        }
+      acquired match {
+        case Some(resources) => return resources
+        case None =>
+          // Holder was evicted between map lookup and acquire; retry with a fresh one.
+          entries.remove(spec.scanId, holder)
+      }
+    }
+    throw new IllegalStateException("unreachable")
+  }
+
+  /**
+   * Releases one reference taken by `acquire`.
+   *
+   * @throws IllegalStateException if there is no cached entry for `scanId`, or if the entry
+   *                               has no outstanding references
+   */
+  def release(scanId: String): Unit = {
+    val holder = entries.get(scanId)
+    if (holder == null) {
+      throw new IllegalStateException(
+        s"release($scanId) without a cached entry: unbalanced acquire/release")
+    }
+    holder.release()
+  }
+
+  /** Close and remove every idle-past-TTL entry. Called by the evictor daemon and by tests. */
+  private[spark] def evictIdleNow(): Unit = {
+    val now = nanoClock()
+    entries.forEach { (scanId, holder) =>
+      if (holder.closeIfIdle(now)) {
+        entries.remove(scanId, holder)
+      }
+    }
+  }
+
+  /** Close everything regardless of refcounts. JVM-shutdown path only. */
+  def shutdown(): Unit = {
+    entries.forEach { (_, holder) => holder.forceClose() }
+    entries.clear()
+  }
+}
+
+object SharedScanCache extends org.apache.spark.internal.Logging {
+
+  /** Evictor period. Short relative to any sane TTL; cheap when the map is empty. */
+  private val EvictorPeriodMs = 5000L
+
+  /** JVM singleton used by executor tasks. Lazily started together with its evictor daemon. */
+  lazy val global: SharedScanCache = {
+    val cache = new SharedScanCache(NativeSharedScanResources.build)
+    val evictor: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor { r =>
+      val t = new Thread(r, "datafusion-shared-scan-evictor")
+      t.setDaemon(true)
+      t
+    }
+    // A periodic task that throws is never scheduled again, and closing an entry can throw.
+    // Catch per sweep so one failing close does not silently stop all future eviction.
+    val sweep: Runnable = () =>
+      try cache.evictIdleNow()
+      catch {
+        case scala.util.control.NonFatal(e) =>
+          logWarning("Shared-scan idle eviction sweep failed; retrying on the next period", e)
+      }
+    evictor.scheduleWithFixedDelay(
+      sweep,
+      EvictorPeriodMs,
+      EvictorPeriodMs,
+      TimeUnit.MILLISECONDS)
+    Runtime.getRuntime.addShutdownHook(new Thread(() => cache.shutdown()))
+    cache
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/SharedScanPartitionReader.scala b/spark/src/main/scala/io/datafusion/spark/SharedScanPartitionReader.scala
new file mode 100644
index 0000000..4f0c9c1
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/SharedScanPartitionReader.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.vector.ipc.ArrowReader
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.connector.read.PartitionReader
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Shared-scan task reader: acquires the executor's cached (provider, plan) entry and streams ONE
+ * DataFusion plan partition from it. The acquire/release refcount pair brackets the reader's
+ * whole lifetime, so the cache can never close the native plan under an open stream.
+ */
+class SharedScanPartitionReader(
+    partition: DatafusionSharedScanPartition,
+    cache: SharedScanCache
+) extends PartitionReader[ColumnarBatch] with ArrowColumnarBatchIteration {
+
+  private val resources: SharedScanResources = cache.acquire(partition.toSpec, partition.idleTtlMs)
+
+  /** Undo `acquire` when construction fails; `cause` stays the primary error. */
+  private def releaseAfter(cause: Throwable): Unit =
+    try cache.release(partition.scanId) catch { case s: Throwable => cause.addSuppressed(s) }
+
+  // Determinism guard: the driver counted partitions by planning once; if this executor's
+  // re-plan disagrees, partition indices are meaningless and every task of the scan must fail
+  // rather than silently drop or duplicate data.
+  if (resources.partitionCount != partition.numPartitions) {
+    val executorCount = resources.partitionCount
+    cache.release(partition.scanId)
+    throw new IllegalStateException(
+      s"shared-scan determinism violation for scanId=${partition.scanId}: driver planned " +
+        s"${partition.numPartitions} partition(s) but this executor planned $executorCount. " +
+        "The provider's partitioning must be a pure function of optionsBytes; pin your " +
+        "source snapshot (see BridgeProviderFactory.sharedScan).")
+  }
+
+  private val taskAllocator: BufferAllocator = {
+    val attempt = Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
+    val name = s"shared-${partition.scanId}-p${partition.partitionIndex}-attempt$attempt"
+    try resources.newTaskAllocator(name)
+    catch { case t: Throwable => releaseAfter(t); throw t }
+  }
+
+  override protected val arrowReader: ArrowReader =
+    try resources.openPartitionStream(partition.partitionIndex, taskAllocator)
+    catch { case t: Throwable =>
+      try taskAllocator.close() catch { case s: Throwable => t.addSuppressed(s) }
+      releaseAfter(t)
+      throw t
+    }
+
+  override def close(): Unit = {
+    var first: Throwable = null
+    def safe(f: => Unit): Unit =
+      try f catch { case t: Throwable => if (first == null) first = t else first.addSuppressed(t) }
+    safe(arrowReader.close())
+    safe(taskAllocator.close())
+    // Release LAST: the refcount must cover the open stream and the task allocator.
+    safe(cache.release(partition.scanId))
+    if (first != null) throw first
+  }
+}
diff --git a/spark/src/main/scala/io/datafusion/spark/SparkPredicateTranslator.scala b/spark/src/main/scala/io/datafusion/spark/SparkPredicateTranslator.scala
new file mode 100644
index 0000000..3092914
--- /dev/null
+++ b/spark/src/main/scala/io/datafusion/spark/SparkPredicateTranslator.scala
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import datafusion_common.DatafusionCommon.{Column, ScalarValue}
+import org.apache.datafusion.protobuf.{
+ BinaryExprNode,
+ InListNode,
+ IsNotNull,
+ IsNull,
+ LikeNode,
+ LogicalExprNode,
+ Not => NotNode
+}
+import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+
+/**
+ * Translate Spark V2 `Predicate` → DataFusion `LogicalExprNode` proto. Only emits expressions
+ * that the producer can apply EXACTLY — anything else returns `None` and the caller marks the
+ * predicate as residual so Spark re-applies it above the scan.
+ */
+object SparkPredicateTranslator {
+
+ /** Entry point: routes `p` by name to its builder; a name with no case is untranslatable. */
+ def translate(p: Predicate): Option[LogicalExprNode] = {
+   val comparisons =
+     Map("=" -> "Eq", "<>" -> "NotEq", "<" -> "Lt", "<=" -> "LtEq", ">" -> "Gt", ">=" -> "GtEq")
+   p.name() match {
+     case cmp if comparisons.contains(cmp) => binary(p, comparisons(cmp))
+     case "IS_NULL" => unary(p, "IsNull")
+     case "IS_NOT_NULL" => unary(p, "IsNotNull")
+     case "AND" => combine(p, "And")
+     case "OR" => combine(p, "Or")
+     case "NOT" => translateNot(p)
+     case "IN" => translateIn(p)
+     case "STARTS_WITH" => like(p, prefix = false, suffix = true) // pattern `text%`
+     case "ENDS_WITH" => like(p, prefix = true, suffix = false) // pattern `%text`
+     case "CONTAINS" => like(p, prefix = true, suffix = true) // pattern `%text%`
+     case _ => None
+   }
+ }
+
+ /**
+  * Comparison of two leaf operands, each a top-level column or a supported literal (see `expr`).
+  * Returns `None` when the predicate does not have exactly two children or when either child
+  * is not such a leaf.
+  *
+  * @param op operator name written verbatim to `BinaryExprNode.op`, e.g. "Eq" or "LtEq"
+  */
+ private def binary(p: Predicate, op: String): Option[LogicalExprNode] = {
+   val operands = p.children()
+   if (operands.length != 2) None
+   else
+     for {
+       lhs <- expr(operands(0))
+       rhs <- expr(operands(1))
+     } yield {
+       val comparison =
+         BinaryExprNode.newBuilder().addOperands(lhs).addOperands(rhs).setOp(op).build()
+       LogicalExprNode.newBuilder().setBinaryExpr(comparison).build()
+     }
+ }
+
+ /** `IS NULL` / `IS NOT NULL` over one leaf child; any other `op` or child count gives `None`. */
+ private def unary(p: Predicate, op: String): Option[LogicalExprNode] = {
+   val operands = p.children()
+   val child = if (operands.length == 1) expr(operands(0)) else None
+   child.flatMap { c =>
+     val out = LogicalExprNode.newBuilder()
+     op match {
+       case "IsNull" => Some(out.setIsNullExpr(IsNull.newBuilder().setExpr(c).build()).build())
+       case "IsNotNull" =>
+         Some(out.setIsNotNullExpr(IsNotNull.newBuilder().setExpr(c).build()).build())
+       case _ => None
+     }
+   }
+ }
+
+ /**
+  * Logical AND / OR. Both children must be predicates that translate on their own; if either
+  * one does not, the whole conjunction or disjunction is `None` (nothing is pushed partially).
+  *
+  * @param op written verbatim to `BinaryExprNode.op`; `translate` passes "And" or "Or"
+  */
+ private def combine(p: Predicate, op: String): Option[LogicalExprNode] = {
+   val operands = p.children()
+   if (operands.length != 2) None
+   else
+     (operands(0), operands(1)) match {
+       case (lhs: Predicate, rhs: Predicate) =>
+         for {
+           l <- translate(lhs)
+           r <- translate(rhs)
+         } yield {
+           val node = BinaryExprNode.newBuilder().addOperands(l).addOperands(r).setOp(op).build()
+           LogicalExprNode.newBuilder().setBinaryExpr(node).build()
+         }
+       case _ => None
+     }
+ }
+
+ /** `NOT child`: translated only when the single child is a predicate that itself translates. */
+ private def translateNot(p: Predicate): Option[LogicalExprNode] =
+   Some(p.children()).filter(_.length == 1).map(_(0))
+     .collect { case inner: Predicate => inner }
+     .flatMap(translate)
+     .map(t => NotNode.newBuilder().setExpr(t).build())
+     .map(not => LogicalExprNode.newBuilder().setNotExpr(not).build())
+
+ /**
+  * `target IN (v1, ..., vn)`: child 0 is the tested expression and the remaining children are
+  * the list values. Every child must be a leaf accepted by `expr`; a single unsupported child
+  * makes the whole predicate `None`. Fewer than two children is also `None`.
+  * The node is always emitted with `negated = false`.
+  */
+ private def translateIn(p: Predicate): Option[LogicalExprNode] = {
+   val operands = p.children()
+   if (operands.length < 2) None
+   else
+     expr(operands.head).flatMap { target =>
+       val candidates = operands.tail.map(expr)
+       if (candidates.exists(_.isEmpty)) None
+       else {
+         val values = java.util.Arrays.asList(candidates.map(_.get): _*)
+         val inList =
+           InListNode.newBuilder().setExpr(target).addAllList(values).setNegated(false).build()
+         Some(LogicalExprNode.newBuilder().setInList(inList).build())
+       }
+     }
+ }
+
+ /**
+  * `STARTS_WITH` / `ENDS_WITH` / `CONTAINS` expressed as a LIKE with `\` as the escape character.
+  * Child 0 goes through `expr`; child 1 must be a string literal whose `\`, `%`, `_` are escaped.
+  *
+  * @param prefix put a `%` wildcard in front of the escaped text
+  * @param suffix put a `%` wildcard behind the escaped text
+  */
+ private def like(p: Predicate, prefix: Boolean, suffix: Boolean): Option[LogicalExprNode] = {
+   val operands = p.children()
+   if (operands.length != 2) None
+   else {
+     // The search text is accepted as a literal holding either a String or a UTF8String.
+     val needle = Some(operands(1)).collect { case lit: Literal[_] => lit.value(): Any }.collect {
+       case s: String => s
+       case u: org.apache.spark.unsafe.types.UTF8String => u.toString
+     }
+     for {
+       target <- expr(operands(0))
+       text <- needle
+     } yield {
+       val escaped = text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
+       val pattern = (if (prefix) "%" else "") + escaped + (if (suffix) "%" else "")
+       val likeNode = LikeNode.newBuilder().setExpr(target).setPattern(stringLiteral(pattern))
+         .setNegated(false).setEscapeChar("\\").build()
+       LogicalExprNode.newBuilder().setLike(likeNode).build()
+     }
+   }
+ }
+
+ /** Leaf operand: a top-level column or a plain-typed literal; everything else is `None`. */
+ private def expr(e: Expression): Option[LogicalExprNode] = e match {
+   case nr: NamedReference if nr.fieldNames().length == 1 =>
+     val column = Column.newBuilder().setName(nr.fieldNames()(0)).build()
+     Some(LogicalExprNode.newBuilder().setColumn(column).build())
+   // A V2 literal holds Spark's internal value (DATE is an Integer day count, TIMESTAMP a Long),
+   // so the JVM class alone is ambiguous: emit only when the declared type is the plain one.
+   case lit: Literal[_] =>
+     import org.apache.spark.sql.types._
+     val plain = Seq[DataType](BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
+       DoubleType, StringType)
+     if (plain.contains(lit.dataType())) literal(lit.value()) else None
+   case _ => None
+ }
+
+ /** Maps a boxed primitive, String or UTF8String to a scalar literal node; other values: `None`. */
+ private def literal(v: Any): Option[LogicalExprNode] = {
+   val scalar = ScalarValue.newBuilder()
+   val filled = v match {
+     case b: java.lang.Boolean => Some(scalar.setBoolValue(b.booleanValue()))
+     case b: java.lang.Byte => Some(scalar.setInt8Value(b.intValue()))
+     case s: java.lang.Short => Some(scalar.setInt16Value(s.intValue()))
+     case i: java.lang.Integer => Some(scalar.setInt32Value(i.intValue()))
+     case l: java.lang.Long => Some(scalar.setInt64Value(l.longValue()))
+     case f: java.lang.Float => Some(scalar.setFloat32Value(f.floatValue()))
+     case d: java.lang.Double => Some(scalar.setFloat64Value(d.doubleValue()))
+     case s: String => Some(scalar.setUtf8Value(s))
+     case u: org.apache.spark.unsafe.types.UTF8String => Some(scalar.setUtf8Value(u.toString))
+     case _ => None
+   }
+   filled.map(b => LogicalExprNode.newBuilder().setLiteral(b.build()).build())
+ }
+
+ private def stringLiteral(s: String): LogicalExprNode = // UTF-8 scalar wrapped as a literal node
+   LogicalExprNode.newBuilder().setLiteral(ScalarValue.newBuilder().setUtf8Value(s)).build()
+}
diff --git a/spark/src/test/scala/io/datafusion/spark/ArrowToSparkSchemaTest.scala b/spark/src/test/scala/io/datafusion/spark/ArrowToSparkSchemaTest.scala
new file mode 100644
index 0000000..2b59601
--- /dev/null
+++ b/spark/src/test/scala/io/datafusion/spark/ArrowToSparkSchemaTest.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import java.util.Collections
+
+import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.spark.sql.types._
+import org.scalatest.funsuite.AnyFunSuite
+
+class ArrowToSparkSchemaTest extends AnyFunSuite { // checks ArrowToSparkSchema.toSparkSchema on hand-built Arrow schemas
+
+ private def primField(name: String, t: ArrowType, nullable: Boolean = true): Field =
+ new Field(name, new FieldType(nullable, t, /*dict=*/ null), Collections.emptyList()) // no dictionary, no child fields
+
+ test("signed ints map to matching Spark int types") {
+ val arrow = new Schema(
+ java.util.Arrays.asList(
+ primField("i8", new ArrowType.Int(8, true)),
+ primField("i16", new ArrowType.Int(16, true)),
+ primField("i32", new ArrowType.Int(32, true)),
+ primField("i64", new ArrowType.Int(64, true))
+ )
+ )
+ val s = ArrowToSparkSchema.toSparkSchema(arrow)
+ assert(s.fields(0).dataType == ByteType)
+ assert(s.fields(1).dataType == ShortType)
+ assert(s.fields(2).dataType == IntegerType)
+ assert(s.fields(3).dataType == LongType)
+ }
+
+ test("unsigned ints are rejected with a clear error") {
+ val arrow = new Schema(
+ java.util.Arrays.asList(primField("u32", new ArrowType.Int(32, false)))
+ )
+ val ex = intercept[UnsupportedOperationException](ArrowToSparkSchema.toSparkSchema(arrow))
+ assert(ex.getMessage.contains("u32")) // the message must name the offending field
+ assert(ex.getMessage.toLowerCase.contains("unsigned"))
+ }
+
+ test("timestamps split on timezone presence") {
+ val withTz = primField("t_utc", new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC"))
+ val noTz = primField("t_local", new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)) // null timezone
+ val s = ArrowToSparkSchema.toSparkSchema(
+ new Schema(java.util.Arrays.asList(withTz, noTz))
+ )
+ assert(s.fields(0).dataType == TimestampType)
+ assert(s.fields(1).dataType == TimestampNTZType)
+ }
+
+ test("decimal preserves precision and scale") {
+ val s = ArrowToSparkSchema.toSparkSchema(
+ new Schema(java.util.Arrays.asList(primField("d", new ArrowType.Decimal(18, 4, 128))))
+ )
+ assert(s.fields(0).dataType == DecimalType(18, 4))
+ }
+
+ test("Time and Float16 are rejected (no Spark accessor)") {
+ intercept[UnsupportedOperationException] {
+ ArrowToSparkSchema.toSparkSchema(
+ new Schema(java.util.Arrays.asList(primField("t", new ArrowType.Time(TimeUnit.MICROSECOND, 64))))
+ )
+ }
+ intercept[UnsupportedOperationException] {
+ ArrowToSparkSchema.toSparkSchema(
+ new Schema(java.util.Arrays.asList(primField("h", new ArrowType.FloatingPoint(FloatingPointPrecision.HALF))))
+ )
+ }
+ }
+
+ test("list element nullability propagates") {
+ val child =
+ new Field(
+ "el",
+ new FieldType(/*nullable=*/ true, new ArrowType.Int(32, true), null),
+ Collections.emptyList()
+ )
+ val listField = new Field(
+ "xs",
+ new FieldType(true, new ArrowType.List(), null),
+ java.util.Arrays.asList(child)
+ )
+ val s = ArrowToSparkSchema.toSparkSchema(
+ new Schema(java.util.Arrays.asList(listField))
+ )
+ assert(s.fields(0).dataType == ArrayType(IntegerType, containsNull = true)) // nullable child => containsNull
+ }
+}
diff --git a/spark/src/test/scala/io/datafusion/spark/BridgeProviderFactoryDefaultsTest.scala b/spark/src/test/scala/io/datafusion/spark/BridgeProviderFactoryDefaultsTest.scala
new file mode 100644
index 0000000..0b94eee
--- /dev/null
+++ b/spark/src/test/scala/io/datafusion/spark/BridgeProviderFactoryDefaultsTest.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class BridgeProviderFactoryDefaultsTest extends AnyFunSuite { // pins the default method behavior of BridgeProviderFactory
+
+ /** Backend stub: the defaults under test never touch native code. */
+ private object StubBackend extends ScanBackend {
+ def providerSchemaIpc(options: Array[Byte], partitionBytes: Array[Byte]): Array[Byte] =
+ throw new UnsupportedOperationException
+ def createScan(
+ options: Array[Byte],
+ partitionBytes: Array[Byte],
+ targetPartitions: Int,
+ batchSize: Int,
+ optionKeys: Array[String],
+ optionValues: Array[String],
+ projectionColumns: Array[String],
+ filterProtos: Array[Array[Byte]]): Long = throw new UnsupportedOperationException
+ def partitionCount(scanHandle: Long): Int = throw new UnsupportedOperationException
+ def executeStreamPartition(scanHandle: Long, partition: Int, ffiStreamAddr: Long): Unit =
+ throw new UnsupportedOperationException
+ def executeStream(scanHandle: Long, ffiStreamAddr: Long): Unit =
+ throw new UnsupportedOperationException
+ def closeScan(scanHandle: Long): Unit = throw new UnsupportedOperationException
+ }
+
+ /** Factory overriding only listPartitions (to spy on its inputs). */
+ private class MinimalFactory extends BridgeProviderFactory {
+ var lastListPartitionsOpts: Array[Byte] = _ // last optionsBytes array handed to listPartitions
+
+ override def scanBackend(): ScanBackend = StubBackend
+
+ override def listPartitions(optionsBytes: Array[Byte]): Array[PartitionInfo] = {
+ lastListPartitionsOpts = optionsBytes
+ Array(new PartitionInfo("p0", Array.emptyByteArray, Array.empty[String]))
+ }
+ }
+
+ /** Only the required method implemented — the literal minimum a bridge can ship. */
+ private class EmptyFactory extends BridgeProviderFactory {
+ override def scanBackend(): ScanBackend = StubBackend
+ }
+
+ test("sharedScan defaults to false") {
+ assert(!new MinimalFactory().sharedScan(Array[Byte](1, 2, 3)))
+ }
+
+ test("default encodeOptions uses OptionsCodec") {
+ val opts = new java.util.HashMap[String, String]()
+ opts.put("url", "grpc://h:1")
+ val bytes = new EmptyFactory().encodeOptions(opts)
+ assert(java.util.Arrays.equals(bytes, OptionsCodec.encode(opts)))
+ assert(OptionsCodec.decode(bytes).get("url") == "grpc://h:1")
+ }
+
+ test("default listPartitions reports a single whole-dataset partition") {
+ val partitions = new EmptyFactory().listPartitions(Array[Byte](1))
+ assert(partitions.length == 1)
+ assert(partitions(0).id == "p0")
+ assert(partitions(0).partitionBytes().isEmpty)
+ assert(partitions(0).preferredLocations().isEmpty)
+ }
+
+ test("filter-aware listPartitions delegates to the filter-unaware overload") {
+ val factory = new MinimalFactory
+ val opts = Array[Byte](7, 8)
+ val filters = Array(Array[Byte](1), Array[Byte](2))
+ val partitions = factory.listPartitions(opts, filters)
+ assert(partitions.length == 1)
+ assert(partitions(0).id == "p0")
+ assert(factory.lastListPartitionsOpts eq opts) // reference equality: the same array instance was forwarded
+ }
+
+ test("reportPartitioning defaults to null") {
+ assert(new MinimalFactory().reportPartitioning(Array.emptyByteArray) == null)
+ }
+
+ test("PartitionInfo 3-arg constructor leaves partitionKeyValues null") {
+ val p = new PartitionInfo("p0", Array.emptyByteArray, Array.empty[String])
+ assert(p.partitionKeyValues() == null)
+ }
+
+ test("PartitionInfo 4-arg constructor carries key values") {
+ val p = new PartitionInfo(
+ "p0",
+ Array.emptyByteArray,
+ Array.empty[String],
+ Array[AnyRef]("segment-a", Long.box(42L)))
+ assert(p.partitionKeyValues().length == 2)
+ assert(p.partitionKeyValues()(0) == "segment-a")
+ }
+}
diff --git a/spark/src/test/scala/io/datafusion/spark/OptionsCodecTest.scala b/spark/src/test/scala/io/datafusion/spark/OptionsCodecTest.scala
new file mode 100644
index 0000000..59f6c8f
--- /dev/null
+++ b/spark/src/test/scala/io/datafusion/spark/OptionsCodecTest.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import java.io.ByteArrayOutputStream
+import java.nio.charset.StandardCharsets
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class OptionsCodecTest extends AnyFunSuite { // wire-format tests for OptionsCodec.encode / OptionsCodec.decode
+
+ /**
+ * Shared fixture: must stay byte-identical to the one asserted by the Rust-side
+ * `datafusion_spark_bridge::options` tests. {"table": "t1", "url": "grpc://h:1"} encodes
+ * (sorted: table < url) as below.
+ */
+ private def fixtureBytes(): Array[Byte] = {
+ val out = new ByteArrayOutputStream()
+ def writeInt(v: Int): Unit = {
+ out.write((v >>> 24) & 0xFF); out.write((v >>> 16) & 0xFF) // big-endian: most significant byte first
+ out.write((v >>> 8) & 0xFF); out.write(v & 0xFF)
+ }
+ def writeString(s: String): Unit = {
+ val b = s.getBytes(StandardCharsets.UTF_8)
+ writeInt(b.length) // 4-byte length prefix counts UTF-8 bytes, not chars
+ out.write(b, 0, b.length)
+ }
+ writeInt(2) // leading entry count
+ Seq("table" -> "t1", "url" -> "grpc://h:1").foreach { case (k, v) =>
+ writeString(k); writeString(v)
+ }
+ out.toByteArray
+ }
+
+ test("encodes the cross-language fixture byte-identically, sorted by key") {
+ // Insertion order deliberately unsorted; encoding must sort.
+ val opts = new java.util.LinkedHashMap[String, String]()
+ opts.put("url", "grpc://h:1")
+ opts.put("table", "t1")
+ assert(java.util.Arrays.equals(OptionsCodec.encode(opts), fixtureBytes()))
+ }
+
+ test("round-trips including unicode values") {
+ val opts = new java.util.HashMap[String, String]()
+ opts.put("a", "1")
+ opts.put("unicode", "héllo→world")
+ val decoded = OptionsCodec.decode(OptionsCodec.encode(opts))
+ assert(decoded.size() == 2)
+ assert(decoded.get("unicode") == "héllo→world")
+ }
+
+ test("null and empty maps encode to a zero count and decode back empty") {
+ assert(OptionsCodec.decode(OptionsCodec.encode(null)).isEmpty)
+ assert(OptionsCodec.decode(Array.emptyByteArray).isEmpty)
+ }
+
+ test("rejects truncation and trailing bytes") {
+ val bytes = fixtureBytes()
+ intercept[IllegalArgumentException] {
+ OptionsCodec.decode(java.util.Arrays.copyOf(bytes, bytes.length - 1)) // last byte cut off
+ }
+ intercept[IllegalArgumentException] {
+ OptionsCodec.decode(java.util.Arrays.copyOf(bytes, bytes.length + 1)) // copyOf pads one trailing zero byte
+ }
+ }
+
+ test("rejects null keys or values") {
+ val opts = new java.util.HashMap[String, String]()
+ opts.put("k", null) // NOTE(review): only the null-value half of the test name is exercised; no null-key case
+ intercept[IllegalArgumentException] { OptionsCodec.encode(opts) }
+ }
+}
diff --git a/spark/src/test/scala/io/datafusion/spark/PartitionKeyConversionTest.scala b/spark/src/test/scala/io/datafusion/spark/PartitionKeyConversionTest.scala
new file mode 100644
index 0000000..e2f876d
--- /dev/null
+++ b/spark/src/test/scala/io/datafusion/spark/PartitionKeyConversionTest.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.spark.unsafe.types.UTF8String
+import org.scalatest.funsuite.AnyFunSuite
+
+class PartitionKeyConversionTest extends AnyFunSuite { // covers DatafusionBatch.toKeyRow and DatafusionBatch.validateKeyedState
+
+ private def info(id: String, keys: Array[AnyRef]): PartitionInfo =
+ new PartitionInfo(id, Array.emptyByteArray, Array.empty[String], keys) // 4-arg form: carries key values
+
+ private def infoNoKeys(id: String): PartitionInfo =
+ new PartitionInfo(id, Array.emptyByteArray, Array.empty[String]) // 3-arg form: no key values supplied
+
+ test("String and Long key values convert to catalyst representations") {
+ val reported = ReportedPartitioning.identity("segment_id", "bucket")
+ val row =
+ DatafusionBatch.toKeyRow("p0", Array[AnyRef]("segment-a", Long.box(42L)), reported)
+ assert(row.numFields == 2)
+ assert(row.get(0, org.apache.spark.sql.types.StringType) == UTF8String.fromString("segment-a"))
+ assert(row.getLong(1) == 42L)
+ }
+
+ test("arity mismatch between key values and declared keys throws") {
+ val reported = ReportedPartitioning.identity("segment_id", "bucket")
+ val e = intercept[IllegalStateException] {
+ DatafusionBatch.toKeyRow("p0", Array[AnyRef]("only-one"), reported) // one value against two declared keys
+ }
+ assert(e.getMessage.contains("declares 2 key(s)"))
+ }
+
+ test("keyed state requires reported partitioning") {
+ val partitions = Array(info("p0", Array[AnyRef]("a")))
+ assert(!DatafusionBatch.validateKeyedState("F", partitions, null)) // null = no reported partitioning
+ }
+
+ test("no partitions with keys means unkeyed, even with reported partitioning") {
+ val reported = ReportedPartitioning.identity("segment_id")
+ val partitions = Array(infoNoKeys("p0"), infoNoKeys("p1"))
+ assert(!DatafusionBatch.validateKeyedState("F", partitions, reported))
+ }
+
+ test("all partitions with keys means keyed") {
+ val reported = ReportedPartitioning.identity("segment_id")
+ val partitions =
+ Array(info("p0", Array[AnyRef]("a")), info("p1", Array[AnyRef]("b")))
+ assert(DatafusionBatch.validateKeyedState("F", partitions, reported))
+ }
+
+ test("mixed keyed and unkeyed partitions throw driver-side") {
+ val reported = ReportedPartitioning.identity("segment_id")
+ val partitions = Array(info("p0", Array[AnyRef]("a")), infoNoKeys("p1"))
+ val e = intercept[IllegalStateException] {
+ DatafusionBatch.validateKeyedState("F", partitions, reported)
+ }
+ assert(e.getMessage.contains("only 1 of 2"))
+ }
+}
diff --git a/spark/src/test/scala/io/datafusion/spark/SharedScanCacheTest.scala b/spark/src/test/scala/io/datafusion/spark/SharedScanCacheTest.scala
new file mode 100644
index 0000000..dae49eb
--- /dev/null
+++ b/spark/src/test/scala/io/datafusion/spark/SharedScanCacheTest.scala
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import java.util.concurrent.{CountDownLatch, Executors, TimeUnit}
+import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
+
+import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.vector.ipc.ArrowReader
+import org.scalatest.funsuite.AnyFunSuite
+
+class SharedScanCacheTest extends AnyFunSuite { // drives SharedScanCache with fake resources and a manual clock
+
+ private def spec(scanId: String): SharedScanSpec =
+ SharedScanSpec(
+ scanId = scanId,
+ factoryFqcn = "test.Factory",
+ optionsBytes = Array.emptyByteArray,
+ projectionColumnNames = Array.empty,
+ filterProtoBytes = Array.empty,
+ pinnedConfig = PinnedSessionConfig(8, 8192, Vector.empty)
+ )
+
+ /** JNI-free fake entry; records close. */
+ private final class FakeResources extends SharedScanResources {
+ @volatile var closed = false
+ override def partitionCount: Int = 3
+ override def newTaskAllocator(name: String): BufferAllocator =
+ throw new UnsupportedOperationException("not used in cache tests")
+ override def openPartitionStream(p: Int, a: BufferAllocator): ArrowReader =
+ throw new UnsupportedOperationException("not used in cache tests")
+ override def close(): Unit = closed = true
+ }
+
+ private final class Fixture {
+ val clock = new AtomicLong(0L) // manual nanosecond clock handed to the cache as nanoClock
+ val buildCount = new AtomicInteger(0) // incremented on every buildEntry call, including failing ones
+ var failBuilds = false
+ var lastBuilt: FakeResources = _
+
+ val cache = new SharedScanCache(
+ buildEntry = _ => {
+ buildCount.incrementAndGet()
+ if (failBuilds) throw new RuntimeException("synthetic build failure")
+ lastBuilt = new FakeResources
+ lastBuilt
+ },
+ nanoClock = () => clock.get()
+ )
+
+ def advanceMillis(ms: Long): Unit = clock.addAndGet(TimeUnit.MILLISECONDS.toNanos(ms))
+ }
+
+ test("acquire builds once, second acquire reuses, refcount pairs with release") {
+ val f = new Fixture
+ val r1 = f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+ val r2 = f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+ assert(f.buildCount.get() == 1)
+ assert(r1 eq r2)
+ f.cache.release("s1")
+ f.cache.release("s1")
+ }
+
+ test("concurrent acquires build exactly once") {
+ val f = new Fixture
+ val n = 8
+ val pool = Executors.newFixedThreadPool(n)
+ val ready = new CountDownLatch(n)
+ val go = new CountDownLatch(1)
+ try {
+ val futures = (0 until n).map { _ =>
+ pool.submit { () =>
+ ready.countDown()
+ go.await() // every worker blocks here until the main thread opens the gate
+ f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+ }
+ }
+ ready.await()
+ go.countDown()
+ val results = futures.map(_.get(10, TimeUnit.SECONDS))
+ assert(f.buildCount.get() == 1)
+ assert(results.forall(_ eq results.head))
+ (0 until n).foreach(_ => f.cache.release("s1")) // one release per successful acquire
+ } finally {
+ pool.shutdownNow()
+ }
+ }
+
+ test("build failure propagates and is not cached") {
+ val f = new Fixture
+ f.failBuilds = true
+ val e = intercept[RuntimeException](f.cache.acquire(spec("s1"), idleTtlMs = 1000))
+ assert(e.getMessage == "synthetic build failure")
+ f.failBuilds = false
+ val r = f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+ assert(f.buildCount.get() == 2) // the failed attempt and the retry each invoked buildEntry
+ assert(r eq f.lastBuilt)
+ f.cache.release("s1")
+ }
+
+ test("idle entry past TTL is evicted and closed") {
+ val f = new Fixture
+ f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+ f.cache.release("s1")
+ val built = f.lastBuilt
+ f.advanceMillis(999) // 999 ms idle: still inside the 1000 ms TTL
+ f.cache.evictIdleNow()
+ assert(!built.closed)
+ f.advanceMillis(2) // 1001 ms idle: past the TTL
+ f.cache.evictIdleNow()
+ assert(built.closed)
+ }
+
+ test("entry in use is never evicted, regardless of idle time") {
+ val f = new Fixture
+ f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+ val built = f.lastBuilt
+ f.advanceMillis(100000)
+ f.cache.evictIdleNow()
+ assert(!built.closed)
+ f.cache.release("s1")
+ f.advanceMillis(100000)
+ f.cache.evictIdleNow()
+ assert(built.closed)
+ }
+
+ test("release then reacquire within TTL resets idleness") {
+ val f = new Fixture
+ f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+ f.cache.release("s1")
+ f.advanceMillis(900)
+ // Next task wave lands before TTL: same entry, no rebuild.
+ val r = f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+ assert(f.buildCount.get() == 1)
+ assert(r eq f.lastBuilt)
+ f.cache.release("s1")
+ f.advanceMillis(900) // 1800 ms since the first release, but only 900 ms since the last one
+ f.cache.evictIdleNow()
+ assert(!f.lastBuilt.closed, "idle clock must restart at the last release")
+ }
+
+ test("acquire after eviction rebuilds") {
+ val f = new Fixture
+ f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+ f.cache.release("s1")
+ val first = f.lastBuilt
+ f.advanceMillis(2000)
+ f.cache.evictIdleNow()
+ assert(first.closed)
+ val r = f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+ assert(f.buildCount.get() == 2)
+ assert(r ne first)
+ f.cache.release("s1")
+ }
+
+ test("distinct scanIds get distinct entries") {
+ val f = new Fixture
+ val r1 = f.cache.acquire(spec("s1"), idleTtlMs = 1000)
+ val r2 = f.cache.acquire(spec("s2"), idleTtlMs = 1000)
+ assert(f.buildCount.get() == 2)
+ assert(r1 ne r2)
+ f.cache.release("s1")
+ f.cache.release("s2")
+ }
+
+ test("unbalanced release throws") {
+ val f = new Fixture
+ intercept[IllegalStateException](f.cache.release("never-acquired"))
+ }
+
+ test("shutdown closes everything, even entries in use") {
+ val f = new Fixture
+ f.cache.acquire(spec("s1"), idleTtlMs = 1000) // deliberately never released
+ val built = f.lastBuilt
+ f.cache.shutdown()
+ assert(built.closed)
+ }
+}
diff --git a/spark/src/test/scala/io/datafusion/spark/SparkPredicateTranslatorTest.scala b/spark/src/test/scala/io/datafusion/spark/SparkPredicateTranslatorTest.scala
new file mode 100644
index 0000000..b7faac1
--- /dev/null
+++ b/spark/src/test/scala/io/datafusion/spark/SparkPredicateTranslatorTest.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.datafusion.spark
+
+import org.apache.datafusion.protobuf.LogicalExprNode
+import org.apache.spark.sql.connector.expressions.{Expression, Expressions, NamedReference}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.scalatest.funsuite.AnyFunSuite
+
+/**
+ * Exercises `SparkPredicateTranslator.translate`: predicates it supports must serialize to a
+ * parseable `LogicalExprNode`, and predicates it does not support must come back empty.
+ */
+class SparkPredicateTranslatorTest extends AnyFunSuite {
+
+  private def col(name: String): NamedReference = Expressions.column(name)
+  private def litInt(v: Int): Expression = Expressions.literal(Int.box(v))
+  private def litLong(v: Long): Expression = Expressions.literal(Long.box(v))
+  private def litStr(v: String): Expression =
+    Expressions.literal(org.apache.spark.unsafe.types.UTF8String.fromString(v))
+  private def pred(op: String, args: Expression*): Predicate = new Predicate(op, args.toArray)
+
+  // Serialized translation of `p`; a rejected predicate fails the test and is named in the message.
+  private def serialized(p: Predicate): Array[Byte] =
+    SparkPredicateTranslator.translate(p)
+      .getOrElse(fail(s"expected Some for predicate '${p.name}'")).toByteArray
+
+  // Translates `p` and parses the bytes back, as every structural assertion below needs.
+  private def roundTrip(p: Predicate): LogicalExprNode = LogicalExprNode.parseFrom(serialized(p))
+
+  test("LessThan(timeline, 1_000_000) translates to a non-empty proto") {
+    // Check the raw bytes first, then the shape they decode to.
+    val bytes = serialized(pred("<", col("timeline"), litLong(1000000L)))
+    assert(bytes.nonEmpty)
+    val parsed = LogicalExprNode.parseFrom(bytes)
+    assert(parsed.hasBinaryExpr)
+    assert(parsed.getBinaryExpr.getOp == "Lt")
+  }
+
+  test("AND of two translatable predicates round-trips through binary op 'And'") {
+    val lessThan = pred("<", col("a"), litInt(10))
+    val equalTo = pred("=", col("b"), litStr("x"))
+    val parsed = roundTrip(pred("AND", lessThan, equalTo))
+    assert(parsed.hasBinaryExpr)
+    assert(parsed.getBinaryExpr.getOp == "And")
+  }
+
+  test("AND becomes residual when an operand is untranslatable") {
+    // "<=>" is the operand expected to be refused; "=" translates in the test above.
+    val nullSafeEq = pred("<=>", col("a"), litInt(1))
+    val equalTo = pred("=", col("b"), litInt(2))
+    assert(SparkPredicateTranslator.translate(pred("AND", nullSafeEq, equalTo)).isEmpty)
+  }
+
+  test("IS_NULL and IS_NOT_NULL emit the dedicated proto variants") {
+    assert(roundTrip(pred("IS_NULL", col("x"))).hasIsNullExpr)
+    assert(roundTrip(pred("IS_NOT_NULL", col("x"))).hasIsNotNullExpr)
+  }
+
+  test("STARTS_WITH translates to a LIKE with a '%' suffix") {
+    val parsed = roundTrip(pred("STARTS_WITH", col("name"), litStr("foo")))
+    assert(parsed.hasLike)
+    assert(parsed.getLike.getPattern.getLiteral.getUtf8Value == "foo%")
+  }
+
+  test("unknown predicate name returns None (becomes residual)") {
+    val p = pred("UNKNOWN_OP", col("x"), litInt(1))
+    assert(SparkPredicateTranslator.translate(p).isEmpty)
+  }
+}