use std::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use bson::RawDocument;
use futures_core::{future::BoxFuture, Stream};
use futures_util::StreamExt;
use serde::{de::DeserializeOwned, Deserialize};
#[cfg(test)]
use tokio::sync::oneshot;
use super::{
common::{
kill_cursor,
CursorBuffer,
CursorInformation,
CursorState,
GenericCursor,
GetMoreProvider,
GetMoreProviderResult,
PinnedConnection,
},
BatchValue,
CursorStream,
};
use crate::{
bson::Document,
change_stream::event::ResumeToken,
client::options::ServerAddress,
cmap::conn::PinnedConnectionHandle,
cursor::CursorSpecification,
error::{Error, Result},
operation::GetMore,
results::GetMoreResult,
Client,
ClientSession,
};
/// A cursor whose iteration methods (`next`, `advance`, `stream`) each take an explicit
/// `&mut ClientSession` argument instead of holding a session itself.
///
/// `T` is the type that documents are deserialized into.
#[derive(Debug)]
pub struct SessionCursor<T> {
    // Client used to issue follow-up operations and to kill the cursor on drop.
    client: Client,
    // Cursor metadata: namespace (`ns`), server-side cursor `id` and server `address`.
    info: CursorInformation,
    // Buffered documents, exhaustion flag and pinned connection. This is `None` only while the
    // state has been moved out, e.g. into a `SessionCursorStream` (returned on the stream's
    // drop) or into the cursor produced by `with_type`.
    state: Option<CursorState>,
    // Optional address handed to `kill_cursor` when the cursor is dropped unexhausted; set via
    // `set_drop_address`.
    drop_address: Option<ServerAddress>,
    // Ties the cursor to `T` without storing a `T`.
    _phantom: PhantomData<T>,
    // Test-only: passed to `kill_cursor` on drop, presumably so tests can observe when a kill
    // is issued — confirm against `kill_cursor`.
    #[cfg(test)]
    kill_watcher: Option<oneshot::Sender<()>>,
}
impl<T> SessionCursor<T> {
pub(crate) fn new(
client: Client,
spec: CursorSpecification,
pinned: Option<PinnedConnectionHandle>,
) -> Self {
let exhausted = spec.info.id == 0;
Self {
client,
info: spec.info,
drop_address: None,
_phantom: Default::default(),
#[cfg(test)]
kill_watcher: None,
state: CursorState {
buffer: CursorBuffer::new(spec.initial_buffer),
exhausted,
post_batch_resume_token: None,
pinned_connection: PinnedConnection::new(pinned),
}
.into(),
}
}
}
impl<T> SessionCursor<T>
where
T: DeserializeOwned + Unpin + Send + Sync,
{
pub fn stream<'session>(
&mut self,
session: &'session mut ClientSession,
) -> SessionCursorStream<'_, 'session, T> {
self.make_stream(session)
}
pub async fn next(&mut self, session: &mut ClientSession) -> Option<Result<T>> {
self.stream(session).next().await
}
}
impl<T> SessionCursor<T> {
fn make_stream<'session>(
&mut self,
session: &'session mut ClientSession,
) -> SessionCursorStream<'_, 'session, T> {
let get_more_provider = ExplicitSessionGetMoreProvider::new(session);
SessionCursorStream {
generic_cursor: ExplicitSessionCursor::from_state(
self.take_state(),
self.client.clone(),
self.info.clone(),
get_more_provider,
),
session_cursor: self,
}
}
fn take_state(&mut self) -> CursorState {
self.state.take().unwrap()
}
pub async fn advance(&mut self, session: &mut ClientSession) -> Result<bool> {
self.make_stream(session).generic_cursor.advance().await
}
pub fn current(&self) -> &RawDocument {
self.state.as_ref().unwrap().buffer.current().unwrap()
}
pub fn deserialize_current<'a>(&'a self) -> Result<T>
where
T: Deserialize<'a>,
{
bson::from_slice(self.current().as_bytes()).map_err(Error::from)
}
pub fn with_type<'a, D>(mut self) -> SessionCursor<D>
where
D: Deserialize<'a>,
{
let out = SessionCursor {
client: self.client.clone(),
info: self.info.clone(),
state: Some(self.take_state()),
drop_address: self.drop_address.take(),
_phantom: Default::default(),
#[cfg(test)]
kill_watcher: self.kill_watcher.take(),
};
self.mark_exhausted(); out
}
pub(crate) fn address(&self) -> &ServerAddress {
&self.info.address
}
pub(crate) fn set_drop_address(&mut self, address: ServerAddress) {
self.drop_address = Some(address);
}
#[cfg(test)]
pub(crate) fn set_kill_watcher(&mut self, tx: oneshot::Sender<()>) {
assert!(
self.kill_watcher.is_none(),
"cursor already has a kill_watcher"
);
self.kill_watcher = Some(tx);
}
}
impl<T> SessionCursor<T> {
    /// Flags the cursor as exhausted so that dropping it does not kill the server-side cursor.
    ///
    /// If the state has already been moved out (for example by `with_type`, which takes the
    /// state and then calls this), there is nothing to flag. In that case this is a no-op
    /// rather than a panic.
    fn mark_exhausted(&mut self) {
        if let Some(state) = self.state.as_mut() {
            state.exhausted = true;
        }
    }

