From cf358a4d150a3c7dcef9d1bccb09986a938aa128 Mon Sep 17 00:00:00 2001 From: Manu Zhang Date: Wed, 16 Aug 2023 07:44:05 +0800 Subject: [PATCH] Spark 3.4: Support setting current snapshot with ref (#8163) --- docs/spark-procedures.md | 10 ++- .../TestSetCurrentSnapshotProcedure.java | 65 +++++++++++++++++-- .../SetCurrentSnapshotProcedure.java | 24 +++++-- 3 files changed, 88 insertions(+), 11 deletions(-) diff --git a/docs/spark-procedures.md b/docs/spark-procedures.md index 59fa39dc9347..debf393a30b6 100644 --- a/docs/spark-procedures.md +++ b/docs/spark-procedures.md @@ -130,7 +130,10 @@ This procedure invalidates all cached Spark plans that reference the affected ta | Argument Name | Required? | Type | Description | |---------------|-----------|------|-------------| | `table` | ✔️ | string | Name of the table to update | -| `snapshot_id` | ✔️ | long | Snapshot ID to set as current | +| `snapshot_id` | | long | Snapshot ID to set as current | +| `ref` | | string | Snapshot Referece (branch or tag) to set as current | + +Either `snapshot_id` or `ref` must be provided but not both. #### Output @@ -146,6 +149,11 @@ Set the current snapshot for `db.sample` to 1: CALL catalog_name.system.set_current_snapshot('db.sample', 1) ``` +Set the current snapshot for `db.sample` to tag `s1`: +```sql +CALL catalog_name.system.set_current_snapshot(table => 'db.sample', tag => 's1'); +``` + ### `cherrypick_snapshot` Cherry-picks changes from a snapshot into the current table state. diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java index 51db8d321059..e894ba4ff0ae 100644 --- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSetCurrentSnapshotProcedure.java @@ -212,12 +212,12 @@ public void testInvalidRollbackToSnapshotCases() { Assertions.assertThatThrownBy( () -> sql("CALL %s.system.set_current_snapshot('t')", catalogName)) - .isInstanceOf(AnalysisException.class) - .hasMessage("Missing required parameters: [snapshot_id]"); + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Either snapshot_id or ref must be provided, not both"); Assertions.assertThatThrownBy(() -> sql("CALL %s.system.set_current_snapshot(1L)", catalogName)) - .isInstanceOf(AnalysisException.class) - .hasMessage("Missing required parameters: [snapshot_id]"); + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Cannot parse identifier for arg table: 1"); Assertions.assertThatThrownBy( () -> sql("CALL %s.system.set_current_snapshot(snapshot_id => 1L)", catalogName)) @@ -226,8 +226,8 @@ public void testInvalidRollbackToSnapshotCases() { Assertions.assertThatThrownBy( () -> sql("CALL %s.system.set_current_snapshot(table => 't')", catalogName)) - .isInstanceOf(AnalysisException.class) - .hasMessage("Missing required parameters: [snapshot_id]"); + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Either snapshot_id or ref must be provided, not both"); Assertions.assertThatThrownBy( () -> sql("CALL %s.system.set_current_snapshot('t', 2.2)", catalogName)) @@ -238,5 +238,58 @@ public void testInvalidRollbackToSnapshotCases() { () -> sql("CALL %s.system.set_current_snapshot('', 1L)", catalogName)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Cannot handle an empty identifier for argument table"); + + Assertions.assertThatThrownBy( + () -> + sql( + "CALL %s.system.set_current_snapshot(table => 't', snapshot_id => 1L, ref => 's1')", + catalogName)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Either snapshot_id or ref must be provided, not both"); + } + + @Test + public void testSetCurrentSnapshotToRef() { + sql("CREATE TABLE %s (id bigint NOT NULL, data string) USING iceberg", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot firstSnapshot = table.currentSnapshot(); + String ref = "s1"; + sql("ALTER TABLE %s CREATE TAG %s", tableName, ref); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); + assertEquals( + "Should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + table.refresh(); + + Snapshot secondSnapshot = table.currentSnapshot(); + + List output = + sql( + "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')", + catalogName, tableIdent, ref); + + assertEquals( + "Procedure output must match", + ImmutableList.of(row(secondSnapshot.snapshotId(), firstSnapshot.snapshotId())), + output); + + assertEquals( + "Set must be successful", + ImmutableList.of(row(1L, "a")), + sql("SELECT * FROM %s ORDER BY id", tableName)); + + String notExistRef = "s2"; + Assertions.assertThatThrownBy( + () -> + sql( + "CALL %s.system.set_current_snapshot(table => '%s', ref => '%s')", + catalogName, tableIdent, notExistRef)) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot find matching snapshot ID for ref " + notExistRef); } } diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java index f8f8049c22b6..22719e43c057 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/SetCurrentSnapshotProcedure.java @@ -19,6 +19,10 @@ package org.apache.iceberg.spark.procedures; import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; +import org.apache.iceberg.Table; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.spark.procedures.SparkProcedures.ProcedureBuilder; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.catalog.Identifier; @@ -42,7 +46,8 @@ class SetCurrentSnapshotProcedure extends BaseProcedure { private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[] { ProcedureParameter.required("table", DataTypes.StringType), - ProcedureParameter.required("snapshot_id", DataTypes.LongType) + ProcedureParameter.optional("snapshot_id", DataTypes.LongType), + ProcedureParameter.optional("ref", DataTypes.StringType) }; private static final StructType OUTPUT_TYPE = @@ -78,7 +83,11 @@ public StructType outputType() { @Override public InternalRow[] call(InternalRow args) { Identifier tableIdent = toIdentifier(args.getString(0), PARAMETERS[0].name()); - long snapshotId = args.getLong(1); + Long snapshotId = args.isNullAt(1) ? null : args.getLong(1); + String ref = args.isNullAt(2) ? null : args.getString(2); + Preconditions.checkArgument( + (snapshotId != null && ref == null) || (snapshotId == null && ref != null), + "Either snapshot_id or ref must be provided, not both"); return modifyIcebergTable( tableIdent, @@ -86,9 +95,10 @@ public InternalRow[] call(InternalRow args) { Snapshot previousSnapshot = table.currentSnapshot(); Long previousSnapshotId = previousSnapshot != null ? previousSnapshot.snapshotId() : null; - table.manageSnapshots().setCurrentSnapshot(snapshotId).commit(); + long targetSnapshotId = snapshotId != null ? snapshotId : toSnapshotId(table, ref); + table.manageSnapshots().setCurrentSnapshot(targetSnapshotId).commit(); - InternalRow outputRow = newInternalRow(previousSnapshotId, snapshotId); + InternalRow outputRow = newInternalRow(previousSnapshotId, targetSnapshotId); return new InternalRow[] {outputRow}; }); } @@ -97,4 +107,10 @@ public InternalRow[] call(InternalRow args) { public String description() { return "SetCurrentSnapshotProcedure"; } + + private long toSnapshotId(Table table, String refName) { + SnapshotRef ref = table.refs().get(refName); + ValidationException.check(ref != null, "Cannot find matching snapshot ID for ref " + refName); + return ref.snapshotId(); + } }