From 83e11298f90488b579c938befa1ca4b7ad5f3061 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sophie=20G=C3=A9linas?=
Date: Tue, 22 Jul 2025 08:07:09 -0700
Subject: [PATCH 1/2] add functionality for BCEWithLogitsLoss for jet
 classification

---
 salt/models/task.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/salt/models/task.py b/salt/models/task.py
index 17258ed6..fa5dea86 100644
--- a/salt/models/task.py
+++ b/salt/models/task.py
@@ -173,6 +173,8 @@ class ClassificationTask(TaskBase):
         if labels is not None:
             if preds.ndim == 3:
                 loss = self.loss(preds.permute(0, 2, 1), labels)
+            elif isinstance(self.loss, torch.nn.BCEWithLogitsLoss):
+                loss = self.loss(preds.squeeze(-1), labels.float())
             else:
                 loss = self.loss(preds, labels)
             loss = self.apply_sample_weight(loss, labels_dict)
@@ -181,7 +183,12 @@
         return preds, loss
 
     def run_inference(self, preds: Tensor, pad_mask: Tensor | None = None) -> Tensor:
-        if pad_mask is None:
+        if isinstance(self.loss, torch.nn.BCEWithLogitsLoss):
+            if pad_mask is None:
+                probs = torch.sigmoid(preds)
+            else:
+                probs = masked_sigmoid(preds, pad_mask.unsqueeze(-1))
+        elif pad_mask is None:
             assert preds.ndim == 2
             probs = torch.softmax(preds, dim=-1)
         else:
-- 
GitLab

From f7f8990ca847bc487961bc3b9fca1c6efc86ddf8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sophie=20G=C3=A9linas?=
Date: Wed, 23 Jul 2025 04:59:06 -0700
Subject: [PATCH 2/2] update for BCEWithLogitsLoss for jet classification

---
 salt/models/task.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/salt/models/task.py b/salt/models/task.py
index fa5dea86..bd395dc8 100644
--- a/salt/models/task.py
+++ b/salt/models/task.py
@@ -184,10 +184,7 @@ class ClassificationTask(TaskBase):
     def run_inference(self, preds: Tensor, pad_mask: Tensor | None = None) -> Tensor:
         if isinstance(self.loss, torch.nn.BCEWithLogitsLoss):
-            if pad_mask is None:
-                probs = torch.sigmoid(preds)
-            else:
-                probs = masked_sigmoid(preds, pad_mask.unsqueeze(-1))
+            probs = torch.sigmoid(preds)
         elif pad_mask is None:
             assert preds.ndim == 2
             probs = torch.softmax(preds, dim=-1)
         else:
-- 
GitLab