diff --git a/salt/models/task.py b/salt/models/task.py index 17258ed660a1047b4bfc971e4377ac2e18e55b46..bd395dc862eb44bf36d7b434281c8a39d3bf605a 100644 --- a/salt/models/task.py +++ b/salt/models/task.py @@ -173,6 +173,8 @@ class ClassificationTask(TaskBase): if labels is not None: if preds.ndim == 3: loss = self.loss(preds.permute(0, 2, 1), labels) + elif isinstance(self.loss, torch.nn.BCEWithLogitsLoss): + loss = self.loss(preds.squeeze(-1), labels.float()) else: loss = self.loss(preds, labels) loss = self.apply_sample_weight(loss, labels_dict) @@ -181,7 +183,9 @@ class ClassificationTask(TaskBase): return preds, loss def run_inference(self, preds: Tensor, pad_mask: Tensor | None = None) -> Tensor: - if pad_mask is None: + if isinstance(self.loss, torch.nn.BCEWithLogitsLoss): + probs = torch.sigmoid(preds) + elif pad_mask is None: assert preds.ndim == 2 probs = torch.softmax(preds, dim=-1) else: