Skip to content

Commit

Permalink
Check permission when registering tables in Iceberg and Delta Lake
Browse files Browse the repository at this point in the history
  • Loading branch information
ebyhr committed Oct 9, 2024
1 parent 3882fb6 commit 0a3a89f
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.plugin.deltalake.procedure;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;
import com.google.inject.Provider;
import io.trino.filesystem.Location;
Expand All @@ -32,6 +33,7 @@
import io.trino.plugin.deltalake.transactionlog.TransactionLogAccess;
import io.trino.spi.TrinoException;
import io.trino.spi.classloader.ThreadContextClassLoader;
import io.trino.spi.connector.ConnectorAccessControl;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.SchemaNotFoundException;
import io.trino.spi.connector.SchemaTableName;
Expand Down Expand Up @@ -68,7 +70,7 @@ public class RegisterTableProcedure

static {
try {
REGISTER_TABLE = lookup().unreflect(RegisterTableProcedure.class.getMethod("registerTable", ConnectorSession.class, String.class, String.class, String.class));
REGISTER_TABLE = lookup().unreflect(RegisterTableProcedure.class.getMethod("registerTable", ConnectorAccessControl.class, ConnectorSession.class, String.class, String.class, String.class));
}
catch (ReflectiveOperationException e) {
throw new AssertionError(e);
Expand Down Expand Up @@ -110,13 +112,15 @@ public Procedure get()
}

public void registerTable(
ConnectorAccessControl accessControl,
ConnectorSession clientSession,
String schemaName,
String tableName,
String tableLocation)
{
try (ThreadContextClassLoader _ = new ThreadContextClassLoader(getClass().getClassLoader())) {
doRegisterTable(
accessControl,
clientSession,
schemaName,
tableName,
Expand All @@ -125,6 +129,7 @@ public void registerTable(
}

private void doRegisterTable(
ConnectorAccessControl accessControl,
ConnectorSession session,
String schemaName,
String tableName,
Expand All @@ -138,6 +143,7 @@ private void doRegisterTable(
checkProcedureArgument(!isNullOrEmpty(tableLocation), "table_location cannot be null or empty");

SchemaTableName schemaTableName = new SchemaTableName(schemaName, tableName);
accessControl.checkCanCreateTable(null, schemaTableName, ImmutableMap.of());
DeltaLakeMetadata metadata = metadataFactory.create(session.getIdentity());
metadata.beginQuery(session);
try (UncheckedCloseable ignore = () -> metadata.cleanupQuery(session)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.trino.testing.MaterializedResult.resultBuilder;
import static io.trino.testing.QueryAssertions.copyTpchTables;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.CREATE_TABLE;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.EXECUTE_FUNCTION;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.SELECT_COLUMN;
import static io.trino.testing.TestingAccessControlManager.privilege;
Expand Down Expand Up @@ -4820,6 +4821,20 @@ public void testDuplicatedFieldNames()
}
}

@Test
void testRegisterTableAccessControl()
{
String tableName = "test_register_table_access_control_" + randomNameSuffix();
assertUpdate("CREATE TABLE " + tableName + " AS SELECT 1 a", 1);
String tableLocation = metastore.getTable(SCHEMA, tableName).orElseThrow().getStorage().getLocation();
metastore.dropTable(SCHEMA, tableName, false);

assertAccessDenied(
"CALL system.register_table(CURRENT_SCHEMA, '" + tableName + "', '" + tableLocation + "')",
"Cannot create table .*",
privilege(tableName, CREATE_TABLE));
}

@Test
public void testMetastoreAfterCreateTable()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.plugin.iceberg.procedure;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;
import com.google.inject.Provider;
import io.trino.filesystem.Location;
Expand All @@ -25,6 +26,7 @@
import io.trino.plugin.iceberg.fileio.ForwardingFileIo;
import io.trino.spi.TrinoException;
import io.trino.spi.classloader.ThreadContextClassLoader;
import io.trino.spi.connector.ConnectorAccessControl;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.procedure.Procedure;
Expand Down Expand Up @@ -64,7 +66,7 @@ public class RegisterTableProcedure

static {
try {
REGISTER_TABLE = lookup().unreflect(RegisterTableProcedure.class.getMethod("registerTable", ConnectorSession.class, String.class, String.class, String.class, String.class));
REGISTER_TABLE = lookup().unreflect(RegisterTableProcedure.class.getMethod("registerTable", ConnectorAccessControl.class, ConnectorSession.class, String.class, String.class, String.class, String.class));
}
catch (ReflectiveOperationException e) {
throw new AssertionError(e);
Expand Down Expand Up @@ -98,6 +100,7 @@ public Procedure get()
}

public void registerTable(
ConnectorAccessControl accessControl,
ConnectorSession clientSession,
String schemaName,
String tableName,
Expand All @@ -106,6 +109,7 @@ public void registerTable(
{
try (ThreadContextClassLoader _ = new ThreadContextClassLoader(getClass().getClassLoader())) {
doRegisterTable(
accessControl,
clientSession,
schemaName,
tableName,
Expand All @@ -115,6 +119,7 @@ public void registerTable(
}

private void doRegisterTable(
ConnectorAccessControl accessControl,
ConnectorSession clientSession,
String schemaName,
String tableName,
Expand All @@ -130,6 +135,7 @@ private void doRegisterTable(
metadataFileName.ifPresent(RegisterTableProcedure::validateMetadataFileName);

SchemaTableName schemaTableName = new SchemaTableName(schemaName, tableName);
accessControl.checkCanCreateTable(null, schemaTableName, ImmutableMap.of());
TrinoCatalog catalog = catalogFactory.create(clientSession.getIdentity());
if (!catalog.namespaceExists(clientSession, schemaTableName.getSchemaName())) {
throw new TrinoException(SCHEMA_NOT_FOUND, format("Schema '%s' does not exist", schemaTableName.getSchemaName()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
import static io.trino.plugin.iceberg.IcebergTestUtils.getFileSystemFactory;
import static io.trino.plugin.iceberg.IcebergUtil.METADATA_FOLDER_NAME;
import static io.trino.plugin.iceberg.IcebergUtil.getLatestMetadataLocation;
import static io.trino.testing.TestingAccessControlManager.TestingPrivilegeType.CREATE_TABLE;
import static io.trino.testing.TestingAccessControlManager.privilege;
import static io.trino.testing.TestingConnectorSession.SESSION;
import static io.trino.testing.TestingNames.randomNameSuffix;
import static io.trino.testing.TestingSession.testSessionBuilder;
Expand Down Expand Up @@ -541,6 +543,20 @@ public void testRegisterHadoopTableAndRead()
assertUpdate("DROP TABLE " + tempTableName);
}

@Test
void testRegisterTableAccessControl()
{
String tableName = "test_register_table_access_control_" + randomNameSuffix();
assertUpdate("CREATE TABLE " + tableName + " AS SELECT 1 a", 1);
String tableLocation = getTableLocation(tableName);
assertUpdate("CALL system.unregister_table(CURRENT_SCHEMA, '" + tableName + "')");

assertAccessDenied(
"CALL system.register_table(CURRENT_SCHEMA, '" + tableName + "', '" + tableLocation + "')",
"Cannot create table .*",
privilege(tableName, CREATE_TABLE));
}

private String getTableLocation(String tableName)
{
Pattern locationPattern = Pattern.compile(".*location = '(.*?)'.*", Pattern.DOTALL);
Expand Down

0 comments on commit 0a3a89f

Please sign in to comment.