#![doc(html_root_url="https://sfackler.github.io/doc")]
#![feature(globs, macro_rules, phase, unsafe_destructor, slicing_syntax, default_type_params)]
#![warn(missing_docs)]
extern crate openssl;
extern crate serialize;
extern crate phf;
#[phase(plugin)]
extern crate phf_mac;
#[phase(plugin, link)]
extern crate log;
extern crate time;
use url::Url;
use openssl::crypto::hash::{HashType, Hasher};
use openssl::ssl::{SslContext, MaybeSslStream};
use serialize::hex::ToHex;
use std::cell::{Cell, RefCell};
use std::cmp::max;
use std::collections::{RingBuf, HashMap};
use std::io::{BufferedStream, IoResult, IoError, IoErrorKind};
use std::io::net::ip::Port;
use std::iter::IteratorCloneExt;
use std::time::Duration;
use std::mem;
use std::fmt;
use std::result;
use io::{InternalStream, Timeout};
use message::{FrontendMessage, BackendMessage, RowDescriptionEntry};
use message::FrontendMessage::*;
use message::BackendMessage::*;
use message::{WriteMessage, ReadMessage};
#[doc(inline)]
pub use types::{Oid, Type, ToSql, FromSql};
pub use error::{Error, ConnectError, SqlState, DbError, ErrorPosition};
#[macro_escape]
mod macros;
mod io;
mod message;
mod url;
mod util;
mod error;
pub mod types;
/// Sentinel stored in `InnerConnection::canary` while a connection is live;
/// `InnerConnection::finish_inner` overwrites it with 0.
const CANARY: u32 = 0xdeadbeef;

/// A `Result` whose error type is this crate's `Error`.
pub type Result<T> = result::Result<T, Error>;
/// Where to reach the Postgres server.
#[deriving(Clone)]
pub enum ConnectTarget {
    /// Connect over TCP to the given host.
    Tcp(String),
    /// Connect through a Unix domain socket located via the given path.
    /// NOTE(review): how the path maps to a socket file is decided by
    /// `io::initialize_stream` — confirm there.
    Unix(Path)
}
/// Credentials used while authenticating.
#[deriving(Clone)]
pub struct UserInfo {
    /// The user to connect as; sent as the `user` startup parameter.
    pub user: String,
    /// The password, if any. It is only consulted when the server requests
    /// cleartext or MD5 password authentication; if it is needed but absent,
    /// connecting fails with `ConnectError::MissingPassword`.
    pub password: Option<String>,
}
/// Everything needed to open a connection.
#[deriving(Clone)]
pub struct ConnectParams {
    /// The server to connect to.
    pub target: ConnectTarget,
    /// The port to connect on; `None` leaves the choice to `io::initialize_stream`.
    pub port: Option<Port>,
    /// The user to log in as. Connecting fails with `ConnectError::MissingUser`
    /// if this is `None`.
    pub user: Option<UserInfo>,
    /// The database to open. When `None`, no `database` startup parameter is sent.
    pub database: Option<String>,
    /// Extra runtime parameters. They are sent in the startup message ahead of
    /// the parameters the driver adds itself.
    pub options: Vec<(String, String)>,
}
/// Types that can be converted into `ConnectParams`.
pub trait IntoConnectParams {
    /// Performs the conversion, reporting malformed input as a `ConnectError`.
    fn into_connect_params(self) -> result::Result<ConnectParams, ConnectError>;
}
/// `ConnectParams` are already in the target form and convert to themselves.
impl IntoConnectParams for ConnectParams {
    fn into_connect_params(self) -> result::Result<ConnectParams, ConnectError> {
        Ok(self)
    }
}

/// A string is first parsed as a `Url`, then converted via the `Url` impl.
/// Parse failures surface as `ConnectError::InvalidUrl`.
impl<'a> IntoConnectParams for &'a str {
    fn into_connect_params(self) -> result::Result<ConnectParams, ConnectError> {
        Url::parse(self)
            .map_err(ConnectError::InvalidUrl)
            .and_then(|url| url.into_connect_params())
    }
}
/// Builds connection parameters from a parsed URL.
///
/// * The host becomes a TCP target, unless it decodes to a string starting
///   with `/`, in which case the decoded value is used as a Unix socket path.
/// * The URL path, minus its first character, names the database.
/// * Query pairs become extra startup `options`.
impl IntoConnectParams for Url {
    fn into_connect_params(self) -> result::Result<ConnectParams, ConnectError> {
        let Url {
            host,
            port,
            user: url_user,
            path: url::Path { path: raw_path, query, .. },
            ..
        } = self;

        let decoded_host = try!(url::decode_component(&*host)
                                    .map_err(ConnectError::InvalidUrl));
        let target = if !decoded_host.starts_with("/") {
            ConnectTarget::Tcp(host)
        } else {
            ConnectTarget::Unix(Path::new(decoded_host))
        };

        let user = match url_user {
            Some(url::UserInfo { user, pass }) => Some(UserInfo { user: user, password: pass }),
            None => None,
        };

        // Strip the leading character of the path (the `/` separator); the
        // remainder, if the path was non-empty, is the database name.
        let database = match raw_path.slice_shift_char() {
            Some((_, rest)) => Some(rest.into_string()),
            None => None,
        };

        Ok(ConnectParams {
            target: target,
            port: port,
            user: user,
            database: database,
            options: query,
        })
    }
}
/// Receives `NoticeResponse` messages sent by the server, parsed into `DbError`s.
pub trait NoticeHandler {
    /// Called for each notice the connection reads that parses successfully.
    fn handle(&mut self, notice: DbError);
}
/// The notice handler installed on new connections: logs each notice's
/// severity and message at `info` level through the `log` crate.
pub struct DefaultNoticeHandler;

