diff --git a/lang/java/avro/src/main/java/org/apache/avro/file/DataFileReader.java b/lang/java/avro/src/main/java/org/apache/avro/file/DataFileReader.java index 8f333a1cb48..ae33df59fbe 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/file/DataFileReader.java +++ b/lang/java/avro/src/main/java/org/apache/avro/file/DataFileReader.java @@ -20,7 +20,7 @@ import org.apache.avro.InvalidAvroMagicException; import org.apache.avro.io.DatumReader; import org.apache.avro.io.DecoderFactory; -import org.apache.commons.compress.utils.IOUtils; +import org.apache.commons.io.IOUtils; import java.io.EOFException; import java.io.File; @@ -37,7 +37,7 @@ * @see DataFileWriter */ public class DataFileReader extends DataFileStream implements FileReader { - private SeekableInputStream sin; + private final SeekableInputStream sin; private long blockStart; private int[] partialMatchTable; @@ -264,9 +264,9 @@ public long tell() throws IOException { static class SeekableInputStream extends InputStream implements SeekableInput { private final byte[] oneByte = new byte[1]; - private SeekableInput in; + private final SeekableInput in; - SeekableInputStream(SeekableInput in) throws IOException { + SeekableInputStream(SeekableInput in) { this.in = in; } @@ -310,15 +310,10 @@ public int read() throws IOException { @Override public long skip(long skip) throws IOException { long position = in.tell(); + long skipToPosition = position + skip; long length = in.length(); - long remaining = length - position; - if (remaining > skip) { - in.seek(skip); - return in.tell() - position; - } else { - in.seek(remaining); - return in.tell() - position; - } + in.seek(Math.min(skipToPosition, length)); + return in.tell() - position; } @Override @@ -330,7 +325,7 @@ public void close() throws IOException { @Override public int available() throws IOException { long remaining = (in.length() - in.tell()); - return (remaining > Integer.MAX_VALUE) ? Integer.MAX_VALUE : (int) remaining; + return (int) Math.min(remaining, Integer.MAX_VALUE); } } } diff --git a/lang/java/avro/src/main/java/org/apache/avro/file/SeekableByteArrayInput.java b/lang/java/avro/src/main/java/org/apache/avro/file/SeekableByteArrayInput.java index 991fc44b4e8..49994a9bc8e 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/file/SeekableByteArrayInput.java +++ b/lang/java/avro/src/main/java/org/apache/avro/file/SeekableByteArrayInput.java @@ -18,6 +18,7 @@ package org.apache.avro.file; import java.io.ByteArrayInputStream; +import java.io.EOFException; import java.io.IOException; /** A {@link SeekableInput} backed with data in a byte array. */ @@ -34,8 +35,12 @@ public long length() throws IOException { @Override public void seek(long p) throws IOException { - this.reset(); - this.skip(p); + if (p >= this.count) { + throw new EOFException(); + } + if (p >= 0) { + this.pos = (int) p; + } } @Override diff --git a/lang/java/avro/src/test/java/org/apache/avro/file/TestSeekableByteArrayInput.java b/lang/java/avro/src/test/java/org/apache/avro/file/TestSeekableByteArrayInput.java index bf6103c6fd8..2e6b46e5d1f 100644 --- a/lang/java/avro/src/test/java/org/apache/avro/file/TestSeekableByteArrayInput.java +++ b/lang/java/avro/src/test/java/org/apache/avro/file/TestSeekableByteArrayInput.java @@ -17,12 +17,6 @@ */ package org.apache.avro.file; -import static org.junit.jupiter.api.Assertions.*; - -import java.io.ByteArrayOutputStream; -import java.util.ArrayList; -import java.util.List; - import org.apache.avro.Schema; import org.apache.avro.Schema.Field; import org.apache.avro.Schema.Type; @@ -34,6 +28,18 @@ import org.apache.avro.util.Utf8; import org.junit.jupiter.api.Test; +import java.io.ByteArrayOutputStream; +import java.io.EOFException; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + public class TestSeekableByteArrayInput { private byte[] getSerializedMessage(IndexedRecord message, Schema schema) throws Exception { @@ -68,7 +74,34 @@ void serialization() throws Exception { result = dfr.next(); } assertNotNull(result); - assertTrue(result instanceof GenericRecord); + assertInstanceOf(GenericRecord.class, result); assertEquals(new Utf8("testValue"), ((GenericRecord) result).get("name")); } + + @Test + void readingData() throws IOException { + byte[] data = "0123456789ABCD".getBytes(StandardCharsets.UTF_8); + byte[] result = new byte[16]; + try (SeekableInput in = new SeekableByteArrayInput(data)) { + in.read(result, 0, 8); + in.seek(4); + in.read(result, 8, 8); + assertEquals(12, in.tell()); + assertEquals(data.length, in.length()); + assertEquals("01234567456789AB", new String(result, StandardCharsets.UTF_8)); + } + } + + @Test + void illegalSeeks() throws IOException { + byte[] data = "0123456789ABCD".getBytes(StandardCharsets.UTF_8); + try (SeekableInput in = new SeekableByteArrayInput(data)) { + byte[] buf = new byte[2]; + in.read(buf, 0, buf.length); + in.seek(-4); + assertEquals(2, in.tell()); + + assertThrows(EOFException.class, () -> in.seek(64)); + } + } } diff --git a/lang/java/avro/src/test/java/org/apache/avro/file/TestSeekableInputStream.java b/lang/java/avro/src/test/java/org/apache/avro/file/TestSeekableInputStream.java new file mode 100644 index 00000000000..34dbf298215 --- /dev/null +++ b/lang/java/avro/src/test/java/org/apache/avro/file/TestSeekableInputStream.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.avro.file; + +import org.junit.Assert; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class TestSeekableInputStream { + @Test + void readingData() throws IOException { + byte[] data = "0123456789ABCD".getBytes(StandardCharsets.UTF_8); + try (DataFileReader.SeekableInputStream sin = new DataFileReader.SeekableInputStream( + new SeekableByteArrayInput(data))) { + byte[] first8 = new byte[8]; + assertEquals(first8.length, sin.read(first8, 0, 8)); + assertArrayEquals("01234567".getBytes(StandardCharsets.UTF_8), first8); + sin.seek(4); + assertEquals(10, sin.available()); + assertEquals(2, sin.skip(2)); + assertEquals((byte) '6', sin.read()); + byte[] next4 = new byte[4]; + assertEquals(next4.length, sin.read(next4)); + assertArrayEquals("789A".getBytes(StandardCharsets.UTF_8), next4); + assertEquals(11, sin.tell()); + assertEquals(data.length, sin.length()); + } + } + + @Test + void illegalSeek() throws IOException { + try (SeekableInput in = new SeekableByteArrayInput("".getBytes(StandardCharsets.UTF_8)); + DataFileReader.SeekableInputStream sin = new DataFileReader.SeekableInputStream(in)) { + Assert.assertThrows(IOException.class, () -> sin.seek(-5)); + } + } +}