use std::error::FromError;
use std::io;
use std::str;
use serialize::Encodable;
use {ByteString, CsvResult, Encoded, Error, RecordTerminator};
/// Selects when the writer wraps a field in quote bytes.
#[deriving(Copy)]
pub enum QuoteStyle {
/// Quote every field, whatever it contains.
Always,
/// Quote a field only when at least one of its bytes is the delimiter,
/// the record terminator, the quote byte, `\r` or `\n`.
Necessary,
}
/// A CSV writer that sends records to an underlying buffered writer.
///
/// The configuration fields are set through the consuming builder methods
/// (`delimiter`, `flexible`, `quote`, ...) and read while records are written.
pub struct Writer<W> {
// Buffered sink that receives every byte of output.
buf: io::BufferedWriter<W>,
// Byte written between consecutive fields of a record.
delimiter: u8,
// Terminator written after the last field of every record.
record_terminator: RecordTerminator,
// When `false`, every record must have as many fields as the first one.
flexible: bool,
// Byte placed around a quoted field.
quote: u8,
// Byte written before a quote inside a field when `double_quote` is `false`.
escape: u8,
// When `true`, a quote inside a quoted field is written twice.
double_quote: bool,
// Policy deciding which fields get quoted.
quote_style: QuoteStyle,
// Field count of the first record written while `flexible` is `false`;
// 0 means no record length has been recorded yet.
first_len: uint,
}
impl Writer<io::IoResult<io::File>> {
    /// Builds a CSV writer whose output goes to the file created at `path`.
    ///
    /// The result of `io::File::create` is wrapped as-is and not inspected
    /// here, so a failure to create the file is not reported by this
    /// constructor (presumably it surfaces on the first write or flush).
    pub fn from_file(path: &Path) -> Writer<io::IoResult<io::File>> {
        let created = io::File::create(path);
        Writer::from_writer(created)
    }
}
impl<W: io::Writer> Writer<W> {
    /// Builds a CSV writer on top of an arbitrary `io::Writer`.
    ///
    /// The writer is wrapped in an `io::BufferedWriter` before use.
    pub fn from_writer(w: W) -> Writer<W> {
        let buffered = io::BufferedWriter::new(w);
        Writer::from_buffer(buffered)
    }

    /// Builds a CSV writer on top of an existing `io::BufferedWriter`.
    ///
    /// The writer starts with these settings: `,` as delimiter, `\n` as
    /// record terminator, `"` as quote, `\` as escape, doubled quotes
    /// enabled, quoting only when necessary, and non-flexible record
    /// lengths.
    pub fn from_buffer(buf: io::BufferedWriter<W>) -> Writer<W> {
        Writer {
            // Output sink.
            buf: buf,
            // Record layout.
            delimiter: b',',
            record_terminator: RecordTerminator::Any(b'\n'),
            flexible: false,
            first_len: 0,
            // Quoting behaviour.
            quote: b'"',
            escape: b'\\',
            double_quote: true,
            quote_style: QuoteStyle::Necessary,
        }
    }
}
impl Writer<Vec<u8>> {
    /// Builds a CSV writer that collects its output in memory.
    ///
    /// The backing `Vec<u8>` is created with a capacity of 64 KiB.
    pub fn from_memory() -> Writer<Vec<u8>> {
        let storage = Vec::with_capacity(1024 * 64);
        Writer::from_writer(storage)
    }

    /// Flushes the buffer and returns everything written so far as a string.
    ///
    /// # Panics
    ///
    /// Panics if flushing fails or if the written bytes are not valid UTF-8.
    pub fn as_string<'r>(&'r mut self) -> &'r str {
        // `as_bytes` performs the flush (and panics on failure) for us.
        str::from_utf8(self.as_bytes()).unwrap()
    }

    /// Flushes the buffer and returns everything written so far as bytes.
    ///
    /// # Panics
    ///
    /// Panics if flushing fails.
    pub fn as_bytes<'r>(&'r mut self) -> &'r [u8] {
        let flushed = self.buf.flush();
        match flushed {
            Ok(()) => {}
            Err(err) => panic!("Error flushing to Vec<u8>: {}", err),
        }
        self.buf.get_ref()[]
    }
}
impl<W: io::Writer> Writer<W> {
    /// Encodes `e` into a single record and writes it.
    ///
    /// The value is first encoded into an `Encoded` buffer; the resulting
    /// fields are then written with `write_bytes`. An encoding error is
    /// returned before anything is written.
    pub fn encode<E: Encodable<Encoded, Error>>
                 (&mut self, e: E) -> CsvResult<()> {
        let mut erecord = Encoded::new();
        try!(e.encode(&mut erecord));
        self.write_bytes(erecord.unwrap().into_iter())
    }

    /// Writes one record whose fields are string references.
    ///
    /// Each field is written as the bytes of its string slice.
    pub fn write<'a, Sized? S: 'a + Str, I: Iterator<&'a S>>
                (&mut self, r: I) -> CsvResult<()> {
        self.write_iter(r, |f| Ok(f.as_slice().as_bytes()))
    }

    /// Writes one record whose fields are arbitrary byte slices.
    pub fn write_bytes<S: AsSlice<u8>, I: Iterator<S>>
                      (&mut self, r: I) -> CsvResult<()> {
        self.write_iter(r, |f| Ok(f))
    }

    /// Writes one record from an iterator of field results.
    ///
    /// The first `Err` yielded by the iterator stops the record and is
    /// returned; fields that preceded it have already been written.
    #[doc(hidden)]
    pub fn write_results<S: AsSlice<u8>, I: Iterator<CsvResult<S>>>
                        (&mut self, r: I) -> CsvResult<()> {
        self.write_iter(r, |f| f)
    }