impl NoticeHandler for DefaultNoticeHandler {
    fn handle(&mut self, notice: DbError) {
        let severity = &notice.severity;
        let message = &notice.message;
        info!("{}: {}", severity, message);
    }
}
/// An asynchronous notification received from the server
/// (a `NotificationResponse` message).
pub struct Notification {
    /// Process ID of the backend that sent the notification.
    pub pid: u32,
    /// The channel the notification was sent on.
    pub channel: String,
    /// The notification's payload string.
    pub payload: String,
}
/// Access to asynchronous notifications received by a `Connection`.
///
/// Iterating only yields notifications that were already read and queued while
/// handling other traffic. `next_block` and `next_block_for` wait on the socket
/// for new ones.
pub struct Notifications<'conn> {
    conn: &'conn Connection
}
impl<'conn> Iterator<Notification> for Notifications<'conn> {
    /// Pops the oldest queued notification without touching the socket.
    fn next(&mut self) -> Option<Notification> {
        let mut inner = self.conn.conn.borrow_mut();
        inner.notifications.pop_front()
    }
}
impl<'conn> Notifications<'conn> {
pub fn next_block(&mut self) -> Result<Notification> {
if let Some(notification) = self.next() {
return Ok(notification);
}
let mut conn = self.conn.conn.borrow_mut();
check_desync!(conn);
match try!(conn.read_message_with_notification()) {
NotificationResponse { pid, channel, payload } => {
Ok(Notification {
pid: pid,
channel: channel,
payload: payload
})
}
_ => unreachable!()
}
}
pub fn next_block_for(&mut self, timeout: Duration) -> Result<Notification> {
fn now() -> i64 {
(time::precise_time_ns() / 100_000) as i64
}
if let Some(notification) = self.next() {
return Ok(notification);
}
let mut conn = self.conn.conn.borrow_mut();
check_desync!(conn);
let end = now() + timeout.num_milliseconds();
loop {
let timeout = max(0, end - now()) as u64;
conn.stream.set_read_timeout(Some(timeout));
match conn.read_one_message() {
Ok(Some(NotificationResponse { pid, channel, payload })) => {
return Ok(Notification {
pid: pid,
channel: channel,
payload: payload
})
}
Ok(Some(_)) => unreachable!(),
Ok(None) => {}
Err(e @ IoError { kind: IoErrorKind::TimedOut, .. }) => {
conn.desynchronized = false;
return Err(Error::IoError(e));
}
Err(e) => return Err(Error::IoError(e)),
}
}
}
}
/// Identifies a backend session so that its running query can be cancelled
/// with `cancel_query`. Filled in from the server's `BackendKeyData` message
/// during connection startup.
pub struct CancelData {
    /// The backend's process ID.
    pub process_id: u32,
    /// The secret key the backend issued for cancellation.
    pub secret_key: u32,
}
/// Asks the server to cancel whatever query the session identified by `data`
/// is running.
///
/// A fresh connection is opened to the server described by `params` (using
/// `ssl`), and a `CancelRequest` carrying `data` is written and flushed. The
/// socket is then dropped. No response is read, so `Ok(())` only means the
/// request was sent, not that anything was cancelled.
pub fn cancel_query<T>(params: T, ssl: &SslMode, data: CancelData)
        -> result::Result<(), ConnectError> where T: IntoConnectParams {
    let params = try!(params.into_connect_params());
    let mut socket = try!(io::initialize_stream(&params, ssl));

    try!(socket.write_message(&CancelRequest {
        code: message::CANCEL_CODE,
        process_id: data.process_id,
        secret_key: data.secret_key
    }));
    try!(socket.flush());

    Ok(())
}
/// The mutable state behind a `Connection`.
struct InnerConnection {
    /// Buffered stream to the server (the `MaybeSslStream` wrapper suggests
    /// optional TLS).
    stream: BufferedStream<MaybeSslStream<InternalStream>>,
    /// Counter used by `make_stmt_name` to produce `s0`, `s1`, ...
    next_stmt_id: uint,
    /// Receives parsed `NoticeResponse` messages.
    notice_handler: Box<NoticeHandler+Send>,
    /// Notifications read while waiting for other messages, queued for `Notifications`.
    notifications: RingBuf<Notification>,
    /// Cancellation key reported by the server at startup.
    cancel_data: CancelData,
    /// Cache of type names looked up in `pg_type`, keyed by OID.
    unknown_types: HashMap<Oid, String>,
    /// Set when the protocol state can no longer be trusted (e.g. after an
    /// unexpected message).
    desynchronized: bool,
    /// Set by `Connection::finish` so `Drop` does not terminate the session again.
    finished: bool,
    /// Number of currently open (possibly nested) transactions.
    trans_depth: u32,
    /// `CANARY` while live; zeroed by `finish_inner`.
    canary: u32,
}
/// Terminates the session on drop unless `Connection::finish` already did.
impl Drop for InnerConnection {
    fn drop(&mut self) {
        if !self.finished {
            // Errors are ignored here; call `Connection::finish` to observe them.
            let _ = self.finish_inner();
        }
    }
}
impl InnerConnection {
fn connect<T>(params: T, ssl: &SslMode) -> result::Result<InnerConnection, ConnectError>
where T: IntoConnectParams {
let params = try!(params.into_connect_params());
let stream = try!(io::initialize_stream(¶ms, ssl));
let ConnectParams { user, database, mut options, .. } = params;
let user = try!(user.ok_or(ConnectError::MissingUser));
let mut conn = InnerConnection {
stream: BufferedStream::new(stream),
next_stmt_id: 0,
notice_handler: box DefaultNoticeHandler,
notifications: RingBuf::new(),
cancel_data: CancelData { process_id: 0, secret_key: 0 },
unknown_types: HashMap::new(),
desynchronized: false,
finished: false,
trans_depth: 0,
canary: CANARY,
};
options.push(("client_encoding".into_string(), "UTF8".into_string()));
options.push(("TimeZone".into_string(), "GMT".into_string()));
options.push(("user".into_string(), user.user.clone()));
if let Some(database) = database {
options.push(("database".into_string(), database));
}
try!(conn.write_messages(&[StartupMessage {
version: message::PROTOCOL_VERSION,
parameters: &*options
}]));
try!(conn.handle_auth(user));
loop {
match try!(conn.read_message()) {
BackendKeyData { process_id, secret_key } => {
conn.cancel_data.process_id = process_id;
conn.cancel_data.secret_key = secret_key;
}
ReadyForQuery { .. } => break,
ErrorResponse { fields } => return DbError::new_connect(fields),
_ => return Err(ConnectError::BadResponse),
}
}
Ok(conn)
}
fn write_messages(&mut self, messages: &[FrontendMessage]) -> IoResult<()> {
debug_assert!(!self.desynchronized);
for message in messages.iter() {
try_desync!(self, self.stream.write_message(message));
}
Ok(try_desync!(self, self.stream.flush()))
}
fn read_one_message(&mut self) -> IoResult<Option<BackendMessage>> {
debug_assert!(!self.desynchronized);
match try_desync!(self, self.stream.read_message()) {
NoticeResponse { fields } => {
if let Ok(err) = DbError::new_raw(fields) {
self.notice_handler.handle(err);
}
Ok(None)
}
ParameterStatus { parameter, value } => {
debug!("Parameter {} = {}", parameter, value);
Ok(None)
}
val => Ok(Some(val))
}
}
fn read_message_with_notification(&mut self) -> IoResult<BackendMessage> {
loop {
if let Some(msg) = try!(self.read_one_message()) {
return Ok(msg);
}
}
}
fn read_message(&mut self) -> IoResult<BackendMessage> {
loop {
match try!(self.read_message_with_notification()) {
NotificationResponse { pid, channel, payload } => {
self.notifications.push_back(Notification {
pid: pid,
channel: channel,
payload: payload
})
}
val => return Ok(val)
}
}
}
fn handle_auth(&mut self, user: UserInfo) -> result::Result<(), ConnectError> {
match try!(self.read_message()) {
AuthenticationOk => return Ok(()),
AuthenticationCleartextPassword => {
let pass = try!(user.password.ok_or(ConnectError::MissingPassword));
try!(self.write_messages(&[PasswordMessage {
password: &*pass,
}]));
}
AuthenticationMD5Password { salt } => {
let pass = try!(user.password.ok_or(ConnectError::MissingPassword));
let mut hasher = Hasher::new(HashType::MD5);
hasher.update(pass.as_bytes());
hasher.update(user.user.as_bytes());
let output = hasher.finalize().to_hex();
let mut hasher = Hasher::new(HashType::MD5);
hasher.update(output.as_bytes());
hasher.update(&salt);
let output = format!("md5{}", hasher.finalize().to_hex());
try!(self.write_messages(&[PasswordMessage {
password: &*output
}]));
}
AuthenticationKerberosV5
| AuthenticationSCMCredential
| AuthenticationGSS
| AuthenticationSSPI => return Err(ConnectError::UnsupportedAuthentication),
ErrorResponse { fields } => return DbError::new_connect(fields),
_ => return Err(ConnectError::BadResponse)
}
match try!(self.read_message()) {
AuthenticationOk => Ok(()),
ErrorResponse { fields } => return DbError::new_connect(fields),
_ => return Err(ConnectError::BadResponse)
}
}
fn set_notice_handler(&mut self, handler: Box<NoticeHandler+Send>) -> Box<NoticeHandler+Send> {
mem::replace(&mut self.notice_handler, handler)
}
fn raw_prepare(&mut self, stmt_name: &str, query: &str)
-> Result<(Vec<Type>, Vec<ResultDescription>)> {
try!(self.write_messages(&[
Parse {
name: stmt_name,
query: query,
param_types: &[]
},
Describe {
variant: b'S',
name: stmt_name,
},
Sync]));
match try!(self.read_message()) {
ParseComplete => {}
ErrorResponse { fields } => {
try!(self.wait_for_ready());
return DbError::new(fields);
}
_ => bad_response!(self),
}
let mut param_types: Vec<_> = match try!(self.read_message()) {
ParameterDescription { types } => {
types.into_iter().map(Type::from_oid).collect()
}
_ => bad_response!(self),
};
let mut result_desc: Vec<_> = match try!(self.read_message()) {
RowDescription { descriptions } => {
descriptions.into_iter().map(|RowDescriptionEntry { name, type_oid, .. }| {
ResultDescription {
name: name,
ty: Type::from_oid(type_oid)
}
}).collect()
}
NoData => vec![],
_ => bad_response!(self)
};
try!(self.wait_for_ready());
if stmt_name != "" {
try!(self.set_type_names(param_types.iter_mut()));
try!(self.set_type_names(result_desc.iter_mut().map(|d| &mut d.ty)));
}
Ok((param_types, result_desc))
}
fn make_stmt_name(&mut self) -> String {
let stmt_name = format!("s{}", self.next_stmt_id);
self.next_stmt_id += 1;
stmt_name
}
fn prepare<'a>(&mut self, query: &str, conn: &'a Connection) -> Result<Statement<'a>> {
let stmt_name = self.make_stmt_name();
let (param_types, result_desc) = try!(self.raw_prepare(&*stmt_name, query));
Ok(Statement {
conn: conn,
name: stmt_name,
param_types: param_types,
result_desc: result_desc,
next_portal_id: Cell::new(0),
finished: false,
})
}
fn prepare_copy_in<'a>(&mut self, table: &str, rows: &[&str], conn: &'a Connection)
-> Result<CopyInStatement<'a>> {
let mut query = vec![];
let _ = write!(&mut query, "SELECT ");
let _ = util::comma_join(&mut query, rows.iter().cloned());
let _ = write!(&mut query, " FROM {}", table);
let query = String::from_utf8(query).unwrap();
let (_, result_desc) = try!(self.raw_prepare("", &*query));
let column_types = result_desc.into_iter().map(|desc| desc.ty).collect();
let mut query = vec![];
let _ = write!(&mut query, "COPY {} (", table);
let _ = util::comma_join(&mut query, rows.iter().cloned());
let _ = write!(&mut query, ") FROM STDIN WITH (FORMAT binary)");
let query = String::from_utf8(query).unwrap();
let stmt_name = self.make_stmt_name();
try!(self.raw_prepare(&*stmt_name, &*query));
Ok(CopyInStatement {
conn: conn,
name: stmt_name,
column_types: column_types,
finished: false,
})
}
fn close_statement(&mut self, name: &str, type_: u8) -> Result<()> {
try!(self.write_messages(&[
Close {
variant: type_,
name: name,
},
Sync]));
let resp = match try!(self.read_message()) {
CloseComplete => Ok(()),
ErrorResponse { fields } => DbError::new(fields),
_ => bad_response!(self)
};
try!(self.wait_for_ready());
resp
}
fn set_type_names<'a, I>(&mut self, mut it: I) -> Result<()> where I: Iterator<&'a mut Type> {
for ty in it {
if let &Type::Unknown { oid, ref mut name } = ty {
*name = try!(self.get_type_name(oid));
}
}
Ok(())
}
fn get_type_name(&mut self, oid: Oid) -> Result<String> {
if let Some(name) = self.unknown_types.get(&oid) {
return Ok(name.clone());
}
let name = try!(self.quick_query(&*format!("SELECT typname FROM pg_type \
WHERE oid={}", oid)))
.into_iter().next().unwrap().into_iter().next().unwrap().unwrap();
self.unknown_types.insert(oid, name.clone());
Ok(name)
}
fn is_desynchronized(&self) -> bool {
self.desynchronized
}
fn canary(&self) -> u32 {
self.canary
}
fn wait_for_ready(&mut self) -> Result<()> {
match try!(self.read_message()) {
ReadyForQuery { .. } => Ok(()),
_ => bad_response!(self)
}
}
fn quick_query(&mut self, query: &str) -> Result<Vec<Vec<Option<String>>>> {
check_desync!(self);
try!(self.write_messages(&[Query { query: query }]));
let mut result = vec![];
loop {
match try!(self.read_message()) {
ReadyForQuery { .. } => break,
DataRow { row } => {
result.push(row.into_iter().map(|opt| {
opt.map(|b| String::from_utf8_lossy(&*b).into_string())
}).collect());
}
CopyInResponse { .. } => {
try!(self.write_messages(&[
CopyFail {
message: "COPY queries cannot be directly executed",
},
Sync]));
}
ErrorResponse { fields } => {
try!(self.wait_for_ready());
return DbError::new(fields);
}
_ => {}
}
}
Ok(result)
}
fn finish_inner(&mut self) -> Result<()> {
check_desync!(self);
self.canary = 0;
try!(self.write_messages(&[Terminate]));
Ok(())
}
}
/// A connection to a Postgres server.
///
/// The protocol state lives in a `RefCell`, so statements, transactions and
/// notification iterators can share a `&Connection`.
pub struct Connection {
    conn: RefCell<InnerConnection>
}
impl Connection {
pub fn connect<T>(params: T, ssl: &SslMode) -> result::Result<Connection, ConnectError>
where T: IntoConnectParams {
InnerConnection::connect(params, ssl).map(|conn| {
Connection { conn: RefCell::new(conn) }
})
}
pub fn set_notice_handler(&self, handler: Box<NoticeHandler+Send>) -> Box<NoticeHandler+Send> {
self.conn.borrow_mut().set_notice_handler(handler)
}
pub fn notifications<'a>(&'a self) -> Notifications<'a> {
Notifications { conn: self }
}
pub fn prepare<'a>(&'a self, query: &str) -> Result<Statement<'a>> {
let mut conn = self.conn.borrow_mut();
if conn.trans_depth != 0 {
return Err(Error::WrongTransaction);
}
conn.prepare(query, self)
}
pub fn prepare_copy_in<'a>(&'a self, table: &str, rows: &[&str])
-> Result<CopyInStatement<'a>> {
let mut conn = self.conn.borrow_mut();
if conn.trans_depth != 0 {
return Err(Error::WrongTransaction);
}
conn.prepare_copy_in(table, rows, self)
}
pub fn transaction<'a>(&'a self) -> Result<Transaction<'a>> {
let mut conn = self.conn.borrow_mut();
check_desync!(conn);
if conn.trans_depth != 0 {
return Err(Error::WrongTransaction);
}
try!(conn.quick_query("BEGIN"));
conn.trans_depth += 1;
Ok(Transaction {
conn: self,
commit: Cell::new(false),
depth: 1,
finished: false,
})
}
pub fn execute(&self, query: &str, params: &[&ToSql]) -> Result<uint> {
let (param_types, result_desc) = try!(self.conn.borrow_mut().raw_prepare("", query));
let stmt = Statement {
conn: self,
name: "".into_string(),
param_types: param_types,
result_desc: result_desc,
next_portal_id: Cell::new(0),
finished: true, };
stmt.execute(params)
}
pub fn batch_execute(&self, query: &str) -> Result<()> {
let mut conn = self.conn.borrow_mut();
if conn.trans_depth != 0 {
return Err(Error::WrongTransaction);
}
conn.quick_query(query).map(|_| ())
}
pub fn cancel_data(&self) -> CancelData {
self.conn.borrow().cancel_data
}
pub fn is_desynchronized(&self) -> bool {
self.conn.borrow().is_desynchronized()
}
pub fn finish(self) -> Result<()> {
let mut conn = self.conn.borrow_mut();
conn.finished = true;
conn.finish_inner()
}
fn canary(&self) -> u32 {
self.conn.borrow().canary()
}
fn write_messages(&self, messages: &[FrontendMessage]) -> IoResult<()> {
self.conn.borrow_mut().write_messages(messages)
}
}
/// How SSL is used when connecting. The negotiation itself happens in
/// `io::initialize_stream`.
pub enum SslMode {
    /// Do not use SSL.
    None,
    /// Use SSL with the given context when possible.
    /// NOTE(review): presumably falls back to plaintext if the server
    /// declines — confirm in `io::initialize_stream`.
    Prefer(SslContext),
    /// Use SSL with the given context; presumably the connection fails
    /// otherwise — confirm in `io::initialize_stream`.
    Require(SslContext)
}
/// An open transaction (`depth == 1`, started with `BEGIN`) or nested
/// transaction (`depth > 1`, backed by the savepoint `sp`).
///
/// When finished or dropped it commits if `set_commit` was called and
/// otherwise rolls back.
pub struct Transaction<'conn> {
    conn: &'conn Connection,
    /// `true` to commit on finish, `false` to roll back.
    commit: Cell<bool>,
    /// Nesting level. The handle may only be used while it equals the
    /// connection's `trans_depth`.
    depth: u32,
    /// Set by `finish` so `Drop` does not end the transaction twice.
    finished: bool,
}
/// Ends the transaction on drop (commit or rollback per `will_commit`)
/// unless `finish`/`commit` already did.
#[unsafe_destructor]
impl<'conn> Drop for Transaction<'conn> {
    fn drop(&mut self) {
        if !self.finished {
            // Errors are ignored here; call `finish` to observe them.
            let _ = self.finish_inner();
        }
    }
}
impl<'conn> Transaction<'conn> {
    /// Ends this transaction level. At depth 1 it issues `COMMIT` or
    /// `ROLLBACK`; when nested it issues `RELEASE sp` or `ROLLBACK TO sp`,
    /// depending on `commit`.
    fn finish_inner(&mut self) -> Result<()> {
        let mut conn = self.conn.conn.borrow_mut();
        debug_assert!(self.depth == conn.trans_depth);
        let query = match (self.commit.get(), self.depth != 1) {
            (false, true) => "ROLLBACK TO sp",
            (false, false) => "ROLLBACK",
            (true, true) => "RELEASE sp",
            (true, false) => "COMMIT",
        };
        conn.trans_depth -= 1;
        conn.quick_query(query).map(|_| ())
    }

