diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index 6f32826eb0..2ed53d6f8e 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -85,20 +85,20 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols f"DataFrames, and cannot be used as column names" ) from None - # Step 1: Prepare source index with join keys and a marker index - # Cast to target table schema, so we can do the join - # See: https://github.com/apache/arrow/issues/37542 + # Step 1: Prepare source index with join keys and a marker index. + # Cast only join columns to target join-column schema so schema evolution + # (for example, newly added non-key columns) doesn't break the join setup. source_index = ( - source_table.cast(target_table.schema) - .select(join_cols_set) + source_table.select(join_cols) + .cast(pa.schema([target_table.schema.field(col) for col in join_cols])) .append_column(SOURCE_INDEX_COLUMN_NAME, pa.array(range(len(source_table)))) ) # Step 2: Prepare target index with join keys and a marker - target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table)))) + target_index = target_table.select(join_cols).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table)))) # Step 3: Perform an inner join to find which rows from source exist in target - matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner") + matching_indices = source_index.join(target_index, keys=join_cols, join_type="inner") # Step 4: Compare all rows using Python to_update_indices = [] @@ -112,7 +112,7 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols for key in non_key_cols: source_val = source_row.column(key)[0].as_py() - target_val = target_row.column(key)[0].as_py() + target_val = target_row.column(key)[0].as_py() if key in target_table.column_names else None if source_val != target_val: to_update_indices.append(source_idx) break diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 08f90c6600..072064cc0d 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -714,6 +714,57 @@ def test_upsert_with_nulls(catalog: Catalog) -> None: ) +def test_upsert_after_schema_add_column(catalog: Catalog) -> None: + identifier = "default.test_upsert_after_schema_add_column" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "id", IntegerType(), required=True), + NestedField(2, "name", StringType(), required=True), + identifier_field_ids=[1], + ) + + tbl = catalog.create_table(identifier, schema=schema) + + initial = pa.Table.from_pylist( + [{"id": 1, "name": "Alice"}], + schema=pa.schema( + [ + pa.field("id", pa.int32(), nullable=False), + pa.field("name", pa.string(), nullable=False), + ] + ), + ) + tbl.append(initial) + + with tbl.update_schema() as update_schema: + update_schema.add_column("country", StringType()) + tbl = tbl.refresh() + + source = pa.Table.from_pylist( + [ + {"id": 1, "name": "Alice", "country": "NL"}, + {"id": 2, "name": "Bob", "country": "US"}, + ], + schema=pa.schema( + [ + pa.field("id", pa.int32(), nullable=False), + pa.field("name", pa.string(), nullable=False), + pa.field("country", pa.string(), nullable=True), + ] + ), + ) + + upd = tbl.upsert(source, ["id"]) + + assert upd.rows_updated == 1 + assert upd.rows_inserted == 1 + assert sorted(tbl.scan().to_arrow().to_pylist(), key=lambda row: row["id"]) == [ + {"id": 1, "name": "Alice", "country": "NL"}, + {"id": 2, "name": "Bob", "country": "US"}, + ] + + def test_transaction(catalog: Catalog) -> None: """Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is rolled back."""