[go: up one dir, main page]

arrow2 0.7.1

Unofficial implementation of Apache Arrow spec in safe Rust
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
//! Common utilities used to write to Arrow's IPC format.

use std::io::Write;
use std::{collections::HashMap, sync::Arc};

use arrow_format::ipc;
use arrow_format::ipc::flatbuffers::FlatBufferBuilder;

use crate::array::Array;
use crate::error::{ArrowError, Result};
use crate::io::ipc::endianess::is_native_little_endian;
use crate::record_batch::RecordBatch;
use crate::{array::DictionaryArray, datatypes::*};

use super::super::CONTINUATION_MARKER;
use super::{write, write_dictionary};

/// IPC write options used to control the behaviour of the writer
///
/// All fields are private: values are created with [`IpcWriteOptions::try_new`], which
/// rejects incompatible combinations, or with [`Default::default`].
#[derive(Debug)]
pub struct IpcWriteOptions {
    /// Write padding after memory buffers to this multiple of bytes.
    /// Generally 8 or 64, defaults to 8
    ///
    /// Within this file it is read by `write_message`, which pads the length-prefixed
    /// message metadata up to a multiple of this value.
    alignment: usize,
    /// The legacy format is for releases before 0.15.0, and uses metadata V4
    ///
    /// When set, messages are prefixed by the 4-byte length only; otherwise a continuation
    /// marker is written before the length.
    write_legacy_ipc_format: bool,
    /// The metadata version to write. The Rust IPC writer supports V4+
    ///
    /// *Default versions per crate*
    ///
    /// When creating the default IpcWriteOptions, the following metadata versions are used:
    ///
    /// version 2.0.0: V4, with legacy format enabled
    /// version 4.0.0: V5
    ///
    /// NOTE(review): the list above appears to refer to upstream crate releases; the
    /// `Default` implementation in this file selects V5 with the legacy format disabled.
    metadata_version: ipc::Schema::MetadataVersion,
}

impl IpcWriteOptions {
    /// Try to create `IpcWriteOptions`, checking for incompatible settings.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::InvalidArgumentError`] when:
    /// * `alignment` is zero or not a multiple of 8;
    /// * `metadata_version` is V1, V2 or V3, which this writer does not support;
    /// * `write_legacy_ipc_format` is requested together with metadata V5;
    /// * `metadata_version` is not a version known to this writer.
    pub fn try_new(
        alignment: usize,
        write_legacy_ipc_format: bool,
        metadata_version: ipc::Schema::MetadataVersion,
    ) -> Result<Self> {
        if alignment == 0 || alignment % 8 != 0 {
            return Err(ArrowError::InvalidArgumentError(
                "Alignment should be greater than 0 and be a multiple of 8".to_string(),
            ));
        }
        match metadata_version {
            ipc::Schema::MetadataVersion::V1
            | ipc::Schema::MetadataVersion::V2
            | ipc::Schema::MetadataVersion::V3 => Err(ArrowError::InvalidArgumentError(
                "Writing IPC metadata version 3 and lower not supported".to_string(),
            )),
            // The legacy framing only exists for V4, so it cannot be combined with V5.
            ipc::Schema::MetadataVersion::V5 if write_legacy_ipc_format => {
                Err(ArrowError::InvalidArgumentError(
                    "Legacy IPC format only supported on metadata version 4".to_string(),
                ))
            }
            ipc::Schema::MetadataVersion::V4 | ipc::Schema::MetadataVersion::V5 => Ok(Self {
                alignment,
                write_legacy_ipc_format,
                metadata_version,
            }),
            // A fallible constructor reports bad input as an error rather than panicking.
            z => Err(ArrowError::InvalidArgumentError(format!(
                "Unsupported ipc::Schema::MetadataVersion {:?}",
                z
            ))),
        }
    }

    /// Returns the metadata version these options were created with.
    pub fn metadata_version(&self) -> &ipc::Schema::MetadataVersion {
        &self.metadata_version
    }
}

