import logging
import os
from collections.abc import Mapping

import h5py
import numpy

from tomoscan.esrf.volume import HDF5Volume
from tomoscan.utils.hdf5 import DatasetReader
from tomoscan.volumebase import VolumeBase

_logger = logging.getLogger(__name__)


def concatenate(output_volume: VolumeBase, volumes: tuple, axis: int) -> None:
"""
Function to do 'raw' concatenation on volumes.
This is agnostic of any metadata. So if you want to ensure about coherence of metadata (and data) you must do it yourself
data will be concatenate in the order volumes are provided. Volumes data must be 3D. Concatenate data will be 3D and concatenation will be done
over the axis `axis`
concatenation will be done with a virtual dataset if input volumes and output_volume are HDF5Volume instances.
warning: concatenation enforce writing data and metadata to disk
:param output_volume VolumeBase: volume to create
:param tuple volumes: tuple of VolumeBase instances
:param axis: axis to use for doing the concatenation. must be in 0, 1, 2
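
    A minimal usage sketch (the file names and data paths below are
    hypothetical, and HDF5Volume is assumed to take `file_path` /
    `data_path` constructor arguments):

    .. code-block:: python

        from tomoscan.esrf.volume import HDF5Volume

        output_volume = HDF5Volume(file_path="concatenated.hdf5", data_path="volume")
        volumes = (
            HDF5Volume(file_path="vol_1.hdf5", data_path="volume"),
            HDF5Volume(file_path="vol_2.hdf5", data_path="volume"),
        )
        concatenate(output_volume=output_volume, volumes=volumes, axis=0)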
"""
    # 0. do some checks
if not isinstance(output_volume, VolumeBase):
raise TypeError(
f"output_volume is expected to be an instance of {VolumeBase}. {type(output_volume)} provided"
)
if not isinstance(axis, int):
raise TypeError(f"axis must be an int. {type(axis)} provided")
elif axis not in (0, 1, 2):
raise ValueError(f"axis must be in (0, 1, 2). {axis} provided")
if not isinstance(volumes, tuple):
raise TypeError(f"volumes must be a tuple. {type(volumes)} provided")
    else:
        invalids = tuple(vol for vol in volumes if not isinstance(vol, VolumeBase))
        if len(invalids) > 0:
            raise ValueError(
                f"All elements of 'volumes' must be VolumeBase instances. Invalid elements: {invalids}"
            )
from tomoscan.esrf.volume.jp2kvolume import JP2KVolume # avoid cyclic import
if isinstance(output_volume, JP2KVolume) and output_volume.rescale_data is True:
        _logger.warning(
            "concatenation will rescale data frames. To avoid this, set the output volume's 'rescale_data' to False"
        )
# 1. compute final shape
def get_volume_shape():
        # the size along `axis` accumulates; the other two axes must match
        new_shape = [None, None, None]
        new_shape[axis] = 0
for vol in volumes:
vol_shape = vol.get_volume_shape()
if vol_shape is None:
raise ValueError(
f"Unable to find shape for volume {vol.get_identifier().to_str()}"
)
new_shape[axis] += vol_shape[axis]
            # check/record the sizes of the two axes not being concatenated
            for other_axis in (i for i in range(3) if i != axis):
                if new_shape[other_axis] is None:
                    new_shape[other_axis] = vol_shape[other_axis]
                elif new_shape[other_axis] != vol_shape[other_axis]:
                    raise ValueError("Found incoherent shapes. Unable to concatenate")
return tuple(new_shape)
final_shape = get_volume_shape()
if final_shape is None:
        # should never happen: another error is expected to be raised first
raise RuntimeError("Unable to get final volume shape")
# 2. Handle volume data (concatenation)
    if isinstance(output_volume, HDF5Volume) and all(
        isinstance(vol, HDF5Volume) for vol in volumes
    ):
        # 2.1 in the case of HDF5 we can shortcut the copy by creating a virtual dataset, which highly speeds up processing by avoiding data duplication
        # note: in theory this could be done for any input volume type using an external dataset, but we don't want to spend ages on
        # this use case for now. For information, some work around this (using EDF) has been done in nxtomomill. See https://gitlab.esrf.fr/tomotools/nxtomomill/-/merge_requests/115
        _logger.info("start creation of the virtual dataset")
with DatasetReader(volumes[0].data_url) as dataset:
data_type = dataset.dtype
        # FIXME: drop the reference to avoid keeping the file open. Not clear why this is needed
        dataset = None
with h5py.File(output_volume.data_url.file_path(), mode="a") as h5s:
# 2.1.1 check data path
if output_volume.data_url.data_path() in h5s:
if output_volume.overwrite:
del h5s[output_volume.data_url.data_path()]
else:
                    raise OSError(
                        f"Unable to save data to {output_volume.data_url.data_path()}. This path already exists in {output_volume.data_url.file_path()}. You can ask to overwrite it by setting the output volume's 'overwrite' attribute to True."
                    )
# 2.1.2 create virtual layout
v_layout = h5py.VirtualLayout(
shape=final_shape,
dtype=data_type,
)
# 2.1.3 create virtual source
start_index = 0
for volume in volumes:
                # use a path relative to the output file so the virtual dataset keeps working if the files are moved together
rel_file_path = os.path.relpath(
volume.data_url.file_path(),
os.path.dirname(output_volume.data_url.file_path()),
)
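                # make the stored source path explicitly relative ("./...")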
rel_file_path = "./" + rel_file_path
data_path = volume.data_url.data_path()
vol_shape = volume.get_volume_shape()
vs = h5py.VirtualSource(
rel_file_path,
name=data_path,
shape=vol_shape,
)
stop_index = start_index + vol_shape[axis]
if axis == 0:
v_layout[start_index:stop_index] = vs
elif axis == 1:
v_layout[:, start_index:stop_index, :] = vs
elif axis == 2:
v_layout[:, :, start_index:stop_index] = vs
start_index = stop_index
# 2.1.4 create virtual dataset
h5s.create_virtual_dataset(
name=output_volume.data_url.data_path(), layout=v_layout
)
    else:
        # 2.2 default case (duplicate all input data slice by slice)
        # 2.2.1 special case: concatenation over axis 0
        if axis == 0:
def iter_input():
for vol in volumes:
for data_slice in vol.browse_slices():
yield data_slice
for frame_dumper, input_slice in zip(
output_volume.data_file_saver_generator(
n_frames=final_shape[0],
data_url=output_volume.data_url,
overwrite=output_volume.overwrite,
),
iter_input(),
):
frame_dumper[:] = input_slice
        else:
            # 2.2.2 concatenation with data duplication over axis 1 or 2
for i_z, frame_dumper in enumerate(
output_volume.data_file_saver_generator(
n_frames=final_shape[0],
data_url=output_volume.data_url,
overwrite=output_volume.overwrite,
)
):
if axis == 1:
frame_dumper[:] = numpy.concatenate(
[vol.get_slice(axis=0, index=i_z) for vol in volumes],
axis=0,
)
elif axis == 2:
frame_dumper[:] = numpy.concatenate(
[vol.get_slice(axis=0, index=i_z) for vol in volumes],
axis=1,
)
            else:
                # unreachable: axis was validated at the beginning of the function
                raise RuntimeError(f"axis must be in (0, 1, 2). Got {axis}")
    # 3. handle metadata
    for vol in volumes:
        if vol.metadata is None:
            try:
                vol.load_metadata(store=True)
            except Exception as e:
                _logger.error(f"Failed to load metadata for {vol}. Error is {e}")
    output_volume.metadata = {}
    for vol in volumes:
        if vol.metadata is not None:
            update_metadata(output_volume.metadata, vol.metadata)
    output_volume.save_metadata()


def update_metadata(ddict_1: dict, ddict_2: dict) -> dict:
    """
    Update metadata dict `ddict_1` with values from `ddict_2`. Nested
    mappings are merged recursively instead of being replaced.

    Warning: this modifies `ddict_1` in place.
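
    A minimal example (the metadata keys are hypothetical):

    >>> update_metadata({"a": {"b": 0}}, {"a": {"c": 1}, "d": 2})
    {'a': {'b': 0, 'c': 1}, 'd': 2}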
"""
if not isinstance(ddict_1, dict) or not isinstance(ddict_2, dict):
raise TypeError(f"ddict_1 and ddict_2 are expected to be instances of {dict}")
for key, value in ddict_2.items():
if isinstance(value, Mapping):
ddict_1[key] = update_metadata(ddict_1.get(key, {}), value)
else:
ddict_1[key] = value
return ddict_1


def rescale_data(data, new_min, new_max, data_min=None, data_max=None):
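    """
    Linearly rescale `data` from the range [data_min, data_max] to
    [new_min, new_max]. If `data_min` or `data_max` is not provided it is
    computed from `data`.

    For example, rescaling [0, 5, 10] to the range [0, 1] gives
    [0.0, 0.5, 1.0].
    """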
if data_min is None:
data_min = numpy.min(data)
if data_max is None:
data_max = numpy.max(data)
return (new_max - new_min) / (data_max - data_min) * (data - data_min) + new_min