    /// Like `Connection::prepare`, but for use inside this transaction.
    ///
    /// Fails with `Error::WrongTransaction` while a nested transaction is active.
    pub fn prepare(&self, query: &str) -> Result<Statement<'conn>> {
        let mut conn = self.conn.conn.borrow_mut();
        if conn.trans_depth != self.depth {
            return Err(Error::WrongTransaction);
        }
        conn.prepare(query, self.conn)
    }

    /// Like `Connection::prepare_copy_in`, but for use inside this transaction.
    ///
    /// Fails with `Error::WrongTransaction` while a nested transaction is active.
    pub fn prepare_copy_in(&self, table: &str, cols: &[&str]) -> Result<CopyInStatement<'conn>> {
        let mut conn = self.conn.conn.borrow_mut();
        if conn.trans_depth != self.depth {
            return Err(Error::WrongTransaction);
        }
        conn.prepare_copy_in(table, cols, self.conn)
    }

    /// Like `Connection::execute`, but for use inside this transaction.
    ///
    /// Fails with `Error::WrongTransaction` while a nested transaction is active.
    pub fn execute(&self, query: &str, params: &[&ToSql]) -> Result<uint> {
        {
            // Scoped so the borrow is released before `Connection::execute`
            // borrows the connection again.
            let conn = self.conn.conn.borrow();
            if conn.trans_depth != self.depth {
                return Err(Error::WrongTransaction);
            }
        }
        self.conn.execute(query, params)
    }

    /// Like `Connection::batch_execute`, but for use inside this transaction.
    ///
    /// Fails with `Error::WrongTransaction` while a nested transaction is active.
    pub fn batch_execute(&self, query: &str) -> Result<()> {
        let mut conn = self.conn.conn.borrow_mut();
        if conn.trans_depth != self.depth {
            return Err(Error::WrongTransaction);
        }
        conn.quick_query(query).map(|_| ())
    }

    /// Starts a nested transaction backed by `SAVEPOINT sp`.
    pub fn transaction<'a>(&'a self) -> Result<Transaction<'a>> {
        let mut conn = self.conn.conn.borrow_mut();
        check_desync!(conn);
        if conn.trans_depth != self.depth {
            return Err(Error::WrongTransaction);
        }
        try!(conn.quick_query("SAVEPOINT sp"));
        conn.trans_depth += 1;
        Ok(Transaction {
            conn: self.conn,
            commit: Cell::new(false),
            depth: self.depth + 1,
            finished: false,
        })
    }

    /// Executes `stmt` with `params`, fetching at most `row_limit` rows per
    /// round trip; errors while fetching are yielded rather than panicking.
    ///
    /// Fails with `Error::WrongConnection` if `stmt` belongs to another
    /// connection, and with `Error::WrongTransaction` while a nested
    /// transaction is active.
    pub fn lazy_query<'trans, 'stmt>(&'trans self,
                                     stmt: &'stmt Statement,
                                     params: &[&ToSql],
                                     row_limit: i32)
                                     -> Result<LazyRows<'trans, 'stmt>> {
        if self.conn as *const _ != stmt.conn as *const _ {
            return Err(Error::WrongConnection);
        }
        let conn = self.conn.conn.borrow();
        check_desync!(conn);
        if conn.trans_depth != self.depth {
            return Err(Error::WrongTransaction);
        }
        // Release the borrow before the statement borrows the connection mutably.
        drop(conn);
        stmt.lazy_query(row_limit, params).map(|result| {
            LazyRows {
                _trans: self,
                result: result
            }
        })
    }

    /// Whether finishing will commit (`true`) or roll back (`false`).
    pub fn will_commit(&self) -> bool {
        self.commit.get()
    }

    /// Makes finishing commit.
    pub fn set_commit(&self) {
        self.commit.set(true);
    }

    /// Makes finishing roll back (the default).
    pub fn set_rollback(&self) {
        self.commit.set(false);
    }

    /// Commits and consumes the transaction.
    pub fn commit(self) -> Result<()> {
        self.set_commit();
        self.finish()
    }

    /// Ends the transaction as configured, returning any error (which `Drop` ignores).
    pub fn finish(mut self) -> Result<()> {
        self.finished = true;
        self.finish_inner()
    }
}
/// A prepared statement.
pub struct Statement<'conn> {
    conn: &'conn Connection,
    /// Server-side statement name; empty for the unnamed statement built by
    /// `Connection::execute`.
    name: String,
    /// Parameter types reported by the server.
    param_types: Vec<Type>,
    /// Names and types of the result columns.
    result_desc: Vec<ResultDescription>,
    /// Counter used to build unique portal names in `lazy_query`.
    next_portal_id: Cell<uint>,
    /// When set, `Drop` does not close the statement.
    finished: bool,
}
/// Closes the server-side statement on drop unless it is already finished.
#[unsafe_destructor]
impl<'conn> Drop for Statement<'conn> {
    fn drop(&mut self) {
        if !self.finished {
            // Errors are ignored here; call `finish` to observe them.
            let _ = self.finish_inner();
        }
    }
}
impl<'conn> Statement<'conn> {
    /// Sends `Close` for this statement and waits for the reply.
    fn finish_inner(&mut self) -> Result<()> {
        let mut conn = self.conn.conn.borrow_mut();
        check_desync!(conn);
        conn.close_statement(&*self.name, b'S')
    }

