From bc029f16ab8d6809db2fff96e9b62fcca386638e Mon Sep 17 00:00:00 2001 From: Theresa Reisch Date: Tue, 18 Nov 2025 12:31:25 +0100 Subject: [PATCH 1/3] add mf hits fixes --- salt/models/maskformer.py | 32 ++++++++++++++++++++++-------- salt/models/maskformer_loss.py | 36 +++++++++++++++++++++++----------- salt/models/matcher.py | 3 ++- salt/utils/cli.py | 6 +----- 4 files changed, 52 insertions(+), 25 deletions(-) diff --git a/salt/models/maskformer.py b/salt/models/maskformer.py index 8d89f3e3..00698a23 100644 --- a/salt/models/maskformer.py +++ b/salt/models/maskformer.py @@ -102,8 +102,15 @@ class MaskDecoder(nn.Module): # MF only supports one input, if we have multiple then we have no way of knowing # what section of the embedding relates to objects we want to generate masks for if isinstance(pad_mask, dict): - assert len(pad_mask) == 1, "Maskformer only supports one input." - pad_mask = next(iter(pad_mask.values())) + # Remove masks that correspond to register tokens (single-token masks) + filtered = {k: v for k, v in pad_mask.items() if v.shape[1] > 1} + + if len(filtered) == 1: + pad_mask = next(iter(filtered.values())) + else: + raise AssertionError( + f"Maskformer expected one real input mask (length>1). Got: { {k: v.shape for k, v in pad_mask.items()} }" + ) x = preds["embed_xs"] # apply norm @@ -111,13 +118,22 @@ class MaskDecoder(nn.Module): x = self.norm2(x) # Add a dummy track to the inputs (and to pad mask) to stop onnx complaining - xpad = torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device, dtype=x.dtype) - x = torch.cat([x, xpad], dim=1) - if pad_mask is not None: - padpad_mask = torch.zeros( - (pad_mask.shape[0], 1), device=pad_mask.device, dtype=pad_mask.dtype + #xpad = torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device, dtype=x.dtype) + #x = torch.cat([x, xpad], dim=1) + #if pad_mask is not None: + # padpad_mask = torch.zeros( + # (pad_mask.shape[0], 1), device=pad_mask.device, dtype=pad_mask.dtype + # ) + # pad_mask = torch.cat([pad_mask, padpad_mask], dim=1) + + if pad_mask is not None and pad_mask.shape[1] != x.shape[1]: + # dynamically pad with False to match embedding length + pad_len = x.shape[1] - pad_mask.shape[1] + pad_mask = torch.cat( + [pad_mask, torch.zeros((pad_mask.shape[0], pad_len), device=pad_mask.device, dtype=pad_mask.dtype)], + dim=1 ) - pad_mask = torch.cat([pad_mask, padpad_mask], dim=1) + intermediate_outputs: list | None = [] if self.aux_loss else None for layer in self.layers: diff --git a/salt/models/maskformer_loss.py b/salt/models/maskformer_loss.py index 24b287e7..4f4fc8f4 100644 --- a/salt/models/maskformer_loss.py +++ b/salt/models/maskformer_loss.py @@ -151,22 +151,36 @@ class MaskFormerLoss(nn.Module): loss_weights=matcher_weights, ) - def loss_labels(self, preds, labels): - """Classification loss (NLL) - labels dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]. 
- """ - # use the new indices to calculate the loss - # process full inidices - flav_pred_logits = preds["class_logits"].flatten(0, 1) - flavour_labels = labels["object_class"].flatten(0, 1) - if flav_pred_logits.shape[1] == 1: + + def loss_labels( + self, + preds: dict[str, torch.Tensor], + labels: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + """Compute the classification loss on object classes, supporting both binary and multi-class tasks.""" + + flav_pred_logits = preds["class_logits"].flatten(0, 1) # (batch * n_queries, n_classes) + flavour_labels = labels["object_class"].flatten(0, 1) # (batch * n_queries,) + + n_classes = flav_pred_logits.shape[1] + + if n_classes == 1: + # binary classification loss = F.binary_cross_entropy_with_logits( - flav_pred_logits.squeeze(), flavour_labels.float(), pos_weight=self.empty_weight + flav_pred_logits.squeeze(-1), flavour_labels.float(), pos_weight=self.empty_weight ) else: - loss = F.cross_entropy(flav_pred_logits, flavour_labels, self.empty_weight) + # multi-class classification + # Adjust weight vector to match n_classes + if self.empty_weight is None or len(self.empty_weight) != n_classes: + weight = torch.ones(n_classes, device=flav_pred_logits.device) + else: + weight = self.empty_weight + loss = F.cross_entropy(flav_pred_logits, flavour_labels.long(), weight=weight) + return {"object_class_ce": loss} + def loss_masks(self, preds, labels): """Compute the losses related to the masks: the focal loss and the dice loss. labels dicts must contain the key "masks" containing a tensor of dim diff --git a/salt/models/matcher.py b/salt/models/matcher.py index 6f86ebc3..668abd9e 100644 --- a/salt/models/matcher.py +++ b/salt/models/matcher.py @@ -135,7 +135,8 @@ class HungarianMatcher(nn.Module): obj_class_tgt[:, : self.num_classes].unsqueeze(1).expand(-1, obj_class_pred.size(1), -1) ) valid_obj_mask = obj_class_tgt != self.num_classes - output = torch.gather(obj_class_pred, 2, obj_class_tgt * valid_obj_mask) * valid_obj_mask + indices = (obj_class_tgt * valid_obj_mask).long() # cast to long + output = torch.gather(obj_class_pred, 2, indices) * valid_obj_mask obj_class_cost = torch.zeros((bs, self.num_objects, self.num_objects), device=dev) obj_class_cost[:, :, : self.num_classes] = -output diff --git a/salt/utils/cli.py b/salt/utils/cli.py index db45ea66..9fd48a56 100644 --- a/salt/utils/cli.py +++ b/salt/utils/cli.py @@ -103,11 +103,7 @@ class SaltCLI(LightningCLI): if not (maskformer_config := config.data.get("mf_config")): raise ValueError("Mask decoder requires 'mf_config' in data config.") if maskformer_config.constituent.name not in labels: - raise ValueError( - f"The constituent name {maskformer_config.constituent.name} is not in the" - " data labels. Ensure that the constituent name is in the input_map of the" - " data config." 
- ) + labels[maskformer_config.constituent.name] = [] # Needed in case no tasks other than mask prediction/classification if "objects" not in labels: -- GitLab From 0cd239fbadf426b3f8e701ae63ef510a5fcae603 Mon Sep 17 00:00:00 2001 From: Theresa Reisch Date: Tue, 18 Nov 2025 17:13:06 +0100 Subject: [PATCH 2/3] linting --- salt/models/maskformer.py | 19 ++++++++++++------- salt/models/maskformer_loss.py | 15 +++++++-------- salt/models/matcher.py | 2 +- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/salt/models/maskformer.py b/salt/models/maskformer.py index 00698a23..128131cd 100644 --- a/salt/models/maskformer.py +++ b/salt/models/maskformer.py @@ -109,7 +109,8 @@ class MaskDecoder(nn.Module): pad_mask = next(iter(filtered.values())) else: raise AssertionError( - f"Maskformer expected one real input mask (length>1). Got: { {k: v.shape for k, v in pad_mask.items()} }" + f"Maskformer expected one real input mask (length>1). Got:\ + { {k: v.shape for k, v in pad_mask.items()} }" ) x = preds["embed_xs"] @@ -118,9 +119,9 @@ class MaskDecoder(nn.Module): x = self.norm2(x) # Add a dummy track to the inputs (and to pad mask) to stop onnx complaining - #xpad = torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device, dtype=x.dtype) - #x = torch.cat([x, xpad], dim=1) - #if pad_mask is not None: + # xpad = torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device, dtype=x.dtype) + # x = torch.cat([x, xpad], dim=1) + # if pad_mask is not None: # padpad_mask = torch.zeros( # (pad_mask.shape[0], 1), device=pad_mask.device, dtype=pad_mask.dtype # ) @@ -130,11 +131,15 @@ class MaskDecoder(nn.Module): # dynamically pad with False to match embedding length pad_len = x.shape[1] - pad_mask.shape[1] pad_mask = torch.cat( - [pad_mask, torch.zeros((pad_mask.shape[0], pad_len), device=pad_mask.device, dtype=pad_mask.dtype)], - dim=1 + [ + pad_mask, + torch.zeros( + (pad_mask.shape[0], pad_len), device=pad_mask.device, dtype=pad_mask.dtype + ), + ], + dim=1, ) - intermediate_outputs: list | None = [] if self.aux_loss else None for layer in self.layers: if self.aux_loss: diff --git a/salt/models/maskformer_loss.py b/salt/models/maskformer_loss.py index 4f4fc8f4..5391a8e5 100644 --- a/salt/models/maskformer_loss.py +++ b/salt/models/maskformer_loss.py @@ -151,19 +151,19 @@ class MaskFormerLoss(nn.Module): loss_weights=matcher_weights, ) - def loss_labels( self, preds: dict[str, torch.Tensor], labels: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]: - """Compute the classification loss on object classes, supporting both binary and multi-class tasks.""" - + """Compute the classification loss on object classes, + supporting both binary and multi-class tasks. + """ flav_pred_logits = preds["class_logits"].flatten(0, 1) # (batch * n_queries, n_classes) - flavour_labels = labels["object_class"].flatten(0, 1) # (batch * n_queries,) - + flavour_labels = labels["object_class"].flatten(0, 1) # (batch * n_queries,) + n_classes = flav_pred_logits.shape[1] - + if n_classes == 1: # binary classification loss = F.binary_cross_entropy_with_logits( @@ -177,9 +177,8 @@ class MaskFormerLoss(nn.Module): else: weight = self.empty_weight loss = F.cross_entropy(flav_pred_logits, flavour_labels.long(), weight=weight) - - return {"object_class_ce": loss} + return {"object_class_ce": loss} def loss_masks(self, preds, labels): """Compute the losses related to the masks: the focal loss and the dice loss. 
diff --git a/salt/models/matcher.py b/salt/models/matcher.py index 668abd9e..6303a713 100644 --- a/salt/models/matcher.py +++ b/salt/models/matcher.py @@ -135,7 +135,7 @@ class HungarianMatcher(nn.Module): obj_class_tgt[:, : self.num_classes].unsqueeze(1).expand(-1, obj_class_pred.size(1), -1) ) valid_obj_mask = obj_class_tgt != self.num_classes - indices = (obj_class_tgt * valid_obj_mask).long() # cast to long + indices = (obj_class_tgt * valid_obj_mask).long() # cast to long output = torch.gather(obj_class_pred, 2, indices) * valid_obj_mask obj_class_cost = torch.zeros((bs, self.num_objects, self.num_objects), device=dev) obj_class_cost[:, :, : self.num_classes] = -output -- GitLab From d913ab5f4c230605f62f13853328524c34b2b7a4 Mon Sep 17 00:00:00 2001 From: Theresa Reisch Date: Tue, 18 Nov 2025 17:38:13 +0100 Subject: [PATCH 3/3] remove comments --- salt/models/maskformer.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/salt/models/maskformer.py b/salt/models/maskformer.py index 128131cd..273e3614 100644 --- a/salt/models/maskformer.py +++ b/salt/models/maskformer.py @@ -118,15 +118,6 @@ class MaskDecoder(nn.Module): q = self.norm1(self.inital_q.expand(x.shape[0], -1, -1)) x = self.norm2(x) - # Add a dummy track to the inputs (and to pad mask) to stop onnx complaining - # xpad = torch.zeros((x.shape[0], 1, x.shape[-1]), device=x.device, dtype=x.dtype) - # x = torch.cat([x, xpad], dim=1) - # if pad_mask is not None: - # padpad_mask = torch.zeros( - # (pad_mask.shape[0], 1), device=pad_mask.device, dtype=pad_mask.dtype - # ) - # pad_mask = torch.cat([pad_mask, padpad_mask], dim=1) - if pad_mask is not None and pad_mask.shape[1] != x.shape[1]: # dynamically pad with False to match embedding length pad_len = x.shape[1] - pad_mask.shape[1] -- GitLab
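
Illustrative sketch of the pad-mask handling introduced in PATCH 1/3: single-token (register) masks are dropped from the input dict, exactly one real mask is expected to remain, and that mask is then padded with False entries so its length matches the embedding sequence. This is a minimal standalone example, not code from the repository; the tensor shapes, dict keys, and variable names below are assumed purely for illustration:

    import torch

    # assumed shapes for illustration only
    batch, seq_len, embed_dim = 2, 10, 8
    x = torch.randn(batch, seq_len, embed_dim)  # embeddings, may include extra tokens
    pad_mask = {
        "tracks": torch.zeros(batch, seq_len - 1, dtype=torch.bool),  # real constituent mask
        "global": torch.zeros(batch, 1, dtype=torch.bool),            # register token, length 1
    }

    # drop single-token (register) masks; exactly one real mask must remain
    filtered = {k: v for k, v in pad_mask.items() if v.shape[1] > 1}
    assert len(filtered) == 1, f"expected one real input mask, got {list(filtered)}"
    mask = next(iter(filtered.values()))

    # pad with False so the mask length matches the embedding length
    if mask.shape[1] != x.shape[1]:
        pad_len = x.shape[1] - mask.shape[1]
        mask = torch.cat(
            [mask, torch.zeros(mask.shape[0], pad_len, dtype=mask.dtype)], dim=1
        )

    assert mask.shape[:2] == x.shape[:2]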
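
The reworked loss_labels (PATCH 1/3, reformatted in PATCH 2/3) branches on the number of logit columns: a single column is treated as binary classification with BCE-with-logits, otherwise cross-entropy is used, rebuilding the class-weight vector whenever the configured empty_weight does not match the number of classes. A minimal sketch of that branching, under assumed shapes and with a hypothetical helper name (not the repository implementation):

    import torch
    import torch.nn.functional as F

    def class_loss(class_logits, object_class, empty_weight=None):
        """Sketch of the binary / multi-class branch; shapes and names assumed."""
        logits = class_logits.flatten(0, 1)   # (batch * n_queries, n_classes)
        targets = object_class.flatten(0, 1)  # (batch * n_queries,)
        n_classes = logits.shape[1]

        if n_classes == 1:
            # single logit column -> binary classification
            return F.binary_cross_entropy_with_logits(
                logits.squeeze(-1), targets.float(), pos_weight=empty_weight
            )

        # multi-class: fall back to uniform weights if the vector does not match n_classes
        if empty_weight is None or len(empty_weight) != n_classes:
            empty_weight = torch.ones(n_classes, device=logits.device)
        return F.cross_entropy(logits, targets.long(), weight=empty_weight)

    # example call with assumed shapes: batch=2, queries=5, classes=3
    logits = torch.randn(2, 5, 3)
    targets = torch.randint(0, 3, (2, 5))
    loss = class_loss(logits, targets)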