use crate::{
util::align_to, Buffer, BufferAddress, BufferDescriptor, BufferSize, BufferUsages,
BufferViewMut, CommandEncoder, Device, MapMode,
};
use std::fmt;
use std::sync::{mpsc, Arc};
/// A single staging buffer plus a bump-allocation cursor into it.
struct Chunk {
    /// The staging buffer backing this chunk. Held in an `Arc` so `recall`
    /// can call `slice`/`map_async` on a second handle while the `Chunk`
    /// itself is moved into the map-completion callback.
    buffer: Arc<Buffer>,
    /// Total capacity of `buffer`, in bytes.
    size: BufferAddress,
    /// Bump cursor: start of the unused region. Bytes before it have already
    /// been handed out by `write_buffer`; reset to 0 when the chunk is
    /// recycled in `receive_chunks`.
    offset: BufferAddress,
}
/// Marker wrapper that makes any `T` usable from a `Sync` container by only
/// ever exposing it through `&mut` access (mirrors `std::sync::Exclusive`).
struct Exclusive<T>(T);

// SAFETY: the only way to reach the inner `T` is `get_mut`, which requires
// `&mut Exclusive<T>`. A shared `&Exclusive<T>` therefore grants no access to
// `T` at all, so sharing references across threads cannot cause a data race.
unsafe impl<T> Sync for Exclusive<T> {}

impl<T> Exclusive<T> {
    /// Wraps `value` in the marker type.
    fn new(value: T) -> Self {
        Exclusive(value)
    }

    /// Exclusive access to the wrapped value; this is the sole accessor.
    fn get_mut(&mut self) -> &mut T {
        let Exclusive(inner) = self;
        inner
    }
}
/// Efficiently performs many buffer writes by sharing and reusing temporary
/// (staging) buffers.
///
/// Chunk lifecycle: `write_buffer` bump-allocates from `active_chunks`;
/// `finish` unmaps them into `closed_chunks`; `recall` starts re-mapping the
/// closed chunks, and once each `map_async` completes the chunk travels
/// through the channel back into `free_chunks` for reuse.
pub struct StagingBelt {
    /// Nominal size for newly created staging buffers; a single write larger
    /// than this gets a dedicated, bigger buffer.
    chunk_size: BufferAddress,
    /// Mapped chunks currently accepting writes.
    active_chunks: Vec<Chunk>,
    /// Unmapped chunks whose commands have been recorded; awaiting `recall`.
    closed_chunks: Vec<Chunk>,
    /// Re-mapped chunks ready to be reused.
    free_chunks: Vec<Chunk>,
    /// Sender cloned into each `map_async` callback to return its chunk.
    /// (`Exclusive` restores `Sync` for the channel endpoints, which are only
    /// ever used through `&mut self`.)
    sender: Exclusive<mpsc::Sender<Chunk>>,
    /// Receiving end drained by `receive_chunks`.
    receiver: Exclusive<mpsc::Receiver<Chunk>>,
}
impl StagingBelt {
    /// Creates a new staging belt.
    ///
    /// `chunk_size` is the size each staging buffer is normally created with,
    /// and therefore the unit of allocation; see `write_buffer` for the
    /// oversized-request case.
    pub fn new(chunk_size: BufferAddress) -> Self {
        let (sender, receiver) = std::sync::mpsc::channel();
        StagingBelt {
            chunk_size,
            active_chunks: Vec::new(),
            closed_chunks: Vec::new(),
            free_chunks: Vec::new(),
            sender: Exclusive::new(sender),
            receiver: Exclusive::new(receiver),
        }
    }