    /// Binds `params` into portal `portal_name` and executes it with at most
    /// `row_limit` rows, consuming only the `BindComplete` reply.
    ///
    /// Parameters and results use format code 1 (binary). Rows and the
    /// completion message are left for the caller to read.
    fn inner_execute(&self, portal_name: &str, row_limit: i32, params: &[&ToSql]) -> Result<()> {
        let mut conn = self.conn.conn.borrow_mut();
        if self.param_types.len() != params.len() {
            return Err(Error::WrongParamCount {
                expected: self.param_types.len(),
                actual: params.len(),
            });
        }

        // Encode everything before writing, so a conversion error leaves the
        // wire untouched.
        let mut values = vec![];
        for (param, ty) in params.iter().zip(self.param_types.iter()) {
            values.push(try!(param.to_sql(ty)));
        }

        try!(conn.write_messages(&[
            Bind {
                portal: portal_name,
                statement: &*self.name,
                formats: &[1],
                values: &*values,
                result_formats: &[1]
            },
            Execute {
                portal: portal_name,
                max_rows: row_limit
            },
            Sync]));

        match try!(conn.read_message()) {
            BindComplete => Ok(()),
            ErrorResponse { fields } => {
                try!(conn.wait_for_ready());
                DbError::new(fields)
            }
            _ => {
                conn.desynchronized = true;
                Err(Error::BadResponse)
            }
        }
    }

