use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut};
use encoding_rs::{Encoding, UTF_8};
use futures_util::stream::{Stream, TryStreamExt};
use http::header::HeaderMap;
#[cfg(feature = "json")]
use serde::de::DeserializeOwned;
use spin::mutex::spin::SpinMutex as Mutex;
use crate::content_disposition::ContentDisposition;
use crate::multipart::{MultipartState, StreamingStage};
use crate::{helpers, Error};
/// A single field (part) of a multipart body.
///
/// A `Field` does not own its body data. It reads it on demand from the parser
/// state held in `state`, either through its [`Stream`] implementation or through
/// the `chunk`/`bytes`/`text`/`json` helpers.
#[derive(Debug)]
pub struct Field<'r> {
    /// Parser state shared behind a spin mutex; presumably also held by the
    /// multipart reader that created this field (NOTE(review): confirm in `multipart.rs`).
    state: Arc<Mutex<MultipartState<'r>>>,
    /// Set once `poll_next` has yielded this field's final chunk; after that the
    /// stream only returns `None`.
    done: bool,
    /// Headers of this part, as passed to `new`.
    headers: HeaderMap,
    /// Content-Disposition data for this part (field name and file name).
    content_disposition: ContentDisposition,
    /// Content type computed from `headers` by `helpers::parse_content_type` at construction.
    content_type: Option<mime::Mime>,
    /// Index assigned at construction and returned by `index()`.
    idx: usize,
}
impl<'r> Field<'r> {
pub(crate) fn new(
state: Arc<Mutex<MultipartState<'r>>>,
headers: HeaderMap,
idx: usize,
content_disposition: ContentDisposition,
) -> Self {
let content_type = helpers::parse_content_type(&headers);
Field {
state,
headers,
content_disposition,
content_type,
idx,
done: false,
}
}
pub fn name(&self) -> Option<&str> {
self.content_disposition.field_name.as_deref()
}
pub fn file_name(&self) -> Option<&str> {
self.content_disposition.file_name.as_deref()
}
pub fn content_type(&self) -> Option<&mime::Mime> {
self.content_type.as_ref()
}
pub fn headers(&self) -> &HeaderMap {
&self.headers
}
pub async fn bytes(self) -> crate::Result<Bytes> {
let mut buf = BytesMut::new();
let mut this = self;
while let Some(bytes) = this.chunk().await? {
buf.extend_from_slice(&bytes);
}
Ok(buf.freeze())
}
pub async fn chunk(&mut self) -> crate::Result<Option<Bytes>> {
self.try_next().await
}
#[cfg(feature = "json")]
#[cfg_attr(nightly, doc(cfg(feature = "json")))]
pub async fn json<T: DeserializeOwned>(self) -> crate::Result<T> {
serde_json::from_slice(&self.bytes().await?).map_err(Error::DecodeJson)
}
pub async fn text(self) -> crate::Result<String> {
self.text_with_charset("utf-8").await
}
pub async fn text_with_charset(self, default_encoding: &str) -> crate::Result<String> {
let encoding_name = self
.content_type()
.and_then(|mime| mime.get_param(mime::CHARSET))
.map(|charset| charset.as_str())
.unwrap_or(default_encoding);
let encoding = Encoding::for_label(encoding_name.as_bytes()).unwrap_or(UTF_8);
let bytes = self.bytes().await?;
Ok(encoding.decode(&bytes).0.into_owned())
}
pub fn index(&self) -> usize {
self.idx
}
}
impl Stream for Field<'_> {
type Item = Result<Bytes, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.done {
return Poll::Ready(None);
}
debug_assert!(self.state.try_lock().is_some(), "expected exlusive lock");
let state = self.state.clone();
let mut lock = match state.try_lock() {
Some(lock) => lock,
None => return Poll::Ready(Some(Err(Error::LockFailure))),
};
let state = &mut *lock;
if let Err(err) = state.buffer.poll_stream(cx) {
return Poll::Ready(Some(Err(err)));
}
match state
.buffer
.read_field_data(&state.boundary, state.curr_field_name.as_deref())
{
Ok(Some((done, bytes))) => {
state.curr_field_size_counter += bytes.len() as u64;
if state.curr_field_size_counter > state.curr_field_size_limit {
return Poll::Ready(Some(Err(Error::FieldSizeExceeded {
limit: state.curr_field_size_limit,
field_name: state.curr_field_name.clone(),
})));
}
if done {
state.stage = StreamingStage::ReadingBoundary;
self.done = true;
}
Poll::Ready(Some(Ok(bytes)))
}
Ok(None) => Poll::Pending,
Err(err) => Poll::Ready(Some(Err(err))),
}
}
}