use crate::ffi_ptr_ext::FfiPtrExt;
use crate::impl_::callback::IntoPyCallbackOutput;
use crate::impl_::pyclass::{PyClassBaseType, PyClassDict, PyClassThreadChecker, PyClassWeakRef};
use crate::impl_::pyclass_init::{PyNativeTypeInitializer, PyObjectInit};
use crate::types::PyAnyMethods;
use crate::{ffi, Bound, Py, PyClass, PyResult, Python};
use crate::{
ffi::PyTypeObject,
pycell::impl_::{PyClassBorrowChecker, PyClassMutability, PyClassObjectContents},
};
use std::{
cell::UnsafeCell,
marker::PhantomData,
mem::{ManuallyDrop, MaybeUninit},
};
/// Describes how to obtain a Python object holding the `#[pyclass]` value `T`.
///
/// This is an opaque wrapper around the private `PyClassInitializerImpl` enum: it either
/// refers to an object that already exists, or carries a new `T` value together with the
/// initializer for `T`'s base class.
pub struct PyClassInitializer<T: PyClass>(PyClassInitializerImpl<T>);
/// The two ways a `PyClassInitializer<T>` can produce an object.
enum PyClassInitializerImpl<T: PyClass> {
/// An object that already exists; it is handed back unchanged instead of allocating a
/// new one, and it is reported as not subclassable by `can_be_subclassed`.
Existing(Py<T>),
/// A value that still has to be moved into a newly created object.
New {
/// The Rust value to store in the new object.
init: T,
/// Initializer for the base class, used to create the underlying object.
super_init: <T::BaseType as PyClassBaseType>::Initializer,
},
}
impl<T: PyClass> PyClassInitializer<T> {
#[track_caller]
#[inline]
pub fn new(init: T, super_init: <T::BaseType as PyClassBaseType>::Initializer) -> Self {
assert!(
super_init.can_be_subclassed(),
"you cannot add a subclass to an existing value",
);
Self(PyClassInitializerImpl::New { init, super_init })
}
#[track_caller]
#[inline]
pub fn add_subclass<S>(self, subclass_value: S) -> PyClassInitializer<S>
where
S: PyClass<BaseType = T>,
S::BaseType: PyClassBaseType<Initializer = Self>,
{
PyClassInitializer::new(subclass_value, self)
}
pub(crate) fn create_class_object(self, py: Python<'_>) -> PyResult<Bound<'_, T>>
where
T: PyClass,
{
unsafe { self.create_class_object_of_type(py, T::type_object_raw(py)) }
}
pub(crate) unsafe fn create_class_object_of_type(
self,
py: Python<'_>,
target_type: *mut crate::ffi::PyTypeObject,
) -> PyResult<Bound<'_, T>>
where
T: PyClass,
{
#[repr(C)]
struct PartiallyInitializedClassObject<T: PyClass> {
_ob_base: <T::BaseType as PyClassBaseType>::LayoutAsBase,
contents: MaybeUninit<PyClassObjectContents<T>>,
}
let (init, super_init) = match self.0 {
PyClassInitializerImpl::Existing(value) => return Ok(value.into_bound(py)),
PyClassInitializerImpl::New { init, super_init } => (init, super_init),
};
let obj = super_init.into_new_object(py, target_type)?;
let part_init: *mut PartiallyInitializedClassObject<T> = obj.cast();
std::ptr::write(
(*part_init).contents.as_mut_ptr(),
PyClassObjectContents {
value: ManuallyDrop::new(UnsafeCell::new(init)),
borrow_checker: <T::PyClassMutability as PyClassMutability>::Storage::new(),
thread_checker: T::ThreadChecker::new(),
dict: T::Dict::INIT,
weakref: T::WeakRef::INIT,
},
);
Ok(obj.assume_owned(py).downcast_into_unchecked())
}
}
impl<T: PyClass> PyObjectInit<T> for PyClassInitializer<T> {
    /// Creates the object as an instance of `subtype` and returns it as a raw pointer.
    ///
    /// # Safety
    ///
    /// Forwards to `create_class_object_of_type`; the caller must uphold that function's
    /// requirements for `subtype`.
    unsafe fn into_new_object(
        self,
        py: Python<'_>,
        subtype: *mut PyTypeObject,
    ) -> PyResult<*mut ffi::PyObject> {
        let object = self.create_class_object_of_type(py, subtype)?;
        Ok(object.into_ptr())
    }