    /// Executes into a fresh portal named `<name>p<id>` and reads the first
    /// batch of up to `row_limit` rows.
    fn lazy_query<'a>(&'a self, row_limit: i32, params: &[&ToSql]) -> Result<Rows<'a>> {
        let id = self.next_portal_id.get();
        self.next_portal_id.set(id + 1);
        let portal_name = format!("{}p{}", self.name, id);

        try!(self.inner_execute(&*portal_name, row_limit, params));

        let mut result = Rows {
            stmt: self,
            name: portal_name,
            data: RingBuf::new(),
            row_limit: row_limit,
            more_rows: true,
            finished: false,
        };
        try!(result.read_rows());

        Ok(result)
    }

    /// The types of this statement's parameters.
    pub fn param_types(&self) -> &[Type] {
        &*self.param_types
    }

    /// The names and types of this statement's result columns.
    pub fn result_descriptions(&self) -> &[ResultDescription] {
        &*self.result_desc
    }

    /// Executes the statement and returns the number of rows affected, as
    /// parsed from the `CommandComplete` tag (0 for an empty query). Any data
    /// rows are discarded, and `COPY ... FROM STDIN` is refused with `CopyFail`.
    pub fn execute(&self, params: &[&ToSql]) -> Result<uint> {
        check_desync!(self.conn);
        try!(self.inner_execute("", 0, params));

        let mut conn = self.conn.conn.borrow_mut();
        let num;
        loop {
            match try!(conn.read_message()) {
                DataRow { .. } => {}
                ErrorResponse { fields } => {
                    try!(conn.wait_for_ready());
                    return DbError::new(fields);
                }
                CommandComplete { tag } => {
                    num = util::parse_update_count(tag);
                    break;
                }
                EmptyQueryResponse => {
                    num = 0;
                    break;
                }
                CopyInResponse { .. } => {
                    try!(conn.write_messages(&[
                        CopyFail {
                            message: "COPY queries cannot be directly executed",
                        },
                        Sync]));
                }
                _ => {
                    conn.desynchronized = true;
                    return Err(Error::BadResponse);
                }
            }
        }
        try!(conn.wait_for_ready());

        Ok(num)
    }

    /// Executes the statement and returns all resulting rows.
    pub fn query<'a>(&'a self, params: &[&ToSql]) -> Result<Rows<'a>> {
        check_desync!(self.conn);
        self.lazy_query(0, params)
    }

    /// Closes the statement, returning any error (which `Drop` ignores).
    pub fn finish(mut self) -> Result<()> {
        self.finished = true;
        self.finish_inner()
    }
}
/// The name and type of one result column.
#[deriving(PartialEq, Eq)]
pub struct ResultDescription {
    /// The column's name.
    pub name: String,
    /// The column's type.
    pub ty: Type
}
/// Rows produced by executing a statement, read from a named portal.
pub struct Rows<'stmt> {
    stmt: &'stmt Statement<'stmt>,
    /// Portal name (`<statement name>p<id>`).
    name: String,
    /// Rows received but not yet yielded; each column is raw bytes or `None` (NULL).
    data: RingBuf<Vec<Option<Vec<u8>>>>,
    /// `max_rows` sent with each `Execute`; `Statement::query` passes 0
    /// (presumably "no limit" per the protocol).
    row_limit: i32,
    /// Whether the portal was suspended with more rows still available.
    more_rows: bool,
    /// Set by `finish` so `Drop` does not close the portal again.
    finished: bool,
}
/// Closes the portal on drop unless `finish` already did.
#[unsafe_destructor]
impl<'stmt> Drop for Rows<'stmt> {
    fn drop(&mut self) {
        if !self.finished {
            // Errors are ignored here; call `finish` to observe them.
            let _ = self.finish_inner();
        }
    }
}
impl<'stmt> Rows<'stmt> {
    // Sends `Close` for this portal (variant b'P') and waits for the reply.
    fn finish_inner(&mut self) -> Result<()> {
        let mut conn = self.stmt.conn.conn.borrow_mut();
        check_desync!(conn);
        conn.close_statement(&*self.name, b'P')
    }

