use std::collections::HashMap;
use re_types::{
archetypes::Tensor,
datatypes::{TensorBuffer, TensorData, TensorDimension},
tensor_data::TensorCastError,
Archetype as _, AsComponents as _,
};
mod util;
#[test]
fn tensor_roundtrip() {
let all_expected = [Tensor {
data: TensorData {
shape: vec![
TensorDimension {
size: 2,
name: None,
},
TensorDimension {
size: 3,
name: None,
},
],
buffer: TensorBuffer::U8(vec![1, 2, 3, 4, 5, 6].into()),
}
.into(),
}];
let all_arch_serialized = [Tensor::try_from(ndarray::array![[1u8, 2, 3], [4, 5, 6]])
.unwrap()
.to_arrow()
.unwrap()];
let expected_extensions: HashMap<_, _> = [("data", vec!["rerun.components.TensorData"])].into();
for (expected, serialized) in all_expected.into_iter().zip(all_arch_serialized) {
for (field, array) in &serialized {
eprintln!("{} = {array:#?}", field.name);
if false {
util::assert_extensions(
&**array,
expected_extensions[field.name.as_str()].as_slice(),
);
}
}
let deserialized = Tensor::from_arrow(serialized).unwrap();
similar_asserts::assert_eq!(expected, deserialized);
}
}
#[test]
fn convert_tensor_to_ndarray_u8() {
let t = TensorData::new(
vec![
TensorDimension::unnamed(3),
TensorDimension::unnamed(4),
TensorDimension::unnamed(5),
],
TensorBuffer::U8(vec![0; 60].into()),
);
let n = ndarray::ArrayViewD::<u8>::try_from(&t).unwrap();
assert_eq!(n.shape(), &[3, 4, 5]);
}
#[test]
fn convert_tensor_to_ndarray_u16() {
let t = TensorData::new(
vec![
TensorDimension::unnamed(3),
TensorDimension::unnamed(4),
TensorDimension::unnamed(5),
],
TensorBuffer::U16(vec![0_u16; 60].into()),
);
let n = ndarray::ArrayViewD::<u16>::try_from(&t).unwrap();
assert_eq!(n.shape(), &[3, 4, 5]);
}
#[test]
fn convert_tensor_to_ndarray_f32() {
let t = TensorData::new(
vec![
TensorDimension::unnamed(3),
TensorDimension::unnamed(4),
TensorDimension::unnamed(5),
],
TensorBuffer::F32(vec![0_f32; 60].into()),
);
let n = ndarray::ArrayViewD::<f32>::try_from(&t).unwrap();
assert_eq!(n.shape(), &[3, 4, 5]);
}
#[test]
fn convert_ndarray_u8_to_tensor() {
let n = ndarray::array![[1., 2., 3.], [4., 5., 6.]];
let t = TensorData::try_from(n).unwrap();
assert_eq!(
t.shape(),
&[TensorDimension::unnamed(2), TensorDimension::unnamed(3)]
);
}
#[test]
fn convert_ndarray_slice_to_tensor() {
let n = ndarray::array![[1., 2., 3.], [4., 5., 6.]];
let n = &n.slice(ndarray::s![.., 1]);
let t = TensorData::try_from(*n).unwrap();
assert_eq!(t.shape(), &[TensorDimension::unnamed(2)]);
}
#[test]
fn check_slices() {
let t = TensorData::new(
vec![
TensorDimension::unnamed(3),
TensorDimension::unnamed(4),
TensorDimension::unnamed(5),
],
TensorBuffer::U16((0_u16..60).collect::<Vec<u16>>().into()),
);
let n = ndarray::ArrayViewD::<u16>::try_from(&t).unwrap();
assert_eq!(n[[0, 0, 0]], 0);
assert_eq!(n[[2, 3, 4]], 59);
for z in 0..3 {
for y in 0..4 {
for x in 0..5 {
assert_eq!(n[[z, y, x]] as usize, z * 4 * 5 + y * 5 + x);
}
}
}
let sl: ndarray::ArrayView2<'_, u16> = n.slice(ndarray::s![.., 1, ..]);
assert_eq!(sl.shape(), &[3, 5]);
assert_eq!(sl[[0, 0]], 5);
assert_eq!(sl[[1, 3]], 28);
}
#[test]
fn check_tensor_shape_error() {
let t = TensorData::new(
vec![
TensorDimension::unnamed(3),
TensorDimension::unnamed(4),
TensorDimension::unnamed(5),
],
TensorBuffer::U8(vec![0; 59].into()),
);
let n = ndarray::ArrayViewD::<u8>::try_from(&t);
assert_eq!(
n,
Err(TensorCastError::BadTensorShape {
source: ndarray::ShapeError::from_kind(ndarray::ErrorKind::OutOfBounds)
})
);
}
#[test]
fn check_tensor_type_error() {
let t = TensorData::new(
vec![
TensorDimension::unnamed(3),
TensorDimension::unnamed(4),
TensorDimension::unnamed(5),
],
TensorBuffer::U16(vec![0; 60].into()),
);
let n = ndarray::ArrayViewD::<u8>::try_from(&t);
assert_eq!(n, Err(TensorCastError::TypeMismatch));
}