Skip to content

Commit

Permalink
Tests for json validation
Browse files Browse the repository at this point in the history
  • Loading branch information
mwylde committed Mar 30, 2024
1 parent 494162e commit e5420bc
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 49 deletions.
10 changes: 4 additions & 6 deletions arrow-json/src/reader/boolean_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,16 @@ use crate::reader::tape::{Tape, TapeElement};
use crate::reader::ArrayDecoder;

pub struct BooleanArrayDecoder {
is_nullable: bool
is_nullable: bool,
}

impl BooleanArrayDecoder {
pub fn new(is_nullable: bool) -> Self {
Self {
is_nullable
}
Self { is_nullable }
}
}

impl ArrayDecoder for BooleanArrayDecoder {
impl ArrayDecoder for BooleanArrayDecoder {
fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result<ArrayData, ArrowError> {
let mut builder = BooleanBuilder::with_capacity(pos.len());
for p in pos {
Expand All @@ -55,6 +53,6 @@ impl ArrayDecoder for BooleanArrayDecoder {
TapeElement::Null => self.is_nullable,
TapeElement::True | TapeElement::False => true,
_ => false,
}
}
}
}
5 changes: 2 additions & 3 deletions arrow-json/src/reader/decimal_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::marker::PhantomData;

use arrow_array::builder::PrimitiveBuilder;
use arrow_array::types::DecimalType;
use arrow_array::{Array};
use arrow_array::Array;
use arrow_cast::parse::parse_decimal;
use arrow_data::ArrayData;
use arrow_schema::ArrowError;
Expand Down Expand Up @@ -81,8 +81,7 @@ where
TapeElement::Null => self.is_nullable,
TapeElement::String(idx) => {
let s = tape.get_string(idx);
parse_decimal::<D>(s, self.precision, self.scale)
.is_ok()
parse_decimal::<D>(s, self.precision, self.scale).is_ok()
}
TapeElement::Number(idx) => {
let s = tape.get_string(idx);
Expand Down
2 changes: 1 addition & 1 deletion arrow-json/src/reader/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ impl<O: OffsetSizeTrait> ArrayDecoder for ListArrayDecoder<O> {
return false;
}
}

true
}
}
6 changes: 3 additions & 3 deletions arrow-json/src/reader/map_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,18 @@ impl ArrayDecoder for MapArrayDecoder {
let Ok(value) = tape.next(key, "map key") else {
return false;
};

if let Ok(i) = tape.next(value, "map value") {
cur_idx = i;
} else {
return false;
}
}

if !(self.keys.validate_row(tape, key) && self.values.validate_row(tape, value)) {
return false;
}
}