impl Default for IpcWriteOptions {
    /// Builds the default options: 8-byte alignment, metadata V5 and no legacy framing.
    fn default() -> Self {
        let metadata_version = ipc::Schema::MetadataVersion::V5;
        Self {
            alignment: 8,
            write_legacy_ipc_format: false,
            metadata_version,
        }
    }
}

/// Encodes `batch` together with any dictionaries that still have to be emitted for it.
///
/// Every top-level dictionary column is registered with `dictionary_tracker`; a dictionary
/// message is produced only when the tracker reports that it must be emitted. The returned
/// tuple holds those dictionary messages first and the encoded record batch second.
///
/// Errors reported by the tracker are propagated unchanged.
pub fn encoded_batch(
    batch: &RecordBatch,
    dictionary_tracker: &mut DictionaryTracker,
    write_options: &IpcWriteOptions,
) -> Result<(Vec<EncodedData>, EncodedData)> {
    // TODO: handle nested dictionaries
    let schema = batch.schema();
    let fields = schema.fields();
    let mut dictionaries = Vec::with_capacity(fields.len());

    for (index, field) in fields.iter().enumerate() {
        let column = batch.column(index);

        // Only dictionary-typed columns carry a dictionary to encode.
        if !matches!(column.data_type(), DataType::Dictionary(_, _)) {
            continue;
        }

        let dict_id = field
            .dict_id()
            .expect("All Dictionary types have `dict_id`");

        // The tracker answers `true` when this dictionary has to be written out.
        if dictionary_tracker.insert(dict_id, column)? {
            let encoded_dictionary = dictionary_batch_to_bytes(
                dict_id,
                column.as_ref(),
                write_options,
                is_native_little_endian(),
            );
            dictionaries.push(encoded_dictionary);
        }
    }

    let message = record_batch_to_bytes(batch, write_options);
    Ok((dictionaries, message))
}

/// Write a `RecordBatch` into two sets of bytes, one for the header (ipc::Schema::Message) and the
/// other for the batch's data
///
/// The returned `EncodedData` carries the finished flatbuffer message in `ipc_message` and the
/// bytes accumulated while serializing every column in `arrow_data`. The message is stamped
/// with the metadata version taken from `write_options`.
fn record_batch_to_bytes(batch: &RecordBatch, write_options: &IpcWriteOptions) -> EncodedData {
    let mut fbb = FlatBufferBuilder::new();

    // Accumulators handed to `write` for each column, in column order.
    let mut nodes: Vec<ipc::Message::FieldNode> = vec![];
    let mut buffers: Vec<ipc::Schema::Buffer> = vec![];
    let mut arrow_data: Vec<u8> = vec![];
    // Shared by all columns; presumably the running byte offset into `arrow_data` —
    // TODO confirm against `write`.
    let mut offset = 0;
    for array in batch.columns() {
        write(
            array.as_ref(),
            &mut buffers,
            &mut arrow_data,
            &mut nodes,
            &mut offset,
            is_native_little_endian(),
        )
    }

    // Serialize the buffer and node descriptors as flatbuffer vectors.
    // NOTE(review): these are created before the table builder below is opened; flatbuffer
    // builders generally require that ordering, so keep it when editing.
    let buffers = fbb.create_vector(&buffers);
    let nodes = fbb.create_vector(&nodes);

    // Build the RecordBatch table and erase its type so it can be used as the message header.
    let root = {
        let mut batch_builder = ipc::Message::RecordBatchBuilder::new(&mut fbb);
        batch_builder.add_length(batch.num_rows() as i64);
        batch_builder.add_nodes(nodes);
        batch_builder.add_buffers(buffers);
        let b = batch_builder.finish();
        b.as_union_value()
    };
    // create an ipc::Schema::Message
    let mut message = ipc::Message::MessageBuilder::new(&mut fbb);
    message.add_version(write_options.metadata_version);
    message.add_header_type(ipc::Message::MessageHeader::RecordBatch);
    // The body length is the number of bytes accumulated in `arrow_data`.
    message.add_bodyLength(arrow_data.len() as i64);
    message.add_header(root);
    let root = message.finish();
    fbb.finish(root, None);
    let finished_data = fbb.finished_data();

    EncodedData {
        // Copy the finished bytes out of the builder, which owns them.
        ipc_message: finished_data.to_vec(),
        arrow_data,
    }
}

