diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1bb97b58dfc4ccbdb7625757e82742c6a19cc7c5..0ad32462bf425ec8ceafa7fb140488844a7a7ef4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,3 +29,14 @@ repos: hooks: - id: lb-add-copyright exclude: "^lhcbproject.yml$" +- repo: local + hooks: + - id: wrap-pid-conditions + name: Wrap PID cuts with conditionally_skip + ensure import + entry: python3 scripts/wrap_pid_conditions.py + language: system + types: [python] + pass_filenames: true + additional_dependencies: [] + # To fail CI without auto-fixing, switch to: + # entry: python3 scripts/wrap_pid_conditions.py --check diff --git a/Hlt/Moore/python/Moore/production.py b/Hlt/Moore/python/Moore/production.py index cb87f440ae1f522ceb3ad2ac4a21c9fb90ea0120..c5effbdd6c9cfd4aa621d93392874c9dca9771e4 100644 --- a/Hlt/Moore/python/Moore/production.py +++ b/Hlt/Moore/python/Moore/production.py @@ -28,6 +28,33 @@ _trkeff_probe_matching_options = { } +# --------------------------------------------------------------------------- +# Helper: turn a regex string into a set of disabled Functor labels +# --------------------------------------------------------------------------- + + +def _labels_from_regex(pattern_str, supported_labels): + """ + Given a regex string and a set of supported labels, return the subset + of labels that match the regex. Raises ConfigurationError on bad regex. 
+ """ + if not pattern_str: + return set() + + try: + pattern = re.compile(pattern_str) + except re.error as exc: + raise ConfigurationError( + f"Invalid regex for --disable-functor-groups-regex: {pattern_str!r}" + ) from exc + + return { + label + for label in supported_labels + if pattern.fullmatch(label) or pattern.search(label) + } + + def hlt1(options: Options, *raw_args): """Setup Hlt1 with given options @@ -93,6 +120,7 @@ def hlt2(options: Options, *raw_args): args.velo_source, args.without_ut, args.trkeff_probe_matching, + args.disable_functor_groups_regex, ) @@ -156,6 +184,16 @@ def _parse_args(raw_args): help="Binds MuonProbeToLongMatcher.addNeighbouringMuonHits if not null. (default: %(default)s)", ) + # NEW: regex over Functor labels (e.g. PID, PROBNN, PID_BOOL) + parser.add_argument( + "--disable-functor-groups-regex", + default=None, + help=( + "Regular expression over Functor labels to disable. " + 'Example: "(PID|PROBNN)" disables PID and PROBNN functors/cuts.' + ), + ) + return parser.parse_args(raw_args) @@ -171,7 +209,9 @@ def _hlt2( velo_source=None, without_ut=False, trkeff_probe_matching="2025-like", + disable_functor_groups_regex=None, ): + from Functors import enable_pid_poisoning, functor_label_switches from GaudiConf.LbExec import DSTFormatTypes from PyConf.Algorithms import MuonProbeToLongMatcher from PyConf.application import metainfo_repos, retrieve_encoding_dictionary @@ -225,6 +265,12 @@ def _hlt2( ] ) + # Decide which Functor labels to disable + SUPPORTED_LABELS = {"PID", "PROBNN", "PID_BOOL"} + disabled_labels = _labels_from_regex(disable_functor_groups_regex, SUPPORTED_LABELS) + if disabled_labels: + print(f"[Moore/Hlt2] Disabling functor groups: {sorted(disabled_labels)}") + def _my_line_maker(): return _line_maker( options, @@ -237,8 +283,11 @@ def _hlt2( process="Hlt2", ) - with reconstruction.bind(from_file=False), _apply_track_binds(): - config = run_moore(options, _my_line_maker, public_tools) + # Bind labels, apply 
poisoning, then run Moore + with functor_label_switches.bind(labels=disabled_labels): + enable_pid_poisoning() + with reconstruction.bind(from_file=False), _apply_track_binds(): + config = run_moore(options, _my_line_maker, public_tools) if options.simulation: from Configurables import DeterministicPrescaler @@ -436,6 +485,7 @@ def spruce(options: Options, *raw_args): args.persistreco, args.rawbanks, args.reco_only, + args.disable_functor_groups_regex, ) @@ -449,8 +499,10 @@ def _spruce( persistreco=False, rawbanks=None, reco_only=False, + disable_functor_groups_regex=None, ): import Functors as F + from Functors import enable_pid_poisoning, functor_label_switches from PyConf.Algorithms import VoidFilter from PyConf.application import ( configure, @@ -476,6 +528,12 @@ def _spruce( trackMasterExtrapolator_with_simplified_geom, ) + # Decide which Functor labels to disable (same as for Hlt2) + SUPPORTED_LABELS = {"PID", "PROBNN", "PID_BOOL"} + disabled_labels = _labels_from_regex(disable_functor_groups_regex, SUPPORTED_LABELS) + if disabled_labels: + print(f"[Moore/Spruce] Disabling functor groups: {sorted(disabled_labels)}") + def _my_line_maker(): return _line_maker( options, @@ -515,7 +573,7 @@ def _spruce( options, _my_line_maker(), process="spruce" ) - ## Add this filter before `moore_control_node` using `LAZY_AND` logic IF flagging + # Add this filter before `moore_control_node` using `LAZY_AND` logic IF flagging moore_withfilter_node = CompositeNode( "MC_spruce_control_node", combine_logic=NodeLogic.LAZY_AND, @@ -525,11 +583,15 @@ def _spruce( if not flagging: moore_control_node = moore_withfilter_node - config.update( - configure(options, moore_control_node, public_tools=public_tools) - ) - ## Remove any line prescales + # Bind labels & poison before configuring + with functor_label_switches.bind(labels=disabled_labels): + enable_pid_poisoning() + config.update( + configure(options, moore_control_node, public_tools=public_tools) + ) + + # Remove any line 
# Attribute name on the ``Functors`` module -> label handed to poison_on_demand.
_POISONED_FUNCTOR_LABELS = {
    "PID_E": "PID",
    "PID_K": "PID",
    "PID_MU": "PID",
    "PID_P": "PID",
    "PID_PI": "PID",
    "PROBNN_D": "PROBNN",
    "PROBNN_E": "PROBNN",
    "PROBNN_GHOST": "PROBNN",
    "PROBNN_K": "PROBNN",
    "PROBNN_MU": "PROBNN",
    "PROBNN_P": "PROBNN",
    "PROBNN_PI": "PROBNN",
    "IS_PHOTON": "IS_PHOTON",
    "IS_NOT_H": "IS_NOT_H",
    "ISMUON": "ISMUON",
}

# Functors as they were before the first call, keyed by attribute name.
_UNPOISONED_FUNCTORS = {}


def enable_pid_poisoning():
    """Replace the PID functors on ``Functors`` by their on-demand-poisoned form.

    Every functor listed in ``_POISONED_FUNCTOR_LABELS`` is passed through
    ``poison_on_demand`` with its label and the result is stored back on the
    ``Functors`` module (``F``).

    The function may be called repeatedly: each call starts again from the
    functor that was present on the first call, so a functor is never wrapped
    twice and the outcome always reflects the labels in effect at the time of
    the latest call.
    """
    for name, label in _POISONED_FUNCTOR_LABELS.items():
        # setdefault stores the functor only on the first call; later calls
        # get that original back instead of an already-wrapped functor.
        original = _UNPOISONED_FUNCTORS.setdefault(name, getattr(F, name))
        setattr(F, name, poison_on_demand(original, label=label))
@dataclass(frozen=True)
class _NoPidState:
    """Immutable value returned by ``no_pid_tool``.

    A "label" is a plain string attached consistently to a PID functor (via
    ``poison_on_demand``) and to the cuts using it (via
    ``conditionally_skip``).
    """

    # Labels for which PID is switched off.
    labels: FrozenSet[str] = frozenset()


@configurable
def no_pid_tool(
    labels: Iterable[str] = (),
) -> _NoPidState:
    """Central knob selecting the labels for which PID is switched off.

    Typical usage::

        from RecoConf.no_pid_tool import no_pid_tool

        with no_pid_tool.bind(labels={"PID", "PROBNN"}):
            ...  # build the configuration

    While a label is bound here, ``conditionally_skip(..., label=...)``
    short-circuits the wrapped cut and ``poison_on_demand(..., label=...)``
    replaces the wrapped functor by a poisoned one.

    Labels are free-form strings; the only requirement is that the functor
    and the cuts that use it carry the same label.

    Args:
        labels: Iterable of label strings. A single bare string is accepted
            and treated as one label.

    Returns:
        _NoPidState: Frozen state object exposing ``labels`` as a frozenset.
    """
    # A str is itself an iterable of characters: without this guard
    # ``labels="PID"`` would silently become {"P", "I", "D"}.
    if isinstance(labels, str):
        labels = (labels,)
    return _NoPidState(labels=frozenset(labels))
@configurable
def poison_on_demand(
    functor,
    label: Optional[str] = None,
):
    """Return ``functor`` itself, or a poisoned stand-in if its label is off.

    The decision is taken when this function is called, based on the labels
    currently reported by ``no_pid_tool()``:

    - ``label`` is ``None`` or not among those labels: ``functor`` is
      returned unchanged.
    - otherwise: ``F.POISON(functor, <message>)`` is returned, so that any use
      of the functor that was not neutralised by ``conditionally_skip`` fails
      loudly instead of silently reintroducing a PID dependence.

    Example::

        from RecoConf.no_pid_tool import poison_on_demand

        F.PID_K = poison_on_demand(F.PID_K, label="PID")
        F.PROBNN_K = poison_on_demand(F.PROBNN_K, label="PROBNN")

    Args:
        functor: The functor to protect.
        label: Label under which the functor can be switched off.

    Returns:
        ``functor`` or the result of ``F.POISON``.
    """
    disabled_labels = no_pid_tool().labels
    label_is_disabled = label is not None and label in disabled_labels
    if not label_is_disabled:
        return functor

    # Mirrors the C++ pattern
    #   Functors::Functional::Poison( Functors::PID::IsMuon, "do not use IsMuon" );
    return F.POISON(
        functor,
        f"PID functor with label {label!r} is disabled by no_pid_tool and "
        f"must not be evaluated (no-PID configuration).",
    )
def contains_pid_marker(node: cst.CSTNode) -> bool:
    """Return True if ``node`` or any of its descendants is a PID attribute.

    "PID attribute" is whatever ``is_pid_attribute`` accepts.

    The visitor's default traversal already reaches every attribute in the
    subtree, including the arguments of calls such as ``F.MIN(...)`` or
    ``F.in_range(...)``, so no special handling of calls is required.
    """
    found = False

    class Finder(cst.CSTVisitor):
        def visit_Attribute(self, n: cst.Attribute) -> bool:
            nonlocal found
            if not found and is_pid_attribute(n):
                found = True
            # Returning False tells libcst not to descend any further below
            # this attribute once a marker has been seen.
            return not found

    node.visit(Finder())
    return found
def pid_label_from_expr(node: cst.CSTNode) -> str | None:
    """Derive a category label from a PID-bearing expression.

    The first attribute in traversal order that ``is_pid_attribute`` accepts
    and that ``pid_category_from_attr_name`` can map decides the label:

        PID_*      -> "PID"
        PROBNN_*   -> "PROBNN"
        IS_PHOTON  -> "IS_PHOTON"
        IS_NOT_H   -> "IS_NOT_H"
        ISMUON     -> "ISMUON"

    Returns:
        The category label, or ``None`` if the expression contains no such
        attribute (for example when it only refers to PID-derived names).
    """
    label: str | None = None

    class Finder(cst.CSTVisitor):
        def visit_Attribute(self, n: cst.Attribute) -> bool:
            nonlocal label
            if label is None and is_pid_attribute(n):
                # May still be None (unmapped name); the search then continues.
                label = pid_category_from_attr_name(attr_value(n))
            # Stop descending once a label is known. Call arguments need no
            # manual recursion: the default traversal visits them anyway.
            return label is None

    node.visit(Finder())
    return label
class WrapPidCuts(cst.CSTTransformer):
    """Wrap PID-dependent cuts in ``F.conditionally_skip(<cut>, label=...)``.

    A "cut" is a comparison, or a ``~`` inversion of a comparison / boolean
    PID functor, that references a PID functor or one of ``derived_names``.

    Rules applied by the transformer:

    - Nothing inside an existing ``F.conditionally_skip(...)`` call is wrapped
      again. Such a call only gains a ``label=`` keyword when it has none and
      a category can be derived from its first argument.
    - An inverted cut ``~(<cut>)`` is wrapped as a whole. Wrapping only the
      inner comparison would produce ``~F.conditionally_skip(...)``, i.e. the
      inversion of an always-accepting predicate once the label is switched
      off.

    Attributes:
        introduced_count: Number of wrappers created by this transformer.
        derived: Names treated as PID-derived.
    """

    def __init__(self, derived_names: set[str]):
        # Suppression stack. One entry is pushed per Call / UnaryOperation
        # being traversed and popped when it is left; an entry is True when
        # everything below that node must be left unwrapped (existing
        # F.conditionally_skip call, or an inversion wrapped as a whole).
        self._in_cond_skip_stack: list[bool] = []
        self.introduced_count: int = 0
        # Copy, so this pass is unaffected by later changes to the argument.
        self.derived = set(derived_names)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _suppressed(self) -> bool:
        """True while an enclosing node forbids introducing wrappers."""
        return any(self._in_cond_skip_stack)

    def _pop_flag(self) -> bool:
        """Pop the flag pushed by the matching ``visit_*`` hook."""
        return self._in_cond_skip_stack.pop() if self._in_cond_skip_stack else False

    def _is_pid_related(self, expr: cst.CSTNode) -> bool:
        return contains_pid_or_derived(expr, self.derived)

    def _is_wrappable_inversion(self, expr: cst.CSTNode) -> bool:
        """True for ``~<PID comparison>`` and ``~<boolean PID functor>``."""
        if not (
            isinstance(expr, cst.UnaryOperation)
            and isinstance(expr.operator, cst.BitInvert)
        ):
            return False
        operand = expr.expression
        if not self._is_pid_related(operand):
            return False
        return isinstance(operand, cst.Comparison) or is_boolean_like(operand)

    @staticmethod
    def _label_arg(label: str) -> cst.Arg:
        """Build the ``label="<label>"`` keyword argument."""
        return cst.Arg(
            keyword=cst.Name("label"),
            value=cst.SimpleString(f'"{label}"'),
        )

    def _wrap(self, expr: cst.BaseExpression) -> cst.Call:
        """Return ``F.conditionally_skip(expr[, label=...])``.

        The label is the category found by ``pid_label_from_expr``; when none
        can be derived the call is emitted without a ``label`` keyword.
        """
        self.introduced_count += 1

        args: list[cst.Arg] = [cst.Arg(expr)]
        label = pid_label_from_expr(expr)
        if label is not None:
            args.append(self._label_arg(label))
        return cst.Call(
            func=cst.Attribute(
                value=cst.Name("F"), attr=cst.Name("conditionally_skip")
            ),
            args=args,
        )

    def _wrap_operand(self, expr: cst.BaseExpression) -> cst.BaseExpression:
        """Wrap ``expr`` if it is a still-unwrapped PID cut, else return it."""
        if is_conditionally_skip_call(expr):
            return expr
        if isinstance(expr, cst.Comparison) and self._is_pid_related(expr):
            return self._wrap(expr)
        if self._is_wrappable_inversion(expr):
            return self._wrap(expr)
        return expr

    def _add_missing_label(self, call: cst.Call) -> cst.Call:
        """Add ``label=`` to an existing wrapper when it can be derived."""
        already_labelled = any(
            isinstance(a.keyword, cst.Name) and a.keyword.value == "label"
            for a in call.args
        )
        if already_labelled or not call.args:
            return call

        # The first argument is taken to be the predicate.
        predicate = call.args[0].value
        if predicate is None:
            return call

        label = pid_label_from_expr(predicate)
        if label is None:
            return call
        return call.with_changes(args=(*call.args, self._label_arg(label)))

    # ------------------------------------------------------------------
    # Visitor hooks
    # ------------------------------------------------------------------

    def visit_Call(self, node: cst.Call) -> None:
        self._in_cond_skip_stack.append(is_conditionally_skip_call(node))

    def leave_Call(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.BaseExpression:
        self._pop_flag()

        # 1) Existing wrapper: only make sure it carries a label.
        if is_conditionally_skip_call(updated_node):
            return self._add_missing_label(updated_node)

        # 2) Never introduce wrappers below an enclosing wrapper.
        if self._suppressed():
            return updated_node

        # 3) Comparison arguments of F.SUM / F.ANY / F.ALL / F.COUNT.
        #    Children are rewritten before their parent is left, so this
        #    normally finds nothing left to wrap; kept as a safety net.
        fn = func_name_like(updated_node.func) or ""
        if is_F_attr(updated_node.func) and fn in BOOL_CONSUMERS:
            new_args = []
            for a in updated_node.args:
                v = a.value
                if isinstance(v, cst.Comparison) and self._is_pid_related(v):
                    v = self._wrap(v)
                new_args.append(a.with_changes(value=v))
            return updated_node.with_changes(args=new_args)

        return updated_node

    def leave_Comparison(
        self, original_node: cst.Comparison, updated_node: cst.Comparison
    ) -> cst.BaseExpression:
        if self._suppressed():
            return updated_node
        if self._is_pid_related(updated_node):
            return self._wrap(updated_node)
        return updated_node

    def visit_UnaryOperation(self, node: cst.UnaryOperation) -> None:
        # Decide on the *unmodified* node whether the whole inversion will be
        # wrapped; if so, the flag keeps the operand from being wrapped first.
        will_wrap = (not self._suppressed()) and self._is_wrappable_inversion(node)
        self._in_cond_skip_stack.append(will_wrap)

    def leave_UnaryOperation(
        self, original_node: cst.UnaryOperation, updated_node: cst.UnaryOperation
    ) -> cst.BaseExpression:
        if self._pop_flag():
            return self._wrap(updated_node)
        return updated_node

    def leave_BinaryOperation(
        self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation
    ) -> cst.BaseExpression:
        if self._suppressed():
            return updated_node
        if isinstance(updated_node.operator, (cst.BitAnd, cst.BitOr)):
            # Operands have usually been wrapped already when they were left;
            # _wrap_operand only acts on cuts that are still bare.
            return updated_node.with_changes(
                left=self._wrap_operand(updated_node.left),
                right=self._wrap_operand(updated_node.right),
            )
        return updated_node
def process_file(path: Path, write: bool) -> bool:
    """Rewrite the PID cuts of one Python file.

    Runs three passes over the parsed module: collect PID-derived names, wrap
    PID cuts, and add ``import Functors as F`` if wrappers were introduced and
    that import is missing.

    Args:
        path: File to process.
        write: If True, write the rewritten source back to ``path`` when it
            differs from the original.

    Returns:
        True if the rewritten source differs from the original, False
        otherwise. A file that cannot be read or parsed is reported on stderr
        and counted as unchanged.
    """
    try:
        code = path.read_text(encoding="utf-8")
    except (OSError, UnicodeError) as exc:
        print(f"wrap_pid_conditions: cannot read {path}: {exc}", file=sys.stderr)
        return False

    try:
        mod = cst.parse_module(code)
    except Exception as exc:
        # Best effort: a file the parser rejects is skipped, but no longer
        # silently, so that an unprocessed file does not go unnoticed.
        print(f"wrap_pid_conditions: cannot parse {path}: {exc}", file=sys.stderr)
        return False

    # Pre-pass: collect PID-derived names
    collector = CollectPidDerivedNames()
    mod.visit(collector)

    # Pass 1: wrap PID cuts (including expressions using derived names)
    wrapper = WrapPidCuts(derived_names=collector.derived)
    wrapped = mod.visit(wrapper)

    # Pass 2: ensure `import Functors as F` only if wrappers were introduced
    needs_import = wrapper.introduced_count > 0
    final = wrapped.visit(EnsureFunctorsImport(should_apply=needs_import))

    new_code = final.code
    changed = new_code != code
    if changed and write:
        path.write_text(new_code, encoding="utf-8")
    return changed