    // Buffers `DataRow`s until the portal completes (no more rows) or is
    // suspended (more rows available), then consumes the trailing
    // `ReadyForQuery`. Server errors are drained and returned; any other
    // unexpected message desynchronizes the connection.
    fn read_rows(&mut self) -> Result<()> {
        let mut conn = self.stmt.conn.conn.borrow_mut();
        loop {
            match try!(conn.read_message()) {
                EmptyQueryResponse | CommandComplete { .. } => {
                    self.more_rows = false;
                    break;
                }
                PortalSuspended => {
                    self.more_rows = true;
                    break;
                }
                DataRow { row } => self.data.push_back(row),
                ErrorResponse { fields } => {
                    try!(conn.wait_for_ready());
                    return DbError::new(fields);
                }
                CopyInResponse { .. } => {
                    try!(conn.write_messages(&[
                        CopyFail {
                            message: "COPY queries cannot be directly executed",
                        },
                        Sync]));
                }
                _ => {
                    conn.desynchronized = true;
                    return Err(Error::BadResponse);
                }
            }
        }
        conn.wait_for_ready()
    }

    // Asks the suspended portal for the next batch of `row_limit` rows.
    fn execute(&mut self) -> Result<()> {
        try!(self.stmt.conn.write_messages(&[
            Execute {
                portal: &*self.name,
                max_rows: self.row_limit
            },
            Sync]));
        self.read_rows()
    }

