From a6a31678c85033a019ed1c4934e70b7721fe9826 Mon Sep 17 00:00:00 2001 From: Gidon Gershinsky Date: Mon, 8 Jan 2024 14:16:52 +0200 Subject: [PATCH 1/6] Avro data encryption --- .../java/org/apache/iceberg/avro/Avro.java | 9 - .../avro/TestEncryptedAvroFileSplit.java | 209 ++++++++++++++++++ .../encryption/EncryptionTestHelpers.java | 38 ++++ .../apache/iceberg/encryption/UnitestKMS.java | 38 ++++ 4 files changed, 285 insertions(+), 9 deletions(-) create mode 100644 core/src/test/java/org/apache/iceberg/avro/TestEncryptedAvroFileSplit.java create mode 100644 core/src/test/java/org/apache/iceberg/encryption/EncryptionTestHelpers.java create mode 100644 core/src/test/java/org/apache/iceberg/encryption/UnitestKMS.java diff --git a/core/src/main/java/org/apache/iceberg/avro/Avro.java b/core/src/main/java/org/apache/iceberg/avro/Avro.java index 6aa3241c3660..0eaa3f2d2400 100644 --- a/core/src/main/java/org/apache/iceberg/avro/Avro.java +++ b/core/src/main/java/org/apache/iceberg/avro/Avro.java @@ -94,9 +94,6 @@ public static WriteBuilder write(OutputFile file) { } public static WriteBuilder write(EncryptedOutputFile file) { - Preconditions.checkState( - file.keyMetadata() == null || file.keyMetadata() == EncryptionKeyMetadata.EMPTY, - "Avro encryption is not supported"); return new WriteBuilder(file.encryptingOutputFile()); } @@ -282,9 +279,6 @@ public static DataWriteBuilder writeData(OutputFile file) { } public static DataWriteBuilder writeData(EncryptedOutputFile file) { - Preconditions.checkState( - file.keyMetadata() == null || file.keyMetadata() == EncryptionKeyMetadata.EMPTY, - "Avro encryption is not supported"); return new DataWriteBuilder(file.encryptingOutputFile()); } @@ -385,9 +379,6 @@ public static DeleteWriteBuilder writeDeletes(OutputFile file) { } public static DeleteWriteBuilder writeDeletes(EncryptedOutputFile file) { - Preconditions.checkState( - file.keyMetadata() == null || file.keyMetadata() == EncryptionKeyMetadata.EMPTY, - "Avro encryption is not supported"); return new DeleteWriteBuilder(file.encryptingOutputFile()); } diff --git a/core/src/test/java/org/apache/iceberg/avro/TestEncryptedAvroFileSplit.java b/core/src/test/java/org/apache/iceberg/avro/TestEncryptedAvroFileSplit.java new file mode 100644 index 000000000000..9020a1230271 --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/avro/TestEncryptedAvroFileSplit.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.avro; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.UUID; +import org.apache.iceberg.Files; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Schema; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.data.GenericRecord; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.data.avro.DataReader; +import org.apache.iceberg.data.avro.DataWriter; +import org.apache.iceberg.encryption.EncryptedFiles; +import org.apache.iceberg.encryption.EncryptedInputFile; +import org.apache.iceberg.encryption.EncryptedOutputFile; +import org.apache.iceberg.encryption.EncryptionManager; +import org.apache.iceberg.encryption.EncryptionTestHelpers; +import org.apache.iceberg.io.FileAppender; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.NestedField; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class TestEncryptedAvroFileSplit { + private static final Schema SCHEMA = + new Schema( + NestedField.required(1, "id", Types.LongType.get()), + NestedField.required(2, "data", Types.StringType.get())); + + private static final EncryptionManager ENCRYPTION_MANAGER = + EncryptionTestHelpers.createEncryptionManager(); + + private static final int NUM_RECORDS = 100_000; + + @TempDir Path temp; + + public List expected = null; + public InputFile file = null; + + @BeforeEach + public void writeDataFile() throws IOException { + this.expected = Lists.newArrayList(); + + OutputFile out = Files.localOutput(temp.toFile()); + + EncryptedOutputFile eOut = ENCRYPTION_MANAGER.encrypt(out); + + try (FileAppender writer = + Avro.write(eOut) + .set(TableProperties.AVRO_COMPRESSION, "uncompressed") + .createWriterFunc(DataWriter::create) + .schema(SCHEMA) + .overwrite() + .build()) { + + Record record = GenericRecord.create(SCHEMA); + for (long i = 0; i < NUM_RECORDS; i += 1) { + Record next = record.copy(ImmutableMap.of("id", i, "data", UUID.randomUUID().toString())); + expected.add(next); + writer.add(next); + } + } + + EncryptedInputFile encryptedIn = + EncryptedFiles.encryptedInput(out.toInputFile(), eOut.keyMetadata()); + + this.file = ENCRYPTION_MANAGER.decrypt(encryptedIn); + } + + @Test + public void testSplitDataSkipping() throws IOException { + long end = file.getLength(); + long splitLocation = end / 2; + + List firstHalf = readAvro(file, SCHEMA, 0, splitLocation); + assertThat(firstHalf.size()).as("First split should not be empty").isNotEqualTo(0); + + List secondHalf = readAvro(file, SCHEMA, splitLocation + 1, end - splitLocation - 1); + assertThat(secondHalf.size()).as("Second split should not be empty").isNotEqualTo(0); + + assertThat(firstHalf.size() + secondHalf.size()) + .as("Total records should match expected") + .isEqualTo(expected.size()); + + for (int i = 0; i < firstHalf.size(); i += 1) { + assertThat(firstHalf.get(i)).isEqualTo(expected.get(i)); + } + + for (int i = 0; i < secondHalf.size(); i += 1) { + assertThat(secondHalf.get(i)).isEqualTo(expected.get(firstHalf.size() + i)); + } + } + + @Test + public void testPosField() throws IOException { + Schema projection = + new Schema(SCHEMA.columns().get(0), MetadataColumns.ROW_POSITION, SCHEMA.columns().get(1)); + + List records = readAvro(file, projection, 0, file.getLength()); + + for (int i = 0; i < expected.size(); i += 1) { + assertThat(records.get(i).getField(MetadataColumns.ROW_POSITION.name())) + .as("Field _pos should match") + .isEqualTo((long) i); + + assertThat(records.get(i).getField("id")) + .as("Field id should match") + .isEqualTo(expected.get(i).getField("id")); + + assertThat(records.get(i).getField("data")) + .as("Field data should match") + .isEqualTo(expected.get(i).getField("data")); + } + } + + @Test + public void testPosFieldWithSplits() throws IOException { + Schema projection = + new Schema(SCHEMA.columns().get(0), MetadataColumns.ROW_POSITION, SCHEMA.columns().get(1)); + + long end = file.getLength(); + long splitLocation = end / 2; + + List secondHalf = + readAvro(file, projection, splitLocation + 1, end - splitLocation - 1); + assertThat(secondHalf.size()).as("Second split should not be empty").isNotEqualTo(0); + + List firstHalf = readAvro(file, projection, 0, splitLocation); + assertThat(firstHalf.size()).as("First split should not be empty").isNotEqualTo(0); + + assertThat(firstHalf.size() + secondHalf.size()) + .as("Total records should match expected") + .isEqualTo(expected.size()); + + for (int i = 0; i < firstHalf.size(); i += 1) { + assertThat(firstHalf.get(i).getField(MetadataColumns.ROW_POSITION.name())) + .as("Field _pos should match") + .isEqualTo((long) i); + assertThat(firstHalf.get(i).getField("id")) + .as("Field id should match") + .isEqualTo(expected.get(i).getField("id")); + assertThat(firstHalf.get(i).getField("data")) + .as("Field data should match") + .isEqualTo(expected.get(i).getField("data")); + } + + for (int i = 0; i < secondHalf.size(); i += 1) { + assertThat(secondHalf.get(i).getField(MetadataColumns.ROW_POSITION.name())) + .as("Field _pos should match") + .isEqualTo((long) (firstHalf.size() + i)); + assertThat(secondHalf.get(i).getField("id")) + .as("Field id should match") + .isEqualTo(expected.get(firstHalf.size() + i).getField("id")); + assertThat(secondHalf.get(i).getField("data")) + .as("Field data should match") + .isEqualTo(expected.get(firstHalf.size() + i).getField("data")); + } + } + + @Test + public void testPosWithEOFSplit() throws IOException { + Schema projection = + new Schema(SCHEMA.columns().get(0), MetadataColumns.ROW_POSITION, SCHEMA.columns().get(1)); + + long end = file.getLength(); + + List records = readAvro(file, projection, end - 10, 10); + assertThat(records.size()).as("Should not read any records").isEqualTo(0); + } + + public List readAvro(InputFile in, Schema projection, long start, long length) + throws IOException { + try (AvroIterable reader = + Avro.read(in) + .createReaderFunc(DataReader::create) + .split(start, length) + .project(projection) + .build()) { + return Lists.newArrayList(reader); + } + } +} diff --git a/core/src/test/java/org/apache/iceberg/encryption/EncryptionTestHelpers.java b/core/src/test/java/org/apache/iceberg/encryption/EncryptionTestHelpers.java new file mode 100644 index 000000000000..f1f59dac9f36 --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/encryption/EncryptionTestHelpers.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.encryption; + +import java.util.Map; +import org.apache.iceberg.CatalogProperties; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; + +public class EncryptionTestHelpers { + public static EncryptionManager createEncryptionManager() { + Map catalogProperties = Maps.newHashMap(); + catalogProperties.put( + CatalogProperties.ENCRYPTION_KMS_IMPL, UnitestKMS.class.getCanonicalName()); + Map tableProperties = Maps.newHashMap(); + tableProperties.put(TableProperties.ENCRYPTION_TABLE_KEY, UnitestKMS.MASTER_KEY_NAME1); + tableProperties.put(TableProperties.FORMAT_VERSION, "2"); + + return EncryptionUtil.createEncryptionManager( + tableProperties, EncryptionUtil.createKmsClient(catalogProperties)); + } +} diff --git a/core/src/test/java/org/apache/iceberg/encryption/UnitestKMS.java b/core/src/test/java/org/apache/iceberg/encryption/UnitestKMS.java new file mode 100644 index 000000000000..52a0e36c0011 --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/encryption/UnitestKMS.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.encryption; + +import java.nio.charset.StandardCharsets; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; + +public class UnitestKMS extends MemoryMockKMS { + public static final String MASTER_KEY_NAME1 = "keyA"; + public static final byte[] MASTER_KEY1 = "0123456789012345".getBytes(StandardCharsets.UTF_8); + public static final String MASTER_KEY_NAME2 = "keyB"; + public static final byte[] MASTER_KEY2 = "1123456789012345".getBytes(StandardCharsets.UTF_8); + + @Override + public void initialize(Map properties) { + masterKeys = + ImmutableMap.of( + MASTER_KEY_NAME1, MASTER_KEY1, + MASTER_KEY_NAME2, MASTER_KEY2); + } +} From 0c5598dd59d813394f48bfd1da5331281f047d8b Mon Sep 17 00:00:00 2001 From: Gidon Gershinsky Date: Mon, 8 Jan 2024 14:25:44 +0200 Subject: [PATCH 2/6] private constructor for util --- .../org/apache/iceberg/encryption/EncryptionTestHelpers.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/test/java/org/apache/iceberg/encryption/EncryptionTestHelpers.java b/core/src/test/java/org/apache/iceberg/encryption/EncryptionTestHelpers.java index f1f59dac9f36..aa49e1c40fe2 100644 --- a/core/src/test/java/org/apache/iceberg/encryption/EncryptionTestHelpers.java +++ b/core/src/test/java/org/apache/iceberg/encryption/EncryptionTestHelpers.java @@ -24,6 +24,9 @@ import org.apache.iceberg.relocated.com.google.common.collect.Maps; public class EncryptionTestHelpers { + + private EncryptionTestHelpers() {} + public static EncryptionManager createEncryptionManager() { Map catalogProperties = Maps.newHashMap(); catalogProperties.put( From dc6f80e11e28a14afb224528a5fb699955c35f0e Mon Sep 17 00:00:00 2001 From: Gidon Gershinsky Date: Wed, 10 Jan 2024 08:55:11 +0200 Subject: [PATCH 3/6] output stream fix --- .../iceberg/encryption/AesGcmInputStream.java | 252 ------------------ .../encryption/AesGcmOutputStream.java | 161 ----------- 2 files changed, 413 deletions(-) delete mode 100644 core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java delete mode 100644 core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java diff --git a/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java b/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java deleted file mode 100644 index 1f52ab3682f8..000000000000 --- a/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java +++ /dev/null @@ -1,252 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.iceberg.encryption; - -import java.io.EOFException; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import org.apache.iceberg.io.IOUtil; -import org.apache.iceberg.io.SeekableInputStream; -import org.apache.iceberg.relocated.com.google.common.base.Preconditions; - -public class AesGcmInputStream extends SeekableInputStream { - private final SeekableInputStream sourceStream; - private final byte[] fileAADPrefix; - private final Ciphers.AesGcmDecryptor decryptor; - private final byte[] cipherBlockBuffer; - private final byte[] currentPlainBlock; - private final long numBlocks; - private final int lastCipherBlockSize; - private final long plainStreamSize; - private final byte[] singleByte; - - private long plainStreamPosition; - private long currentPlainBlockIndex; - private int currentPlainBlockSize; - - AesGcmInputStream( - SeekableInputStream sourceStream, long sourceLength, byte[] aesKey, byte[] fileAADPrefix) { - this.sourceStream = sourceStream; - this.fileAADPrefix = fileAADPrefix; - this.decryptor = new Ciphers.AesGcmDecryptor(aesKey); - this.cipherBlockBuffer = new byte[Ciphers.CIPHER_BLOCK_SIZE]; - this.currentPlainBlock = new byte[Ciphers.PLAIN_BLOCK_SIZE]; - this.plainStreamPosition = 0; - this.currentPlainBlockIndex = -1; - this.currentPlainBlockSize = 0; - - long streamLength = sourceLength - Ciphers.GCM_STREAM_HEADER_LENGTH; - long numFullBlocks = Math.toIntExact(streamLength / Ciphers.CIPHER_BLOCK_SIZE); - long cipherFullBlockLength = numFullBlocks * Ciphers.CIPHER_BLOCK_SIZE; - int cipherBytesInLastBlock = Math.toIntExact(streamLength - cipherFullBlockLength); - boolean fullBlocksOnly = (0 == cipherBytesInLastBlock); - this.numBlocks = fullBlocksOnly ? numFullBlocks : numFullBlocks + 1; - this.lastCipherBlockSize = - fullBlocksOnly ? Ciphers.CIPHER_BLOCK_SIZE : cipherBytesInLastBlock; // never 0 - - long lastPlainBlockSize = - (long) lastCipherBlockSize - Ciphers.NONCE_LENGTH - Ciphers.GCM_TAG_LENGTH; - this.plainStreamSize = - numFullBlocks * Ciphers.PLAIN_BLOCK_SIZE + (fullBlocksOnly ? 0 : lastPlainBlockSize); - this.singleByte = new byte[1]; - } - - private void validateHeader() throws IOException { - byte[] headerBytes = new byte[Ciphers.GCM_STREAM_HEADER_LENGTH]; - IOUtil.readFully(sourceStream, headerBytes, 0, headerBytes.length); - - Preconditions.checkState( - Ciphers.GCM_STREAM_MAGIC.equals(ByteBuffer.wrap(headerBytes, 0, 4)), - "Invalid GCM stream: magic does not match AGS1"); - - int plainBlockSize = ByteBuffer.wrap(headerBytes, 4, 4).order(ByteOrder.LITTLE_ENDIAN).getInt(); - Preconditions.checkState( - plainBlockSize == Ciphers.PLAIN_BLOCK_SIZE, - "Invalid GCM stream: block size %d != %d", - plainBlockSize, - Ciphers.PLAIN_BLOCK_SIZE); - } - - @Override - public int available() { - long maxAvailable = plainStreamSize - plainStreamPosition; - // See InputStream.available contract - if (maxAvailable >= Integer.MAX_VALUE) { - return Integer.MAX_VALUE; - } else { - return (int) maxAvailable; - } - } - - private int availableInCurrentBlock() { - if (blockIndex(plainStreamPosition) != currentPlainBlockIndex) { - return 0; - } - - return currentPlainBlockSize - offsetInBlock(plainStreamPosition); - } - - @Override - public int read(byte[] b, int off, int len) throws IOException { - Preconditions.checkArgument(len >= 0, "Invalid read length: " + len); - - if (currentPlainBlockIndex < 0) { - decryptBlock(0); - } - - if (available() <= 0 && len > 0) { - return -1; - } - - if (len == 0) { - return 0; - } - - int totalBytesRead = 0; - int resultBufferOffset = off; - int remainingBytesToRead = len; - - while (remainingBytesToRead > 0) { - int availableInBlock = availableInCurrentBlock(); - if (availableInBlock > 0) { - int bytesToCopy = Math.min(availableInBlock, remainingBytesToRead); - int offsetInBlock = offsetInBlock(plainStreamPosition); - System.arraycopy(currentPlainBlock, offsetInBlock, b, resultBufferOffset, bytesToCopy); - totalBytesRead += bytesToCopy; - remainingBytesToRead -= bytesToCopy; - resultBufferOffset += bytesToCopy; - this.plainStreamPosition += bytesToCopy; - } else if (available() > 0) { - decryptBlock(blockIndex(plainStreamPosition)); - - } else { - break; - } - } - - // return -1 for EOF - return totalBytesRead > 0 ? totalBytesRead : -1; - } - - @Override - public void seek(long newPos) throws IOException { - if (newPos < 0) { - throw new IOException("Invalid position: " + newPos); - } else if (newPos > plainStreamSize) { - throw new EOFException( - "Invalid position: " + newPos + " > stream length, " + plainStreamSize); - } - - this.plainStreamPosition = newPos; - } - - @Override - public long skip(long n) { - if (n <= 0) { - return 0; - } - - long bytesLeftInStream = plainStreamSize - plainStreamPosition; - if (n > bytesLeftInStream) { - // skip the rest of the stream - this.plainStreamPosition = plainStreamSize; - return bytesLeftInStream; - } - - this.plainStreamPosition += n; - - return n; - } - - @Override - public long getPos() throws IOException { - return plainStreamPosition; - } - - @Override - public int read() throws IOException { - int read = read(singleByte); - if (read == -1) { - return -1; - } - - return singleByte[0] >= 0 ? singleByte[0] : 256 + singleByte[0]; - } - - @Override - public void close() throws IOException { - sourceStream.close(); - } - - private void decryptBlock(long blockIndex) throws IOException { - if (blockIndex == currentPlainBlockIndex) { - return; - } - - long blockPositionInStream = blockOffset(blockIndex); - if (sourceStream.getPos() != blockPositionInStream) { - if (sourceStream.getPos() == 0) { - validateHeader(); - } - - sourceStream.seek(blockPositionInStream); - } - - boolean isLastBlock = blockIndex == numBlocks - 1; - int cipherBlockSize = isLastBlock ? lastCipherBlockSize : Ciphers.CIPHER_BLOCK_SIZE; - IOUtil.readFully(sourceStream, cipherBlockBuffer, 0, cipherBlockSize); - - byte[] blockAAD = Ciphers.streamBlockAAD(fileAADPrefix, Math.toIntExact(blockIndex)); - decryptor.decrypt(cipherBlockBuffer, 0, cipherBlockSize, currentPlainBlock, 0, blockAAD); - this.currentPlainBlockSize = cipherBlockSize - Ciphers.NONCE_LENGTH - Ciphers.GCM_TAG_LENGTH; - this.currentPlainBlockIndex = blockIndex; - } - - private static long blockIndex(long plainPosition) { - return plainPosition / Ciphers.PLAIN_BLOCK_SIZE; - } - - private static int offsetInBlock(long plainPosition) { - return Math.toIntExact(plainPosition % Ciphers.PLAIN_BLOCK_SIZE); - } - - private static long blockOffset(long blockIndex) { - return blockIndex * Ciphers.CIPHER_BLOCK_SIZE + Ciphers.GCM_STREAM_HEADER_LENGTH; - } - - static long calculatePlaintextLength(long sourceLength) { - long streamLength = sourceLength - Ciphers.GCM_STREAM_HEADER_LENGTH; - - if (streamLength == 0) { - return 0; - } - - long numberOfFullBlocks = streamLength / Ciphers.CIPHER_BLOCK_SIZE; - long fullBlockSize = numberOfFullBlocks * Ciphers.CIPHER_BLOCK_SIZE; - long cipherBytesInLastBlock = streamLength - fullBlockSize; - boolean fullBlocksOnly = (0 == cipherBytesInLastBlock); - long plainBytesInLastBlock = - fullBlocksOnly - ? 0 - : (cipherBytesInLastBlock - Ciphers.NONCE_LENGTH - Ciphers.GCM_TAG_LENGTH); - - return (numberOfFullBlocks * Ciphers.PLAIN_BLOCK_SIZE) + plainBytesInLastBlock; - } -} diff --git a/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java b/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java deleted file mode 100644 index b4f723cca3e7..000000000000 --- a/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.iceberg.encryption; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import org.apache.iceberg.io.PositionOutputStream; -import org.apache.iceberg.relocated.com.google.common.base.Preconditions; - -public class AesGcmOutputStream extends PositionOutputStream { - - private static final byte[] HEADER_BYTES = - ByteBuffer.allocate(Ciphers.GCM_STREAM_HEADER_LENGTH) - .order(ByteOrder.LITTLE_ENDIAN) - .put(Ciphers.GCM_STREAM_MAGIC_ARRAY) - .putInt(Ciphers.PLAIN_BLOCK_SIZE) - .array(); - - private final Ciphers.AesGcmEncryptor gcmEncryptor; - private final PositionOutputStream targetStream; - private final byte[] fileAadPrefix; - private final byte[] singleByte; - private final byte[] plainBlock; - private final byte[] cipherBlock; - - private int positionInPlainBlock; - private int currentBlockIndex; - private boolean isHeaderWritten; - private boolean lastBlockWritten; - private boolean isClosed; - private long finalPosition; - - AesGcmOutputStream(PositionOutputStream targetStream, byte[] aesKey, byte[] fileAadPrefix) { - this.targetStream = targetStream; - this.gcmEncryptor = new Ciphers.AesGcmEncryptor(aesKey); - this.fileAadPrefix = fileAadPrefix; - this.singleByte = new byte[1]; - this.plainBlock = new byte[Ciphers.PLAIN_BLOCK_SIZE]; - this.cipherBlock = new byte[Ciphers.CIPHER_BLOCK_SIZE]; - this.positionInPlainBlock = 0; - this.currentBlockIndex = 0; - this.isHeaderWritten = false; - this.lastBlockWritten = false; - this.isClosed = false; - this.finalPosition = 0; - } - - @Override - public void write(int b) throws IOException { - singleByte[0] = (byte) (b & 0x000000FF); - write(singleByte); - } - - @Override - public void write(byte[] b, int off, int len) throws IOException { - if (isClosed) { - throw new IOException("Writing to closed stream"); - } - - if (!isHeaderWritten) { - writeHeader(); - } - - if (b.length - off < len) { - throw new IOException( - "Insufficient bytes in buffer: " + b.length + " - " + off + " < " + len); - } - - int remaining = len; - int offset = off; - - while (remaining > 0) { - int freeBlockBytes = plainBlock.length - positionInPlainBlock; - int toWrite = Math.min(freeBlockBytes, remaining); - - System.arraycopy(b, offset, plainBlock, positionInPlainBlock, toWrite); - positionInPlainBlock += toWrite; - offset += toWrite; - remaining -= toWrite; - - if (positionInPlainBlock == plainBlock.length) { - encryptAndWriteBlock(); - } - } - } - - @Override - public long getPos() throws IOException { - if (isClosed) { - return finalPosition; - } - - return (long) currentBlockIndex * Ciphers.PLAIN_BLOCK_SIZE + positionInPlainBlock; - } - - @Override - public void flush() throws IOException { - targetStream.flush(); - } - - @Override - public void close() throws IOException { - if (!isHeaderWritten) { - writeHeader(); - } - - finalPosition = getPos(); - isClosed = true; - - encryptAndWriteBlock(); - - targetStream.close(); - } - - private void writeHeader() throws IOException { - targetStream.write(HEADER_BYTES); - isHeaderWritten = true; - } - - private void encryptAndWriteBlock() throws IOException { - Preconditions.checkState( - !lastBlockWritten, "Cannot encrypt block: a partial block has already been written"); - - if (currentBlockIndex == Integer.MAX_VALUE) { - throw new IOException("Cannot write block: exceeded Integer.MAX_VALUE blocks"); - } - - if (positionInPlainBlock == 0 && currentBlockIndex != 0) { - return; - } - - if (positionInPlainBlock != plainBlock.length) { - // signal that a partial block has been written and must be the last - this.lastBlockWritten = true; - } - - byte[] aad = Ciphers.streamBlockAAD(fileAadPrefix, currentBlockIndex); - int ciphertextLength = - gcmEncryptor.encrypt(plainBlock, 0, positionInPlainBlock, cipherBlock, 0, aad); - targetStream.write(cipherBlock, 0, ciphertextLength); - positionInPlainBlock = 0; - currentBlockIndex++; - } -} From d7540ee330e9ef047354c95967aa09dfd5b7bfd9 Mon Sep 17 00:00:00 2001 From: Gidon Gershinsky Date: Tue, 16 Jan 2024 08:12:07 +0200 Subject: [PATCH 4/6] rebase --- .../iceberg/encryption/TestGcmStreams.java | 416 ------------------ 1 file changed, 416 deletions(-) delete mode 100644 core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java diff --git a/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java b/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java deleted file mode 100644 index a954cf760baa..000000000000 --- a/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java +++ /dev/null @@ -1,416 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.iceberg.encryption; - -import java.io.File; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.FileChannel; -import java.nio.file.StandardOpenOption; -import java.util.Arrays; -import java.util.Random; -import javax.crypto.AEADBadTagException; -import org.apache.iceberg.Files; -import org.apache.iceberg.io.PositionOutputStream; -import org.apache.iceberg.io.SeekableInputStream; -import org.assertj.core.api.Assertions; -import org.junit.Assert; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; - -public class TestGcmStreams { - - @Rule public TemporaryFolder temp = new TemporaryFolder(); - - @Test - public void testEmptyFile() throws IOException { - Random random = new Random(); - byte[] key = new byte[16]; - random.nextBytes(key); - byte[] aadPrefix = new byte[16]; - random.nextBytes(aadPrefix); - byte[] readBytes = new byte[1]; - - File testFile = temp.newFile(); - - AesGcmOutputFile encryptedFile = - new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); - PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite(); - encryptedStream.close(); - - AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); - Assert.assertEquals("File size", 0, decryptedFile.getLength()); - - try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { - Assert.assertEquals("Read empty stream", -1, decryptedStream.read(readBytes)); - } - - // check that the AAD is still verified, even for an empty file - byte[] badAAD = Arrays.copyOf(aadPrefix, aadPrefix.length); - badAAD[1] -= 1; // modify the AAD slightly - AesGcmInputFile badAADFile = new AesGcmInputFile(Files.localInput(testFile), key, badAAD); - Assert.assertEquals("File size", 0, badAADFile.getLength()); - - try (SeekableInputStream decryptedStream = badAADFile.newStream()) { - Assertions.assertThatThrownBy(() -> decryptedStream.read(readBytes)) - .isInstanceOf(RuntimeException.class) - .hasCauseInstanceOf(AEADBadTagException.class) - .hasMessageContaining("GCM tag check failed"); - } - } - - @Test - public void testAADValidation() throws IOException { - Random random = new Random(); - byte[] key = new byte[16]; - random.nextBytes(key); - byte[] aadPrefix = new byte[16]; - random.nextBytes(aadPrefix); - byte[] content = new byte[Ciphers.PLAIN_BLOCK_SIZE / 2]; // half a block - random.nextBytes(content); - - File testFile = temp.newFile(); - - AesGcmOutputFile encryptedFile = - new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); - try (PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite()) { - encryptedStream.write(content); - } - - // verify the data can be read correctly with the right AAD - AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); - Assert.assertEquals("File size", content.length, decryptedFile.getLength()); - - try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { - byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; - int bytesRead = decryptedStream.read(readContent); - Assert.assertEquals("Bytes read should match bytes written", content.length, bytesRead); - Assert.assertEquals( - "Content should match", - ByteBuffer.wrap(content), - ByteBuffer.wrap(readContent, 0, bytesRead)); - } - - // test with the wrong AAD - byte[] badAAD = Arrays.copyOf(aadPrefix, aadPrefix.length); - badAAD[1] -= 1; // modify the AAD slightly - AesGcmInputFile badAADFile = new AesGcmInputFile(Files.localInput(testFile), key, badAAD); - Assert.assertEquals("File size", content.length, badAADFile.getLength()); - - try (SeekableInputStream decryptedStream = badAADFile.newStream()) { - byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; - Assertions.assertThatThrownBy(() -> decryptedStream.read(readContent)) - .isInstanceOf(RuntimeException.class) - .hasCauseInstanceOf(AEADBadTagException.class) - .hasMessageContaining("GCM tag check failed"); - } - - // modify the file contents - try (FileChannel out = FileChannel.open(testFile.toPath(), StandardOpenOption.WRITE)) { - long lastTagPosition = testFile.length() - Ciphers.GCM_TAG_LENGTH; - out.position(lastTagPosition); - out.write(ByteBuffer.wrap(key)); // overwrite the tag with other random bytes (the key) - } - - // read with the correct AAD and verify the tag check fails - try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { - byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; - Assertions.assertThatThrownBy(() -> decryptedStream.read(readContent)) - .isInstanceOf(RuntimeException.class) - .hasCauseInstanceOf(AEADBadTagException.class) - .hasMessageContaining("GCM tag check failed"); - } - } - - @Test - public void testCorruptNonce() throws IOException { - Random random = new Random(); - byte[] key = new byte[16]; - random.nextBytes(key); - byte[] aadPrefix = new byte[16]; - random.nextBytes(aadPrefix); - byte[] content = new byte[Ciphers.PLAIN_BLOCK_SIZE / 2]; // half a block - random.nextBytes(content); - - File testFile = temp.newFile(); - - AesGcmOutputFile encryptedFile = - new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); - try (PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite()) { - encryptedStream.write(content); - } - - // verify the data can be read correctly with the right AAD - AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); - Assert.assertEquals("File size", content.length, decryptedFile.getLength()); - - try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { - byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; - int bytesRead = decryptedStream.read(readContent); - Assert.assertEquals("Bytes read should match bytes written", content.length, bytesRead); - Assert.assertEquals( - "Content should match", - ByteBuffer.wrap(content), - ByteBuffer.wrap(readContent, 0, bytesRead)); - } - - // replace the first block's nonce - try (FileChannel out = FileChannel.open(testFile.toPath(), StandardOpenOption.WRITE)) { - out.position(Ciphers.GCM_STREAM_HEADER_LENGTH); - // overwrite the nonce with other random bytes (the key) - out.write(ByteBuffer.wrap(key, 0, Ciphers.NONCE_LENGTH)); - } - - // read with the correct AAD and verify the read fails - try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { - byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; - Assertions.assertThatThrownBy(() -> decryptedStream.read(readContent)) - .isInstanceOf(RuntimeException.class) - .hasCauseInstanceOf(AEADBadTagException.class) - .hasMessageContaining("GCM tag check failed"); - } - } - - @Test - public void testCorruptCiphertext() throws IOException { - Random random = new Random(); - byte[] key = new byte[16]; - random.nextBytes(key); - byte[] aadPrefix = new byte[16]; - random.nextBytes(aadPrefix); - byte[] content = new byte[Ciphers.PLAIN_BLOCK_SIZE / 2]; // half a block - random.nextBytes(content); - - File testFile = temp.newFile(); - - AesGcmOutputFile encryptedFile = - new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); - try (PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite()) { - encryptedStream.write(content); - } - - // verify the data can be read correctly with the right AAD - AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); - Assert.assertEquals("File size", content.length, decryptedFile.getLength()); - - try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { - byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; - int bytesRead = decryptedStream.read(readContent); - Assert.assertEquals("Bytes read should match bytes written", content.length, bytesRead); - Assert.assertEquals( - "Content should match", - ByteBuffer.wrap(content), - ByteBuffer.wrap(readContent, 0, bytesRead)); - } - - // replace part of the first block's content - try (FileChannel out = FileChannel.open(testFile.toPath(), StandardOpenOption.WRITE)) { - out.position(Ciphers.GCM_STREAM_HEADER_LENGTH + Ciphers.NONCE_LENGTH + 34); - // overwrite the nonce with other random bytes (the key) - out.write(ByteBuffer.wrap(key)); - } - - // read with the correct AAD and verify the read fails - try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { - byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; - Assertions.assertThatThrownBy(() -> decryptedStream.read(readContent)) - .isInstanceOf(RuntimeException.class) - .hasCauseInstanceOf(AEADBadTagException.class) - .hasMessageContaining("GCM tag check failed"); - } - } - - @Test - public void testRandomWriteRead() throws IOException { - Random random = new Random(); - int smallerThanBlock = (int) (Ciphers.PLAIN_BLOCK_SIZE * 0.5); - int largerThanBlock = (int) (Ciphers.PLAIN_BLOCK_SIZE * 1.5); - int alignedWithBlock = Ciphers.PLAIN_BLOCK_SIZE; - int[] testFileSizes = { - smallerThanBlock, - largerThanBlock, - alignedWithBlock, - alignedWithBlock - 1, - alignedWithBlock + 1 - }; - - for (int testFileSize : testFileSizes) { - byte[] testFileContents = new byte[testFileSize]; - random.nextBytes(testFileContents); - int[] aesKeyLengthArray = {16, 24, 32}; - byte[] aadPrefix = new byte[16]; - for (int keyLength : aesKeyLengthArray) { - byte[] key = new byte[keyLength]; - random.nextBytes(key); - random.nextBytes(aadPrefix); - File testFile = temp.newFile(); - - AesGcmOutputFile encryptedFile = - new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); - PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite(); - - int maxChunkLen = testFileSize / 5; - int offset = 0; - int left = testFileSize; - - while (left > 0) { - int chunkLen = random.nextInt(maxChunkLen); - if (chunkLen > left) { - chunkLen = left; - } - encryptedStream.write(testFileContents, offset, chunkLen); - offset += chunkLen; - Assert.assertEquals("Position", offset, encryptedStream.getPos()); - left -= chunkLen; - } - - encryptedStream.close(); - Assert.assertEquals("Final position in closed stream", offset, encryptedStream.getPos()); - - AesGcmInputFile decryptedFile = - new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); - SeekableInputStream decryptedStream = decryptedFile.newStream(); - Assert.assertEquals("File size", testFileSize, decryptedFile.getLength()); - - byte[] chunk = new byte[testFileSize]; - - // Test seek and read - for (int n = 0; n < 100; n++) { - int chunkLen = random.nextInt(testFileSize); - int pos = random.nextInt(testFileSize); - left = testFileSize - pos; - - if (left < chunkLen) { - chunkLen = left; - } - - decryptedStream.seek(pos); - int len = decryptedStream.read(chunk, 0, chunkLen); - Assert.assertEquals("Read length", len, chunkLen); - long pos2 = decryptedStream.getPos(); - Assert.assertEquals("Position", pos + len, pos2); - - ByteBuffer bb1 = ByteBuffer.wrap(chunk, 0, chunkLen); - ByteBuffer bb2 = ByteBuffer.wrap(testFileContents, pos, chunkLen); - Assert.assertEquals("Read contents", bb1, bb2); - - // Test skip - long toSkip = random.nextInt(testFileSize); - long skipped = decryptedStream.skip(toSkip); - - if (pos2 + toSkip < testFileSize) { - Assert.assertEquals("Skipped", toSkip, skipped); - } else { - Assert.assertEquals("Skipped", (testFileSize - pos2), skipped); - } - - int pos3 = (int) decryptedStream.getPos(); - Assert.assertEquals("Position", pos2 + skipped, pos3); - - chunkLen = random.nextInt(testFileSize); - left = testFileSize - pos3; - - if (left < chunkLen) { - chunkLen = left; - } - - decryptedStream.read(chunk, 0, chunkLen); - bb1 = ByteBuffer.wrap(chunk, 0, chunkLen); - bb2 = ByteBuffer.wrap(testFileContents, pos3, chunkLen); - Assert.assertEquals("Read contents", bb1, bb2); - } - - decryptedStream.close(); - } - } - } - - @Test - public void testAlignedWriteRead() throws IOException { - Random random = new Random(); - int[] testFileSizes = { - Ciphers.PLAIN_BLOCK_SIZE, Ciphers.PLAIN_BLOCK_SIZE + 1, Ciphers.PLAIN_BLOCK_SIZE - 1 - }; - - for (int testFileSize : testFileSizes) { - byte[] testFileContents = new byte[testFileSize]; - random.nextBytes(testFileContents); - byte[] key = new byte[16]; - random.nextBytes(key); - byte[] aadPrefix = new byte[16]; - random.nextBytes(aadPrefix); - - File testFile = temp.newFile(); - AesGcmOutputFile encryptedFile = - new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); - PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite(); - - int offset = 0; - int chunkLen = Ciphers.PLAIN_BLOCK_SIZE; - int left = testFileSize; - - while (left > 0) { - - if (chunkLen > left) { - chunkLen = left; - } - - encryptedStream.write(testFileContents, offset, chunkLen); - offset += chunkLen; - Assert.assertEquals("Position", offset, encryptedStream.getPos()); - left -= chunkLen; - } - - encryptedStream.close(); - Assert.assertEquals("Final position in closed stream", offset, encryptedStream.getPos()); - - AesGcmInputFile decryptedFile = - new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); - SeekableInputStream decryptedStream = decryptedFile.newStream(); - Assert.assertEquals("File size", testFileSize, decryptedFile.getLength()); - - offset = 0; - chunkLen = Ciphers.PLAIN_BLOCK_SIZE; - byte[] chunk = new byte[chunkLen]; - left = testFileSize; - - while (left > 0) { - - if (chunkLen > left) { - chunkLen = left; - } - - decryptedStream.seek(offset); - int len = decryptedStream.read(chunk, 0, chunkLen); - Assert.assertEquals("Read length", len, chunkLen); - Assert.assertEquals("Position", offset + len, decryptedStream.getPos()); - - ByteBuffer bb1 = ByteBuffer.wrap(chunk, 0, chunkLen); - ByteBuffer bb2 = ByteBuffer.wrap(testFileContents, offset, chunkLen); - Assert.assertEquals("Read contents", bb1, bb2); - - offset += len; - left = testFileSize - offset; - } - - decryptedStream.close(); - } - } -} From 9c90db856d4a79ee57742c12761da8cc8219bde3 Mon Sep 17 00:00:00 2001 From: Gidon Gershinsky Date: Tue, 16 Jan 2024 08:14:52 +0200 Subject: [PATCH 5/6] fix patch --- .../iceberg/encryption/AesGcmInputStream.java | 252 +++++++++++ .../encryption/AesGcmOutputStream.java | 161 +++++++ .../iceberg/encryption/TestGcmStreams.java | 414 ++++++++++++++++++ 3 files changed, 827 insertions(+) create mode 100644 core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java create mode 100644 core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java create mode 100644 core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java diff --git a/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java b/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java new file mode 100644 index 000000000000..1f52ab3682f8 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.encryption; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import org.apache.iceberg.io.IOUtil; +import org.apache.iceberg.io.SeekableInputStream; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; + +public class AesGcmInputStream extends SeekableInputStream { + private final SeekableInputStream sourceStream; + private final byte[] fileAADPrefix; + private final Ciphers.AesGcmDecryptor decryptor; + private final byte[] cipherBlockBuffer; + private final byte[] currentPlainBlock; + private final long numBlocks; + private final int lastCipherBlockSize; + private final long plainStreamSize; + private final byte[] singleByte; + + private long plainStreamPosition; + private long currentPlainBlockIndex; + private int currentPlainBlockSize; + + AesGcmInputStream( + SeekableInputStream sourceStream, long sourceLength, byte[] aesKey, byte[] fileAADPrefix) { + this.sourceStream = sourceStream; + this.fileAADPrefix = fileAADPrefix; + this.decryptor = new Ciphers.AesGcmDecryptor(aesKey); + this.cipherBlockBuffer = new byte[Ciphers.CIPHER_BLOCK_SIZE]; + this.currentPlainBlock = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + this.plainStreamPosition = 0; + this.currentPlainBlockIndex = -1; + this.currentPlainBlockSize = 0; + + long streamLength = sourceLength - Ciphers.GCM_STREAM_HEADER_LENGTH; + long numFullBlocks = Math.toIntExact(streamLength / Ciphers.CIPHER_BLOCK_SIZE); + long cipherFullBlockLength = numFullBlocks * Ciphers.CIPHER_BLOCK_SIZE; + int cipherBytesInLastBlock = Math.toIntExact(streamLength - cipherFullBlockLength); + boolean fullBlocksOnly = (0 == cipherBytesInLastBlock); + this.numBlocks = fullBlocksOnly ? numFullBlocks : numFullBlocks + 1; + this.lastCipherBlockSize = + fullBlocksOnly ? Ciphers.CIPHER_BLOCK_SIZE : cipherBytesInLastBlock; // never 0 + + long lastPlainBlockSize = + (long) lastCipherBlockSize - Ciphers.NONCE_LENGTH - Ciphers.GCM_TAG_LENGTH; + this.plainStreamSize = + numFullBlocks * Ciphers.PLAIN_BLOCK_SIZE + (fullBlocksOnly ? 0 : lastPlainBlockSize); + this.singleByte = new byte[1]; + } + + private void validateHeader() throws IOException { + byte[] headerBytes = new byte[Ciphers.GCM_STREAM_HEADER_LENGTH]; + IOUtil.readFully(sourceStream, headerBytes, 0, headerBytes.length); + + Preconditions.checkState( + Ciphers.GCM_STREAM_MAGIC.equals(ByteBuffer.wrap(headerBytes, 0, 4)), + "Invalid GCM stream: magic does not match AGS1"); + + int plainBlockSize = ByteBuffer.wrap(headerBytes, 4, 4).order(ByteOrder.LITTLE_ENDIAN).getInt(); + Preconditions.checkState( + plainBlockSize == Ciphers.PLAIN_BLOCK_SIZE, + "Invalid GCM stream: block size %d != %d", + plainBlockSize, + Ciphers.PLAIN_BLOCK_SIZE); + } + + @Override + public int available() { + long maxAvailable = plainStreamSize - plainStreamPosition; + // See InputStream.available contract + if (maxAvailable >= Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } else { + return (int) maxAvailable; + } + } + + private int availableInCurrentBlock() { + if (blockIndex(plainStreamPosition) != currentPlainBlockIndex) { + return 0; + } + + return currentPlainBlockSize - offsetInBlock(plainStreamPosition); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + Preconditions.checkArgument(len >= 0, "Invalid read length: " + len); + + if (currentPlainBlockIndex < 0) { + decryptBlock(0); + } + + if (available() <= 0 && len > 0) { + return -1; + } + + if (len == 0) { + return 0; + } + + int totalBytesRead = 0; + int resultBufferOffset = off; + int remainingBytesToRead = len; + + while (remainingBytesToRead > 0) { + int availableInBlock = availableInCurrentBlock(); + if (availableInBlock > 0) { + int bytesToCopy = Math.min(availableInBlock, remainingBytesToRead); + int offsetInBlock = offsetInBlock(plainStreamPosition); + System.arraycopy(currentPlainBlock, offsetInBlock, b, resultBufferOffset, bytesToCopy); + totalBytesRead += bytesToCopy; + remainingBytesToRead -= bytesToCopy; + resultBufferOffset += bytesToCopy; + this.plainStreamPosition += bytesToCopy; + } else if (available() > 0) { + decryptBlock(blockIndex(plainStreamPosition)); + + } else { + break; + } + } + + // return -1 for EOF + return totalBytesRead > 0 ? totalBytesRead : -1; + } + + @Override + public void seek(long newPos) throws IOException { + if (newPos < 0) { + throw new IOException("Invalid position: " + newPos); + } else if (newPos > plainStreamSize) { + throw new EOFException( + "Invalid position: " + newPos + " > stream length, " + plainStreamSize); + } + + this.plainStreamPosition = newPos; + } + + @Override + public long skip(long n) { + if (n <= 0) { + return 0; + } + + long bytesLeftInStream = plainStreamSize - plainStreamPosition; + if (n > bytesLeftInStream) { + // skip the rest of the stream + this.plainStreamPosition = plainStreamSize; + return bytesLeftInStream; + } + + this.plainStreamPosition += n; + + return n; + } + + @Override + public long getPos() throws IOException { + return plainStreamPosition; + } + + @Override + public int read() throws IOException { + int read = read(singleByte); + if (read == -1) { + return -1; + } + + return singleByte[0] >= 0 ? singleByte[0] : 256 + singleByte[0]; + } + + @Override + public void close() throws IOException { + sourceStream.close(); + } + + private void decryptBlock(long blockIndex) throws IOException { + if (blockIndex == currentPlainBlockIndex) { + return; + } + + long blockPositionInStream = blockOffset(blockIndex); + if (sourceStream.getPos() != blockPositionInStream) { + if (sourceStream.getPos() == 0) { + validateHeader(); + } + + sourceStream.seek(blockPositionInStream); + } + + boolean isLastBlock = blockIndex == numBlocks - 1; + int cipherBlockSize = isLastBlock ? lastCipherBlockSize : Ciphers.CIPHER_BLOCK_SIZE; + IOUtil.readFully(sourceStream, cipherBlockBuffer, 0, cipherBlockSize); + + byte[] blockAAD = Ciphers.streamBlockAAD(fileAADPrefix, Math.toIntExact(blockIndex)); + decryptor.decrypt(cipherBlockBuffer, 0, cipherBlockSize, currentPlainBlock, 0, blockAAD); + this.currentPlainBlockSize = cipherBlockSize - Ciphers.NONCE_LENGTH - Ciphers.GCM_TAG_LENGTH; + this.currentPlainBlockIndex = blockIndex; + } + + private static long blockIndex(long plainPosition) { + return plainPosition / Ciphers.PLAIN_BLOCK_SIZE; + } + + private static int offsetInBlock(long plainPosition) { + return Math.toIntExact(plainPosition % Ciphers.PLAIN_BLOCK_SIZE); + } + + private static long blockOffset(long blockIndex) { + return blockIndex * Ciphers.CIPHER_BLOCK_SIZE + Ciphers.GCM_STREAM_HEADER_LENGTH; + } + + static long calculatePlaintextLength(long sourceLength) { + long streamLength = sourceLength - Ciphers.GCM_STREAM_HEADER_LENGTH; + + if (streamLength == 0) { + return 0; + } + + long numberOfFullBlocks = streamLength / Ciphers.CIPHER_BLOCK_SIZE; + long fullBlockSize = numberOfFullBlocks * Ciphers.CIPHER_BLOCK_SIZE; + long cipherBytesInLastBlock = streamLength - fullBlockSize; + boolean fullBlocksOnly = (0 == cipherBytesInLastBlock); + long plainBytesInLastBlock = + fullBlocksOnly + ? 0 + : (cipherBytesInLastBlock - Ciphers.NONCE_LENGTH - Ciphers.GCM_TAG_LENGTH); + + return (numberOfFullBlocks * Ciphers.PLAIN_BLOCK_SIZE) + plainBytesInLastBlock; + } +} diff --git a/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java b/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java new file mode 100644 index 000000000000..b4f723cca3e7 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.encryption; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import org.apache.iceberg.io.PositionOutputStream; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; + +public class AesGcmOutputStream extends PositionOutputStream { + + private static final byte[] HEADER_BYTES = + ByteBuffer.allocate(Ciphers.GCM_STREAM_HEADER_LENGTH) + .order(ByteOrder.LITTLE_ENDIAN) + .put(Ciphers.GCM_STREAM_MAGIC_ARRAY) + .putInt(Ciphers.PLAIN_BLOCK_SIZE) + .array(); + + private final Ciphers.AesGcmEncryptor gcmEncryptor; + private final PositionOutputStream targetStream; + private final byte[] fileAadPrefix; + private final byte[] singleByte; + private final byte[] plainBlock; + private final byte[] cipherBlock; + + private int positionInPlainBlock; + private int currentBlockIndex; + private boolean isHeaderWritten; + private boolean lastBlockWritten; + private boolean isClosed; + private long finalPosition; + + AesGcmOutputStream(PositionOutputStream targetStream, byte[] aesKey, byte[] fileAadPrefix) { + this.targetStream = targetStream; + this.gcmEncryptor = new Ciphers.AesGcmEncryptor(aesKey); + this.fileAadPrefix = fileAadPrefix; + this.singleByte = new byte[1]; + this.plainBlock = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + this.cipherBlock = new byte[Ciphers.CIPHER_BLOCK_SIZE]; + this.positionInPlainBlock = 0; + this.currentBlockIndex = 0; + this.isHeaderWritten = false; + this.lastBlockWritten = false; + this.isClosed = false; + this.finalPosition = 0; + } + + @Override + public void write(int b) throws IOException { + singleByte[0] = (byte) (b & 0x000000FF); + write(singleByte); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + if (isClosed) { + throw new IOException("Writing to closed stream"); + } + + if (!isHeaderWritten) { + writeHeader(); + } + + if (b.length - off < len) { + throw new IOException( + "Insufficient bytes in buffer: " + b.length + " - " + off + " < " + len); + } + + int remaining = len; + int offset = off; + + while (remaining > 0) { + int freeBlockBytes = plainBlock.length - positionInPlainBlock; + int toWrite = Math.min(freeBlockBytes, remaining); + + System.arraycopy(b, offset, plainBlock, positionInPlainBlock, toWrite); + positionInPlainBlock += toWrite; + offset += toWrite; + remaining -= toWrite; + + if (positionInPlainBlock == plainBlock.length) { + encryptAndWriteBlock(); + } + } + } + + @Override + public long getPos() throws IOException { + if (isClosed) { + return finalPosition; + } + + return (long) currentBlockIndex * Ciphers.PLAIN_BLOCK_SIZE + positionInPlainBlock; + } + + @Override + public void flush() throws IOException { + targetStream.flush(); + } + + @Override + public void close() throws IOException { + if (!isHeaderWritten) { + writeHeader(); + } + + finalPosition = getPos(); + isClosed = true; + + encryptAndWriteBlock(); + + targetStream.close(); + } + + private void writeHeader() throws IOException { + targetStream.write(HEADER_BYTES); + isHeaderWritten = true; + } + + private void encryptAndWriteBlock() throws IOException { + Preconditions.checkState( + !lastBlockWritten, "Cannot encrypt block: a partial block has already been written"); + + if (currentBlockIndex == Integer.MAX_VALUE) { + throw new IOException("Cannot write block: exceeded Integer.MAX_VALUE blocks"); + } + + if (positionInPlainBlock == 0 && currentBlockIndex != 0) { + return; + } + + if (positionInPlainBlock != plainBlock.length) { + // signal that a partial block has been written and must be the last + this.lastBlockWritten = true; + } + + byte[] aad = Ciphers.streamBlockAAD(fileAadPrefix, currentBlockIndex); + int ciphertextLength = + gcmEncryptor.encrypt(plainBlock, 0, positionInPlainBlock, cipherBlock, 0, aad); + targetStream.write(cipherBlock, 0, ciphertextLength); + positionInPlainBlock = 0; + currentBlockIndex++; + } +} diff --git a/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java b/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java new file mode 100644 index 000000000000..cbbe96fb58c3 --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.encryption; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; +import java.util.Arrays; +import java.util.Random; +import javax.crypto.AEADBadTagException; +import org.apache.iceberg.Files; +import org.apache.iceberg.io.PositionOutputStream; +import org.apache.iceberg.io.SeekableInputStream; +import org.assertj.core.api.Assertions; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestGcmStreams { + + @Rule public TemporaryFolder temp = new TemporaryFolder(); + + @Test + public void testEmptyFile() throws IOException { + Random random = new Random(); + byte[] key = new byte[16]; + random.nextBytes(key); + byte[] aadPrefix = new byte[16]; + random.nextBytes(aadPrefix); + byte[] readBytes = new byte[1]; + + File testFile = temp.newFile(); + + AesGcmOutputFile encryptedFile = + new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); + PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite(); + encryptedStream.close(); + + AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); + Assert.assertEquals("File size", 0, decryptedFile.getLength()); + + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + Assert.assertEquals("Read empty stream", -1, decryptedStream.read(readBytes)); + } + + // check that the AAD is still verified, even for an empty file + byte[] badAAD = Arrays.copyOf(aadPrefix, aadPrefix.length); + badAAD[1] -= 1; // modify the AAD slightly + AesGcmInputFile badAADFile = new AesGcmInputFile(Files.localInput(testFile), key, badAAD); + Assert.assertEquals("File size", 0, badAADFile.getLength()); + + try (SeekableInputStream decryptedStream = badAADFile.newStream()) { + Assertions.assertThatThrownBy(() -> decryptedStream.read(readBytes)) + .isInstanceOf(RuntimeException.class) + .hasCauseInstanceOf(AEADBadTagException.class) + .hasMessageContaining("GCM tag check failed"); + } + } + + @Test + public void testAADValidation() throws IOException { + Random random = new Random(); + byte[] key = new byte[16]; + random.nextBytes(key); + byte[] aadPrefix = new byte[16]; + random.nextBytes(aadPrefix); + byte[] content = new byte[Ciphers.PLAIN_BLOCK_SIZE / 2]; // half a block + random.nextBytes(content); + + File testFile = temp.newFile(); + + AesGcmOutputFile encryptedFile = + new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); + try (PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite()) { + encryptedStream.write(content); + } + + // verify the data can be read correctly with the right AAD + AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); + Assert.assertEquals("File size", content.length, decryptedFile.getLength()); + + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + int bytesRead = decryptedStream.read(readContent); + Assert.assertEquals("Bytes read should match bytes written", content.length, bytesRead); + Assert.assertEquals( + "Content should match", + ByteBuffer.wrap(content), + ByteBuffer.wrap(readContent, 0, bytesRead)); + } + + // test with the wrong AAD + byte[] badAAD = Arrays.copyOf(aadPrefix, aadPrefix.length); + badAAD[1] -= 1; // modify the AAD slightly + AesGcmInputFile badAADFile = new AesGcmInputFile(Files.localInput(testFile), key, badAAD); + Assert.assertEquals("File size", content.length, badAADFile.getLength()); + + try (SeekableInputStream decryptedStream = badAADFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + Assertions.assertThatThrownBy(() -> decryptedStream.read(readContent)) + .isInstanceOf(RuntimeException.class) + .hasCauseInstanceOf(AEADBadTagException.class) + .hasMessageContaining("GCM tag check failed"); + } + + // modify the file contents + try (FileChannel out = FileChannel.open(testFile.toPath(), StandardOpenOption.WRITE)) { + long lastTagPosition = testFile.length() - Ciphers.GCM_TAG_LENGTH; + out.position(lastTagPosition); + out.write(ByteBuffer.wrap(key)); // overwrite the tag with other random bytes (the key) + } + + // read with the correct AAD and verify the tag check fails + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + Assertions.assertThatThrownBy(() -> decryptedStream.read(readContent)) + .isInstanceOf(RuntimeException.class) + .hasCauseInstanceOf(AEADBadTagException.class) + .hasMessageContaining("GCM tag check failed"); + } + } + + @Test + public void testCorruptNonce() throws IOException { + Random random = new Random(); + byte[] key = new byte[16]; + random.nextBytes(key); + byte[] aadPrefix = new byte[16]; + random.nextBytes(aadPrefix); + byte[] content = new byte[Ciphers.PLAIN_BLOCK_SIZE / 2]; // half a block + random.nextBytes(content); + + File testFile = temp.newFile(); + + AesGcmOutputFile encryptedFile = + new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); + try (PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite()) { + encryptedStream.write(content); + } + + // verify the data can be read correctly with the right AAD + AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); + Assert.assertEquals("File size", content.length, decryptedFile.getLength()); + + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + int bytesRead = decryptedStream.read(readContent); + Assert.assertEquals("Bytes read should match bytes written", content.length, bytesRead); + Assert.assertEquals( + "Content should match", + ByteBuffer.wrap(content), + ByteBuffer.wrap(readContent, 0, bytesRead)); + } + + // replace the first block's nonce + try (FileChannel out = FileChannel.open(testFile.toPath(), StandardOpenOption.WRITE)) { + out.position(Ciphers.GCM_STREAM_HEADER_LENGTH); + // overwrite the nonce with other random bytes (the key) + out.write(ByteBuffer.wrap(key, 0, Ciphers.NONCE_LENGTH)); + } + + // read with the correct AAD and verify the read fails + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + Assertions.assertThatThrownBy(() -> decryptedStream.read(readContent)) + .isInstanceOf(RuntimeException.class) + .hasCauseInstanceOf(AEADBadTagException.class) + .hasMessageContaining("GCM tag check failed"); + } + } + + @Test + public void testCorruptCiphertext() throws IOException { + Random random = new Random(); + byte[] key = new byte[16]; + random.nextBytes(key); + byte[] aadPrefix = new byte[16]; + random.nextBytes(aadPrefix); + byte[] content = new byte[Ciphers.PLAIN_BLOCK_SIZE / 2]; // half a block + random.nextBytes(content); + + File testFile = temp.newFile(); + + AesGcmOutputFile encryptedFile = + new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); + try (PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite()) { + encryptedStream.write(content); + } + + // verify the data can be read correctly with the right AAD + AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); + Assert.assertEquals("File size", content.length, decryptedFile.getLength()); + + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + int bytesRead = decryptedStream.read(readContent); + Assert.assertEquals("Bytes read should match bytes written", content.length, bytesRead); + Assert.assertEquals( + "Content should match", + ByteBuffer.wrap(content), + ByteBuffer.wrap(readContent, 0, bytesRead)); + } + + // replace part of the first block's content + try (FileChannel out = FileChannel.open(testFile.toPath(), StandardOpenOption.WRITE)) { + out.position(Ciphers.GCM_STREAM_HEADER_LENGTH + Ciphers.NONCE_LENGTH + 34); + // overwrite the nonce with other random bytes (the key) + out.write(ByteBuffer.wrap(key)); + } + + // read with the correct AAD and verify the read fails + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + Assertions.assertThatThrownBy(() -> decryptedStream.read(readContent)) + .isInstanceOf(RuntimeException.class) + .hasCauseInstanceOf(AEADBadTagException.class) + .hasMessageContaining("GCM tag check failed"); + } + } + + @Test + public void testRandomWriteRead() throws IOException { + Random random = new Random(); + int smallerThanBlock = (int) (Ciphers.PLAIN_BLOCK_SIZE * 0.5); + int largerThanBlock = (int) (Ciphers.PLAIN_BLOCK_SIZE * 1.5); + int alignedWithBlock = Ciphers.PLAIN_BLOCK_SIZE; + int[] testFileSizes = { + smallerThanBlock, + largerThanBlock, + alignedWithBlock, + alignedWithBlock - 1, + alignedWithBlock + 1 + }; + + for (int testFileSize : testFileSizes) { + byte[] testFileContents = new byte[testFileSize]; + random.nextBytes(testFileContents); + int[] aesKeyLengthArray = {16, 24, 32}; + byte[] aadPrefix = new byte[16]; + for (int keyLength : aesKeyLengthArray) { + byte[] key = new byte[keyLength]; + random.nextBytes(key); + random.nextBytes(aadPrefix); + File testFile = temp.newFile(); + + AesGcmOutputFile encryptedFile = + new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); + PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite(); + + int maxChunkLen = testFileSize / 5; + int offset = 0; + int left = testFileSize; + + while (left > 0) { + int chunkLen = random.nextInt(maxChunkLen); + if (chunkLen > left) { + chunkLen = left; + } + encryptedStream.write(testFileContents, offset, chunkLen); + offset += chunkLen; + Assert.assertEquals("Position", offset, encryptedStream.getPos()); + left -= chunkLen; + } + + encryptedStream.close(); + + AesGcmInputFile decryptedFile = + new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); + SeekableInputStream decryptedStream = decryptedFile.newStream(); + Assert.assertEquals("File size", testFileSize, decryptedFile.getLength()); + + byte[] chunk = new byte[testFileSize]; + + // Test seek and read + for (int n = 0; n < 100; n++) { + int chunkLen = random.nextInt(testFileSize); + int pos = random.nextInt(testFileSize); + left = testFileSize - pos; + + if (left < chunkLen) { + chunkLen = left; + } + + decryptedStream.seek(pos); + int len = decryptedStream.read(chunk, 0, chunkLen); + Assert.assertEquals("Read length", len, chunkLen); + long pos2 = decryptedStream.getPos(); + Assert.assertEquals("Position", pos + len, pos2); + + ByteBuffer bb1 = ByteBuffer.wrap(chunk, 0, chunkLen); + ByteBuffer bb2 = ByteBuffer.wrap(testFileContents, pos, chunkLen); + Assert.assertEquals("Read contents", bb1, bb2); + + // Test skip + long toSkip = random.nextInt(testFileSize); + long skipped = decryptedStream.skip(toSkip); + + if (pos2 + toSkip < testFileSize) { + Assert.assertEquals("Skipped", toSkip, skipped); + } else { + Assert.assertEquals("Skipped", (testFileSize - pos2), skipped); + } + + int pos3 = (int) decryptedStream.getPos(); + Assert.assertEquals("Position", pos2 + skipped, pos3); + + chunkLen = random.nextInt(testFileSize); + left = testFileSize - pos3; + + if (left < chunkLen) { + chunkLen = left; + } + + decryptedStream.read(chunk, 0, chunkLen); + bb1 = ByteBuffer.wrap(chunk, 0, chunkLen); + bb2 = ByteBuffer.wrap(testFileContents, pos3, chunkLen); + Assert.assertEquals("Read contents", bb1, bb2); + } + + decryptedStream.close(); + } + } + } + + @Test + public void testAlignedWriteRead() throws IOException { + Random random = new Random(); + int[] testFileSizes = { + Ciphers.PLAIN_BLOCK_SIZE, Ciphers.PLAIN_BLOCK_SIZE + 1, Ciphers.PLAIN_BLOCK_SIZE - 1 + }; + + for (int testFileSize : testFileSizes) { + byte[] testFileContents = new byte[testFileSize]; + random.nextBytes(testFileContents); + byte[] key = new byte[16]; + random.nextBytes(key); + byte[] aadPrefix = new byte[16]; + random.nextBytes(aadPrefix); + + File testFile = temp.newFile(); + AesGcmOutputFile encryptedFile = + new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); + PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite(); + + int offset = 0; + int chunkLen = Ciphers.PLAIN_BLOCK_SIZE; + int left = testFileSize; + + while (left > 0) { + + if (chunkLen > left) { + chunkLen = left; + } + + encryptedStream.write(testFileContents, offset, chunkLen); + offset += chunkLen; + Assert.assertEquals("Position", offset, encryptedStream.getPos()); + left -= chunkLen; + } + + encryptedStream.close(); + + AesGcmInputFile decryptedFile = + new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); + SeekableInputStream decryptedStream = decryptedFile.newStream(); + Assert.assertEquals("File size", testFileSize, decryptedFile.getLength()); + + offset = 0; + chunkLen = Ciphers.PLAIN_BLOCK_SIZE; + byte[] chunk = new byte[chunkLen]; + left = testFileSize; + + while (left > 0) { + + if (chunkLen > left) { + chunkLen = left; + } + + decryptedStream.seek(offset); + int len = decryptedStream.read(chunk, 0, chunkLen); + Assert.assertEquals("Read length", len, chunkLen); + Assert.assertEquals("Position", offset + len, decryptedStream.getPos()); + + ByteBuffer bb1 = ByteBuffer.wrap(chunk, 0, chunkLen); + ByteBuffer bb2 = ByteBuffer.wrap(testFileContents, offset, chunkLen); + Assert.assertEquals("Read contents", bb1, bb2); + + offset += len; + left = testFileSize - offset; + } + + decryptedStream.close(); + } + } +} From 9ad9801a5e9d6b3efb2f2ce3d8e31733f2d3ba51 Mon Sep 17 00:00:00 2001 From: Gidon Gershinsky Date: Tue, 16 Jan 2024 09:14:02 +0200 Subject: [PATCH 6/6] fix patch 2 --- .../test/java/org/apache/iceberg/encryption/TestGcmStreams.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java b/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java index cbbe96fb58c3..a954cf760baa 100644 --- a/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java +++ b/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java @@ -282,6 +282,7 @@ public void testRandomWriteRead() throws IOException { } encryptedStream.close(); + Assert.assertEquals("Final position in closed stream", offset, encryptedStream.getPos()); AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); @@ -378,6 +379,7 @@ public void testAlignedWriteRead() throws IOException { } encryptedStream.close(); + Assert.assertEquals("Final position in closed stream", offset, encryptedStream.getPos()); AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix);