use std::error::FromError;
use std::io;
use std::str;
use rustc_serialize::Encodable;
use {
BorrowBytes, ByteString, CsvResult, Encoded, Error, RecordTerminator,
StrAllocating,
};
/// The quoting policy a `Writer` applies to each field it writes.
#[derive(Copy)]
pub enum QuoteStyle {
/// Wrap every field in quotes, whether or not its contents require it.
Always,
/// Wrap a field in quotes only when it contains a byte that requires
/// quoting (the delimiter, the record terminator, the quote byte, `\r`
/// or `\n`).
Necessary,
/// Never wrap a field in quotes. Writing a field that contains a byte
/// requiring quotes is reported as an encode error instead.
Never,
}
/// A CSV writer that emits records to an underlying buffered writer.
///
/// Configuration is done with the consuming builder methods (`delimiter`,
/// `quote_style`, ...) before any record is written.
pub struct Writer<W> {
// Buffered wrapper around the underlying writer; all output goes through it.
buf: io::BufferedWriter<W>,
// Byte written between consecutive fields of a record.
delimiter: u8,
// Terminator written after every record.
record_terminator: RecordTerminator,
// When false, every record must have as many fields as the first one written.
flexible: bool,
// Byte used to open and close quoted fields.
quote: u8,
// Byte written before a quote inside a quoted field when `double_quote` is false.
escape: u8,
// When true, a quote inside a quoted field is written as two quote bytes.
double_quote: bool,
// Policy deciding which fields get wrapped in quotes.
quote_style: QuoteStyle,
// Field count of the first record written while not flexible; 0 means that
// no record length has been recorded yet.
first_len: usize,
}
impl Writer<io::IoResult<io::File>> {
    /// Builds a CSV writer whose output goes to the file at `path`.
    ///
    /// The result of `io::File::create` is handed to the writer unexamined,
    /// so this constructor never reports a failure itself.
    // NOTE(review): presumably a failed create surfaces as an I/O error on
    // the first write or flush -- confirm against `IoResult`'s writer impl.
    pub fn from_file(path: &Path) -> Writer<io::IoResult<io::File>> {
        let created = io::File::create(path);
        Writer::from_writer(created)
    }
}
impl<W: io::Writer> Writer<W> {
    /// Builds a CSV writer on top of an arbitrary `io::Writer`.
    ///
    /// The writer is wrapped in an `io::BufferedWriter` before use.
    pub fn from_writer(w: W) -> Writer<W> {
        let buffered = io::BufferedWriter::new(w);
        Writer::from_buffer(buffered)
    }

    /// Builds a CSV writer on top of an existing buffered writer.
    ///
    /// The writer starts with these settings: `,` as the delimiter, `\n` as
    /// the record terminator, `"` as the quote byte, `\` as the escape byte,
    /// quotes doubled inside quoted fields, quoting only when necessary, and
    /// all records required to have the same number of fields.
    pub fn from_buffer(buf: io::BufferedWriter<W>) -> Writer<W> {
        Writer {
            // Record layout.
            delimiter: b',',
            record_terminator: RecordTerminator::Any(b'\n'),
            flexible: false,
            first_len: 0,
            // Quoting behaviour.
            quote_style: QuoteStyle::Necessary,
            quote: b'"',
            double_quote: true,
            escape: b'\\',
            // Output sink.
            buf: buf,
        }
    }
}
impl Writer<Vec<u8>> {
    /// Builds a CSV writer that accumulates its output in memory.
    ///
    /// The backing vector is created with room for 64 KiB.
    pub fn from_memory() -> Writer<Vec<u8>> {
        let backing = Vec::with_capacity(1024 * 64);
        Writer::from_writer(backing)
    }

    /// Flushes the buffer and returns everything written so far as a string.
    ///
    /// # Panics
    ///
    /// Panics if flushing fails or if the written bytes are not valid UTF-8.
    pub fn as_string<'r>(&'r mut self) -> &'r str {
        let flushed = self.buf.flush();
        match flushed {
            Ok(()) => {}
            Err(err) => panic!("Error flushing to Vec<u8>: {}", err),
        }
        str::from_utf8(&**self.buf.get_ref()).unwrap()
    }

    /// Flushes the buffer and returns everything written so far as bytes.
    ///
    /// # Panics
    ///
    /// Panics if flushing fails.
    pub fn as_bytes<'r>(&'r mut self) -> &'r [u8] {
        let flushed = self.buf.flush();
        match flushed {
            Ok(()) => {}
            Err(err) => panic!("Error flushing to Vec<u8>: {}", err),
        }
        &**self.buf.get_ref()
    }
}
impl<W: io::Writer> Writer<W> {
pub fn encode<E>(&mut self, e: E) -> CsvResult<()> where E: Encodable {
let mut erecord = Encoded::new();
try!(e.encode(&mut erecord));
self.write(erecord.unwrap().into_iter())
}
pub fn write<'a, I>(&mut self, r: I) -> CsvResult<()>
where I: Iterator, <I as Iterator>::Item: BorrowBytes {
self.write_iter(r.map(|f| Ok(f)))
}
#[doc(hidden)]
pub fn write_iter<'a, I, F>(&mut self, mut r: I) -> CsvResult<()>
where I: Iterator<Item=CsvResult<F>>, F: BorrowBytes {
let delim = self.delimiter;
let mut count = 0;
let mut last_len = 0;
for field in r {
if count > 0 {
try!(self.w_bytes(&[delim]));
}
count += 1;
let field = try!(field);
last_len = field.borrow_bytes().len();
try!(self.w_user_bytes(field.borrow_bytes()));
}
if count == 1 && last_len == 0 {
let q = self.quote;
try!(self.w_bytes(&[q, q]));
}
try!(self.w_lineterm());
self.set_first_len(count)
}
pub fn flush(&mut self) -> CsvResult<()> {
self.buf.flush().map_err(FromError::from_error)
}
}
impl<W: io::Writer> Writer<W> {
    /// Sets the byte written between the fields of a record.
    pub fn delimiter(self, delimiter: u8) -> Writer<W> {
        let mut wtr = self;
        wtr.delimiter = delimiter;
        wtr
    }

    /// Sets whether records may differ in their number of fields.
    ///
    /// When `yes` is false, a record whose field count differs from the
    /// first record's is reported as an error.
    pub fn flexible(self, yes: bool) -> Writer<W> {
        let mut wtr = self;
        wtr.flexible = yes;
        wtr
    }

    /// Sets the terminator written after every record.
    pub fn record_terminator(self, term: RecordTerminator) -> Writer<W> {
        let mut wtr = self;
        wtr.record_terminator = term;
        wtr
    }

    /// Sets the policy deciding which fields are wrapped in quotes.
    pub fn quote_style(self, style: QuoteStyle) -> Writer<W> {
        let mut wtr = self;
        wtr.quote_style = style;
        wtr
    }

    /// Sets the byte used to open and close quoted fields.
    pub fn quote(self, quote: u8) -> Writer<W> {
        let mut wtr = self;
        wtr.quote = quote;
        wtr
    }

    /// Sets the byte written before a quote inside a quoted field.
    ///
    /// Only used when quote doubling is disabled.
    pub fn escape(self, escape: u8) -> Writer<W> {
        let mut wtr = self;
        wtr.escape = escape;
        wtr
    }

    /// Sets whether a quote inside a quoted field is written as two quote
    /// bytes (`true`) or as the escape byte followed by a quote (`false`).
    pub fn double_quote(self, yes: bool) -> Writer<W> {
        let mut wtr = self;
        wtr.double_quote = yes;
        wtr
    }
}
impl<W: io::Writer> Writer<W> {
fn err<S, T>(&self, msg: S) -> CsvResult<T> where S: StrAllocating {
Err(Error::Encode(msg.into_str()))
}
fn w_bytes(&mut self, s: &[u8]) -> CsvResult<()> {
self.buf.write(s).map_err(Error::Io)
}
fn w_user_bytes(&mut self, s: &[u8]) -> CsvResult<()> {
if try!(self.should_quote(s)) {
let quoted = self.quote_field(s);
self.w_bytes(&*quoted)
} else {
self.w_bytes(s)
}
}
fn w_lineterm(&mut self) -> CsvResult<()> {
match self.record_terminator {
RecordTerminator::CRLF => self.w_bytes(b"\r\n"),
RecordTerminator::Any(b) => self.w_bytes(&[b]),
}
}
fn set_first_len(&mut self, cur_len: usize) -> CsvResult<()> {
if cur_len == 0 {
return self.err("Records must have length greater than 0.")
}
if !self.flexible {
if self.first_len == 0 {
self.first_len = cur_len;
} else if self.first_len != cur_len {
return self.err(format!(
"Record has length {} but other records have length {}",
cur_len, self.first_len))
}
}
Ok(())
}
fn should_quote(&self, field: &[u8]) -> CsvResult<bool> {
let needs = |&:| field.iter().any(|&b| self.byte_needs_quotes(b));
match self.quote_style {
QuoteStyle::Always => Ok(true),
QuoteStyle::Necessary => Ok(needs()),
QuoteStyle::Never => {
if !needs() {
Ok(false)
} else {
self.err(format!(
"Field requires quotes, but quote style \
is 'Never': '{}'",
String::from_utf8_lossy(field)))
}
}
}
}
fn byte_needs_quotes(&self, b: u8) -> bool {
b == self.delimiter
|| self.record_terminator == b
|| b == self.quote
|| b == b'\r' || b == b'\n'
}
fn quote_field(&self, mut s: &[u8]) -> ByteString {
let mut buf = Vec::with_capacity(s.len() + 2);
buf.push(self.quote);
loop {
match s.position_elem(&self.quote) {
None => {
buf.push_all(s);
break
}
Some(next_quote) => {
buf.push_all(s.slice_to(next_quote));
if self.double_quote {
buf.push(self.quote);
buf.push(self.quote);
} else {
buf.push(self.escape);
buf.push(self.quote);
}
s = s.slice_from(next_quote + 1);
}
}
}
buf.push(self.quote);
ByteString::from_bytes(buf)
}
}