    /// Whether this cursor has nothing left to clean up on the server.
    ///
    /// A cursor whose state has been moved out (`state == None`) is reported as exhausted. It
    /// no longer holds the state (for example the pinned connection) needed to kill the
    /// cursor, so treating it as exhausted lets `Drop` skip cleanup instead of panicking.
    pub(crate) fn is_exhausted(&self) -> bool {
        match &self.state {
            Some(state) => state.exhausted,
            None => true,
        }
    }
}
impl<T> Drop for SessionCursor<T> {
    /// Kills the server-side cursor if it was not exhausted.
    ///
    /// The state is `None` when it has been moved out: into the cursor returned by
    /// `with_type`, or into a `SessionCursorStream` that was never dropped (e.g. leaked with
    /// `mem::forget`). In either case this cursor no longer holds the pinned connection needed
    /// for `kill_cursor`. Cleanup is skipped rather than unwrapping `None`, which would panic
    /// inside `drop` (and abort if already unwinding).
    fn drop(&mut self) {
        let pinned_connection = match self.state.as_ref() {
            Some(state) if !state.exhausted => state.pinned_connection.replicate(),
            _ => return,
        };

        kill_cursor(
            self.client.clone(),
            &self.info.ns,
            self.info.id,
            pinned_connection,
            self.drop_address.take(),
            #[cfg(test)]
            self.kill_watcher.take(),
        );
    }
}
/// A `GenericCursor` whose getMore requests run through an explicitly borrowed
/// `ClientSession` (via `ExplicitSessionGetMoreProvider`).
type ExplicitSessionCursor<'session, T> =
    GenericCursor<ExplicitSessionGetMoreProvider<'session>, T>;
/// A stream over a [`SessionCursor`] that mutably borrows both the cursor and a session.
///
/// While the stream exists, the cursor's state lives inside `generic_cursor`. Dropping the
/// stream hands that state back to the underlying `SessionCursor`.
pub struct SessionCursorStream<'cursor, 'session, T = Document> {
    /// Iteration machinery; issues getMores through the borrowed session.
    generic_cursor: ExplicitSessionCursor<'session, T>,
    /// The cursor this stream was created from; receives its state back on drop.
    session_cursor: &'cursor mut SessionCursor<T>,
}
impl<'cursor, 'session, T> SessionCursorStream<'cursor, 'session, T>
where
T: DeserializeOwned + Unpin + Send + Sync,
{
pub(crate) fn post_batch_resume_token(&self) -> Option<&ResumeToken> {
self.generic_cursor.post_batch_resume_token()
}
pub(crate) fn client(&self) -> &Client {
&self.session_cursor.client
}
}
impl<'cursor, 'session, T> Stream for SessionCursorStream<'cursor, 'session, T>
where
    T: DeserializeOwned + Unpin + Send + Sync,
{
    type Item = Result<T>;

    /// Polls the inner generic cursor for the next deserialized document.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The stream is `Unpin`, so it is safe to step out of the pin and re-pin the field.
        let this = self.get_mut();
        let inner = Pin::new(&mut this.generic_cursor);
        inner.poll_next(cx)
    }
}
impl<'cursor, 'session, T> CursorStream for SessionCursorStream<'cursor, 'session, T>
where
    T: DeserializeOwned + Unpin + Send + Sync,
{
    /// Forwards batch-aware polling to the inner generic cursor.
    fn poll_next_in_batch(&mut self, cx: &mut Context<'_>) -> Poll<Result<BatchValue>> {
        let inner = &mut self.generic_cursor;
        inner.poll_next_in_batch(cx)
    }
}
impl<'cursor, 'session, T> Drop for SessionCursorStream<'cursor, 'session, T> {
    /// Hands the iteration state back to the borrowed `SessionCursor` so that it can be
    /// iterated (or dropped) normally afterwards.
    fn drop(&mut self) {
        let restored = self.generic_cursor.take_state();
        self.session_cursor.state = Some(restored);
    }
}
/// Runs getMore operations for a cursor that is iterated with an explicit session.
///
/// The borrowed session moves between the variants. It is held directly while idle and is
/// moved into the in-flight future while a getMore executes; the future yields it back.
enum ExplicitSessionGetMoreProvider<'session> {
    /// A getMore is in flight; the future resolves to its result plus the session.
    Executing(BoxFuture<'session, ExecutionResult<'session>>),
    /// No getMore is running; the session is available.
    Idle(MutableSessionReference<'session>),
}
impl<'session> ExplicitSessionGetMoreProvider<'session> {
fn new(session: &'session mut ClientSession) -> Self {
Self::Idle(MutableSessionReference { reference: session })
}
}
impl<'session> GetMoreProvider for ExplicitSessionGetMoreProvider<'session> {
    type ResultType = ExecutionResult<'session>;
    type GetMoreFuture = BoxFuture<'session, ExecutionResult<'session>>;

