use std::io::Write;
use std::{collections::HashMap, sync::Arc};
use arrow_format::ipc;
use arrow_format::ipc::flatbuffers::FlatBufferBuilder;
use crate::array::Array;
use crate::error::{ArrowError, Result};
use crate::io::ipc::endianess::is_native_little_endian;
use crate::record_batch::RecordBatch;
use crate::{array::DictionaryArray, datatypes::*};
use super::super::CONTINUATION_MARKER;
use super::{write, write_dictionary};
/// Options controlling how IPC messages are framed and versioned when written.
#[derive(Debug)]
pub struct IpcWriteOptions {
/// Byte boundary that the length prefix plus flatbuffer metadata is padded to
/// when a message is written. Must be non-zero and a multiple of 8.
alignment: usize,
/// When `true`, messages are prefixed with a 4-byte length only; otherwise the
/// continuation marker is written before the length (8-byte prefix).
write_legacy_ipc_format: bool,
/// Metadata version stamped on every encoded message.
metadata_version: ipc::Schema::MetadataVersion,
}
impl IpcWriteOptions {
    /// Builds a validated set of write options.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`] when:
    /// * `alignment` is zero or not a multiple of 8;
    /// * `metadata_version` is V3 or lower;
    /// * `write_legacy_ipc_format` is requested together with metadata version 5;
    /// * `metadata_version` is not one of the known versions (this previously
    ///   panicked, which is inappropriate for a fallible constructor).
    pub fn try_new(
        alignment: usize,
        write_legacy_ipc_format: bool,
        metadata_version: ipc::Schema::MetadataVersion,
    ) -> Result<Self> {
        if alignment == 0 || alignment % 8 != 0 {
            return Err(ArrowError::InvalidArgumentError(
                "Alignment should be greater than 0 and be a multiple of 8".to_string(),
            ));
        }
        match metadata_version {
            ipc::Schema::MetadataVersion::V1
            | ipc::Schema::MetadataVersion::V2
            | ipc::Schema::MetadataVersion::V3 => Err(ArrowError::InvalidArgumentError(
                "Writing IPC metadata version 3 and lower not supported".to_string(),
            )),
            // The legacy framing is only defined for metadata version 4.
            ipc::Schema::MetadataVersion::V5 if write_legacy_ipc_format => {
                Err(ArrowError::InvalidArgumentError(
                    "Legacy IPC format only supported on metadata version 4".to_string(),
                ))
            }
            ipc::Schema::MetadataVersion::V4 | ipc::Schema::MetadataVersion::V5 => Ok(Self {
                alignment,
                write_legacy_ipc_format,
                metadata_version,
            }),
            // Unknown versions are a caller error, not a broken invariant: report
            // them through the `Result` instead of panicking.
            z => Err(ArrowError::InvalidArgumentError(format!(
                "Unsupported ipc::Schema::MetadataVersion {:?}",
                z
            ))),
        }
    }