    /// The names and types of the result columns.
    pub fn result_descriptions(&self) -> &'stmt [ResultDescription] {
        self.stmt.result_descriptions()
    }

    /// Closes the portal, returning any error (which `Drop` ignores).
    pub fn finish(mut self) -> Result<()> {
        self.finished = true;
        self.finish_inner()
    }

    // Yields the next buffered row, fetching another batch first if the
    // buffer is empty and the portal has more rows. A fetch error is
    // returned as `Some(Err(..))`.
    fn try_next(&mut self) -> Option<Result<Row<'stmt>>> {
        if self.data.is_empty() && self.more_rows {
            if let Err(err) = self.execute() {
                return Some(Err(err));
            }
        }
        self.data.pop_front().map(|row| Ok(Row { stmt: self.stmt, data: row }))
    }
}
impl<'stmt> Iterator<Row<'stmt>> for Rows<'stmt> {
    /// Yields the next row, fetching more from the portal when needed.
    ///
    /// Panics if fetching fails; `LazyRows` yields such errors instead.
    #[inline]
    fn next(&mut self) -> Option<Row<'stmt>> {
        match self.try_next() {
            Some(res) => Some(res.unwrap()),
            None => None,
        }
    }

    /// The lower bound is the number of buffered rows. The upper bound is only
    /// known once the portal has no more rows to send.
    #[inline]
    fn size_hint(&self) -> (uint, Option<uint>) {
        let buffered = self.data.len();
        if self.more_rows {
            (buffered, None)
        } else {
            (buffered, Some(buffered))
        }
    }
}
/// One result row.
pub struct Row<'stmt> {
    stmt: &'stmt Statement<'stmt>,
    /// Raw column values as received; `None` for NULL.
    data: Vec<Option<Vec<u8>>>
}
impl<'stmt> Row<'stmt> {
pub fn len(&self) -> uint {
self.data.len()
}
pub fn result_descriptions(&self) -> &'stmt [ResultDescription] {
self.stmt.result_descriptions()
}
pub fn get_opt<I, T>(&self, idx: I) -> Result<T> where I: RowIndex, T: FromSql {
let idx = try!(idx.idx(self.stmt).ok_or(Error::InvalidColumn));
FromSql::from_sql(&self.stmt.result_desc[idx].ty, &self.data[idx])
}
pub fn get<I, T>(&self, idx: I) -> T where I: RowIndex + fmt::Show + Clone, T: FromSql {
match self.get_opt(idx.clone()) {
Ok(ok) => ok,
Err(err) => panic!("error retrieving column {}: {}", idx, err)
}
}
}
/// Types that select a column of a `Row`: a `uint` position or a `&str` name.
pub trait RowIndex {
    /// Returns the position of the selected column in `stmt`'s results, or
    /// `None` if there is no such column.
    fn idx(&self, stmt: &Statement) -> Option<uint>;
}
/// A zero-based column position.
impl RowIndex for uint {
    /// Returns `Some(self)` when `self` is a valid position, i.e. strictly
    /// less than the number of result columns.
    ///
    /// The bound must be `<`: returning `Some(len)` would make
    /// `Row::get_opt` index past the end of `result_desc` and panic instead
    /// of reporting `Error::InvalidColumn`.
    #[inline]
    fn idx(&self, stmt: &Statement) -> Option<uint> {
        if *self < stmt.result_desc.len() {
            Some(*self)
        } else {
            None
        }
    }
}
/// A column name; the first column with exactly this name is selected.
impl<'a> RowIndex for &'a str {
    #[inline]
    fn idx(&self, stmt: &Statement) -> Option<uint> {
        let wanted: &str = *self;
        stmt.result_descriptions()
            .iter()
            .enumerate()
            .find(|&(_, desc)| &*desc.name == wanted)
            .map(|(i, _)| i)
    }
}
/// Rows fetched in batches within a transaction (see `Transaction::lazy_query`).
pub struct LazyRows<'trans, 'stmt> {
    result: Rows<'stmt>,
    /// Keeps the transaction borrowed while the portal is in use.
    _trans: &'trans Transaction<'trans>,
}
impl<'trans, 'stmt> LazyRows<'trans, 'stmt> {
pub fn finish(self) -> Result<()> {
self.result.finish()
}
}
impl<'trans, 'stmt> Iterator<Result<Row<'stmt>>> for LazyRows<'trans, 'stmt> {
fn next(&mut self) -> Option<Result<Row<'stmt>>> {
self.result.try_next()
}
fn size_hint(&self) -> (uint, Option<uint>) {
self.result.size_hint()
}
}
/// A prepared `COPY ... FROM STDIN WITH (FORMAT binary)` statement.
pub struct CopyInStatement<'a> {
    conn: &'a Connection,
    /// Server-side statement name.
    name: String,
    /// Types of the target columns, in the order values must be supplied.
    column_types: Vec<Type>,
    /// Set by `finish` so `Drop` does not close the statement again.
    finished: bool,
}
/// Closes the server-side statement on drop unless `finish` already did.
#[unsafe_destructor]
impl<'a> Drop for CopyInStatement<'a> {
    fn drop(&mut self) {
        if !self.finished {
            // Errors are ignored here; call `finish` to observe them.
            let _ = self.finish_inner();
        }
    }
}
impl<'a> CopyInStatement<'a> {
    /// Sends `Close` for this statement and waits for the reply.
    fn finish_inner(&mut self) -> Result<()> {
        let mut conn = self.conn.conn.borrow_mut();
        check_desync!(conn);
        conn.close_statement(&*self.name, b'S')
    }

    /// The types of the target columns, in the order values must be supplied.
    pub fn column_types(&self) -> &[Type] {
        &*self.column_types
    }