true
}
}
122 changes: 106 additions & 16 deletions arrow-json/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,10 @@ use serde::Serialize;
use arrow_array::timezone::Tz;
use arrow_array::types::Float32Type;
use arrow_array::types::*;
use arrow_array::{downcast_integer, make_array, RecordBatch, RecordBatchReader, StringArray, StructArray};
use arrow_array::builder::StringBuilder;
use arrow_array::{
downcast_integer, make_array, BooleanArray, RecordBatch, RecordBatchReader, StringArray,
StructArray,
};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, FieldRef, Schema, SchemaRef, TimeUnit};
pub use schema::*;
Expand Down Expand Up @@ -291,7 +293,7 @@ impl ReaderBuilder {
/// Sets if the decoder should continue decoding when encountering records that do not
/// match the schema. If set, the schema is made nullable as bad data records will be
/// recorded as null values.
pub fn with_allow_dad_data(self, allow_bad_data: bool) -> Self {
pub fn with_allow_bad_data(self, allow_bad_data: bool) -> Self {
Self {
allow_bad_data,
..self
Expand Down Expand Up @@ -658,12 +660,11 @@ impl Decoder {

// First offset is null sentinel
let mut next_object = 1;
let pos = (0..tape.num_rows())
.map(|_| {
let next = tape.next(next_object, "row").unwrap();
std::mem::replace(&mut next_object, next)
});

let pos = (0..tape.num_rows()).map(|_| {
let next = tape.next(next_object, "row").unwrap();
std::mem::replace(&mut next_object, next)
});

let pos: Vec<_> = if self.allow_bad_data {
// filter out invalid rows before we attempt to deserialize
pos.filter(|p| self.decoder.validate_row(&tape, *p))
Expand All @@ -684,8 +685,17 @@ impl Decoder {

Ok(Some(batch))
}

pub fn flush_with_bad_data(&mut self) -> Result<Option<(RecordBatch, Option<StringArray>)>, ArrowError> {

/// Flushes schema-conforming JSON in the current buffer to a [`RecordBatch`], and returns
/// a BooleanArray that marks good rows and an Option<StringArray> with invalid records, if
/// any exist
///
/// Returns `Ok(None)` if no buffered data
///
/// Note: if called part way through decoding a record, this will return an error
pub fn flush_with_bad_data(
&mut self,
) -> Result<Option<(RecordBatch, BooleanArray, Option<StringArray>)>, ArrowError> {
let tape = self.tape_decoder.finish()?;

if tape.num_rows() == 0 {
Expand All @@ -694,14 +704,18 @@ impl Decoder {

// First offset is null sentinel
let mut next_object = 1;
let mut good_rows = Vec::with_capacity(tape.num_rows());

let (good, bad): (Vec<_>, Vec<_>) = (0..tape.num_rows())
.map(|_| {
let next = tape.next(next_object, "row").unwrap();

std::mem::replace(&mut next_object, next)
})
.partition(|p| {
self.decoder.validate_row(&tape, *p)
let valid = self.decoder.validate_row(&tape, *p);
good_rows.push(valid);
valid
});

let bad_data = if !bad.is_empty() {
Expand All @@ -711,7 +725,7 @@ impl Decoder {
} else {
None
};

let decoded = self.decoder.decode(&tape, &good)?;
self.tape_decoder.clear();

Expand All @@ -722,8 +736,7 @@ impl Decoder {
}
};

Ok(Some((batch, bad_data)))

Ok(Some((batch, good_rows.into(), bad_data)))
}
}

Expand All @@ -737,7 +750,10 @@ trait ArrayDecoder: Send {

macro_rules! primitive_decoder {
($t:ty, $data_type:expr, $is_nullable:expr) => {
Ok(Box::new(PrimitiveArrayDecoder::<$t>::new($data_type, $is_nullable)))
Ok(Box::new(PrimitiveArrayDecoder::<$t>::new(
$data_type,
$is_nullable,
)))
};
}

Expand Down Expand Up @@ -847,6 +863,7 @@ mod tests {
unbuffered = ReaderBuilder::new(schema.clone())
.with_batch_size(batch_size)
.with_coerce_primitive(coerce_primitive)
.with_allow_bad_data(true)
.build(Cursor::new(buf.as_bytes()))
.unwrap()
.collect::<Result<Vec<_>, _>>()
Expand Down Expand Up @@ -2479,4 +2496,77 @@ mod tests {
assert_eq!(b, false);
assert_eq!(c, 10);
}

#[test]
fn test_deserialize_bad_data() {
let j1 = r#"{"a":5,"b":{"d":5},"c":10,"e":[1,2,3]}"#; // valid
let j2 = r#"{"a":5,"b":{"d":"nope"},"c":10}"#; // invalid
let j3 = r#"{"a":5,"c":10}"#; // invalid
let j4 = r#"{"a":5,"b":null,"c":10}"#; // invalid
let j5 = r#"{"a":5,"b":{"d":5},"c":10}"#; // valid
let j6 = r#"{"a":5,"b":{"d":5},"c":10,"e":["hello"]}"#; // invalid
let j7 = r#"{"a":5,"b":{"d":5},"c":10,"e":true}"#; // invalid

let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, false),
Field::new(
"b",
DataType::Struct(vec![Field::new("d", DataType::Int64, false)].into()),
false,
),
Field::new("c", DataType::Int64, false),
Field::new(
"e",
DataType::List(Arc::new(Field::new("item", DataType::Int64, false))),
true,
),
]));

// allow_bad_data
let mut decoder = ReaderBuilder::new(schema.clone())
.with_batch_size(10)
.with_coerce_primitive(false)
.with_allow_bad_data(true)
.build_decoder()
.unwrap();

decoder.decode(&j1.as_bytes()).unwrap();
decoder.decode(&j2.as_bytes()).unwrap();
decoder.decode(&j3.as_bytes()).unwrap();
decoder.decode(&j4.as_bytes()).unwrap();
decoder.decode(&j5.as_bytes()).unwrap();
decoder.decode(&j6.as_bytes()).unwrap();
decoder.decode(&j7.as_bytes()).unwrap();
let batch = decoder.flush().unwrap().unwrap();
assert_eq!(batch.num_rows(), 2);

// flush_with_bad_data
let mut decoder = ReaderBuilder::new(schema.clone())
.with_batch_size(10)
.with_coerce_primitive(false)
.build_decoder()
.unwrap();

decoder.decode(&j1.as_bytes()).unwrap();
decoder.decode(&j2.as_bytes()).unwrap();
decoder.decode(&j3.as_bytes()).unwrap();
decoder.decode(&j4.as_bytes()).unwrap();
decoder.decode(&j5.as_bytes()).unwrap();
decoder.decode(&j6.as_bytes()).unwrap();
decoder.decode(&j7.as_bytes()).unwrap();

let (good, mask, bad) = decoder.flush_with_bad_data().unwrap().unwrap();
assert_eq!(
mask,
vec![true, false, false, false, true, false, false].into()
);

assert_eq!(good.num_rows(), 2);
let bad = bad.unwrap();
assert_eq!(bad.value(0), j2);
assert_eq!(bad.value(1), j3);
assert_eq!(bad.value(2), j4);
assert_eq!(bad.value(3), j6);
assert_eq!(bad.value(4), j7);
}
}
6 changes: 3 additions & 3 deletions arrow-json/src/reader/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ impl<P: ArrowPrimitiveType> PrimitiveArrayDecoder<P> {
}
}


impl<P> ArrayDecoder for PrimitiveArrayDecoder<P>
where
P: ArrowPrimitiveType + Parser,
Expand Down Expand Up @@ -159,7 +158,7 @@ where

Ok(builder.finish().into_data())
}

fn validate_row(&self, tape: &Tape<'_>, pos: u32) -> bool {
match tape.get(pos) {
TapeElement::Null => self.is_nullable,
Expand All @@ -169,7 +168,8 @@ where
}
TapeElement::Number(idx) => {
let s = tape.get_string(idx);
let v: Option<<P as ArrowPrimitiveType>::Native> = ParseJsonNumber::parse(s.as_bytes());
let v: Option<<P as ArrowPrimitiveType>::Native> =
ParseJsonNumber::parse(s.as_bytes());
v.is_some()
}
TapeElement::F32(v) => {
Expand Down
17 changes: 8 additions & 9 deletions arrow-json/src/reader/string_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,14 @@ impl<O: OffsetSizeTrait> ArrayDecoder for StringArrayDecoder<O> {
match tape.get(pos) {
TapeElement::String(_) => true,
TapeElement::Null => self.is_nullable,
TapeElement::True | TapeElement::False |
TapeElement::Number(_) |
TapeElement::I64(_) | TapeElement::I32(_) |
TapeElement::F32(_) | TapeElement::F64(_) => {
self.coerce_primitive
}
_ => {
false
}
TapeElement::True
| TapeElement::False
| TapeElement::Number(_)
| TapeElement::I64(_)
| TapeElement::I32(_)
| TapeElement::F32(_)
| TapeElement::F64(_) => self.coerce_primitive,
_ => false,
}
}
}
21 changes: 15 additions & 6 deletions arrow-json/src/reader/struct_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ impl ArrayDecoder for StructArrayDecoder {
};

let fields = struct_fields(&self.data_type);

let mut validated_fields = vec![false; fields.len()];

let mut cur_idx = pos + 1;
while cur_idx < end_idx {
// Read field name
Expand All @@ -178,11 +179,11 @@ impl ArrayDecoder for StructArrayDecoder {
match fields.iter().position(|x| x.name() == field_name) {
Some(field_idx) => {
let child_pos = cur_idx + 1;
if !self.decoders[field_idx]
.validate_row(tape, child_pos) {
if !self.decoders[field_idx].validate_row(tape, child_pos) {
return false;
}
}
validated_fields[field_idx] = true;
}
None => {
if self.strict_mode {
return false;
Expand All @@ -198,8 +199,16 @@ impl ArrayDecoder for StructArrayDecoder {
}
}
}

true

validated_fields
.iter()
.zip(fields)
.all(|(validated, field)| {
if !validated && !field.is_nullable() {
return false;
}
true
})
}
}

Expand Down
4 changes: 2 additions & 2 deletions arrow-json/src/reader/timestamp_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ where
if let Ok(d) = string_to_datetime(&self.timezone, s) {
match P::UNIT {
TimeUnit::Nanosecond => d.timestamp_nanos_opt().is_some(),
_ => true
_ => true,
}
} else {
false
Expand All @@ -132,7 +132,7 @@ where
.is_ok()
}
TapeElement::I32(_) | TapeElement::I64(_) => true,
_ => false
_ => false,
}
}
}

0 comments on commit e5420bc

Please sign in to comment.