    /// Returns the metadata version these options were created with.
    pub fn metadata_version(&self) -> &ipc::Schema::MetadataVersion {
        &self.metadata_version
    }
}
impl Default for IpcWriteOptions {
    /// Default options: 8-byte alignment, non-legacy framing, metadata version 5.
    fn default() -> Self {
        let metadata_version = ipc::Schema::MetadataVersion::V5;
        Self {
            alignment: 8,
            write_legacy_ipc_format: false,
            metadata_version,
        }
    }
}
/// Encodes `batch` into IPC messages.
///
/// Returns the dictionary messages that must be emitted before the batch
/// (one per dictionary column whose values the tracker reports as new or
/// replaced), followed by the encoded record batch itself.
///
/// # Errors
///
/// Propagates any error returned by [`DictionaryTracker::insert`].
///
/// # Panics
///
/// Panics if a dictionary-typed column's field has no `dict_id`.
pub fn encoded_batch(
    batch: &RecordBatch,
    dictionary_tracker: &mut DictionaryTracker,
    write_options: &IpcWriteOptions,
) -> Result<(Vec<EncodedData>, EncodedData)> {
    let schema = batch.schema();
    let fields = schema.fields();
    let mut dictionaries = Vec::with_capacity(fields.len());

    for (index, field) in fields.iter().enumerate() {
        let column = batch.column(index);
        // Only dictionary-encoded columns carry a separate dictionary message.
        if !matches!(column.data_type(), DataType::Dictionary(_, _)) {
            continue;
        }

        let dict_id = field
            .dict_id()
            .expect("All Dictionary types have `dict_id`");

        // The tracker decides whether these values still need to be emitted.
        if dictionary_tracker.insert(dict_id, column)? {
            let encoded = dictionary_batch_to_bytes(
                dict_id,
                column.as_ref(),
                write_options,
                is_native_little_endian(),
            );
            dictionaries.push(encoded);
        }
    }

    Ok((dictionaries, record_batch_to_bytes(batch, write_options)))
}
/// Serializes a record batch into its flatbuffer message and body bytes.
///
/// Every column is passed through `write`, which fills the field nodes, the
/// buffer descriptors and the body; those are then wrapped in a `RecordBatch`
/// header inside a `Message` stamped with the configured metadata version.
fn record_batch_to_bytes(batch: &RecordBatch, write_options: &IpcWriteOptions) -> EncodedData {
    let mut builder = FlatBufferBuilder::new();

    let mut field_nodes: Vec<ipc::Message::FieldNode> = Vec::new();
    let mut buffer_specs: Vec<ipc::Schema::Buffer> = Vec::new();
    let mut body: Vec<u8> = Vec::new();
    let mut body_offset = 0;
    let little_endian = is_native_little_endian();

    batch.columns().iter().for_each(|column| {
        write(
            column.as_ref(),
            &mut buffer_specs,
            &mut body,
            &mut field_nodes,
            &mut body_offset,
            little_endian,
        )
    });

    // Vectors must be created before the table that references them is started.
    let buffers_offset = builder.create_vector(&buffer_specs);
    let nodes_offset = builder.create_vector(&field_nodes);

    let header = {
        let mut record_batch = ipc::Message::RecordBatchBuilder::new(&mut builder);
        record_batch.add_length(batch.num_rows() as i64);
        record_batch.add_nodes(nodes_offset);
        record_batch.add_buffers(buffers_offset);
        record_batch.finish().as_union_value()
    };

    let message = {
        let mut message_builder = ipc::Message::MessageBuilder::new(&mut builder);
        message_builder.add_version(write_options.metadata_version);
        message_builder.add_header_type(ipc::Message::MessageHeader::RecordBatch);
        message_builder.add_bodyLength(body.len() as i64);
        message_builder.add_header(header);
        message_builder.finish()
    };
    builder.finish(message, None);

    EncodedData {
        ipc_message: builder.finished_data().to_vec(),
        arrow_data: body,
    }
}
/// Serializes the dictionary of `array` into a `DictionaryBatch` message.
///
/// The dictionary is written through `write_dictionary`; the length it
/// returns becomes the length of the inner record batch, which is wrapped in
/// a `DictionaryBatch` tagged with `dict_id` and then in a `Message` stamped
/// with the configured metadata version.
fn dictionary_batch_to_bytes(
    dict_id: i64,
    array: &dyn Array,
    write_options: &IpcWriteOptions,
    is_little_endian: bool,
) -> EncodedData {
    let mut builder = FlatBufferBuilder::new();

    let mut field_nodes: Vec<ipc::Message::FieldNode> = Vec::new();
    let mut buffer_specs: Vec<ipc::Schema::Buffer> = Vec::new();
    let mut body: Vec<u8> = Vec::new();
    let mut body_offset = 0;

    let values_len = write_dictionary(
        array,
        &mut buffer_specs,
        &mut body,
        &mut field_nodes,
        &mut body_offset,
        is_little_endian,
        false,
    );

    // Vectors must be created before the table that references them is started.
    let buffers_offset = builder.create_vector(&buffer_specs);
    let nodes_offset = builder.create_vector(&field_nodes);

    let record_batch = {
        let mut record_batch_builder = ipc::Message::RecordBatchBuilder::new(&mut builder);
        record_batch_builder.add_length(values_len as i64);
        record_batch_builder.add_nodes(nodes_offset);
        record_batch_builder.add_buffers(buffers_offset);
        record_batch_builder.finish()
    };

    let header = {
        let mut dictionary_builder = ipc::Message::DictionaryBatchBuilder::new(&mut builder);
        dictionary_builder.add_id(dict_id);
        dictionary_builder.add_data(record_batch);
        dictionary_builder.finish().as_union_value()
    };

    let message = {
        let mut message_builder = ipc::Message::MessageBuilder::new(&mut builder);
        message_builder.add_version(write_options.metadata_version);
        message_builder.add_header_type(ipc::Message::MessageHeader::DictionaryBatch);
        message_builder.add_bodyLength(body.len() as i64);
        message_builder.add_header(header);
        message_builder.finish()
    };
    builder.finish(message, None);

    EncodedData {
        ipc_message: builder.finished_data().to_vec(),
        arrow_data: body,
    }
}
/// Remembers which dictionary values were last recorded for each dictionary
/// id, so that unchanged dictionaries are not emitted again.
pub struct DictionaryTracker {
/// Dictionary values most recently recorded for each dictionary id.
written: HashMap<i64, Arc<dyn Array>>,
/// When `true`, recording different values for an already-seen id is an error.
error_on_replacement: bool,
}
impl DictionaryTracker {
    /// Creates an empty tracker.
    ///
    /// `error_on_replacement` selects whether [`DictionaryTracker::insert`]
    /// rejects a dictionary whose values differ from the ones already recorded
    /// under the same id.
    pub fn new(error_on_replacement: bool) -> Self {
        let written = HashMap::new();
        Self {
            written,
            error_on_replacement,
        }
    }