    /// Runs the COPY, streaming `rows` in the binary COPY format, and returns
    /// the row count from the server's `CommandComplete` tag.
    ///
    /// Each item of `rows` iterates over one row's values, which must match
    /// `column_types()` in number. If a value fails to encode or a row has the
    /// wrong number of values, a `CopyFail` is sent and the server's resulting
    /// error is returned. A server error while starting the COPY is returned
    /// as a `DbError`; other unexpected messages desynchronize the connection.
    pub fn execute<'b, I, J>(&self, rows: I) -> Result<uint>
            where I: Iterator<J>, J: Iterator<&'b (ToSql + 'b)> {
        let mut conn = self.conn.conn.borrow_mut();

        try!(conn.write_messages(&[
            Bind {
                portal: "",
                statement: &*self.name,
                formats: &[],
                values: &[],
                result_formats: &[]
            },
            Execute {
                portal: "",
                max_rows: 0,
            },
            Sync]));

        match try!(conn.read_message()) {
            BindComplete => {},
            ErrorResponse { fields } => {
                try!(conn.wait_for_ready());
                return DbError::new(fields);
            }
            _ => {
                conn.desynchronized = true;
                return Err(Error::BadResponse);
            }
        }

        // Execute answers with CopyInResponse, or with an ErrorResponse if the
        // COPY cannot start. In the error case the Sync already sent produces
        // a ReadyForQuery, so the connection stays synchronized.
        match try!(conn.read_message()) {
            CopyInResponse { .. } => {}
            ErrorResponse { fields } => {
                try!(conn.wait_for_ready());
                return DbError::new(fields);
            }
            _ => {
                conn.desynchronized = true;
                return Err(Error::BadResponse);
            }
        }

        // Binary COPY header: signature, then a 32-bit flags field and a
        // 32-bit header-extension length, both zero.
        let mut buf = vec![];
        let _ = buf.write(b"PGCOPY\n\xff\r\n\x00");
        let _ = buf.write_be_i32(0);
        let _ = buf.write_be_i32(0);

        'l: for mut row in rows {
            // Each tuple: 16-bit field count, then per field a 32-bit length
            // (-1 for NULL) followed by the value bytes.
            let _ = buf.write_be_i16(self.column_types.len() as i16);

            let mut types = self.column_types.iter();
            loop {
                match (row.next(), types.next()) {
                    (Some(val), Some(ty)) => {
                        match val.to_sql(ty) {
                            Ok(None) => {
                                let _ = buf.write_be_i32(-1);
                            }
                            Ok(Some(val)) => {
                                let _ = buf.write_be_i32(val.len() as i32);
                                let _ = buf.write(&*val);
                            }
                            Err(err) => {
                                try_desync!(conn, conn.stream.write_message(
                                    &CopyFail {
                                        message: &*err.to_string(),
                                    }));
                                break 'l;
                            }
                        }
                    }
                    (Some(_), None) | (None, Some(_)) => {
                        try_desync!(conn, conn.stream.write_message(
                            &CopyFail {
                                message: "Invalid column count",
                            }));
                        break 'l;
                    }
                    (None, None) => break
                }
            }

            try_desync!(conn, conn.stream.write_message(
                &CopyData {
                    data: &*buf
                }));
            buf.clear();
        }

        // File trailer: a field count of -1. After a CopyFail the server
        // discards these messages until Sync.
        let _ = buf.write_be_i16(-1);
        try!(conn.write_messages(&[
            CopyData {
                data: &*buf,
            },
            CopyDone,
            Sync]));

        let num = match try!(conn.read_message()) {
            CommandComplete { tag } => util::parse_update_count(tag),
            ErrorResponse { fields } => {
                try!(conn.wait_for_ready());
                return DbError::new(fields);
            }
            _ => {
                conn.desynchronized = true;
                return Err(Error::BadResponse);
            }
        };
        try!(conn.wait_for_ready());

        Ok(num)
    }

    /// Closes the statement, returning any error (which `Drop` ignores).
    pub fn finish(mut self) -> Result<()> {
        self.finished = true;
        self.finish_inner()
    }
}
/// Operations shared by `Connection` and `Transaction`, so code can be
/// written against either.
pub trait GenericConnection {
    /// Like `Connection::prepare`.
    fn prepare<'a>(&'a self, query: &str) -> Result<Statement<'a>>;

    /// Like `Connection::execute`.
    fn execute(&self, query: &str, params: &[&ToSql]) -> Result<uint>;

    /// Like `Connection::prepare_copy_in`.
    fn prepare_copy_in<'a>(&'a self, table: &str, columns: &[&str])
                           -> Result<CopyInStatement<'a>>;

    /// Like `Connection::transaction`.
    fn transaction<'a>(&'a self) -> Result<Transaction<'a>>;

    /// Like `Connection::batch_execute`.
    fn batch_execute(&self, query: &str) -> Result<()>;
}
/// Each trait method forwards to the inherent `Connection` method of the
/// same name; inherent methods take precedence in method resolution.
impl GenericConnection for Connection {
    fn prepare<'a>(&'a self, query: &str) -> Result<Statement<'a>> {
        self.prepare(query)
    }

    fn execute(&self, query: &str, params: &[&ToSql]) -> Result<uint> {
        self.execute(query, params)
    }

    fn prepare_copy_in<'a>(&'a self, table: &str, columns: &[&str])
                           -> Result<CopyInStatement<'a>> {
        self.prepare_copy_in(table, columns)
    }

    fn transaction<'a>(&'a self) -> Result<Transaction<'a>> {
        self.transaction()
    }

    fn batch_execute(&self, query: &str) -> Result<()> {
        self.batch_execute(query)
    }
}
/// Each trait method forwards to the inherent `Transaction` method of the
/// same name.
///
/// The impl lifetime is named `'conn` rather than `'a` because every method
/// declares its own `'a`; reusing the name would shadow the impl's lifetime,
/// which is a compile error (E0496).
impl<'conn> GenericConnection for Transaction<'conn> {
    fn prepare<'a>(&'a self, query: &str) -> Result<Statement<'a>> {
        self.prepare(query)
    }

    fn execute(&self, query: &str, params: &[&ToSql]) -> Result<uint> {
        self.execute(query, params)
    }

    fn transaction<'a>(&'a self) -> Result<Transaction<'a>> {
        self.transaction()
    }

    fn prepare_copy_in<'a>(&'a self, table: &str, columns: &[&str])
                           -> Result<CopyInStatement<'a>> {
        self.prepare_copy_in(table, columns)
    }

    fn batch_execute(&self, query: &str) -> Result<()> {
        self.batch_execute(query)
    }
}