    /// Shared implementation behind all record-writing methods.
    ///
    /// Every item of `r` is turned into a byte slice by `as_sliceable` and
    /// written as a field (quoted according to the quoting policy), with the
    /// delimiter between consecutive fields and the record terminator at the
    /// end.
    ///
    /// The record length is validated only after the record has been
    /// written, so when a length error is returned the offending bytes are
    /// already in the buffer.
    fn write_iter<T, R: AsSlice<u8>, I: Iterator<T>>
                 (&mut self, mut r: I, as_sliceable: |T| -> CsvResult<R>)
                 -> CsvResult<()> {
        let delim = self.delimiter;
        let mut count = 0;
        let mut last_len = 0;
        for field in r {
            if count > 0 {
                try!(self.w_bytes(&[delim]));
            }
            count += 1;
            let field = try!(as_sliceable(field));
            last_len = field.as_slice().len();
            try!(self.w_user_bytes(field.as_slice()));
        }
        // A record made of one empty field would otherwise consist of the
        // record terminator alone, so it is spelled out as a pair of quotes.
        // When the quoting policy has already wrapped the empty field in
        // quotes (e.g. `QuoteStyle::Always`), adding a second pair would
        // produce `""""`, i.e. a field containing a quote, so it is skipped.
        if count == 1 && last_len == 0 && !self.should_quote(b"") {
            let q = self.quote;
            try!(self.w_bytes(&[q, q]));
        }
        try!(self.w_lineterm());
        self.set_first_len(count)
    }

    /// Flushes the underlying buffered writer, converting any I/O error
    /// into the crate's error type.
    pub fn flush(&mut self) -> CsvResult<()> {
        self.buf.flush().map_err(FromError::from_error)
    }
}
impl<W: io::Writer> Writer<W> {
    /// Sets the byte written between fields and returns the writer.
    pub fn delimiter(self, delimiter: u8) -> Writer<W> {
        let mut wtr = self;
        wtr.delimiter = delimiter;
        wtr
    }

    /// Sets whether records may differ in length and returns the writer.
    ///
    /// When `yes` is `false`, a record whose field count differs from the
    /// first record's makes the write return an error.
    pub fn flexible(self, yes: bool) -> Writer<W> {
        let mut wtr = self;
        wtr.flexible = yes;
        wtr
    }

    /// Sets the terminator written after each record and returns the writer.
    pub fn record_terminator(self, term: RecordTerminator) -> Writer<W> {
        let mut wtr = self;
        wtr.record_terminator = term;
        wtr
    }

    /// Sets the policy deciding which fields are quoted and returns the
    /// writer.
    pub fn quote_style(self, style: QuoteStyle) -> Writer<W> {
        let mut wtr = self;
        wtr.quote_style = style;
        wtr
    }

    /// Sets the byte used to quote fields and returns the writer.
    pub fn quote(self, quote: u8) -> Writer<W> {
        let mut wtr = self;
        wtr.quote = quote;
        wtr
    }

    /// Sets the byte written before a quote inside a field when doubled
    /// quotes are disabled, and returns the writer.
    pub fn escape(self, escape: u8) -> Writer<W> {
        let mut wtr = self;
        wtr.escape = escape;
        wtr
    }

    /// Sets whether a quote inside a field is written twice (`true`) or
    /// preceded by the escape byte (`false`), and returns the writer.
    pub fn double_quote(self, yes: bool) -> Writer<W> {
        let mut wtr = self;
        wtr.double_quote = yes;
        wtr
    }
}
impl<W: io::Writer> Writer<W> {
fn err<S: StrAllocating>(&self, msg: S) -> CsvResult<()> {
Err(Error::Encode(msg.into_string()))
}
fn w_bytes(&mut self, s: &[u8]) -> CsvResult<()> {
self.buf.write(s).map_err(Error::Io)
}
fn w_user_bytes(&mut self, s: &[u8]) -> CsvResult<()> {
if self.should_quote(s) {
let quoted = self.quote_field(s);
self.w_bytes(quoted.as_slice())
} else {
self.w_bytes(s)
}
}
fn w_lineterm(&mut self) -> CsvResult<()> {
match self.record_terminator {
RecordTerminator::CRLF => self.w_bytes(b"\r\n"),
RecordTerminator::Any(b) => self.w_bytes(&[b]),
}
}
fn set_first_len(&mut self, cur_len: uint) -> CsvResult<()> {
if cur_len == 0 {
return self.err("Records must have length greater than 0.")
}
if !self.flexible {
if self.first_len == 0 {
self.first_len = cur_len;
} else if self.first_len != cur_len {
return self.err(format!(
"Record has length {} but other records have length {}",
cur_len, self.first_len))
}
}
Ok(())
}
fn should_quote(&self, field: &[u8]) -> bool {
match self.quote_style {
QuoteStyle::Always => true,
QuoteStyle::Necessary =>
field.iter().any(|&b| self.byte_needs_quotes(b)),
}
}
fn byte_needs_quotes(&self, b: u8) -> bool {
b == self.delimiter
|| self.record_terminator == b
|| b == self.quote
|| b == b'\r' || b == b'\n'
}
fn quote_field(&self, mut s: &[u8]) -> ByteString {
let mut buf = Vec::with_capacity(s.len() + 2);
buf.push(self.quote);
loop {
match s.position_elem(&self.quote) {
None => {
buf.push_all(s);
break
}
Some(next_quote) => {
buf.push_all(s.slice_to(next_quote));
if self.double_quote {
buf.push(self.quote);
buf.push(self.quote);
} else {
buf.push(self.escape);
buf.push(self.quote);
}
s = s.slice_from(next_quote + 1);
}
}
}
buf.push(self.quote);
ByteString::from_bytes(buf)
}
}