From 346c4a4b409a1dc206a68e94ef82b6ca79d3999f Mon Sep 17 00:00:00 2001 From: Jan wagner Date: Thu, 16 Oct 2025 12:54:59 +0200 Subject: [PATCH 1/9] Working version of conditionally_skip script --- .pre-commit-config.yaml | 11 + scripts/wrap_pid_conditions.py | 390 +++++++++++++++++++++++++++++++++ 2 files changed, 401 insertions(+) create mode 100755 scripts/wrap_pid_conditions.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1bb97b58dfc..0ad32462bf4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,3 +29,14 @@ repos: hooks: - id: lb-add-copyright exclude: "^lhcbproject.yml$" +- repo: local + hooks: + - id: wrap-pid-conditions + name: Wrap PID cuts with conditionally_skip + ensure import + entry: python3 scripts/wrap_pid_conditions.py + language: system + types: [python] + pass_filenames: true + additional_dependencies: [] + # To fail CI without auto-fixing, switch to: + # entry: python3 scripts/wrap_pid_conditions.py --check diff --git a/scripts/wrap_pid_conditions.py b/scripts/wrap_pid_conditions.py new file mode 100755 index 00000000000..e5ab801e7eb --- /dev/null +++ b/scripts/wrap_pid_conditions.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +############################################################################### +# (c) Copyright 2025 CERN for the benefit of the LHCb Collaboration # +# # +# This software is distributed under the terms of the GNU General Public # +# Licence version 3 (GPL Version 3), copied verbatim in the file "COPYING". # +# # +# In applying this licence, CERN does not waive the privileges and immunities # +# granted to it by virtue of its status as an Intergovernmental Organization # +# or submit itself to any jurisdiction. # +############################################################################### +from __future__ import annotations + +import argparse +import re +import sys +from pathlib import Path + +import libcst as cst + +# ================================ +# Config +# ================================ + +PID_ATTR_RE = re.compile(r"^(PID|PROBNN)_", re.IGNORECASE) +PID_BOOL_NAMES = {"ISMUON"} + +PID_WRAPPER_FUNCS = { + "CHILD", + "MIN", + "MAX", + "SUM", + "ANY", + "ALL", + "MINTREE", + "MAXTREE", + "in_range", +} + +BOOL_CONSUMERS = {"SUM", "ANY", "ALL", "COUNT"} + +# ================================ +# Helpers +# ================================ + + +def is_F_attr(node: cst.CSTNode) -> bool: + return ( + isinstance(node, cst.Attribute) + and isinstance(node.value, cst.Name) + and node.value.value == "F" + and isinstance(node.attr, cst.Name) + ) + + +def attr_value(node: cst.Attribute) -> str: + return node.attr.value # type: ignore[return-value] + + +def is_pid_attribute(node: cst.CSTNode) -> bool: + if not is_F_attr(node): + return False + name = attr_value(node) + return bool(PID_ATTR_RE.match(name)) or (name in PID_BOOL_NAMES) + + +def is_pid_wrapper_call(node: cst.CSTNode) -> bool: + if not isinstance(node, cst.Call): + return False + f = node.func + if is_F_attr(f): + return attr_value(f) in PID_WRAPPER_FUNCS + return False + + +def func_name_like(node: cst.CSTNode) -> str | None: + if isinstance(node, cst.Name): + return node.value + if is_F_attr(node): + return attr_value(node) + if isinstance(node, cst.Attribute) and isinstance(node.attr, cst.Name): + return node.attr.value + return None + + +def contains_pid_marker(node: cst.CSTNode) -> bool: + found = False + + class Finder(cst.CSTVisitor): + def visit_Attribute(self, n: cst.Attribute) -> None: + nonlocal found + if found: + return + if is_pid_attribute(n): + found = True + + def visit_Call(self, n: cst.Call) -> None: + nonlocal found + if found: + return + if is_pid_wrapper_call(n): + for a in n.args: + if a.value is not None and contains_pid_marker(a.value): + found = True + + node.visit(Finder()) + return found + + +def is_conditionally_skip_call(node: cst.CSTNode) -> bool: + return ( + isinstance(node, cst.Call) + and isinstance(node.func, cst.Name) + and node.func.value == "conditionally_skip" + ) + + +def is_boolean_like(node: cst.CSTNode) -> bool: + if is_pid_attribute(node): + return attr_value(node) in PID_BOOL_NAMES + if isinstance(node, cst.UnaryOperation) and isinstance( + node.operator, cst.BitInvert + ): + return is_boolean_like(node.expression) + return False + + +# ================================ +# Transformer: wrap PID cuts +# ================================ + + +class WrapPidCuts(cst.CSTTransformer): + def __init__(self): + self._in_cond_skip_stack: list[bool] = [] + self.introduced_count: int = 0 + + def _wrap(self, expr: cst.BaseExpression) -> cst.Call: + self.introduced_count += 1 + return cst.Call(func=cst.Name("conditionally_skip"), args=[cst.Arg(expr)]) + + def visit_Call(self, node: cst.Call) -> None: + self._in_cond_skip_stack.append(is_conditionally_skip_call(node)) + + def leave_Call( + self, original_node: cst.Call, updated_node: cst.Call + ) -> cst.BaseExpression: + _ = self._in_cond_skip_stack.pop() if self._in_cond_skip_stack else False + fn = func_name_like(updated_node.func) or "" + if is_F_attr(updated_node.func) and (fn in BOOL_CONSUMERS): + new_args = [] + for a in updated_node.args: + v = a.value + if ( + isinstance(v, cst.Comparison) + and contains_pid_marker(v) + and not is_conditionally_skip_call(v) + ): + v = self._wrap(v) + new_args.append(a.with_changes(value=v)) + return updated_node.with_changes(args=new_args) + return updated_node + + def leave_Comparison( + self, original_node: cst.Comparison, updated_node: cst.Comparison + ) -> cst.BaseExpression: + if any(self._in_cond_skip_stack): + return updated_node + if contains_pid_marker(updated_node): + return self._wrap(updated_node) + return updated_node + + def leave_UnaryOperation( + self, original_node: cst.UnaryOperation, updated_node: cst.UnaryOperation + ) -> cst.BaseExpression: + if any(self._in_cond_skip_stack): + return updated_node + if isinstance(updated_node.operator, cst.BitInvert): + opnd = updated_node.expression + if (isinstance(opnd, cst.Comparison) and contains_pid_marker(opnd)) or ( + contains_pid_marker(opnd) and is_boolean_like(opnd) + ): + return self._wrap(updated_node) + return updated_node + + def leave_BinaryOperation( + self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation + ) -> cst.BaseExpression: + if isinstance(updated_node.operator, (cst.BitAnd, cst.BitOr)): + + def maybe_wrap(e: cst.BaseExpression) -> cst.BaseExpression: + if is_conditionally_skip_call(e): + return e + if isinstance(e, cst.Comparison) and contains_pid_marker(e): + return self._wrap(e) + if isinstance(e, cst.UnaryOperation) and isinstance( + e.operator, cst.BitInvert + ): + inner = e.expression + if ( + isinstance(inner, cst.Comparison) and contains_pid_marker(inner) + ) or (contains_pid_marker(inner) and is_boolean_like(inner)): + return self._wrap(e) + return e + + return updated_node.with_changes( + left=maybe_wrap(updated_node.left), + right=maybe_wrap(updated_node.right), + ) + return updated_node + + +# ================================ +# Transformer: ensure import +# ================================ + + +class EnsureConditionallySkipImport(cst.CSTTransformer): + """ + Ensure: from RecoConf.global_tools import conditionally_skip + Only acts if `should_apply` is True. + Fixes: + - no brittle string guard + - handle names as Sequence[ImportAlias] (tuple) instead of list + """ + + def __init__(self, should_apply: bool): + self.should_apply = should_apply + self.injected_or_found = False + self.added_to_existing = False + self.insert_after_index = 0 + + def _is_target_module(self, module: cst.CSTNode | None) -> bool: + return ( + isinstance(module, cst.Attribute) + and isinstance(module.value, cst.Name) + and module.value.value == "RecoConf" + and isinstance(module.attr, cst.Name) + and module.attr.value == "global_tools" + ) + + def visit_Module(self, node: cst.Module) -> None: + idx = 0 + body = node.body + # after module docstring + if body and isinstance(body[0], cst.SimpleStatementLine): + stmt = body[0] + if ( + len(stmt.body) == 1 + and isinstance(stmt.body[0], cst.Expr) + and isinstance(stmt.body[0].value, cst.SimpleString) + ): + idx = 1 + # after __future__ imports + while idx < len(body): + stmt = body[idx] + if ( + isinstance(stmt, cst.SimpleStatementLine) + and stmt.body + and isinstance(stmt.body[0], cst.ImportFrom) + ): + imp: cst.ImportFrom = stmt.body[0] + if ( + isinstance(imp.module, cst.Name) + and imp.module.value == "__future__" + ): + idx += 1 + continue + break + self.insert_after_index = idx + + def leave_ImportFrom( + self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom + ) -> cst.ImportFrom: + # If ANY import-from brings conditionally_skip into scope, mark as found (don’t inject again) + names = updated_node.names + if isinstance(names, cst.ImportStar): + # Star import from *anywhere* could already provide the name; be conservative: + if self._is_target_module(updated_node.module): + self.injected_or_found = True + return updated_node + + if isinstance(names, tuple): + for n in names: + if isinstance(n, cst.ImportAlias): + if ( + isinstance(n.name, cst.Name) + and n.name.value == "conditionally_skip" + ): + self.injected_or_found = True + + # If it's specifically from RecoConf.global_tools, append the alias if missing + if self.should_apply and self._is_target_module(updated_node.module): + if isinstance(names, tuple): + have = any( + isinstance(n, cst.ImportAlias) + and isinstance(n.name, cst.Name) + and n.name.value == "conditionally_skip" + for n in names + ) + if not have: + new_names = list(names) + [ + cst.ImportAlias(name=cst.Name("conditionally_skip")) + ] + self.added_to_existing = True + self.injected_or_found = True + return updated_node.with_changes(names=tuple(new_names)) + + return updated_node + + def leave_Module( + self, original_node: cst.Module, updated_node: cst.Module + ) -> cst.Module: + if not self.should_apply: + return updated_node + if self.injected_or_found: + return updated_node + # Insert a fresh import line + new_import = cst.SimpleStatementLine( + body=[ + cst.ImportFrom( + module=cst.Attribute( + value=cst.Name("RecoConf"), attr=cst.Name("global_tools") + ), + names=(cst.ImportAlias(name=cst.Name("conditionally_skip")),), + ) + ] + ) + body = list(updated_node.body) + i = min(self.insert_after_index, len(body)) + body.insert(i, new_import) + return updated_node.with_changes(body=body) + + +# ================================ +# Driver +# ================================ + + +def process_file(path: Path, write: bool) -> bool: + try: + code = path.read_text(encoding="utf-8") + except Exception: + return False + try: + mod = cst.parse_module(code) + except Exception: + return False + + # Pass 1: wrap PID cuts + wrapper = WrapPidCuts() + mod1 = mod.visit(wrapper) + + # Pass 2: ensure import only if we introduced any wrappers + mod2 = mod1.visit( + EnsureConditionallySkipImport(should_apply=(wrapper.introduced_count > 0)) + ) + + changed = mod2.code != code + if changed and write: + path.write_text(mod2.code, encoding="utf-8") + return changed + + +def main(): + ap = argparse.ArgumentParser( + description="Wrap PID cuts with conditionally_skip(...) and ensure proper import." + ) + ap.add_argument("files", nargs="+", help=".py files to process") + ap.add_argument( + "--check", + action="store_true", + help="Only check; exit 1 if changes would be made", + ) + args = ap.parse_args() + + any_changed = False + for f in args.files: + p = Path(f) + if p.suffix != ".py" or not p.is_file(): + continue + if process_file(p, write=not args.check): + any_changed = True + + if args.check and any_changed: + sys.exit(1) + + +if __name__ == "__main__": + main() -- GitLab From 3303a617869bdff11a726785d7bbd36f3c673a88 Mon Sep 17 00:00:00 2001 From: Jan wagner Date: Thu, 16 Oct 2025 14:35:15 +0200 Subject: [PATCH 2/9] Add GHOSTPROB functor --- scripts/wrap_pid_conditions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/wrap_pid_conditions.py b/scripts/wrap_pid_conditions.py index e5ab801e7eb..db7035bbfcb 100755 --- a/scripts/wrap_pid_conditions.py +++ b/scripts/wrap_pid_conditions.py @@ -22,7 +22,7 @@ import libcst as cst # Config # ================================ -PID_ATTR_RE = re.compile(r"^(PID|PROBNN)_", re.IGNORECASE) +PID_ATTR_RE = re.compile(r"^((PID|PROBNN)_|GHOSTPROB)", re.IGNORECASE) PID_BOOL_NAMES = {"ISMUON"} PID_WRAPPER_FUNCS = { -- GitLab From 1101a1212ba71fc82c7664c429b93e6d233994c7 Mon Sep 17 00:00:00 2001 From: Jan wagner Date: Thu, 16 Oct 2025 14:42:06 +0200 Subject: [PATCH 3/9] Add support for derived pid variables --- scripts/wrap_pid_conditions.py | 113 ++++++++++++++++++++++++++++----- 1 file changed, 96 insertions(+), 17 deletions(-) diff --git a/scripts/wrap_pid_conditions.py b/scripts/wrap_pid_conditions.py index db7035bbfcb..9b1a26d2ce5 100755 --- a/scripts/wrap_pid_conditions.py +++ b/scripts/wrap_pid_conditions.py @@ -107,6 +107,28 @@ def contains_pid_marker(node: cst.CSTNode) -> bool: return found +def contains_pid_or_derived(node: cst.CSTNode, derived: set[str]) -> bool: + """ + True if the subtree contains a PID marker (F.PID_*, F.PROBNN_*, F.GHOSTPROB, etc.) + OR references any name already marked as PID-derived. + """ + if contains_pid_marker(node): + return True + + found = False + + class FindDerived(cst.CSTVisitor): + def visit_Name(self, n: cst.Name) -> None: + nonlocal found + if found: + return + if n.value in derived: + found = True + + node.visit(FindDerived()) + return found + + def is_conditionally_skip_call(node: cst.CSTNode) -> bool: return ( isinstance(node, cst.Call) @@ -125,15 +147,62 @@ def is_boolean_like(node: cst.CSTNode) -> bool: return False +# ================================ +# Pre-pass: collect PID-derived names +# ================================ + + +class CollectPidDerivedNames(cst.CSTVisitor): + """ + Builds a set of simple names assigned from PID-bearing expressions, + and propagates through reassignments/aug-assignments. + + Tracks only simple names (e.g. x = ...), not attributes (self.x) or destructuring. + """ + + def __init__(self): + self.derived: set[str] = set() + + def _expr_is_pid_related(self, expr: cst.BaseExpression) -> bool: + return contains_pid_or_derived(expr, self.derived) + + def visit_Assign(self, node: cst.Assign) -> None: + # a = or a = b = + if node.value is None: + return + if not self._expr_is_pid_related(node.value): + return + for t in node.targets: + if isinstance(t.target, cst.Name): + self.derived.add(t.target.value) + + def visit_AnnAssign(self, node: cst.AnnAssign) -> None: + # a: T = + if node.value is None: + return + if not self._expr_is_pid_related(node.value): + return + if isinstance(node.target, cst.Name): + self.derived.add(node.target.value) + + def visit_AugAssign(self, node: cst.AugAssign) -> None: + # a *= , a += , etc. (mark 'a' if RHS is PID-related) + if not self._expr_is_pid_related(node.value): + return + if isinstance(node.target, cst.Name): + self.derived.add(node.target.value) + + # ================================ # Transformer: wrap PID cuts # ================================ class WrapPidCuts(cst.CSTTransformer): - def __init__(self): + def __init__(self, derived_names: set[str]): self._in_cond_skip_stack: list[bool] = [] self.introduced_count: int = 0 + self.derived = set(derived_names) def _wrap(self, expr: cst.BaseExpression) -> cst.Call: self.introduced_count += 1 @@ -153,7 +222,7 @@ class WrapPidCuts(cst.CSTTransformer): v = a.value if ( isinstance(v, cst.Comparison) - and contains_pid_marker(v) + and contains_pid_or_derived(v, self.derived) and not is_conditionally_skip_call(v) ): v = self._wrap(v) @@ -166,7 +235,7 @@ class WrapPidCuts(cst.CSTTransformer): ) -> cst.BaseExpression: if any(self._in_cond_skip_stack): return updated_node - if contains_pid_marker(updated_node): + if contains_pid_or_derived(updated_node, self.derived): return self._wrap(updated_node) return updated_node @@ -177,8 +246,11 @@ class WrapPidCuts(cst.CSTTransformer): return updated_node if isinstance(updated_node.operator, cst.BitInvert): opnd = updated_node.expression - if (isinstance(opnd, cst.Comparison) and contains_pid_marker(opnd)) or ( - contains_pid_marker(opnd) and is_boolean_like(opnd) + if ( + isinstance(opnd, cst.Comparison) + and contains_pid_or_derived(opnd, self.derived) + ) or ( + contains_pid_or_derived(opnd, self.derived) and is_boolean_like(opnd) ): return self._wrap(updated_node) return updated_node @@ -191,15 +263,21 @@ class WrapPidCuts(cst.CSTTransformer): def maybe_wrap(e: cst.BaseExpression) -> cst.BaseExpression: if is_conditionally_skip_call(e): return e - if isinstance(e, cst.Comparison) and contains_pid_marker(e): + if isinstance(e, cst.Comparison) and contains_pid_or_derived( + e, self.derived + ): return self._wrap(e) if isinstance(e, cst.UnaryOperation) and isinstance( e.operator, cst.BitInvert ): inner = e.expression if ( - isinstance(inner, cst.Comparison) and contains_pid_marker(inner) - ) or (contains_pid_marker(inner) and is_boolean_like(inner)): + isinstance(inner, cst.Comparison) + and contains_pid_or_derived(inner, self.derived) + ) or ( + contains_pid_or_derived(inner, self.derived) + and is_boolean_like(inner) + ): return self._wrap(e) return e @@ -219,9 +297,7 @@ class EnsureConditionallySkipImport(cst.CSTTransformer): """ Ensure: from RecoConf.global_tools import conditionally_skip Only acts if `should_apply` is True. - Fixes: - - no brittle string guard - - handle names as Sequence[ImportAlias] (tuple) instead of list + Handles both single-line and parenthesized multi-line import-from statements. """ def __init__(self, should_apply: bool): @@ -272,10 +348,9 @@ class EnsureConditionallySkipImport(cst.CSTTransformer): def leave_ImportFrom( self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom ) -> cst.ImportFrom: - # If ANY import-from brings conditionally_skip into scope, mark as found (don’t inject again) + # If ANY import-from brings conditionally_skip into scope, mark as found names = updated_node.names if isinstance(names, cst.ImportStar): - # Star import from *anywhere* could already provide the name; be conservative: if self._is_target_module(updated_node.module): self.injected_or_found = True return updated_node @@ -289,7 +364,7 @@ class EnsureConditionallySkipImport(cst.CSTTransformer): ): self.injected_or_found = True - # If it's specifically from RecoConf.global_tools, append the alias if missing + # Append to `from RecoConf.global_tools import ...` if missing if self.should_apply and self._is_target_module(updated_node.module): if isinstance(names, tuple): have = any( @@ -347,8 +422,12 @@ def process_file(path: Path, write: bool) -> bool: except Exception: return False - # Pass 1: wrap PID cuts - wrapper = WrapPidCuts() + # Pre-pass: collect PID-derived names + collector = CollectPidDerivedNames() + mod.visit(collector) + + # Pass 1: wrap PID cuts (including expressions using derived names) + wrapper = WrapPidCuts(derived_names=collector.derived) mod1 = mod.visit(wrapper) # Pass 2: ensure import only if we introduced any wrappers @@ -364,7 +443,7 @@ def process_file(path: Path, write: bool) -> bool: def main(): ap = argparse.ArgumentParser( - description="Wrap PID cuts with conditionally_skip(...) and ensure proper import." + description="Wrap PID cuts with conditionally_skip(...), including expressions using PID-derived variables, and ensure proper import." ) ap.add_argument("files", nargs="+", help=".py files to process") ap.add_argument( -- GitLab From e99a3cd748511aea9d6990b39155334918f9109c Mon Sep 17 00:00:00 2001 From: Jan wagner Date: Tue, 21 Oct 2025 08:05:14 +0200 Subject: [PATCH 4/9] Replace GHOSTPROB with IS_PHOTON and IS_NOT_H --- scripts/wrap_pid_conditions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/wrap_pid_conditions.py b/scripts/wrap_pid_conditions.py index 9b1a26d2ce5..2b5c8604f6c 100755 --- a/scripts/wrap_pid_conditions.py +++ b/scripts/wrap_pid_conditions.py @@ -22,7 +22,7 @@ import libcst as cst # Config # ================================ -PID_ATTR_RE = re.compile(r"^((PID|PROBNN)_|GHOSTPROB)", re.IGNORECASE) +PID_ATTR_RE = re.compile(r"^((PID|PROBNN)_|IS_PHOTON|IS_NOT_H)", re.IGNORECASE) PID_BOOL_NAMES = {"ISMUON"} PID_WRAPPER_FUNCS = { -- GitLab From afb633a10a7b4397974805abdc3f596afc56e995 Mon Sep 17 00:00:00 2001 From: Jan wagner Date: Mon, 1 Dec 2025 11:11:58 +0100 Subject: [PATCH 5/9] Add no_pid_tool and modify wrap pid to include labels --- Hlt/RecoConf/python/RecoConf/no_pid_tool.py | 175 ++++++++++++++++++++ scripts/wrap_pid_conditions.py | 119 ++++++++++++- 2 files changed, 292 insertions(+), 2 deletions(-) create mode 100644 Hlt/RecoConf/python/RecoConf/no_pid_tool.py diff --git a/Hlt/RecoConf/python/RecoConf/no_pid_tool.py b/Hlt/RecoConf/python/RecoConf/no_pid_tool.py new file mode 100644 index 00000000000..58e076798e4 --- /dev/null +++ b/Hlt/RecoConf/python/RecoConf/no_pid_tool.py @@ -0,0 +1,175 @@ +############################################################################### +# (c) Copyright 2025 CERN for the benefit of the LHCb Collaboration # +# # +# This software is distributed under the terms of the GNU General Public # +# Licence version 3 (GPL Version 3), copied verbatim in the file "COPYING". # +# # +# In applying this licence, CERN does not waive the privileges and immunities # +# granted to it by virtue of its status as an Intergovernmental Organization # +# or submit itself to any jurisdiction. # +############################################################################### +""" +Utilities to *switch off PID* in a controlled way. + +Intended use case: produce “no PID” (noPID) samples, e.g. HLT2 trigger on +simulation where PID is known to be poorly modelled and is corrected +later with data-driven tools. + +This module provides: + +- no_pid_tool: configurable that defines which PID “labels” are switched off. +- conditionally_skip: wrapper used around PID-dependent cuts. +- poison_on_demand: wrapper used around PID functors themselves, as a safety net. + +If a label is “switched off” via no_pid_tool.bind(...), then: + +- PID cuts wrapped by conditionally_skip(..., label=...) will be short-circuited. +- PID functors wrapped by poison_on_demand(..., label=...) will throw if evaluated. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import FrozenSet, Iterable, Optional + +import Functors as F +from PyConf import configurable + +# --------------------------------------------------------------------------- +# 1) Central configuration: which labels are "no PID" +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class _NoPidState: + """ + Internal state object returned by no_pid_tool. + + labels: + A set of labels for which PID is switched off. A “label” is just a + string attached consistently to: + - the PID functor (via poison_on_demand), + - the PID cut (via conditionally_skip). + """ + + labels: FrozenSet[str] = frozenset() + + +@configurable +def no_pid_tool( + labels: Iterable[str] = (), +) -> _NoPidState: + """ + Central knob to switch off PID for selected groups of functors / cuts. + + Typical usage in job options (for simulation): + + from RecoConf.no_pid_tool import no_pid_tool + + # Switch off PID for selected labels: + no_pid_tool.bind(labels={"PID_K", "PID_P", "PID_MU"}) + + Any PID functor / cut annotated with one of these labels will: + * be skipped via conditionally_skip(..., label="..."), + * be poisoned via poison_on_demand(..., label="..."). + + You are free to define labels however you like (per particle species, + per detector, or a single coarse label like "ALL_PID"). + """ + return _NoPidState(labels=frozenset(labels)) + + +# --------------------------------------------------------------------------- +# 2) conditionally_skip: used by the rewrite script on PID cuts +# --------------------------------------------------------------------------- + + +@configurable +def conditionally_skip( + predicate, + label: Optional[str] = None, +): + """ + Wrap a PID-dependent cut so it can be switched off. + + The libcst script will rewrite PID cuts like: + + F.PID_K > 3 + + into something like: + + conditionally_skip(F.PID_K > 3, label="PID_K") + + Behaviour: + + - If `label` is None -> no-op, just returns `predicate`. + - If `label` is not listed in no_pid_tool().labels -> no-op. + - If `label` *is* listed in no_pid_tool().labels: + return F.ACCEPT(True) | predicate + + So the functor algebra will short-circuit and `predicate` will + not be evaluated, effectively removing that PID cut. + """ + cfg = no_pid_tool() + + # If no label or label not configured for no-PID, leave the cut as-is. + if label is None or label not in cfg.labels: + return predicate + + # When this label is in "no PID" mode: force acceptance so that + # the underlying predicate is never evaluated. + # + # In C++ this corresponds to Functors::Accept{ true }. + return F.ACCEPT(True) | predicate + + +# --------------------------------------------------------------------------- +# 3) poison_on_demand: used on PID functors themselves +# --------------------------------------------------------------------------- + + +@configurable +def poison_on_demand( + functor, + label: Optional[str] = None, +): + """ + Wrap PID-related functors so they *throw* if evaluated while the + corresponding label is in "no PID" mode. + + Intended usage at functor definition time, e.g. where F.PID_K is created: + + from RecoConf.no_pid_tool import poison_on_demand + + F.PID_K = poison_on_demand(F.PID_K, label="PID_K") + F.PROBNN_K = poison_on_demand(F.PROBNN_K, label="PID_K") + + Then, when you configure a no-PID sample: + + from RecoConf.no_pid_tool import no_pid_tool + no_pid_tool.bind(labels={"PID_K"}) + + you get: + + - PID cuts rewritten as conditionally_skip(..., label="PID_K") + being skipped cleanly; and + + - any remaining use of F.PID_K / F.PROBNN_K (e.g. missed by rewriting, + or used directly elsewhere) triggering an exception, so you don't + silently get PID dependence in a “no PID” configuration. + """ + cfg = no_pid_tool() + + # If this label isn't in "no PID" mode, leave the functor unchanged. + if label is None or label not in cfg.labels: + return functor + + # When the label is in "no PID" mode: replace with a POISON functor. + # + # This matches the C++ pattern: + # Functors::Functional::Poison( Functors::PID::IsMuon, "do not use IsMuon" ); + msg = ( + f"PID functor with label {label!r} is disabled by no_pid_tool and " + f"must not be evaluated (no-PID configuration)." + ) + return F.POISON(functor, msg) diff --git a/scripts/wrap_pid_conditions.py b/scripts/wrap_pid_conditions.py index 2b5c8604f6c..1034d7153c8 100755 --- a/scripts/wrap_pid_conditions.py +++ b/scripts/wrap_pid_conditions.py @@ -109,7 +109,7 @@ def contains_pid_marker(node: cst.CSTNode) -> bool: def contains_pid_or_derived(node: cst.CSTNode, derived: set[str]) -> bool: """ - True if the subtree contains a PID marker (F.PID_*, F.PROBNN_*, F.GHOSTPROB, etc.) + True if the subtree contains a PID marker (F.PID_*, F.PROBNN_*, IS_PHOTON, IS_NOT_H, ISMUON) OR references any name already marked as PID-derived. """ if contains_pid_marker(node): @@ -147,6 +147,68 @@ def is_boolean_like(node: cst.CSTNode) -> bool: return False +# NEW: map attribute name -> category label +def pid_category_from_attr_name(name: str) -> str | None: + """ + Map a PID-like attribute name to a coarse category label. + + Examples: + PID_K -> "PID" + PROBNN_MU -> "PROBNN" + IS_PHOTON -> "IS_PHOTON" + IS_NOT_H -> "IS_NOT_H" + ISMUON -> "ISMUON" + """ + upper = name.upper() + if upper.startswith("PID_"): + return "PID" + if upper.startswith("PROBNN_"): + return "PROBNN" + if upper in {"IS_PHOTON", "IS_NOT_H"}: + return upper + if upper in PID_BOOL_NAMES: + return upper + return None + + +# NEW: derive category label from an expression +def pid_label_from_expr(node: cst.CSTNode) -> str | None: + """ + Try to derive a *category* label from a PID-bearing expression. + + Strategy: + - Find the first F.PID_*/F.PROBNN_* attribute, or IS_PHOTON/IS_NOT_H/ISMUON. + - Map it to a category: + PID_* -> "PID" + PROBNN_* -> "PROBNN" + IS_PHOTON -> "IS_PHOTON" + IS_NOT_H -> "IS_NOT_H" + ISMUON -> "ISMUON" + - If none is found, return None. + """ + label: str | None = None + + class Finder(cst.CSTVisitor): + def visit_Attribute(self, n: cst.Attribute) -> None: + nonlocal label + if label is not None: + return + if is_pid_attribute(n): + cat = pid_category_from_attr_name(attr_value(n)) + if cat is not None: + label = cat + + def visit_Call(self, n: cst.Call) -> None: + # Recurse into PID wrapper calls like F.in_range(...) + if is_pid_wrapper_call(n): + for a in n.args: + if a.value is not None: + a.value.visit(self) + + node.visit(Finder()) + return label + + # ================================ # Pre-pass: collect PID-derived names # ================================ @@ -205,8 +267,27 @@ class WrapPidCuts(cst.CSTTransformer): self.derived = set(derived_names) def _wrap(self, expr: cst.BaseExpression) -> cst.Call: + """ + Wrap an expression in conditionally_skip(..., label=). + + - If we can find a PID attribute, we pass the category label + ("PID", "PROBNN", "IS_PHOTON", "IS_NOT_H", "ISMUON"). + - Otherwise we emit conditionally_skip(expr) without label + (behaves like your original implementation). + """ self.introduced_count += 1 - return cst.Call(func=cst.Name("conditionally_skip"), args=[cst.Arg(expr)]) + + label = pid_label_from_expr(expr) + + args: list[cst.Arg] = [cst.Arg(expr)] + if label is not None: + args.append( + cst.Arg( + keyword=cst.Name("label"), + value=cst.SimpleString(f'"{label}"'), + ) + ) + return cst.Call(func=cst.Name("conditionally_skip"), args=args) def visit_Call(self, node: cst.Call) -> None: self._in_cond_skip_stack.append(is_conditionally_skip_call(node)) @@ -215,6 +296,39 @@ class WrapPidCuts(cst.CSTTransformer): self, original_node: cst.Call, updated_node: cst.Call ) -> cst.BaseExpression: _ = self._in_cond_skip_stack.pop() if self._in_cond_skip_stack else False + + # 1) Upgrade existing conditionally_skip(...) calls by adding a label + # if missing. + if is_conditionally_skip_call(updated_node): + # Do nothing if there's already a label=... arg + has_label_kw = any( + isinstance(a.keyword, cst.Name) and a.keyword.value == "label" + for a in updated_node.args + if a.keyword is not None + ) + if has_label_kw: + return updated_node + + # Take the first positional argument as the predicate expression + if not updated_node.args: + return updated_node + pred_arg = updated_node.args[0].value + if pred_arg is None: + return updated_node + + label = pid_label_from_expr(pred_arg) + if label is None: + return updated_node + + new_args = list(updated_node.args) + [ + cst.Arg( + keyword=cst.Name("label"), + value=cst.SimpleString(f'"{label}"'), + ) + ] + return updated_node.with_changes(args=tuple(new_args)) + + # 2) Existing BOOL_CONSUMERS logic (unchanged) fn = func_name_like(updated_node.func) or "" if is_F_attr(updated_node.func) and (fn in BOOL_CONSUMERS): new_args = [] @@ -228,6 +342,7 @@ class WrapPidCuts(cst.CSTTransformer): v = self._wrap(v) new_args.append(a.with_changes(value=v)) return updated_node.with_changes(args=new_args) + return updated_node def leave_Comparison( -- GitLab From 4c109acc0fe4e8e826f14ac528a129a44a809cdf Mon Sep 17 00:00:00 2001 From: Jan wagner Date: Mon, 1 Dec 2025 11:31:38 +0100 Subject: [PATCH 6/9] Add no_pid_poisoning script and add two options of modified hlt2_pp_2024_no_pid --- .../hlt2_pp_2024_no_pid_functor-based copy.py | 43 +++++++++++++ .../hlt2_pp_2024_no_pid_recoconf-based.py | 60 +++++++++++++++++++ .../python/RecoConf/no_pid_poisoning.py | 23 +++++++ 3 files changed, 126 insertions(+) create mode 100644 Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_functor-based copy.py create mode 100644 Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_recoconf-based.py create mode 100644 Hlt/RecoConf/python/RecoConf/no_pid_poisoning.py diff --git a/Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_functor-based copy.py b/Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_functor-based copy.py new file mode 100644 index 00000000000..8361cd7419c --- /dev/null +++ b/Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_functor-based copy.py @@ -0,0 +1,43 @@ +############################################################################### +# (c) Copyright 2024 CERN for the benefit of the LHCb Collaboration # +# # +# This software is distributed under the terms of the GNU General Public # +# Licence version 3 (GPL Version 3), copied verbatim in the file "COPYING". # +# # +# In applying this licence, CERN does not waive the privileges and immunities # +# granted to it by virtue of its status as an Intergovernmental Organization # +# or submit itself to any jurisdiction. # +############################################################################### +from DDDB.CheckDD4Hep import UseDD4Hep +from Hlt2Conf.settings.hlt2_pp_2024 import make_streams +from Moore import options, run_moore +from RecoConf.global_tools import ( + stateProvider_with_simplified_geom, + trackMasterExtrapolator_with_simplified_geom, +) +from RecoConf.no_pid_tool import no_pid_tool +from RecoConf.reconstruction_objects import reconstruction + +options.lines_maker = make_streams + +if UseDD4Hep: + # This needs to happen before the public tools are instantiated, + # which means we cannot put it inside make_streams(). + from PyConf.Tools import TrackMasterExtrapolator, TrackMasterFitter + + TrackMasterExtrapolator.global_bind( + ApplyMultScattCorr=False, + ApplyEnergyLossCorr=False, + ApplyElectronEnergyLossCorr=False, + ) + TrackMasterFitter.global_bind(ApplyMaterialCorrections=False) + +public_tools = [ + trackMasterExtrapolator_with_simplified_geom(), + stateProvider_with_simplified_geom(), +] + +nopid_labels = {"PID", "PROBNN", "IS_PHOTON", "IS_NOT_H", "ISMUON"} + +with reconstruction.bind(from_file=False), no_pid_tool.bind(labels=nopid_labels): + config = run_moore(options, public_tools=public_tools) diff --git a/Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_recoconf-based.py b/Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_recoconf-based.py new file mode 100644 index 00000000000..c8580e78e3f --- /dev/null +++ b/Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_recoconf-based.py @@ -0,0 +1,60 @@ +############################################################################### +# (c) Copyright 2024 CERN for the benefit of the LHCb Collaboration # +# # +# This software is distributed under the terms of the GNU General Public # +# Licence version 3 (GPL Version 3), copied verbatim in the file "COPYING". # +# # +# In applying this licence, CERN does not waive the privileges and immunities # +# granted to it by virtue of its status as an Intergovernmental Organization # +# or submit itself to any jurisdiction. # +############################################################################### + +from DDDB.CheckDD4Hep import UseDD4Hep +from Moore import options, run_moore +from RecoConf.global_tools import ( + stateProvider_with_simplified_geom, + trackMasterExtrapolator_with_simplified_geom, +) +from RecoConf.no_pid_poisoning import enable_pid_poisoning + +# NEW: no-PID machinery +from RecoConf.no_pid_tool import no_pid_tool +from RecoConf.reconstruction_objects import reconstruction + +# Choose which PID categories to switch off for this configuration. +# These must match the labels produced by your wrap_pid_conditions script +# and the labels used in enable_pid_poisoning(). +nopid_labels = {"PID", "PROBNN", "IS_PHOTON", "IS_NOT_H", "ISMUON"} + +# Activate the no-PID configuration *and* poison the functors +# BEFORE importing the Hlt2 lines/settings. +with no_pid_tool.bind(labels=nopid_labels): + # Poison the Functors.* PID objects under this configuration + enable_pid_poisoning() + + # Now import the Hlt2 configuration; any use of F.PID_*, F.PROBNN_*, etc. + # in cuts will see the poisoned functors and the conditionally_skip + # wrappers will also consult this bound configuration. + from Hlt2Conf.settings.hlt2_pp_2024 import make_streams + + options.lines_maker = make_streams + + if UseDD4Hep: + # This needs to happen before the public tools are instantiated, + # which means we cannot put it inside make_streams(). + from PyConf.Tools import TrackMasterExtrapolator, TrackMasterFitter + + TrackMasterExtrapolator.global_bind( + ApplyMultScattCorr=False, + ApplyEnergyLossCorr=False, + ApplyElectronEnergyLossCorr=False, + ) + TrackMasterFitter.global_bind(ApplyMaterialCorrections=False) + + public_tools = [ + trackMasterExtrapolator_with_simplified_geom(), + stateProvider_with_simplified_geom(), + ] + + with reconstruction.bind(from_file=False): + config = run_moore(options, public_tools=public_tools) diff --git a/Hlt/RecoConf/python/RecoConf/no_pid_poisoning.py b/Hlt/RecoConf/python/RecoConf/no_pid_poisoning.py new file mode 100644 index 00000000000..bfc39547e88 --- /dev/null +++ b/Hlt/RecoConf/python/RecoConf/no_pid_poisoning.py @@ -0,0 +1,23 @@ +import Functors as F + +from RecoConf.no_pid_tool import poison_on_demand + + +def enable_pid_poisoning(): + F.PID_E = poison_on_demand(F.PID_E, label="PID") + F.PID_K = poison_on_demand(F.PID_K, label="PID") + F.PID_MU = poison_on_demand(F.PID_MU, label="PID") + F.PID_P = poison_on_demand(F.PID_P, label="PID") + F.PID_PI = poison_on_demand(F.PID_PI, label="PID") + + F.PROBNN_D = poison_on_demand(F.PROBNN_D, label="PROBNN") + F.PROBNN_E = poison_on_demand(F.PROBNN_E, label="PROBNN") + F.PROBNN_GHOST = poison_on_demand(F.PROBNN_GHOST, label="PROBNN") + F.PROBNN_K = poison_on_demand(F.PROBNN_K, label="PROBNN") + F.PROBNN_MU = poison_on_demand(F.PROBNN_MU, label="PROBNN") + F.PROBNN_P = poison_on_demand(F.PROBNN_P, label="PROBNN") + F.PROBNN_PI = poison_on_demand(F.PROBNN_PI, label="PROBNN") + + F.IS_PHOTON = poison_on_demand(F.IS_PHOTON, label="IS_PHOTON") + F.IS_NOT_H = poison_on_demand(F.IS_NOT_H, label="IS_NOT_H") + F.ISMUON = poison_on_demand(F.ISMUON, label="ISMUON") -- GitLab From 8d461ecf8b58e25e9c0e439e2723cdeccef2b946 Mon Sep 17 00:00:00 2001 From: Jan wagner Date: Tue, 2 Dec 2025 09:41:17 +0100 Subject: [PATCH 7/9] Add example production.py for lbexec usage of poisoning --- Hlt/Moore/python/Moore/production.py | 80 +++++++++++++++++++++++++--- 1 file changed, 72 insertions(+), 8 deletions(-) diff --git a/Hlt/Moore/python/Moore/production.py b/Hlt/Moore/python/Moore/production.py index cb87f440ae1..c5effbdd6c9 100644 --- a/Hlt/Moore/python/Moore/production.py +++ b/Hlt/Moore/python/Moore/production.py @@ -28,6 +28,33 @@ _trkeff_probe_matching_options = { } +# --------------------------------------------------------------------------- +# Helper: turn a regex string into a set of disabled Functor labels +# --------------------------------------------------------------------------- + + +def _labels_from_regex(pattern_str, supported_labels): + """ + Given a regex string and a set of supported labels, return the subset + of labels that match the regex. Raises ConfigurationError on bad regex. + """ + if not pattern_str: + return set() + + try: + pattern = re.compile(pattern_str) + except re.error as exc: + raise ConfigurationError( + f"Invalid regex for --disable-functor-groups-regex: {pattern_str!r}" + ) from exc + + return { + label + for label in supported_labels + if pattern.fullmatch(label) or pattern.search(label) + } + + def hlt1(options: Options, *raw_args): """Setup Hlt1 with given options @@ -93,6 +120,7 @@ def hlt2(options: Options, *raw_args): args.velo_source, args.without_ut, args.trkeff_probe_matching, + args.disable_functor_groups_regex, ) @@ -156,6 +184,16 @@ def _parse_args(raw_args): help="Binds MuonProbeToLongMatcher.addNeighbouringMuonHits if not null. (default: %(default)s)", ) + # NEW: regex over Functor labels (e.g. PID, PROBNN, PID_BOOL) + parser.add_argument( + "--disable-functor-groups-regex", + default=None, + help=( + "Regular expression over Functor labels to disable. " + 'Example: "(PID|PROBNN)" disables PID and PROBNN functors/cuts.' + ), + ) + return parser.parse_args(raw_args) @@ -171,7 +209,9 @@ def _hlt2( velo_source=None, without_ut=False, trkeff_probe_matching="2025-like", + disable_functor_groups_regex=None, ): + from Functors import enable_pid_poisoning, functor_label_switches from GaudiConf.LbExec import DSTFormatTypes from PyConf.Algorithms import MuonProbeToLongMatcher from PyConf.application import metainfo_repos, retrieve_encoding_dictionary @@ -225,6 +265,12 @@ def _hlt2( ] ) + # Decide which Functor labels to disable + SUPPORTED_LABELS = {"PID", "PROBNN", "PID_BOOL"} + disabled_labels = _labels_from_regex(disable_functor_groups_regex, SUPPORTED_LABELS) + if disabled_labels: + print(f"[Moore/Hlt2] Disabling functor groups: {sorted(disabled_labels)}") + def _my_line_maker(): return _line_maker( options, @@ -237,8 +283,11 @@ def _hlt2( process="Hlt2", ) - with reconstruction.bind(from_file=False), _apply_track_binds(): - config = run_moore(options, _my_line_maker, public_tools) + # Bind labels, apply poisoning, then run Moore + with functor_label_switches.bind(labels=disabled_labels): + enable_pid_poisoning() + with reconstruction.bind(from_file=False), _apply_track_binds(): + config = run_moore(options, _my_line_maker, public_tools) if options.simulation: from Configurables import DeterministicPrescaler @@ -436,6 +485,7 @@ def spruce(options: Options, *raw_args): args.persistreco, args.rawbanks, args.reco_only, + args.disable_functor_groups_regex, ) @@ -449,8 +499,10 @@ def _spruce( persistreco=False, rawbanks=None, reco_only=False, + disable_functor_groups_regex=None, ): import Functors as F + from Functors import enable_pid_poisoning, functor_label_switches from PyConf.Algorithms import VoidFilter from PyConf.application import ( configure, @@ -476,6 +528,12 @@ def _spruce( trackMasterExtrapolator_with_simplified_geom, ) + # Decide which Functor labels to disable (same as for Hlt2) + SUPPORTED_LABELS = {"PID", "PROBNN", "PID_BOOL"} + disabled_labels = _labels_from_regex(disable_functor_groups_regex, SUPPORTED_LABELS) + if disabled_labels: + print(f"[Moore/Spruce] Disabling functor groups: {sorted(disabled_labels)}") + def _my_line_maker(): return _line_maker( options, @@ -515,7 +573,7 @@ def _spruce( options, _my_line_maker(), process="spruce" ) - ## Add this filter before `moore_control_node` using `LAZY_AND` logic IF flagging + # Add this filter before `moore_control_node` using `LAZY_AND` logic IF flagging moore_withfilter_node = CompositeNode( "MC_spruce_control_node", combine_logic=NodeLogic.LAZY_AND, @@ -525,11 +583,15 @@ def _spruce( if not flagging: moore_control_node = moore_withfilter_node - config.update( - configure(options, moore_control_node, public_tools=public_tools) - ) - ## Remove any line prescales + # Bind labels & poison before configuring + with functor_label_switches.bind(labels=disabled_labels): + enable_pid_poisoning() + config.update( + configure(options, moore_control_node, public_tools=public_tools) + ) + + # Remove any line prescales from Configurables import DeterministicPrescaler from Gaudi.Configuration import allConfigurables @@ -548,7 +610,9 @@ def _spruce( trackMasterExtrapolator_with_simplified_geom(), stateProvider_with_simplified_geom(), ] - config = run_moore(options, _my_line_maker, public_tools) + with functor_label_switches.bind(labels=disabled_labels): + enable_pid_poisoning() + config = run_moore(options, _my_line_maker, public_tools) return config -- GitLab From 379c71a857c706edaafc8f7c9c1eb9183a9f9795 Mon Sep 17 00:00:00 2001 From: Jan wagner Date: Fri, 5 Dec 2025 11:57:12 +0100 Subject: [PATCH 8/9] Adapt wrap_pid with new Rec setup --- .../hlt2_pp_2024_no_pid_functor-based copy.py | 43 ------ .../hlt2_pp_2024_no_pid_recoconf-based.py | 60 -------- scripts/wrap_pid_conditions.py | 135 ++++++++---------- 3 files changed, 58 insertions(+), 180 deletions(-) delete mode 100644 Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_functor-based copy.py delete mode 100644 Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_recoconf-based.py diff --git a/Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_functor-based copy.py b/Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_functor-based copy.py deleted file mode 100644 index 8361cd7419c..00000000000 --- a/Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_functor-based copy.py +++ /dev/null @@ -1,43 +0,0 @@ -############################################################################### -# (c) Copyright 2024 CERN for the benefit of the LHCb Collaboration # -# # -# This software is distributed under the terms of the GNU General Public # -# Licence version 3 (GPL Version 3), copied verbatim in the file "COPYING". # -# # -# In applying this licence, CERN does not waive the privileges and immunities # -# granted to it by virtue of its status as an Intergovernmental Organization # -# or submit itself to any jurisdiction. # -############################################################################### -from DDDB.CheckDD4Hep import UseDD4Hep -from Hlt2Conf.settings.hlt2_pp_2024 import make_streams -from Moore import options, run_moore -from RecoConf.global_tools import ( - stateProvider_with_simplified_geom, - trackMasterExtrapolator_with_simplified_geom, -) -from RecoConf.no_pid_tool import no_pid_tool -from RecoConf.reconstruction_objects import reconstruction - -options.lines_maker = make_streams - -if UseDD4Hep: - # This needs to happen before the public tools are instantiated, - # which means we cannot put it inside make_streams(). - from PyConf.Tools import TrackMasterExtrapolator, TrackMasterFitter - - TrackMasterExtrapolator.global_bind( - ApplyMultScattCorr=False, - ApplyEnergyLossCorr=False, - ApplyElectronEnergyLossCorr=False, - ) - TrackMasterFitter.global_bind(ApplyMaterialCorrections=False) - -public_tools = [ - trackMasterExtrapolator_with_simplified_geom(), - stateProvider_with_simplified_geom(), -] - -nopid_labels = {"PID", "PROBNN", "IS_PHOTON", "IS_NOT_H", "ISMUON"} - -with reconstruction.bind(from_file=False), no_pid_tool.bind(labels=nopid_labels): - config = run_moore(options, public_tools=public_tools) diff --git a/Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_recoconf-based.py b/Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_recoconf-based.py deleted file mode 100644 index c8580e78e3f..00000000000 --- a/Hlt/Hlt2Conf/options/hlt2_pp_2024_no_pid_recoconf-based.py +++ /dev/null @@ -1,60 +0,0 @@ -############################################################################### -# (c) Copyright 2024 CERN for the benefit of the LHCb Collaboration # -# # -# This software is distributed under the terms of the GNU General Public # -# Licence version 3 (GPL Version 3), copied verbatim in the file "COPYING". # -# # -# In applying this licence, CERN does not waive the privileges and immunities # -# granted to it by virtue of its status as an Intergovernmental Organization # -# or submit itself to any jurisdiction. # -############################################################################### - -from DDDB.CheckDD4Hep import UseDD4Hep -from Moore import options, run_moore -from RecoConf.global_tools import ( - stateProvider_with_simplified_geom, - trackMasterExtrapolator_with_simplified_geom, -) -from RecoConf.no_pid_poisoning import enable_pid_poisoning - -# NEW: no-PID machinery -from RecoConf.no_pid_tool import no_pid_tool -from RecoConf.reconstruction_objects import reconstruction - -# Choose which PID categories to switch off for this configuration. -# These must match the labels produced by your wrap_pid_conditions script -# and the labels used in enable_pid_poisoning(). -nopid_labels = {"PID", "PROBNN", "IS_PHOTON", "IS_NOT_H", "ISMUON"} - -# Activate the no-PID configuration *and* poison the functors -# BEFORE importing the Hlt2 lines/settings. -with no_pid_tool.bind(labels=nopid_labels): - # Poison the Functors.* PID objects under this configuration - enable_pid_poisoning() - - # Now import the Hlt2 configuration; any use of F.PID_*, F.PROBNN_*, etc. - # in cuts will see the poisoned functors and the conditionally_skip - # wrappers will also consult this bound configuration. - from Hlt2Conf.settings.hlt2_pp_2024 import make_streams - - options.lines_maker = make_streams - - if UseDD4Hep: - # This needs to happen before the public tools are instantiated, - # which means we cannot put it inside make_streams(). - from PyConf.Tools import TrackMasterExtrapolator, TrackMasterFitter - - TrackMasterExtrapolator.global_bind( - ApplyMultScattCorr=False, - ApplyEnergyLossCorr=False, - ApplyElectronEnergyLossCorr=False, - ) - TrackMasterFitter.global_bind(ApplyMaterialCorrections=False) - - public_tools = [ - trackMasterExtrapolator_with_simplified_geom(), - stateProvider_with_simplified_geom(), - ] - - with reconstruction.bind(from_file=False): - config = run_moore(options, public_tools=public_tools) diff --git a/scripts/wrap_pid_conditions.py b/scripts/wrap_pid_conditions.py index 1034d7153c8..cc5d4e16100 100755 --- a/scripts/wrap_pid_conditions.py +++ b/scripts/wrap_pid_conditions.py @@ -130,10 +130,16 @@ def contains_pid_or_derived(node: cst.CSTNode, derived: set[str]) -> bool: def is_conditionally_skip_call(node: cst.CSTNode) -> bool: + """ + Detect F.conditionally_skip(...) + """ return ( isinstance(node, cst.Call) - and isinstance(node.func, cst.Name) - and node.func.value == "conditionally_skip" + and isinstance(node.func, cst.Attribute) + and isinstance(node.func.value, cst.Name) + and node.func.value.value == "F" + and isinstance(node.func.attr, cst.Name) + and node.func.attr.value == "conditionally_skip" ) @@ -268,12 +274,11 @@ class WrapPidCuts(cst.CSTTransformer): def _wrap(self, expr: cst.BaseExpression) -> cst.Call: """ - Wrap an expression in conditionally_skip(..., label=). + Wrap an expression in F.conditionally_skip(..., label=). - If we can find a PID attribute, we pass the category label ("PID", "PROBNN", "IS_PHOTON", "IS_NOT_H", "ISMUON"). - - Otherwise we emit conditionally_skip(expr) without label - (behaves like your original implementation). + - Otherwise we emit F.conditionally_skip(expr) without label. """ self.introduced_count += 1 @@ -287,7 +292,12 @@ class WrapPidCuts(cst.CSTTransformer): value=cst.SimpleString(f'"{label}"'), ) ) - return cst.Call(func=cst.Name("conditionally_skip"), args=args) + return cst.Call( + func=cst.Attribute( + value=cst.Name("F"), attr=cst.Name("conditionally_skip") + ), + args=args, + ) def visit_Call(self, node: cst.Call) -> None: self._in_cond_skip_stack.append(is_conditionally_skip_call(node)) @@ -297,7 +307,7 @@ class WrapPidCuts(cst.CSTTransformer): ) -> cst.BaseExpression: _ = self._in_cond_skip_stack.pop() if self._in_cond_skip_stack else False - # 1) Upgrade existing conditionally_skip(...) calls by adding a label + # 1) Upgrade existing F.conditionally_skip(...) calls by adding a label # if missing. if is_conditionally_skip_call(updated_node): # Do nothing if there's already a label=... arg @@ -404,36 +414,26 @@ class WrapPidCuts(cst.CSTTransformer): # ================================ -# Transformer: ensure import +# Transformer: ensure `import Functors as F` # ================================ -class EnsureConditionallySkipImport(cst.CSTTransformer): +class EnsureFunctorsImport(cst.CSTTransformer): """ - Ensure: from RecoConf.global_tools import conditionally_skip - Only acts if `should_apply` is True. - Handles both single-line and parenthesized multi-line import-from statements. + Ensure there is an `import Functors as F` at the top of the file + if `should_apply` is True and such an import is not already present. """ def __init__(self, should_apply: bool): self.should_apply = should_apply - self.injected_or_found = False - self.added_to_existing = False + self.has_import = False self.insert_after_index = 0 - def _is_target_module(self, module: cst.CSTNode | None) -> bool: - return ( - isinstance(module, cst.Attribute) - and isinstance(module.value, cst.Name) - and module.value.value == "RecoConf" - and isinstance(module.attr, cst.Name) - and module.attr.value == "global_tools" - ) - def visit_Module(self, node: cst.Module) -> None: - idx = 0 body = node.body - # after module docstring + idx = 0 + + # Skip over module docstring if body and isinstance(body[0], cst.SimpleStatementLine): stmt = body[0] if ( @@ -442,7 +442,8 @@ class EnsureConditionallySkipImport(cst.CSTTransformer): and isinstance(stmt.body[0].value, cst.SimpleString) ): idx = 1 - # after __future__ imports + + # Skip over __future__ imports while idx < len(body): stmt = body[idx] if ( @@ -458,64 +459,43 @@ class EnsureConditionallySkipImport(cst.CSTTransformer): idx += 1 continue break - self.insert_after_index = idx - def leave_ImportFrom( - self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom - ) -> cst.ImportFrom: - # If ANY import-from brings conditionally_skip into scope, mark as found - names = updated_node.names - if isinstance(names, cst.ImportStar): - if self._is_target_module(updated_node.module): - self.injected_or_found = True - return updated_node - - if isinstance(names, tuple): - for n in names: - if isinstance(n, cst.ImportAlias): - if ( - isinstance(n.name, cst.Name) - and n.name.value == "conditionally_skip" - ): - self.injected_or_found = True - - # Append to `from RecoConf.global_tools import ...` if missing - if self.should_apply and self._is_target_module(updated_node.module): - if isinstance(names, tuple): - have = any( - isinstance(n, cst.ImportAlias) - and isinstance(n.name, cst.Name) - and n.name.value == "conditionally_skip" - for n in names - ) - if not have: - new_names = list(names) + [ - cst.ImportAlias(name=cst.Name("conditionally_skip")) - ] - self.added_to_existing = True - self.injected_or_found = True - return updated_node.with_changes(names=tuple(new_names)) + self.insert_after_index = idx - return updated_node + def visit_Import(self, node: cst.Import) -> None: + # Detect existing "import Functors as F" + for alias in node.names: + if not isinstance(alias, cst.ImportAlias): + continue + if ( + isinstance(alias.name, cst.Name) + and alias.name.value == "Functors" + and alias.asname is not None + and isinstance(alias.asname.name, cst.Name) + and alias.asname.name.value == "F" + ): + self.has_import = True def leave_Module( self, original_node: cst.Module, updated_node: cst.Module ) -> cst.Module: - if not self.should_apply: + if not self.should_apply or self.has_import: return updated_node - if self.injected_or_found: - return updated_node - # Insert a fresh import line + + # Insert: import Functors as F new_import = cst.SimpleStatementLine( body=[ - cst.ImportFrom( - module=cst.Attribute( - value=cst.Name("RecoConf"), attr=cst.Name("global_tools") - ), - names=(cst.ImportAlias(name=cst.Name("conditionally_skip")),), + cst.Import( + names=[ + cst.ImportAlias( + name=cst.Name("Functors"), + asname=cst.AsName(name=cst.Name("F")), + ) + ] ) ] ) + body = list(updated_node.body) i = min(self.insert_after_index, len(body)) body.insert(i, new_import) @@ -545,10 +525,8 @@ def process_file(path: Path, write: bool) -> bool: wrapper = WrapPidCuts(derived_names=collector.derived) mod1 = mod.visit(wrapper) - # Pass 2: ensure import only if we introduced any wrappers - mod2 = mod1.visit( - EnsureConditionallySkipImport(should_apply=(wrapper.introduced_count > 0)) - ) + # Pass 2: ensure `import Functors as F` only if we introduced any wrappers + mod2 = mod1.visit(EnsureFunctorsImport(should_apply=(wrapper.introduced_count > 0))) changed = mod2.code != code if changed and write: @@ -558,7 +536,10 @@ def process_file(path: Path, write: bool) -> bool: def main(): ap = argparse.ArgumentParser( - description="Wrap PID cuts with conditionally_skip(...), including expressions using PID-derived variables, and ensure proper import." + description=( + "Wrap PID cuts with F.conditionally_skip(...), including expressions " + "using PID-derived variables, and ensure `import Functors as F`." + ) ) ap.add_argument("files", nargs="+", help=".py files to process") ap.add_argument( -- GitLab From 4da46b1226af8d838ee80b0ab354349608684f41 Mon Sep 17 00:00:00 2001 From: Jan wagner Date: Sat, 6 Dec 2025 01:22:13 +0100 Subject: [PATCH 9/9] Make pid is None safe --- scripts/wrap_pid_conditions.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/scripts/wrap_pid_conditions.py b/scripts/wrap_pid_conditions.py index cc5d4e16100..b3c34125dbb 100755 --- a/scripts/wrap_pid_conditions.py +++ b/scripts/wrap_pid_conditions.py @@ -109,12 +109,25 @@ def contains_pid_marker(node: cst.CSTNode) -> bool: def contains_pid_or_derived(node: cst.CSTNode, derived: set[str]) -> bool: """ - True if the subtree contains a PID marker (F.PID_*, F.PROBNN_*, IS_PHOTON, IS_NOT_H, ISMUON) - OR references any name already marked as PID-derived. + True if the subtree contains a PID marker (F.PID_*, ...) OR references a + name already marked as PID-derived, *excluding* simple None-guards like + `pid is not None`. """ if contains_pid_marker(node): return True + # Special case: `derived_name is (not) None` is just a guard, not the cut. + if isinstance(node, cst.Comparison): + if isinstance(node.left, cst.Name) and node.left.value in derived: + if all( + isinstance(comp.operator, (cst.Is, cst.IsNot)) + and isinstance(comp.comparator, cst.Name) + and comp.comparator.value == "None" + for comp in node.comparisons + ): + # treat as NOT PID-related + return False + found = False class FindDerived(cst.CSTVisitor): -- GitLab