/// Write dictionary values into two sets of bytes, one for the header (ipc::Schema::Message) and the
/// other for the data
///
/// * `dict_id` - id stored in the DictionaryBatch header.
/// * `array` - array handed to `write_dictionary`; `encoded_batch` passes the dictionary column.
/// * `write_options` - supplies the metadata version stamped on the message.
/// * `is_little_endian` - forwarded unchanged to `write_dictionary`.
///
/// The message header is a DictionaryBatch wrapping a RecordBatch table whose length is the
/// value returned by `write_dictionary`.
fn dictionary_batch_to_bytes(
    dict_id: i64,
    array: &dyn Array,
    write_options: &IpcWriteOptions,
    is_little_endian: bool,
) -> EncodedData {
    let mut fbb = FlatBufferBuilder::new();

    // Accumulators handed to `write_dictionary`.
    let mut nodes: Vec<ipc::Message::FieldNode> = vec![];
    let mut buffers: Vec<ipc::Schema::Buffer> = vec![];
    let mut arrow_data: Vec<u8> = vec![];

    // The offset argument is a fresh temporary starting at 0, since this message has its own body.
    // NOTE(review): the meaning of the trailing `false` flag is defined by `write_dictionary`
    // and is not visible from here — confirm before changing it.
    let length = write_dictionary(
        array,
        &mut buffers,
        &mut arrow_data,
        &mut nodes,
        &mut 0,
        is_little_endian,
        false,
    );

    // Serialize the buffer and node descriptors as flatbuffer vectors.
    // NOTE(review): these are created before any table builder is opened; flatbuffer
    // builders generally require that ordering, so keep it when editing.
    let buffers = fbb.create_vector(&buffers);
    let nodes = fbb.create_vector(&nodes);

    // Inner RecordBatch table describing the dictionary values.
    let root = {
        let mut batch_builder = ipc::Message::RecordBatchBuilder::new(&mut fbb);
        batch_builder.add_length(length as i64);
        batch_builder.add_nodes(nodes);
        batch_builder.add_buffers(buffers);
        batch_builder.finish()
    };

    // DictionaryBatch table tying the values to `dict_id`, type-erased for use as a header.
    let root = {
        let mut batch_builder = ipc::Message::DictionaryBatchBuilder::new(&mut fbb);
        batch_builder.add_id(dict_id);
        batch_builder.add_data(root);
        batch_builder.finish().as_union_value()
    };

    // Enclosing message: version, header type, body length and the header itself.
    let root = {
        let mut message_builder = ipc::Message::MessageBuilder::new(&mut fbb);
        message_builder.add_version(write_options.metadata_version);
        message_builder.add_header_type(ipc::Message::MessageHeader::DictionaryBatch);
        message_builder.add_bodyLength(arrow_data.len() as i64);
        message_builder.add_header(root);
        message_builder.finish()
    };

    fbb.finish(root, None);
    let finished_data = fbb.finished_data();

    EncodedData {
        // Copy the finished bytes out of the builder, which owns them.
        ipc_message: finished_data.to_vec(),
        arrow_data,
    }
}

/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary
/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which
/// isn't allowed in the `FileWriter`.
pub struct DictionaryTracker {
    /// Dictionary values most recently recorded by `insert`, keyed by dictionary id.
    written: HashMap<i64, Arc<dyn Array>>,
    /// When `true`, `insert` returns an error instead of replacing the values recorded for an id.
    error_on_replacement: bool,
}