    /// Returns `false` for an initializer wrapping an existing object, `true` otherwise.
    #[inline]
    fn can_be_subclassed(&self) -> bool {
        match self.0 {
            PyClassInitializerImpl::Existing(..) => false,
            PyClassInitializerImpl::New { .. } => true,
        }
    }
}
impl<T> From<T> for PyClassInitializer<T>
where
T: PyClass,
T::BaseType: PyClassBaseType<Initializer = PyNativeTypeInitializer<T::BaseType>>,
{
#[inline]
fn from(value: T) -> PyClassInitializer<T> {
Self::new(value, PyNativeTypeInitializer(PhantomData))
}
}
impl<S, B> From<(S, B)> for PyClassInitializer<S>
where
    S: PyClass<BaseType = B>,
    B: PyClass + PyClassBaseType<Initializer = PyClassInitializer<B>>,
    B::BaseType: PyClassBaseType<Initializer = PyNativeTypeInitializer<B::BaseType>>,
{
    /// Builds an initializer for the subclass `S` from a `(subclass, base)` pair of values.
    #[track_caller]
    #[inline]
    fn from((sub, base): (S, B)) -> PyClassInitializer<S> {
        let base_init = PyClassInitializer::<B>::from(base);
        base_init.add_subclass(sub)
    }
}
impl<T: PyClass> From<Py<T>> for PyClassInitializer<T> {
    /// Wraps an existing object; no new object is created when this initializer is used.
    #[inline]
    fn from(value: Py<T>) -> PyClassInitializer<T> {
        Self(PyClassInitializerImpl::Existing(value))
    }
}
impl<'py, T: PyClass> From<Bound<'py, T>> for PyClassInitializer<T> {
#[inline]
fn from(value: Bound<'py, T>) -> PyClassInitializer<T> {
PyClassInitializer::from(value.unbind())
}
}
impl<T, U> IntoPyCallbackOutput<'_, PyClassInitializer<T>> for U
where
    T: PyClass,
    U: Into<PyClassInitializer<T>>,
{
    /// Converts any value that is `Into<PyClassInitializer<T>>`; this conversion never fails.
    #[inline]
    fn convert(self, _py: Python<'_>) -> PyResult<PyClassInitializer<T>> {
        let initializer: PyClassInitializer<T> = self.into();
        Ok(initializer)
    }
}
#[cfg(all(test, feature = "macros"))]
mod tests {
    //! Tests for `PyClassInitializer`.

    use crate::prelude::*;

    /// Minimal base class that allows subclassing.
    #[pyclass(crate = "crate", subclass)]
    struct BaseClass {}

    /// Subclass of `BaseClass` carrying some data.
    #[pyclass(crate = "crate", extends=BaseClass)]
    struct SubClass {
        _data: i32,
    }

    /// Adding a subclass on top of an initializer that wraps an existing object must be
    /// rejected by the assertion in `PyClassInitializer::new`.
    ///
    /// The expected message is pinned so that an unrelated panic (for example a failing
    /// `Py::new(..).unwrap()`) cannot make this test pass by accident.
    #[test]
    #[should_panic(expected = "you cannot add a subclass to an existing value")]
    fn add_subclass_to_py_is_unsound() {
        Python::with_gil(|py| {
            let base = Py::new(py, BaseClass {}).unwrap();
            let _subclass = PyClassInitializer::from(base).add_subclass(SubClass { _data: 42 });
        });
    }
}