Skip to content

Commit

Permalink
AVRO-3748: [Java] Fix SeekableInput.skip (#2984)
Browse files Browse the repository at this point in the history
* AVRO-3748: Fix SeekableInput.skip

Two of the implementations of SeekableInput.skip had a bug: skip was
implemented as seek (i.e. using an absolute input position instead of a
relative one). This fixes that.

* AVRO-3748: Avoid reset+skip confusion
  • Loading branch information
opwvhk authored Jun 28, 2024
1 parent 677e982 commit 9443fa9
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import org.apache.avro.InvalidAvroMagicException;
import org.apache.avro.io.DatumReader;
import org.apache.avro.io.DecoderFactory;
import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.IOUtils;

import java.io.EOFException;
import java.io.File;
Expand All @@ -37,7 +37,7 @@
* @see DataFileWriter
*/
public class DataFileReader<D> extends DataFileStream<D> implements FileReader<D> {
private SeekableInputStream sin;
private final SeekableInputStream sin;
private long blockStart;
private int[] partialMatchTable;

Expand Down Expand Up @@ -264,9 +264,9 @@ public long tell() throws IOException {

static class SeekableInputStream extends InputStream implements SeekableInput {
private final byte[] oneByte = new byte[1];
/** The wrapped input that this stream delegates position and length queries to. */
private final SeekableInput in;

/**
 * Wraps a {@link SeekableInput} so it can be used through the
 * {@link InputStream} API.
 *
 * @param in the input to wrap; the reference is stored as-is
 */
SeekableInputStream(SeekableInput in) {
  this.in = in;
}

Expand Down Expand Up @@ -310,15 +310,10 @@ public int read() throws IOException {
/**
 * Skips forward by at most {@code skip} bytes, relative to the current
 * position, without moving past the end of the input.
 *
 * @param skip the number of bytes to skip; zero or a negative value skips
 *             nothing
 * @return the number of bytes actually skipped, as measured by the position
 *         reported before and after the move
 * @throws IOException if the wrapped input fails to report its position or
 *                     length, or rejects the seek
 */
@Override
public long skip(long skip) throws IOException {
  if (skip <= 0) {
    return 0;
  }
  long position = in.tell();
  long remaining = in.length() - position;
  if (remaining <= 0) {
    // Already at (or beyond) the end: nothing left to skip.
    return 0;
  }
  // Clamp the relative distance before adding it to the position, so that a
  // huge argument such as Long.MAX_VALUE cannot overflow into a negative
  // target. seek() takes an absolute position, hence position + distance.
  // NOTE(review): when skip >= remaining the target equals length(); some
  // SeekableInput implementations may reject that position - confirm.
  in.seek(position + Math.min(skip, remaining));
  return in.tell() - position;
}

@Override
Expand All @@ -330,7 +325,7 @@ public void close() throws IOException {
/**
 * Reports how many bytes lie between the current position and the end of the
 * wrapped input.
 *
 * @return the remaining byte count, limited to the range
 *         {@code [0, Integer.MAX_VALUE]}
 * @throws IOException if the wrapped input fails to report its length or
 *                     position
 */
@Override
public int available() throws IOException {
  long remaining = in.length() - in.tell();
  // Clamp on both sides so the narrowing cast below can never wrap around.
  return (int) Math.max(0L, Math.min(remaining, Integer.MAX_VALUE));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.avro.file;

import java.io.ByteArrayInputStream;
import java.io.EOFException;
import java.io.IOException;

/** A {@link SeekableInput} backed with data in a byte array. */
Expand All @@ -34,8 +35,12 @@ public long length() throws IOException {

/**
 * Moves the read position to the absolute offset {@code p}.
 * <p>
 * The target is checked before the position is modified, so a rejected or
 * ignored request leaves the current position untouched.
 *
 * @param p the absolute position to move to; a negative value is ignored
 * @throws EOFException if {@code p} is at or beyond {@code count}
 * @throws IOException  declared by the overridden method
 */
@Override
public void seek(long p) throws IOException {
  if (p >= this.count) {
    throw new EOFException();
  }
  if (p >= 0) {
    // Here 0 <= p < count, so the narrowing cast cannot lose information.
    this.pos = (int) p;
  }
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,6 @@
*/
package org.apache.avro.file;

import static org.junit.jupiter.api.Assertions.*;

import java.io.ByteArrayOutputStream;
import java.util.ArrayList;
import java.util.List;

import org.apache.avro.Schema;
import org.apache.avro.Schema.Field;
import org.apache.avro.Schema.Type;
Expand All @@ -34,6 +28,18 @@
import org.apache.avro.util.Utf8;
import org.junit.jupiter.api.Test;

import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;

public class TestSeekableByteArrayInput {

private byte[] getSerializedMessage(IndexedRecord message, Schema schema) throws Exception {
Expand Down Expand Up @@ -68,7 +74,34 @@ void serialization() throws Exception {
result = dfr.next();
}
assertNotNull(result);
assertTrue(result instanceof GenericRecord);
assertInstanceOf(GenericRecord.class, result);
assertEquals(new Utf8("testValue"), ((GenericRecord) result).get("name"));
}

/**
 * Reads eight bytes, seeks back to absolute offset 4 and reads eight more,
 * then checks the reported position, the length and the collected bytes.
 */
@Test
void readingData() throws IOException {
  byte[] data = "0123456789ABCD".getBytes(StandardCharsets.UTF_8);
  byte[] result = new byte[16];
  try (SeekableInput in = new SeekableByteArrayInput(data)) {
    // Assert the byte counts so a short read cannot slip through unnoticed.
    assertEquals(8, in.read(result, 0, 8));
    in.seek(4);
    assertEquals(8, in.read(result, 8, 8));
    assertEquals(12, in.tell());
    assertEquals(data.length, in.length());
    assertEquals("01234567456789AB", new String(result, StandardCharsets.UTF_8));
  }
}

/**
 * Checks the two kinds of invalid seek targets: a negative position leaves
 * the current position unchanged, and a position beyond the data raises an
 * {@link EOFException}.
 */
@Test
void illegalSeeks() throws IOException {
  byte[] data = "0123456789ABCD".getBytes(StandardCharsets.UTF_8);
  try (SeekableInput in = new SeekableByteArrayInput(data)) {
    byte[] buf = new byte[2];
    // The position assertion below relies on exactly two bytes being consumed.
    assertEquals(buf.length, in.read(buf, 0, buf.length));
    in.seek(-4);
    assertEquals(2, in.tell());

    assertThrows(EOFException.class, () -> in.seek(64));
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.avro.file;

import org.junit.Assert;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.nio.charset.StandardCharsets;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

/**
 * Tests for {@link DataFileReader.SeekableInputStream} wrapped around a
 * {@link SeekableByteArrayInput}.
 */
public class TestSeekableInputStream {
  /**
   * Walks through a 14-byte input with a mix of bulk reads, an absolute seek, a
   * relative skip and a single-byte read, checking content and position after
   * each step.
   */
  @Test
  void readingData() throws IOException {
    byte[] data = "0123456789ABCD".getBytes(StandardCharsets.UTF_8);
    try (DataFileReader.SeekableInputStream sin = new DataFileReader.SeekableInputStream(
        new SeekableByteArrayInput(data))) {
      byte[] first8 = new byte[8];
      assertEquals(first8.length, sin.read(first8, 0, 8));
      assertArrayEquals("01234567".getBytes(StandardCharsets.UTF_8), first8);
      sin.seek(4);
      assertEquals(10, sin.available());
      // skip() is relative: from offset 4, skipping 2 must land on offset 6.
      assertEquals(2, sin.skip(2));
      assertEquals((byte) '6', sin.read());
      byte[] next4 = new byte[4];
      assertEquals(next4.length, sin.read(next4));
      assertArrayEquals("789A".getBytes(StandardCharsets.UTF_8), next4);
      assertEquals(11, sin.tell());
      assertEquals(data.length, sin.length());
    }
  }

  /** Seeking to a negative absolute position is expected to raise an {@link IOException}. */
  @Test
  void illegalSeek() throws IOException {
    try (SeekableInput in = new SeekableByteArrayInput("".getBytes(StandardCharsets.UTF_8));
        DataFileReader.SeekableInputStream sin = new DataFileReader.SeekableInputStream(in)) {
      assertThrows(IOException.class, () -> sin.seek(-5));
    }
  }
}

0 comments on commit 9443fa9

Please sign in to comment.