impl DictionaryTracker {
    /// Creates a tracker with no recorded dictionaries.
    ///
    /// `error_on_replacement` selects whether [`DictionaryTracker::insert`] fails when an
    /// already-recorded id is seen again with different values.
    pub fn new(error_on_replacement: bool) -> Self {
        let written = HashMap::new();
        Self {
            written,
            error_on_replacement,
        }
    }

    /// Records the dictionary values of `array` under `dict_id` and reports whether they
    /// need to be emitted.
    ///
    /// * `Ok(false)`: the id is already recorded with equal values, so nothing changes.
    /// * `Err(_)`: the id is already recorded with different values and the tracker was
    ///   created with `error_on_replacement` set.
    /// * `Ok(true)`: in every other case the values are stored (replacing any previous
    ///   ones) and must be emitted.
    ///
    /// # Panics
    ///
    /// Panics if `array` does not have a `DataType::Dictionary` data type.
    pub fn insert(&mut self, dict_id: i64, array: &Arc<dyn Array>) -> Result<bool> {
        // Downcast to the concrete dictionary array to reach its values.
        let values = match array.data_type() {
            DataType::Dictionary(key_type, _) => {
                with_match_dictionary_key_type!(key_type.as_ref(), |$T| {
                    let dictionary = array
                        .as_any()
                        .downcast_ref::<DictionaryArray<$T>>()
                        .unwrap();
                    dictionary.values()
                })
            }
            _ => unreachable!(),
        };

        match self.written.get(&dict_id) {
            // Same dictionary values => no need to emit it again
            Some(previous) if previous.as_ref() == values.as_ref() => return Ok(false),
            Some(_) if self.error_on_replacement => {
                return Err(ArrowError::InvalidArgumentError(
                    "Dictionary replacement detected when writing IPC file format. \
                     Arrow IPC files only support a single dictionary for a given field \
                     across all batches."
                        .to_string(),
                ));
            }
            // Unknown id, or a replacement that is allowed: fall through and record it.
            _ => {}
        }

        self.written.insert(dict_id, values.clone());
        Ok(true)
    }
}

/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data
///
/// Values are consumed by `write_message`, which writes the length-prefixed `ipc_message`
/// followed by `arrow_data`.
pub struct EncodedData {
    /// An encoded ipc::Schema::Message
    /// (the finished flatbuffer bytes, without length prefix or padding)
    pub ipc_message: Vec<u8>,
    /// Arrow buffers to be written, should be an empty vec for schema messages
    ///
    /// `write_message` returns an error when this length is not a multiple of 8.
    pub arrow_data: Vec<u8>,
}

/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written
///
/// The output is laid out as: length prefix (4 bytes in the legacy format, otherwise a
/// continuation marker plus length, 8 bytes), the flatbuffer metadata, zero padding so that
/// prefix + metadata is a multiple of the configured alignment, and finally the Arrow body.
///
/// Returns `(metadata_len, body_len)`, where `metadata_len` includes the prefix and padding.
///
/// # Errors
///
/// * `ArrowError::Ipc` if the body length is not a multiple of 8.
/// * `ArrowError::Ipc` if the padded metadata does not fit the 32-bit length prefix.
/// * Any I/O error reported by `writer`.
pub fn write_message<W: Write>(
    writer: &mut W,
    encoded: EncodedData,
    write_options: &IpcWriteOptions,
) -> Result<(usize, usize)> {
    let arrow_data_len = encoded.arrow_data.len();
    if arrow_data_len % 8 != 0 {
        return Err(ArrowError::Ipc("Arrow data not aligned".to_string()));
    }

    let alignment = write_options.alignment;
    let buffer = encoded.ipc_message;
    let flatbuf_size = buffer.len();
    let prefix_size = if write_options.write_legacy_ipc_format {
        4
    } else {
        8
    };

    // Round up with modulo arithmetic rather than a bit mask: the options only guarantee
    // that the alignment is a non-zero multiple of 8, not a power of two (e.g. 24).
    let unpadded_size = flatbuf_size + prefix_size;
    let padding_bytes = (alignment - unpadded_size % alignment) % alignment;
    let aligned_size = unpadded_size + padding_bytes;

    // The length prefix is a 32-bit signed integer; refuse to write a truncated length.
    let metadata_len = aligned_size - prefix_size;
    if metadata_len > i32::MAX as usize {
        return Err(ArrowError::Ipc(
            "IPC message metadata is too large for a 32-bit length prefix".to_string(),
        ));
    }

    write_continuation(writer, write_options, metadata_len as i32)?;

    // write the flatbuf
    if flatbuf_size > 0 {
        writer.write_all(&buffer)?;
    }
    // write padding
    writer.write_all(&vec![0; padding_bytes])?;

    // write arrow data
    let body_len = if arrow_data_len > 0 {
        write_body_buffers(writer, &encoded.arrow_data)?
    } else {
        0
    };

    Ok((aligned_size, body_len))
}

