Skip to content

Commit

Permalink
512-bit vectors in utf8 validator
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrrzysko committed Nov 12, 2023
1 parent 32d9059 commit 3d91a69
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 58 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ repositories {

java {
toolchain {
languageVersion = JavaLanguageVersion.of(18)
languageVersion = JavaLanguageVersion.of(21)
}
withJavadocJar()
withSourcesJar()
Expand Down
82 changes: 31 additions & 51 deletions src/main/java/org/simdjson/Utf8Validator.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,22 @@

import java.util.Arrays;

public class Utf8Validator {
private static final VectorSpecies<Byte> VECTOR_SPECIES = ByteVector.SPECIES_256;
class Utf8Validator {
private static final VectorSpecies<Byte> VECTOR_SPECIES = StructuralIndexer.SPECIES;
private static final ByteVector INCOMPLETE_CHECK = getIncompleteCheck();
private static final VectorShuffle<Integer> SHIFT_FOUR_BYTES_FORWARD = VectorShuffle.iota(IntVector.SPECIES_256,
IntVector.SPECIES_256.elementSize() - 1, 1, true);
private static final ByteVector LOW_NIBBLE_MASK = ByteVector.broadcast(VECTOR_SPECIES, 0b0000_1111);
private static final byte LOW_NIBBLE_MASK = 0x0f;
private static final ByteVector ALL_ASCII_MASK = ByteVector.broadcast(VECTOR_SPECIES, (byte) 0b1000_0000);


/**
* Validate the input bytes are valid UTF8
*
* @param inputBytes the input bytes to validate
* @throws JsonParsingException if the input is not valid UTF8
*/
public static void validate(byte[] inputBytes) {
long previousIncomplete = 0;
long errors = 0;
int previousFourUtf8Bytes = 0;
static void validate(byte[] inputBytes) {
boolean previousIncomplete = false;
boolean errors = false;
ByteVector prevChunk = ByteVector.zero(VECTOR_SPECIES);

int idx = 0;
for (; idx < VECTOR_SPECIES.loopBound(inputBytes.length); idx += VECTOR_SPECIES.vectorByteSize()) {
Expand All @@ -33,14 +30,12 @@ public static void validate(byte[] inputBytes) {
} else {
previousIncomplete = isIncomplete(utf8Vector);

var fourBytesPrevious = fourBytesPreviousSlice(utf8Vector, previousFourUtf8Bytes);

ByteVector firstCheck = firstTwoByteSequenceCheck(utf8Vector.reinterpretAsInts(), fourBytesPrevious);
ByteVector secondCheck = lastTwoByteSequenceCheck(utf8Vector.reinterpretAsInts(), fourBytesPrevious, firstCheck);
ByteVector firstCheck = firstTwoByteSequenceCheck(utf8Vector, prevChunk);
ByteVector secondCheck = lastTwoByteSequenceCheck(utf8Vector, prevChunk, firstCheck);

errors |= secondCheck.compare(VectorOperators.NE, 0).toLong();
errors |= secondCheck.compare(VectorOperators.NE, 0).anyTrue();
}
previousFourUtf8Bytes = utf8Vector.reinterpretAsInts().lane(IntVector.SPECIES_256.length() - 1);
prevChunk = utf8Vector;
}

// if the input file doesn't align with the vector width, pad the missing bytes with zero
Expand All @@ -51,47 +46,26 @@ public static void validate(byte[] inputBytes) {
} else {
previousIncomplete = isIncomplete(lastVectorChunk);

var fourBytesPrevious = fourBytesPreviousSlice(lastVectorChunk, previousFourUtf8Bytes);

ByteVector firstCheck = firstTwoByteSequenceCheck(lastVectorChunk.reinterpretAsInts(), fourBytesPrevious);
ByteVector secondCheck = lastTwoByteSequenceCheck(lastVectorChunk.reinterpretAsInts(), fourBytesPrevious, firstCheck);
ByteVector firstCheck = firstTwoByteSequenceCheck(lastVectorChunk, prevChunk);
ByteVector secondCheck = lastTwoByteSequenceCheck(lastVectorChunk, prevChunk, firstCheck);

errors |= secondCheck.compare(VectorOperators.NE, 0).toLong();
errors |= secondCheck.compare(VectorOperators.NE, 0).anyTrue();
}

if ((errors | previousIncomplete) != 0) {
if (errors | previousIncomplete) {
throw new JsonParsingException("Invalid UTF8");
}
}

/* Shuffles the input forward by four bytes to make space for the previous four bytes.
The previous three bytes are required for validation, pulling in the last integer will give the previous four bytes.
The switch to integer vectors is to allow for integer shifting instead of the more expensive shuffle / slice operations */
private static IntVector fourBytesPreviousSlice(ByteVector vectorChunk, int previousFourUtf8Bytes) {
return vectorChunk.reinterpretAsInts()
.rearrange(SHIFT_FOUR_BYTES_FORWARD)
.withLane(0, previousFourUtf8Bytes);
}

// works similar to previousUtf8Vector.slice(VECTOR_SPECIES.length() - numOfBytesToInclude, utf8Vector) but without the performance cost
private static ByteVector previousVectorSlice(IntVector utf8Vector, IntVector fourBytesPrevious, int numOfPreviousBytes) {
return utf8Vector
.lanewise(VectorOperators.LSHL, Byte.SIZE * numOfPreviousBytes)
.or(fourBytesPrevious.lanewise(VectorOperators.LSHR, Byte.SIZE * (4 - numOfPreviousBytes)))
.reinterpretAsBytes();
}

private static ByteVector firstTwoByteSequenceCheck(IntVector utf8Vector, IntVector fourBytesPrevious) {
private static ByteVector firstTwoByteSequenceCheck(ByteVector utf8Vector, ByteVector prevChunk) {
// shift the current input forward by 1 byte to include 1 byte from the previous input
var oneBytePrevious = previousVectorSlice(utf8Vector, fourBytesPrevious, 1);
var oneBytePrevious = concatenate(utf8Vector, prevChunk, 1);

// high nibbles of the current input (e.g. 0xC3 >> 4 = 0xC)
ByteVector byte2HighNibbles = utf8Vector.lanewise(VectorOperators.LSHR, 4)
.reinterpretAsBytes().and(LOW_NIBBLE_MASK);
ByteVector byte2HighNibbles = utf8Vector.lanewise(VectorOperators.LSHR, 4);

// high nibbles of the shifted input
ByteVector byte1HighNibbles = oneBytePrevious.reinterpretAsInts().lanewise(VectorOperators.LSHR, 4)
.reinterpretAsBytes().and(LOW_NIBBLE_MASK);
ByteVector byte1HighNibbles = oneBytePrevious.lanewise(VectorOperators.LSHR, 4);

// low nibbles of the shifted input (e.g. 0xC3 & 0xF = 0x3)
ByteVector byte1LowNibbles = oneBytePrevious.and(LOW_NIBBLE_MASK);
Expand All @@ -104,20 +78,26 @@ private static ByteVector firstTwoByteSequenceCheck(IntVector utf8Vector, IntVec
}

// All remaining checks are invalid 3–4 byte sequences, which either have too many continuations bytes or not enough
private static ByteVector lastTwoByteSequenceCheck(IntVector utf8Vector, IntVector fourBytesPrevious, ByteVector firstCheck) {
private static ByteVector lastTwoByteSequenceCheck(ByteVector utf8Vector, ByteVector prevChunk, ByteVector firstCheck) {
// the minimum 3byte lead - 1110_0000 is always greater than the max 2byte lead - 110_11111
ByteVector twoBytesPrevious = previousVectorSlice(utf8Vector, fourBytesPrevious, 2);
ByteVector twoBytesPrevious = concatenate(utf8Vector, prevChunk, 2);

VectorMask<Byte> is3ByteLead = twoBytesPrevious.compare(VectorOperators.UNSIGNED_GT, (byte) 0b110_11111);

// the minimum 4byte lead - 1111_0000 is always greater than the max 3byte lead - 1110_1111
ByteVector threeBytesPrevious = previousVectorSlice(utf8Vector, fourBytesPrevious, 3);
ByteVector threeBytesPrevious = concatenate(utf8Vector, prevChunk, 3);

VectorMask<Byte> is4ByteLead = threeBytesPrevious.compare(VectorOperators.UNSIGNED_GT, (byte) 0b1110_1111);

// the firstCheck vector contains 0x80 values on continuation byte indexes
// the 3/4 byte lead bytes should match up with these indexes and zero them out
return firstCheck.add((byte) 0x80, is3ByteLead.or(is4ByteLead));
}

private static ByteVector concatenate(ByteVector curr, ByteVector prev, int byteCountFromPrev) {
return prev.slice(VECTOR_SPECIES.length() - byteCountFromPrev, curr);
}

/* checks that the previous vector isn't in an incomplete state.
Previous vector is in an incomplete state if the last byte is smaller than 0xC0,
or the second last byte is smaller than 0xE0, or the third last byte is smaller than 0xF0.*/
Expand All @@ -131,12 +111,12 @@ private static ByteVector getIncompleteCheck() {
return ByteVector.fromArray(VECTOR_SPECIES, eofArray, 0);
}

protected static long isIncomplete(ByteVector utf8Vector) {
return utf8Vector.compare(VectorOperators.UNSIGNED_GE, INCOMPLETE_CHECK).toLong();
private static boolean isIncomplete(ByteVector utf8Vector) {
return utf8Vector.compare(VectorOperators.UNSIGNED_GE, INCOMPLETE_CHECK).anyTrue();
}

// ASCII will never exceed 01111_1111
protected static boolean isAscii(ByteVector utf8Vector) {
private static boolean isAscii(ByteVector utf8Vector) {
return utf8Vector.and(ALL_ASCII_MASK).compare(VectorOperators.EQ, 0).allTrue();
}

Expand Down
11 changes: 5 additions & 6 deletions src/test/java/org/simdjson/Utf8ValidatorTest.java
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
package org.simdjson;

import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorSpecies;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

class Utf8ValidatorTest {
private static final VectorSpecies<Byte> VECTOR_SPECIES = ByteVector.SPECIES_256;
private static final VectorSpecies<Byte> VECTOR_SPECIES = StructuralIndexer.SPECIES;


/* ASCII / 1 BYTE TESTS */
Expand Down Expand Up @@ -482,14 +481,14 @@ void validate_continuationThreeBytesTooShort_4Byte_eof_invalid() {
@ParameterizedTest
@ValueSource(strings = {"/twitter.json", "/nhkworld.json"})
void validate_utf8InputFiles_valid(String inputFilePath) throws IOException {
byte[] inputBytes = Objects.requireNonNull(Utf8ValidatorTest.class.getResourceAsStream(inputFilePath)).readAllBytes();
byte[] inputBytes = TestUtils.loadTestFile(inputFilePath);
SimdJsonParser parser = new SimdJsonParser();
assertThatCode(() -> parser.parse(inputBytes, inputBytes.length)).doesNotThrowAnyException();
}

@Test
void validate_utf8InputFile_invalid() throws IOException {
byte[] inputBytes = Objects.requireNonNull(Utf8ValidatorTest.class.getResourceAsStream("/malformed.txt")).readAllBytes();
byte[] inputBytes = TestUtils.loadTestFile("/malformed.txt");
SimdJsonParser parser = new SimdJsonParser();
assertThatExceptionOfType(JsonParsingException.class)
.isThrownBy(() -> parser.parse(inputBytes, inputBytes.length))
Expand Down

0 comments on commit 3d91a69

Please sign in to comment.