[go: up one dir, main page]

File: strategy.cpp

package info (click to toggle)
iminuit 2.30.1-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 8,660 kB
  • sloc: cpp: 14,591; python: 11,177; makefile: 11; sh: 5
file content (82 lines) | stat: -rw-r--r-- 3,415 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include "equal.hpp"
#include "pybind11.hpp"
#include <Minuit2/MnStrategy.h>
#include <stdexcept>

namespace ROOT {
namespace Minuit2 {
// Equality for MnStrategy: two strategies are equal when the strategy level and
// every individually tunable gradient, Hessian and storage setting match.
// Floating-point tolerances are compared with plain `==`, so a NaN setting never
// compares equal. Defined inside ROOT::Minuit2 so that argument-dependent lookup
// finds it when pybind11 instantiates `py::self == py::self` for MnStrategy.
bool operator==(const MnStrategy& a, const MnStrategy& b) {
  // Strategy level, then the gradient settings.
  if (a.Strategy() != b.Strategy()) return false;
  if (a.GradientNCycles() != b.GradientNCycles()) return false;
  if (a.GradientStepTolerance() != b.GradientStepTolerance()) return false;
  if (a.GradientTolerance() != b.GradientTolerance()) return false;

  // Hessian settings.
  if (a.HessianNCycles() != b.HessianNCycles()) return false;
  if (a.HessianStepTolerance() != b.HessianStepTolerance()) return false;
  if (a.HessianG2Tolerance() != b.HessianG2Tolerance()) return false;
  if (a.HessianGradientNCycles() != b.HessianGradientNCycles()) return false;

  // Last remaining field: the storage level decides the result.
  return a.StorageLevel() == b.StorageLevel();
}
} // namespace Minuit2
} // namespace ROOT

namespace py = pybind11;
using namespace ROOT::Minuit2;

// Property setter for `MnStrategy.strategy`: switches `self` to one of the
// predefined presets by calling the matching MnStrategy setter.
//   0 -> SetLowStrategy(), 1 -> SetMediumStrategy(), 2 -> SetHighStrategy().
// Any other value throws std::invalid_argument, which pybind11's default
// exception translation raises in Python as ValueError.
void set_strategy(MnStrategy& self, unsigned s) {
  if (s == 0) {
    self.SetLowStrategy();
  } else if (s == 1) {
    self.SetMediumStrategy();
  } else if (s == 2) {
    self.SetHighStrategy();
  } else {
    throw std::invalid_argument("invalid strategy");
  }
}

// Register the Python class `MnStrategy` on module `m`.
//
// The class exposes:
//  - a default constructor and a constructor taking an `unsigned` strategy level;
//  - read/write properties for the strategy level and for every individual
//    gradient, Hessian and storage setting, backed by MnStrategy's getters/setters;
//  - `==`, using the MnStrategy operator== defined in ROOT::Minuit2;
//  - pickle support that saves and restores all of the settings above.
// It also registers an implicit conversion from `unsigned` to `MnStrategy`, so
// bound functions that take an MnStrategy also accept a plain int.
void bind_strategy(py::module m) {
  py::class_<MnStrategy>(m, "MnStrategy")

      .def(py::init<>())
      // NOTE(review): this forwards straight to MnStrategy(unsigned) and does not
      // go through set_strategy, so values outside 0..2 are not rejected here
      // (and neither are they in the implicit conversion below); how they are
      // mapped is up to the Minuit2 constructor — confirm this is intended.
      .def(py::init<unsigned>())
      // Assigning `strategy` goes through set_strategy, which accepts only 0, 1, 2.
      .def_property("strategy", &MnStrategy::Strategy, set_strategy)

      .def_property("gradient_ncycles", &MnStrategy::GradientNCycles,
                    &MnStrategy::SetGradientNCycles)
      .def_property("gradient_step_tolerance", &MnStrategy::GradientStepTolerance,
                    &MnStrategy::SetGradientStepTolerance)
      .def_property("gradient_tolerance", &MnStrategy::GradientTolerance,
                    &MnStrategy::SetGradientTolerance)
      .def_property("hessian_ncycles", &MnStrategy::HessianNCycles,
                    &MnStrategy::SetHessianNCycles)
      .def_property("hessian_step_tolerance", &MnStrategy::HessianStepTolerance,
                    &MnStrategy::SetHessianStepTolerance)
      .def_property("hessian_g2_tolerance", &MnStrategy::HessianG2Tolerance,
                    &MnStrategy::SetHessianG2Tolerance)
      .def_property("hessian_gradient_ncycles", &MnStrategy::HessianGradientNCycles,
                    &MnStrategy::SetHessianGradientNCycles)
      .def_property("storage_level", &MnStrategy::StorageLevel,
                    &MnStrategy::SetStorageLevel)

      .def(py::self == py::self)

      .def(py::pickle(
          // __getstate__: a flat 9-tuple of every setting. The element order here
          // must match the indices read back in __setstate__ below.
          [](const MnStrategy& self) {
            return py::make_tuple(
                self.Strategy(), self.GradientNCycles(), self.GradientStepTolerance(),
                self.GradientTolerance(), self.HessianNCycles(),
                self.HessianStepTolerance(), self.HessianG2Tolerance(),
                self.HessianGradientNCycles(), self.StorageLevel());
          },
          // __setstate__: construct from the saved strategy level first, then
          // override each individual setting so customised values are restored.
          [](py::tuple tp) {
            // Reject malformed state up front with a clear ValueError instead of
            // failing part-way through on an out-of-range tuple index.
            if (tp.size() != 9)
              throw std::invalid_argument(
                  "invalid MnStrategy state: expected a tuple of 9 items");
            MnStrategy str(tp[0].cast<unsigned>());
            str.SetGradientNCycles(tp[1].cast<unsigned>());
            str.SetGradientStepTolerance(tp[2].cast<double>());
            str.SetGradientTolerance(tp[3].cast<double>());
            str.SetHessianNCycles(tp[4].cast<unsigned>());
            str.SetHessianStepTolerance(tp[5].cast<double>());
            str.SetHessianG2Tolerance(tp[6].cast<double>());
            str.SetHessianGradientNCycles(tp[7].cast<unsigned>());
            str.SetStorageLevel(tp[8].cast<unsigned>());
            return str;
          }))

      ;

  // Lets Python callers pass an int wherever an MnStrategy argument is expected;
  // the conversion uses the `unsigned` constructor registered above.
  py::implicitly_convertible<unsigned, MnStrategy>();
}