/// Writes `data` followed by the zero padding needed to reach the next 8-byte boundary,
/// then flushes the writer.
///
/// Returns the total number of bytes written (payload plus padding).
fn write_body_buffers<W: Write>(mut writer: W, data: &[u8]) -> Result<usize> {
    // Computed in `usize` so that bodies larger than `u32::MAX` bytes are neither
    // truncated nor able to overflow the padding calculation.
    let pad_len = (8 - data.len() % 8) % 8;

    // write body buffer
    writer.write_all(data)?;
    if pad_len > 0 {
        // At most 7 padding bytes are needed, so a stack buffer suffices.
        writer.write_all(&[0u8; 8][..pad_len])?;
    }

    writer.flush()?;
    Ok(data.len() + pad_len)
}

/// Write the message prefix to the writer: the continuation marker (unless the legacy
/// format is used) followed by `total_len` as a little-endian 32-bit integer.
///
/// Returns the number of prefix bytes written: the length alone for the legacy V4 format,
/// marker plus length otherwise. The writer is flushed before returning.
///
/// # Panics
///
/// Panics if the metadata version of `write_options` is not V4 or V5.
pub fn write_continuation<W: Write>(
    writer: &mut W,
    write_options: &IpcWriteOptions,
    total_len: i32,
) -> Result<usize> {
    let length_bytes = total_len.to_le_bytes();

    // the version of the writer determines whether continuation markers should be added
    let written = match write_options.metadata_version {
        ipc::Schema::MetadataVersion::V1
        | ipc::Schema::MetadataVersion::V2
        | ipc::Schema::MetadataVersion::V3 => {
            unreachable!("Options with the metadata version cannot be created")
        }
        ipc::Schema::MetadataVersion::V4 => {
            if write_options.write_legacy_ipc_format {
                // pre-0.15.0 format: length only
                writer.write_all(&length_bytes)?;
                length_bytes.len()
            } else {
                // v0.15.0 format
                writer.write_all(&CONTINUATION_MARKER)?;
                writer.write_all(&length_bytes)?;
                CONTINUATION_MARKER.len() + length_bytes.len()
            }
        }
        ipc::Schema::MetadataVersion::V5 => {
            // write continuation marker and message length
            writer.write_all(&CONTINUATION_MARKER)?;
            writer.write_all(&length_bytes)?;
            CONTINUATION_MARKER.len() + length_bytes.len()
        }
        z => panic!("Unsupported ipc::Schema::MetadataVersion {:?}", z),
    };

    writer.flush()?;

    Ok(written)
}

/// Calculate an 8-byte boundary and return the number of bytes needed to pad to 8 bytes
///
/// The result is always in `0..=7`. It is derived from `len % 8`, so it cannot overflow
/// even for lengths close to `u32::MAX`.
#[inline]
pub(crate) fn pad_to_8(len: u32) -> usize {
    ((8 - len % 8) % 8) as usize
}