    /// Returns the in-flight getMore future, or `None` when the provider is idle.
    fn executing_future(&mut self) -> Option<&mut Self::GetMoreFuture> {
        match self {
            Self::Executing(future) => Some(future),
            Self::Idle { .. } => None,
        }
    }

    /// Returns to the idle state, taking back the session that the finished getMore used.
    /// The `_exhausted` flag is not used by this provider.
    fn clear_execution(&mut self, session: &'session mut ClientSession, _exhausted: bool) {
        *self = Self::Idle(MutableSessionReference { reference: session })
    }

    /// Starts a getMore if the provider is idle.
    ///
    /// The session is moved out of `self` into a boxed future. The future runs the `GetMore`
    /// operation and then yields the session back inside an `ExecutionResult`. If a getMore is
    /// already executing, `self` is left unchanged and the arguments are dropped.
    fn start_execution(
        &mut self,
        info: CursorInformation,
        client: Client,
        pinned_connection: Option<&PinnedConnectionHandle>,
    ) {
        // `take_mut::take` hands the closure `*self` by value, so the `&'session mut` session
        // can be moved into the future. NOTE(review): the `take_mut` crate documents that it
        // aborts the process if the closure panics — keep this closure panic-free.
        take_mut::take(self, |self_| {
            if let ExplicitSessionGetMoreProvider::Idle(session) = self_ {
                // The future must own its handle rather than borrow the caller's reference.
                let pinned_connection = pinned_connection.map(|c| c.replicate());
                let future = Box::pin(async move {
                    let get_more = GetMore::new(info, pinned_connection.as_ref());
                    // Reborrow the session for the operation so it can be returned afterwards.
                    let get_more_result = client
                        .execute_operation(get_more, Some(&mut *session.reference))
                        .await;
                    ExecutionResult {
                        get_more_result,
                        session: session.reference,
                    }
                });
                return ExplicitSessionGetMoreProvider::Executing(future);
            }
            // Already executing: keep the in-flight future in place.
            self_
        });
    }

    /// Runs a getMore directly with the idle session and returns its result.
    ///
    /// If a getMore started by `start_execution` is still in flight, the session belongs to
    /// that future. In that case an internal error is returned without running anything.
    fn execute(
        &mut self,
        info: CursorInformation,
        client: Client,
        pinned_connection: PinnedConnection,
    ) -> BoxFuture<'_, Result<GetMoreResult>> {
        match self {
            Self::Idle(ref mut session) => Box::pin(async move {
                let get_more = GetMore::new(info, pinned_connection.handle());
                client
                    .execute_operation(get_more, Some(&mut *session.reference))
                    .await
            }),
            Self::Executing(_fut) => Box::pin(async {
                Err(Error::internal(
                    "streaming the cursor was cancelled while a request was in progress and must \
                     be continued before iterating manually",
                ))
            }),
        }
    }
}
/// Output of a getMore started by `ExplicitSessionGetMoreProvider::start_execution`.
///
/// It carries the operation's result together with the session that was moved into the
/// future. The session is presumably passed back to `clear_execution` by the generic cursor —
/// confirm in `GenericCursor`.
struct ExecutionResult<'session> {
    // Outcome of the getMore operation.
    get_more_result: Result<GetMoreResult>,
    // The session borrowed for the operation, handed back to the caller.
    session: &'session mut ClientSession,
}
impl<'session> GetMoreProviderResult for ExecutionResult<'session> {
type Session = &'session mut ClientSession;
fn as_ref(&self) -> std::result::Result<&GetMoreResult, &Error> {
self.get_more_result.as_ref()
}
fn into_parts(self) -> (Result<GetMoreResult>, Self::Session) {
(self.get_more_result, self.session)
}
}
/// Wrapper around the mutable session borrow held by an idle
/// `ExplicitSessionGetMoreProvider`.
struct MutableSessionReference<'a> {
    // The borrowed session; moved into the getMore future when execution starts.
    reference: &'a mut ClientSession,
}