Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 153 additions & 11 deletions parquet/src/arrow/array_reader/row_group_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,69 @@ use std::sync::Arc;

pub(crate) struct RowGroupIndexReader {
buffered_indices: Vec<i64>,
remaining_indices: std::iter::Flatten<std::vec::IntoIter<std::iter::RepeatN<i64>>>,
state: ReaderState,
}

enum ReaderState {
// fast path: single row group with constant index value
SingleRowGroup {
index: i64,
remaining_rows: usize,
},
// general path: multiple row groups with iterator
MultipleRowGroups {
remaining_indices: std::iter::Flatten<std::vec::IntoIter<std::iter::RepeatN<i64>>>,
},
}

impl RowGroupIndexReader {
pub(crate) fn try_new<'a>(
parquet_metadata: &'a ParquetMetaData,
row_groups: impl Iterator<Item = &'a RowGroupMetaData>,
) -> Result<Self> {
// build mapping from ordinal to row group index
let row_groups: Vec<_> = row_groups.collect();

// optimize for single row group case
if row_groups.len() == 1 {
let rg = row_groups[0];
let ordinal = rg.ordinal().ok_or_else(|| {
ParquetError::General(
"Row group missing ordinal field, required to compute row group indices"
.to_string(),
)
})?;

// find the row group index by a linear scan through metadata
// this is O(n) but only done once, avoiding HashMap allocation
let index = parquet_metadata
.row_groups()
.iter()
.enumerate()
.find_map(|(idx, metadata_rg)| {
if metadata_rg.ordinal() == Some(ordinal) {
Some(idx as i64)
} else {
None
}
})
.ok_or_else(|| {
ParquetError::General(format!(
"Row group with ordinal {} not found in metadata",
ordinal
))
})?;

return Ok(Self {
buffered_indices: Vec::new(),
state: ReaderState::SingleRowGroup {
index,
remaining_rows: rg.num_rows() as usize,
},
});
}

// general path: many row groups
// builds a mapping from ordinal to row group index
// this is O(n) where n is the total number of row groups in the file
let ordinal_to_index: HashMap<i16, i64> =
HashMap::from_iter(parquet_metadata.row_groups().iter().enumerate().filter_map(
Expand All @@ -46,7 +100,8 @@ impl RowGroupIndexReader {

// build repeating iterators in the order specified by the row_groups iterator
// this is O(m) where m is the number of selected row groups
let repeated_indices: Vec<_> = row_groups
let repeated_indices = row_groups
.iter()
.map(|rg| {
let ordinal = rg.ordinal().ok_or_else(|| {
ParquetError::General(
Expand All @@ -68,26 +123,51 @@ impl RowGroupIndexReader {
rg.num_rows() as usize,
))
})
.collect::<Result<_>>()?;
.collect::<Result<Vec<_>>>()?;

Ok(Self {
buffered_indices: Vec::new(),
remaining_indices: repeated_indices.into_iter().flatten(),
state: ReaderState::MultipleRowGroups {
remaining_indices: repeated_indices.into_iter().flatten(),
},
})
}
}

impl ArrayReader for RowGroupIndexReader {
fn read_records(&mut self, batch_size: usize) -> Result<usize> {
let starting_len = self.buffered_indices.len();
self.buffered_indices
.extend((self.remaining_indices.by_ref()).take(batch_size));
Ok(self.buffered_indices.len() - starting_len)
match &mut self.state {
ReaderState::SingleRowGroup {
index,
remaining_rows,
} => {
let num_to_read = batch_size.min(*remaining_rows);
self.buffered_indices
.resize(self.buffered_indices.len() + num_to_read, *index);
*remaining_rows -= num_to_read;
Ok(num_to_read)
}
ReaderState::MultipleRowGroups { remaining_indices } => {
let starting_len = self.buffered_indices.len();
self.buffered_indices
.extend((remaining_indices.by_ref()).take(batch_size));
Ok(self.buffered_indices.len() - starting_len)
}
}
}

fn skip_records(&mut self, num_records: usize) -> Result<usize> {
// TODO: Use advance_by when it stabilizes to improve performance
Ok((self.remaining_indices.by_ref()).take(num_records).count())
match &mut self.state {
ReaderState::SingleRowGroup { remaining_rows, .. } => {
let num_to_skip = num_records.min(*remaining_rows);
*remaining_rows -= num_to_skip;
Ok(num_to_skip)
}
ReaderState::MultipleRowGroups { remaining_indices } => {
// TODO: Use advance_by when it stabilizes to improve performance
Ok((remaining_indices.by_ref()).take(num_records).count())
}
}
}

fn as_any(&self) -> &dyn Any {
Expand Down Expand Up @@ -238,4 +318,66 @@ mod tests {
let actual = indices.iter().map(|v| v.unwrap()).collect::<Vec<i64>>();
assert_eq!(actual, [1, 1, 2, 2]);
}

#[test]
fn test_row_group_index_reader_single_row_group() {
// 3 row groups
let metadata = create_test_parquet_metadata(vec![(0, 2), (1, 3), (2, 5)]);

// select last row group [2, 2, 2, 2, 2]
let selected_row_groups = [&metadata.row_groups()[2]];

let mut reader =
RowGroupIndexReader::try_new(&metadata, selected_row_groups.into_iter()).unwrap();

assert!(matches!(
reader.state,
ReaderState::SingleRowGroup { index: 2, .. }
));

let num_read = reader.read_records(10).unwrap();
assert_eq!(num_read, 5);

let array = reader.consume_batch().unwrap();
let indices = array.as_any().downcast_ref::<Int64Array>().unwrap();

let actual: Vec<i64> = indices.iter().map(|v| v.unwrap()).collect();
assert_eq!(actual, [2, 2, 2, 2, 2]);
}

#[test]
fn test_row_group_index_reader_single_row_group_with_skip() {
let metadata = create_test_parquet_metadata(vec![(0, 10), (1, 10), (2, 10)]);

// second row group (index 1, ordinal 1 with 10 rows)
let selected_row_groups = vec![&metadata.row_groups()[1]];

let mut reader =
RowGroupIndexReader::try_new(&metadata, selected_row_groups.into_iter()).unwrap();

assert!(matches!(
reader.state,
ReaderState::SingleRowGroup { index: 1, .. }
));

// skip first 3 rows
let num_skipped = reader.skip_records(3).unwrap();
assert_eq!(num_skipped, 3);

// read next 5 rows
let num_read = reader.read_records(5).unwrap();
assert_eq!(num_read, 5);

let array = reader.consume_batch().unwrap();
let indices = array.as_any().downcast_ref::<Int64Array>().unwrap();

let actual: Vec<i64> = indices.iter().map(|v| v.unwrap()).collect();
assert_eq!(actual, [1, 1, 1, 1, 1]);

if let ReaderState::SingleRowGroup { remaining_rows, .. } = reader.state {
assert_eq!(remaining_rows, 2);
} else {
panic!("Expected SingleRowGroup state");
}
}
}
Loading