    /// Allocates `size` bytes of staging space, records a copy from it into
    /// `target` at `offset` on `encoder`, and returns the mapped staging
    /// region for the caller to fill with data.
    ///
    /// NOTE(review): the copy command is recorded *before* the caller writes
    /// through the returned view; this relies on the staging data being in
    /// place by the time the encoder's commands are submitted — confirm the
    /// caller upholds the fill-before-submit ordering.
    pub fn write_buffer(
        &mut self,
        encoder: &mut CommandEncoder,
        target: &Buffer,
        offset: BufferAddress,
        size: BufferSize,
        device: &Device,
    ) -> BufferViewMut<'_> {
        // Find (or create) a chunk with enough free space. Every chunk in
        // `active_chunks` and `free_chunks` is currently mapped: fresh ones
        // via `mapped_at_creation`, recycled ones via `recall`'s `map_async`.
        let mut chunk = if let Some(index) = self
            .active_chunks
            .iter()
            .position(|chunk| chunk.offset + size.get() <= chunk.size)
        {
            self.active_chunks.swap_remove(index)
        } else {
            // Pull in any chunks whose re-mapping completed since last time.
            self.receive_chunks();
            if let Some(index) = self
                .free_chunks
                .iter()
                .position(|chunk| size.get() <= chunk.size)
            {
                self.free_chunks.swap_remove(index)
            } else {
                // No reusable chunk is big enough: allocate a fresh buffer,
                // oversized when the request exceeds the nominal chunk size.
                let size = self.chunk_size.max(size.get());
                Chunk {
                    buffer: Arc::new(device.create_buffer(&BufferDescriptor {
                        label: Some("(wgpu internal) StagingBelt staging buffer"),
                        size,
                        usage: BufferUsages::MAP_WRITE | BufferUsages::COPY_SRC,
                        mapped_at_creation: true,
                    })),
                    size,
                    offset: 0,
                }
            }
        };
        encoder.copy_buffer_to_buffer(&chunk.buffer, chunk.offset, target, offset, size.get());
        let old_offset = chunk.offset;
        // Advance the bump cursor, rounded up to `MAP_ALIGNMENT` so the next
        // mapped range starts on a valid map boundary.
        chunk.offset = align_to(chunk.offset + size.get(), crate::MAP_ALIGNMENT);
        // Push first, then borrow back out: the returned `BufferViewMut` must
        // borrow a buffer owned by `self`, not the local `chunk` (which is
        // moved by the push). `unwrap` is fine — we just pushed.
        self.active_chunks.push(chunk);
        self.active_chunks
            .last()
            .unwrap()
            .buffer
            .slice(old_offset..old_offset + size.get())
            .get_mapped_range_mut()
    }

    /// Unmaps all actively used chunks (making their contents available to
    /// the GPU) and parks them in `closed_chunks` until `recall`.
    pub fn finish(&mut self) {
        for chunk in self.active_chunks.drain(..) {
            chunk.buffer.unmap();
            self.closed_chunks.push(chunk);
        }
    }

    /// Starts re-mapping every chunk closed by `finish` so it can be reused.
    ///
    /// Each chunk is returned through the channel from inside its `map_async`
    /// callback and harvested by a later `receive_chunks` call.
    /// NOTE(review): when the callback actually fires depends on the backend
    /// driving the device (e.g. polling/maintain) — confirm for your setup.
    pub fn recall(&mut self) {
        self.receive_chunks();
        for chunk in self.closed_chunks.drain(..) {
            // Each callback needs its own sender, as the `move` closure takes
            // ownership of it.
            let sender = self.sender.get_mut().clone();
            // Clone the `Arc` so `slice`/`map_async` can be called on the
            // buffer while the chunk itself is moved into the callback.
            chunk
                .buffer
                .clone()
                .slice(..)
                .map_async(MapMode::Write, move |_| {
                    // The map result is ignored (`|_|`), and a failed send
                    // (receiver already dropped) simply drops the chunk.
                    let _ = sender.send(chunk);
                });
        }
    }

    /// Drains the channel of chunks whose re-mapping has finished, resetting
    /// their bump cursors and returning them to the free list. Non-blocking.
    fn receive_chunks(&mut self) {
        while let Ok(mut chunk) = self.receiver.get_mut().try_recv() {
            chunk.offset = 0;
            self.free_chunks.push(chunk);
        }
    }
}
impl fmt::Debug for StagingBelt {
    /// Summarizes the belt by its configuration and per-state chunk counts
    /// (the chunks themselves are not `Debug`), marked non-exhaustive since
    /// the channel fields are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("StagingBelt");
        out.field("chunk_size", &self.chunk_size);
        out.field("active_chunks", &self.active_chunks.len());
        out.field("closed_chunks", &self.closed_chunks.len());
        out.field("free_chunks", &self.free_chunks.len());
        out.finish_non_exhaustive()
    }
}