    /// Records the dictionary values of `array` under `dict_id`.
    ///
    /// Returns `Ok(false)` when identical values are already recorded for this
    /// id, and `Ok(true)` when the values were newly recorded (or replaced).
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`] when different values are
    /// already recorded for `dict_id` and the tracker was created with
    /// `error_on_replacement = true`.
    ///
    /// # Panics
    ///
    /// Panics if `array` is not a dictionary array.
    pub fn insert(&mut self, dict_id: i64, array: &Arc<dyn Array>) -> Result<bool> {
        let values = if let DataType::Dictionary(key_type, _) = array.data_type() {
            with_match_dictionary_key_type!(key_type.as_ref(), |$T| {
                let dictionary = array
                    .as_any()
                    .downcast_ref::<DictionaryArray<$T>>()
                    .unwrap();
                dictionary.values()
            })
        } else {
            unreachable!()
        };

        let replaces_existing = match self.written.get(&dict_id) {
            // Same values as last time: nothing new to record.
            Some(last) if last.as_ref() == values.as_ref() => return Ok(false),
            Some(_) => true,
            None => false,
        };

        if replaces_existing && self.error_on_replacement {
            return Err(ArrowError::InvalidArgumentError(
                "Dictionary replacement detected when writing IPC file format. \
                 Arrow IPC files only support a single dictionary for a given field \
                 across all batches."
                    .to_string(),
            ));
        }

        self.written.insert(dict_id, values.clone());
        Ok(true)
    }
}
/// An encoded IPC message: flatbuffer metadata plus the message body.
pub struct EncodedData {
/// Finished flatbuffer bytes of the message.
pub ipc_message: Vec<u8>,
/// Body bytes of the message (may be empty).
pub arrow_data: Vec<u8>,
}
/// Writes one encoded message (length prefix, flatbuffer metadata, padding and
/// body) to `writer`.
///
/// Returns `(metadata_bytes, body_bytes)`, where `metadata_bytes` covers the
/// length prefix, the flatbuffer and its padding, and `body_bytes` is the
/// value reported by `write_body_buffers` (0 for an empty body).
///
/// # Errors
///
/// Returns [`ArrowError::Ipc`] if the body length is not a multiple of 8 or if
/// the padded metadata length does not fit the `i32` length prefix; in both
/// cases nothing has been written yet. I/O errors from `writer` are propagated.
pub fn write_message<W: Write>(
    writer: &mut W,
    encoded: EncodedData,
    write_options: &IpcWriteOptions,
) -> Result<(usize, usize)> {
    let arrow_data_len = encoded.arrow_data.len();
    if arrow_data_len % 8 != 0 {
        return Err(ArrowError::Ipc("Arrow data not aligned".to_string()));
    }

    // NOTE(review): the mask-based round-up below yields a multiple of
    // `alignment` only when `alignment` is a power of two.
    let a = write_options.alignment - 1;
    let buffer = encoded.ipc_message;
    let flatbuf_size = buffer.len();
    let prefix_size = if write_options.write_legacy_ipc_format {
        4
    } else {
        8
    };
    let aligned_size = (flatbuf_size + prefix_size + a) & !a;
    let padding_bytes = aligned_size - flatbuf_size - prefix_size;

    // The length prefix is a signed 32-bit integer: reject anything larger
    // instead of silently truncating it with an `as` cast.
    let metadata_len = aligned_size - prefix_size;
    if metadata_len > i32::MAX as usize {
        return Err(ArrowError::Ipc(format!(
            "IPC message metadata of {} bytes exceeds the maximum encodable length",
            metadata_len
        )));
    }
    write_continuation(writer, write_options, metadata_len as i32)?;

    if flatbuf_size > 0 {
        writer.write_all(&buffer)?;
    }
    writer.write_all(&vec![0u8; padding_bytes])?;

    let body_len = if arrow_data_len > 0 {
        write_body_buffers(writer, &encoded.arrow_data)?
    } else {
        0
    };
    Ok((aligned_size, body_len))
}
/// Writes `data` followed by zero padding up to the next multiple of 8 bytes,
/// then flushes `writer`.
///
/// Returns the total number of bytes written (data plus padding). Lengths are
/// kept in `usize` so bodies of 4 GiB or more are reported correctly rather
/// than being truncated to 32 bits.
fn write_body_buffers<W: Write>(mut writer: W, data: &[u8]) -> Result<usize> {
    let len = data.len();
    // The padding depends only on `len % 8`, which always fits in a `u32`.
    let pad_len = pad_to_8((len % 8) as u32);

    writer.write_all(data)?;
    if pad_len > 0 {
        // At most 7 bytes of padding are needed, so a stack buffer suffices.
        writer.write_all(&[0u8; 8][..pad_len])?;
    }
    writer.flush()?;
    Ok(len + pad_len)
}
/// Writes the message prefix: the continuation marker (unless the legacy
/// format is selected for metadata version 4) followed by `total_len` as a
/// little-endian `i32`, then flushes `writer`.
///
/// Returns the number of prefix bytes actually written: the length of the
/// encoded `total_len` alone for the legacy format, or marker plus length
/// otherwise. (The previous implementation reported these two counts swapped.)
///
/// # Panics
///
/// Panics if `write_options` carries a metadata version other than 4 or 5.
pub fn write_continuation<W: Write>(
    writer: &mut W,
    write_options: &IpcWriteOptions,
    total_len: i32,
) -> Result<usize> {
    let length_bytes = total_len.to_le_bytes();

    let written = match write_options.metadata_version {
        ipc::Schema::MetadataVersion::V1
        | ipc::Schema::MetadataVersion::V2
        | ipc::Schema::MetadataVersion::V3 => {
            unreachable!("Options with the metadata version cannot be created")
        }
        // Legacy framing: the length only, no continuation marker.
        ipc::Schema::MetadataVersion::V4 if write_options.write_legacy_ipc_format => {
            writer.write_all(&length_bytes)?;
            length_bytes.len()
        }
        ipc::Schema::MetadataVersion::V4 | ipc::Schema::MetadataVersion::V5 => {
            writer.write_all(&CONTINUATION_MARKER)?;
            writer.write_all(&length_bytes)?;
            CONTINUATION_MARKER.len() + length_bytes.len()
        }
        z => panic!("Unsupported ipc::Schema::MetadataVersion {:?}", z),
    };

    writer.flush()?;
    Ok(written)
}
/// Returns how many padding bytes must follow `len` bytes to reach the next
/// multiple of 8 (0 when `len` is already a multiple of 8).
///
/// Computed with modular arithmetic so that it cannot overflow, even for
/// values of `len` within 7 of `u32::MAX`.
#[inline]
pub(crate) fn pad_to_8(len: u32) -> usize {
    ((8 - len % 8) % 8) as usize
}