Errr didn't git add some things
desmondcheongzx committed Oct 23, 2024
1 parent 0559142 commit 43c8a6a
Showing 1 changed file with 80 additions and 59 deletions.
src/daft-csv/src/local.rs (139 changes: 80 additions & 59 deletions)
@@ -2,7 +2,8 @@ use core::str;
 use std::{
     io::Read,
     num::NonZeroUsize,
-    sync::{Arc, Mutex, RwLock, Weak},
+    ops::{Deref, DerefMut},
+    sync::{Arc, Weak},
 };
 
 use arrow2::{
@@ -22,6 +23,7 @@ use daft_dsl::{optimization::get_required_columns, Expr};
 use daft_io::{IOClient, IOStatsRef};
 use daft_table::Table;
 use futures::{Stream, StreamExt, TryStreamExt};
+use parking_lot::{Mutex, RwLock};
 use rayon::{
     iter::IndexedParallelIterator,
     prelude::{IntoParallelRefIterator, ParallelIterator},
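Note: the switch to parking_lot is what removes every `.unwrap()` after a lock in this diff. Unlike `std::sync`, parking_lot's `Mutex` and `RwLock` do not poison on panic, so `lock()`, `read()`, and `write()` return guards directly instead of a `Result`. A standalone sketch of the difference (assumes `parking_lot` as a dependency; not part of this commit):

```rust
// Standalone sketch; assumes `parking_lot = "0.12"` in Cargo.toml.
fn main() {
    // std::sync::Mutex can be poisoned by a panic, so lock() returns a Result.
    let std_counter = std::sync::Mutex::new(0_i32);
    *std_counter.lock().unwrap() += 1;

    // parking_lot::Mutex does not poison: lock() returns the guard directly.
    let pl_counter = parking_lot::Mutex::new(0_i32);
    *pl_counter.lock() += 1;

    // Same for RwLock: read()/write() hand back guards, not Results.
    let pl_state = parking_lot::RwLock::new(vec![1, 2, 3]);
    pl_state.write().push(4);
    let len = pl_state.read().len();

    println!("{} {} {}", *std_counter.lock().unwrap(), *pl_counter.lock(), len);
}
```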
@@ -135,55 +137,71 @@ use crate::{
 // we simply move on to the next chunk of bytes and try to find a valid CSV record there. This is a
 // simplification that makes the implementation a lot easier to maintain.
 
+#[derive(Clone, Debug, Default)]
+struct CsvSlab(Vec<read::ByteRecord>);
+
+impl CsvSlab {
+    fn new(record_size: usize, num_fields: usize, num_rows: usize) -> Self {
+        Self(vec![
+            read::ByteRecord::with_capacity(record_size, num_fields);
+            num_rows
+        ])
+    }
+}
+
+impl Deref for CsvSlab {
+    type Target = Vec<read::ByteRecord>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl DerefMut for CsvSlab {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.0
+    }
+}
+
 /// A pool of ByteRecord slabs. Used for deserializing CSV.
 #[derive(Debug)]
 struct CsvBufferPool {
-    buffers: Mutex<Vec<Vec<read::ByteRecord>>>,
-    buffer_size: usize,
-    record_buffer_size: usize,
+    buffers: Mutex<Vec<CsvSlab>>,
+    record_size: usize,
     num_fields: usize,
+    num_rows: usize,
 }
 
 /// A slab of ByteRecords. Used for deserializing CSV.
 struct CsvBuffer {
-    buffer: Vec<read::ByteRecord>,
+    buffer: CsvSlab,
     pool: Weak<CsvBufferPool>,
 }
 
 impl CsvBufferPool {
     pub fn new(
-        record_buffer_size: usize,
+        record_size: usize,
         num_fields: usize,
-        chunk_size_rows: usize,
+        num_rows: usize,
         initial_pool_size: usize,
     ) -> Self {
-        let chunk_buffers = vec![
-            vec![
-                read::ByteRecord::with_capacity(record_buffer_size, num_fields);
-                chunk_size_rows
-            ];
-            initial_pool_size
-        ];
+        let chunk_buffers =
+            vec![CsvSlab::new(record_size, num_fields, num_rows); initial_pool_size];
         Self {
             buffers: Mutex::new(chunk_buffers),
-            buffer_size: chunk_size_rows,
-            record_buffer_size,
+            record_size,
             num_fields,
+            num_rows,
         }
     }
 
     pub fn get_buffer(self: &Arc<Self>) -> CsvBuffer {
         let buffer = {
-            let mut buffers = self.buffers.lock().unwrap();
+            let mut buffers = self.buffers.lock();
             let buffer = buffers.pop();
             match buffer {
                 Some(buffer) => buffer,
-                None => {
-                    vec![
-                        read::ByteRecord::with_capacity(self.record_buffer_size, self.num_fields);
-                        self.buffer_size
-                    ]
-                }
+                None => CsvSlab::new(self.record_size, self.num_fields, self.num_rows),
             }
         };

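Note: `CsvSlab` is the newtype-plus-Deref pattern: the wrapper gives pooled `Vec<read::ByteRecord>` buffers a named type with its own constructor, while `Deref`/`DerefMut` let existing call sites keep treating a slab as the underlying `Vec`. The `Clone` derive matters because `vec![elem; n]` clones its element. A minimal standalone sketch of the same pattern (toy element type, not Daft code):

```rust
use std::ops::{Deref, DerefMut};

// Newtype over a Vec; Clone is needed so `vec![slab; n]` can fill a pool.
#[derive(Clone, Debug, Default)]
struct Slab(Vec<String>);

impl Slab {
    fn new(num_rows: usize) -> Self {
        // `vec![elem; n]` clones `elem` n times, mirroring CsvSlab::new.
        Self(vec![String::with_capacity(64); num_rows])
    }
}

impl Deref for Slab {
    type Target = Vec<String>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for Slab {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

fn main() {
    let mut slab = Slab::new(3);
    slab[0].push_str("hello"); // index via Deref/DerefMut, as if it were a Vec
    println!("{} rows, first = {:?}", slab.len(), slab[0]);
}
```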
@@ -193,10 +211,9 @@ impl CsvBufferPool {
         }
     }
 
-    fn return_buffer(&self, buffer: Vec<read::ByteRecord>) {
-        if let Ok(mut buffers) = self.buffers.lock() {
-            buffers.push(buffer);
-        }
+    fn return_buffer(&self, buffer: CsvSlab) {
+        let mut buffers = self.buffers.lock();
+        buffers.push(buffer);
     }
 }

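Note: `CsvBuffer` carries a `Weak<CsvBufferPool>` back-reference so a buffer can find its way back without keeping the pool alive. `return_buffer` above is the explicit return path; a common complementary wiring (sketched below with hypothetical types, not Daft's actual code) recycles the buffer automatically on `Drop` by upgrading the `Weak`:

```rust
use std::sync::{Arc, Weak};

use parking_lot::Mutex;

struct Pool {
    buffers: Mutex<Vec<Vec<u8>>>,
}

struct PooledBuffer {
    buffer: Vec<u8>,
    pool: Weak<Pool>,
}

impl Drop for PooledBuffer {
    fn drop(&mut self) {
        // Upgrade succeeds only if the pool is still alive; otherwise the
        // buffer is simply freed instead of being recycled.
        if let Some(pool) = self.pool.upgrade() {
            pool.buffers.lock().push(std::mem::take(&mut self.buffer));
        }
    }
}

fn main() {
    let pool = Arc::new(Pool { buffers: Mutex::new(vec![vec![0u8; 1024]]) });
    {
        let buf = PooledBuffer {
            buffer: pool.buffers.lock().pop().unwrap(),
            pool: Arc::downgrade(&pool),
        };
        assert_eq!(buf.buffer.len(), 1024);
    } // `buf` dropped here; its storage goes back into the pool
    assert_eq!(pool.buffers.lock().len(), 1);
}
```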
@@ -226,12 +243,7 @@ impl FileSlabPool {
             // We get uninitialized buffers because we will always populate the buffers with a file read before use.
             .map(|_| Box::new_uninit_slice(SLABSIZE))
             .map(|x| unsafe { x.assume_init() })
-            .map(|buffer| {
-                RwLock::new(FileSlabState {
-                    buffer,
-                    valid_bytes: 0,
-                })
-            })
+            .map(|buffer| RwLock::new(FileSlabState::new(buffer, 0)))
             .collect();
         Arc::new(Self {
             slabs: Mutex::new(slabs),
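Note: `Box::new_uninit_slice(SLABSIZE).assume_init()` allocates each slab without zeroing it. The `assume_init` is sound only because, as the comment says, every slab is populated by a file read before anything reads from it. A minimal sketch of the idiom (`new_uninit_slice` is stable as of Rust 1.82; the `SLABSIZE` value below is a stand-in, not the crate's constant):

```rust
use std::mem::MaybeUninit;

fn main() {
    const SLABSIZE: usize = 4 * 1024 * 1024; // stand-in value

    // Allocate the slab without zeroing it first.
    let uninit: Box<[MaybeUninit<u8>]> = Box::new_uninit_slice(SLABSIZE);

    // SAFETY: sound here only because every byte we later read is first
    // overwritten (e.g. by a file read). Reading a byte that was never
    // written would be undefined behavior.
    let mut buffer: Box<[u8]> = unsafe { uninit.assume_init() };

    buffer[..5].copy_from_slice(b"hello"); // "populate with a file read"
    println!("{:?}", &buffer[..5]);
}
```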
@@ -240,14 +252,14 @@
 
     fn get_slab(self: &Arc<Self>) -> Arc<FileSlab> {
         let slab = {
-            let mut slabs = self.slabs.lock().unwrap();
+            let mut slabs = self.slabs.lock();
             let slab = slabs.pop();
             match slab {
                 Some(slab) => slab,
-                None => RwLock::new(FileSlabState {
-                    buffer: unsafe { Box::new_uninit_slice(SLABSIZE).assume_init() },
-                    valid_bytes: 0,
-                }),
+                None => RwLock::new(FileSlabState::new(
+                    unsafe { Box::new_uninit_slice(SLABSIZE).assume_init() },
+                    0,
+                )),
             }
         };

@@ -258,9 +270,8 @@
     }
 
     fn return_slab(&self, slab: RwLock<FileSlabState>) {
-        if let Ok(mut slabs) = self.slabs.lock() {
-            slabs.push(slab);
-        }
+        let mut slabs = self.slabs.lock();
+        slabs.push(slab);
     }
 }

@@ -272,9 +283,11 @@ struct FileSlab {
 }
 
 impl FileSlab {
-    fn find_newline(&self, offset: usize) -> Option<usize> {
-        let guard = self.state.read().unwrap();
-        guard.find_newline(offset)
+    /// Given an offset into a FileSlab, finds the first \n char in the FileSlabState's buffer,
+    /// then returns the position relative to the given offset.
+    fn find_first_newline_from(&self, offset: usize) -> Option<usize> {
+        let guard = self.state.read();
+        guard.find_first_newline_from(offset)
     }
 }

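Note: `newline_position` itself is not part of this diff. A plausible shape for it, and for the offset-relative contract of `find_first_newline_from`, is sketched below (free functions stand in for the methods; not the actual daft-csv implementation):

```rust
// Sketch of a `newline_position`-style helper: scan for the first '\n'.
fn newline_position(buffer: &[u8]) -> Option<usize> {
    buffer.iter().position(|&b| b == b'\n')
}

fn find_first_newline_from(buffer: &[u8], valid_bytes: usize, offset: usize) -> Option<usize> {
    // Only scan bytes actually filled by a file read, and return a position
    // relative to `offset`, not to the start of the buffer.
    newline_position(&buffer[offset..valid_bytes])
}

fn main() {
    let buf = b"aaa\nbbb\nccc";
    assert_eq!(find_first_newline_from(buf, buf.len(), 0), Some(3));
    assert_eq!(find_first_newline_from(buf, buf.len(), 4), Some(3)); // relative!
    assert_eq!(find_first_newline_from(buf, buf.len(), 8), None);
    println!("ok");
}
```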
@@ -295,8 +308,15 @@ struct FileSlabState {
 }
 
 impl FileSlabState {
+    fn new(buffer: Box<[u8]>, valid_bytes: usize) -> Self {
+        Self {
+            buffer,
+            valid_bytes,
+        }
+    }
+
     /// Helper function that finds the first \n char in the file slab state's buffer, starting from `offset`.
-    fn find_newline(&self, offset: usize) -> Option<usize> {
+    fn find_first_newline_from(&self, offset: usize) -> Option<usize> {
         newline_position(&self.buffer[offset..self.valid_bytes])
     }

@@ -535,7 +555,7 @@ impl Iterator for SlabIterator {
     fn next(&mut self) -> Option<Self::Item> {
        let slab = self.slabpool.get_slab();
         let bytes_read = {
-            let mut writer = slab.state.write().unwrap();
+            let mut writer = slab.state.write();
             let bytes_read = self.file.read(&mut writer.buffer).unwrap();
             if bytes_read == 0 {
                 return None;
@@ -580,11 +600,11 @@
         let mut curr_pos = 0;
         let mut chunk_state: Option<ChunkState> = None;
         while chunk_state.is_none()
-            && let Some(pos) = slab.find_newline(curr_pos)
+            && let Some(pos) = slab.find_first_newline_from(curr_pos)
             && curr_pos < valid_bytes
         {
             let offset = curr_pos + pos;
-            let guard = slab.state.read().unwrap();
+            let guard = slab.state.read();
             chunk_state = match guard.validate_record(&mut self.validator, offset + 1) {
                 Some(true) => Some(ChunkState::Final {
                     slab: slab.clone(),
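Note: the `while ... && let Some(pos) = ...` condition is a let-chain, which at the time of this commit still requires nightly Rust (`#![feature(let_chains)]`). A stable-Rust equivalent of the same control flow, using `let ... else` (sketch; a free function stands in for the slab method):

```rust
fn find_first_newline_from(buf: &[u8], offset: usize) -> Option<usize> {
    buf[offset..].iter().position(|&b| b == b'\n')
}

fn main() {
    let buf = b"r1,a\nr2,b\nr3,c\n";
    let valid_bytes = buf.len();
    let mut curr_pos = 0;
    let mut chunk_state: Option<usize> = None;

    while chunk_state.is_none() {
        // `&& let Some(pos) = ...` on nightly; on stable, bail out explicitly.
        let Some(pos) = find_first_newline_from(buf, curr_pos) else { break };
        if curr_pos >= valid_bytes {
            break;
        }
        let offset = curr_pos + pos;
        // The real loop validates the record starting at offset + 1 and
        // advances curr_pos in the match arms elided from this hunk.
        chunk_state = Some(offset);
    }
    assert_eq!(chunk_state, Some(4));
    println!("{chunk_state:?}");
}
```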
@@ -716,7 +736,7 @@ fn consume_slab_iterator(
     let (tx, rx) = tokio::sync::oneshot::channel();
     rayon::spawn(move || {
         let reader = MultiSliceReader::new(&w);
-        let tables = dispatch_to_parse_csv(
+        let tables = collect_tables(
             has_header,
             &parse_options,
             reader,
@@ -785,7 +805,7 @@ impl std::fmt::Display for ChunkState {
         }
     }
 }
-/// A helper struct that implements `std::io::Read` over a slice of ChunkStates' buffers.
+/// A helper struct that implements `std::io::Read` over a slice of ChunkStates.
 struct MultiSliceReader<'a> {
     states: &'a [ChunkState],
     curr_read_idx: usize,
@@ -810,16 +830,15 @@ impl<'a> Read for MultiSliceReader<'a> {
             let state = &self.states[self.curr_read_idx];
             let (start, end, guard) = match state {
                 ChunkState::Start { slab, start, end } => {
-                    let guard: std::sync::RwLockReadGuard<'_, FileSlabState> =
-                        slab.state.read().unwrap();
+                    let guard = slab.state.read();
                     (*start, *end, guard)
                 }
                 ChunkState::Continue { slab, end } => {
-                    let guard = slab.state.read().unwrap();
+                    let guard = slab.state.read();
                     (0, *end, guard)
                 }
                 ChunkState::Final { slab, end, .. } => {
-                    let guard = slab.state.read().unwrap();
+                    let guard = slab.state.read();
                     (0, *end, guard)
                 }
             };
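Note: the `Read` impl advances `curr_read_idx` across chunk states so the CSV parser downstream sees one contiguous byte stream spanning slab boundaries. A self-contained analogue over plain byte slices (a sketch of the idea, not the Daft implementation):

```rust
use std::io::Read;

// Minimal analogue of MultiSliceReader: presents several slices as one reader.
struct MultiSliceReader<'a> {
    slices: &'a [&'a [u8]],
    curr_read_idx: usize,
    curr_read_offset: usize,
}

impl<'a> Read for MultiSliceReader<'a> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        while self.curr_read_idx < self.slices.len() {
            let slice = self.slices[self.curr_read_idx];
            if self.curr_read_offset < slice.len() {
                // Copy as much of the current slice as fits into `buf`.
                let n = (slice.len() - self.curr_read_offset).min(buf.len());
                buf[..n].copy_from_slice(&slice[self.curr_read_offset..self.curr_read_offset + n]);
                self.curr_read_offset += n;
                return Ok(n);
            }
            // Exhausted this slice; move to the next one.
            self.curr_read_idx += 1;
            self.curr_read_offset = 0;
        }
        Ok(0) // EOF
    }
}

fn main() -> std::io::Result<()> {
    let parts: [&[u8]; 3] = [b"id,name\n1,", b"alice\n2,bo", b"b\n"];
    let mut reader = MultiSliceReader { slices: &parts, curr_read_idx: 0, curr_read_offset: 0 };
    let mut out = String::new();
    reader.read_to_string(&mut out)?;
    assert_eq!(out, "id,name\n1,alice\n2,bob\n");
    Ok(())
}
```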
@@ -951,8 +970,10 @@ impl CsvValidator {
     }
 
     fn validate_record<'a>(&mut self, iter: &mut impl Iterator<Item = &'a u8>) -> Option<bool> {
+        // Reset state machine for each new validation attempt.
         self.state = CsvState::FieldStart;
         self.num_fields_seen = 1;
+        // Start running the state machine against each byte.
         for &byte in iter {
             let next_state = self.transition_table[self.state as usize][byte as usize];

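Note: the validator is a table-driven state machine: `transition_table[state as usize][byte as usize]` maps (state, byte) to the next state, and the two added comments flag that the state must be reset before each candidate record because one `CsvValidator` is reused across attempts. A toy version of the technique (a far smaller state space than the real validator, and no quoting rules):

```rust
// Toy table-driven validator: accepts records of printable ASCII fields
// separated by commas and terminated by '\n'.
#[derive(Clone, Copy, PartialEq)]
enum State {
    Field,   // 0: inside a field or at a separator
    End,     // 1: saw the record terminator
    Invalid, // 2: dead state
}

fn build_table() -> Vec<[State; 256]> {
    let mut field = [State::Invalid; 256];
    for b in 0x20..0x7f {
        field[b as usize] = State::Field; // printable ASCII stays in-field
    }
    field[b'\n' as usize] = State::End; // record terminator
    // End and Invalid accept no further bytes in this toy.
    vec![field, [State::Invalid; 256], [State::Invalid; 256]]
}

fn validate_record(table: &[[State; 256]], bytes: &[u8]) -> bool {
    // Reset state for each new validation attempt, as in the diff above.
    let mut state = State::Field;
    for &byte in bytes {
        state = table[state as usize][byte as usize];
        if state == State::Invalid {
            return false;
        }
    }
    state == State::End
}

fn main() {
    let table = build_table();
    assert!(validate_record(&table, b"a,b,c\n"));
    assert!(!validate_record(&table, b"a,b\x01c\n"));
    println!("ok");
}
```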
@@ -977,13 +998,13 @@ impl CsvValidator {
     }
 }
 
-/// Helper function that takes in a BufferSource, calls parse_csv() to extract table values from
-/// the buffer source, then streams the results to `sender`.
+/// Helper function that takes in a source of bytes, calls parse_csv() to extract table values from
+/// the byte reader, then returns a vector of Daft tables.
 #[allow(clippy::too_many_arguments)]
-fn dispatch_to_parse_csv<R>(
+fn collect_tables<R>(
     has_header: bool,
     parse_options: &CsvParseOptions,
-    buffer_source: R,
+    byte_reader: R,
     projection_indices: Arc<Vec<usize>>,
     fields: Vec<Field>,
     read_daft_fields: Arc<Vec<Arc<daft_core::datatypes::Field>>>,
@@ -1004,7 +1025,7 @@ where
         .escape(parse_options.escape_char)
         .comment(parse_options.comment)
         .flexible(parse_options.allow_variable_columns)
-        .from_reader(buffer_source);
+        .from_reader(byte_reader);
     // The header should not count towards the limit.
     let limit = limit.map(|limit| limit + (has_header as usize));
     parse_csv_chunk(
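Note: the builder chain configures the underlying CSV reader from `CsvParseOptions`, and the limit bump relies on `bool as usize` being 0 or 1, so a header row never eats into the requested row limit. A standalone sketch using the `csv` crate's `ReaderBuilder`, whose option methods mirror the calls above (assumes `csv` as a dependency; the `CsvParseOptions` struct here is a simplified stand-in for Daft's):

```rust
// Simplified stand-in for daft-csv's CsvParseOptions.
struct CsvParseOptions {
    has_header: bool,
    escape_char: Option<u8>,
    comment: Option<u8>,
    allow_variable_columns: bool,
}

fn main() {
    let parse_options = CsvParseOptions {
        has_header: true,
        escape_char: None,
        comment: Some(b'#'),
        allow_variable_columns: true,
    };
    let data = "a,b\n1,2\n3,4\n";
    let mut rdr = csv::ReaderBuilder::new()
        .has_headers(parse_options.has_header)
        .escape(parse_options.escape_char)
        .comment(parse_options.comment)
        .flexible(parse_options.allow_variable_columns)
        .from_reader(data.as_bytes());

    // The header should not count towards the limit: `true as usize == 1`.
    let limit = Some(2).map(|limit| limit + (parse_options.has_header as usize));
    println!("effective limit: {limit:?}, headers: {:?}", rdr.headers().unwrap());
}
```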
