diff --git a/changelog/fragments/927.added.rst b/changelog/fragments/927.added.rst new file mode 100644 index 0000000000000000000000000000000000000000..4253f4d63097af975e7b92757c07d70b571d4d97 --- /dev/null +++ b/changelog/fragments/927.added.rst @@ -0,0 +1 @@ +New discipline ``RetryDiscipline``: it wraps a discipline and retries its execution several times. If an execution raises an exception, the discipline is executed again, up to a maximum number of attempts. A timeout can also be set to abort executions that last too long. A tuple of exceptions can be passed so that, if one of them is raised, the execution is not retried. diff --git a/src/gemseo/disciplines/wrappers/retry_discipline.py b/src/gemseo/disciplines/wrappers/retry_discipline.py new file mode 100644 index 0000000000000000000000000000000000000000..d7d617713d4411a3b2eb9b24538ade080e443b36 --- /dev/null +++ b/src/gemseo/disciplines/wrappers/retry_discipline.py @@ -0,0 +1,209 @@ +# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License version 3 as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+"""The retry discipline.""" + +from __future__ import annotations + +import concurrent.futures as cfutures +import math +import os +import signal +import time +from concurrent.futures import ProcessPoolExecutor +from logging import getLogger +from typing import TYPE_CHECKING +from typing import ClassVar + +from gemseo.core.discipline import Discipline + +if TYPE_CHECKING: + from collections.abc import Iterable + + from gemseo.typing import StrKeyMapping + +LOGGER = getLogger(__name__) + + +class RetryDiscipline(Discipline): + """A discipline to be executed with retry and timeout options. + + This :class:`.Discipline` wraps another discipline. + + It can be executed multiple times (up to a specified number of retries) + if the previous attempts fail to produce any result. + + A timeout in seconds can be specified to prevent executions from becoming stuck. + + Users can also provide a tuple of :class:`.Exception` that, if one of them is + raised, it does not retry the execution. + + Please note that the ``TimeoutError`` exception is also caught if the wrapped + discipline raises such an exception (i.e. aside from ``RetryDiscipline`` itself). + So it could lead to 2 surprising cases, but in fact normal cases: + - a ``TimeoutError`` exception even though the user didn't provide any timeout + value. + - a ``TimeoutError`` raised sooner than the ``timeout`` value set by the user. 
+ """ + + __n_executions: int + """The number of performed executions of the discipline.""" + + __time_out_exceptions: ClassVar[tuple[type[Exception], ...]] = ( + TimeoutError, + cfutures.TimeoutError, + ) + """The possible timeout exceptions that can be raised during execution.""" + + n_retry: int + """The number of retry of the discipline.""" + + wait_time: float + """The time to wait between 2 retries (in seconds).""" + + timeout: float + """The maximum duration, in seconds, that the discipline is allowed to run.""" + + fatal_exceptions: Iterable[type[Exception]] + """The exceptions for which the code raises an exception and exit immediately + without retrying a run.""" + + def __init__( + self, + discipline: Discipline, + n_retry: int = 5, + wait_time: float = 0.0, + timeout: float = math.inf, + fatal_exceptions: Iterable[type[Exception]] = (), + ) -> None: + """ + Args: + discipline: The discipline to wrap in the retry loop. + n_retry: The number of retry of the discipline. + wait_time: The time to wait between 2 retries (in seconds). + timeout: The maximum duration, in seconds, that the discipline is + allowed to run. If this time limit is exceeded, the + execution is terminated. If ``math.inf``, the + discipline is executed without timeout limit. + fatal_exceptions: The exceptions for which the code raises an + exception and exit immediately without retrying a run. + + Raises: + TimeoutError: If the ``timeout`` limit is reached. + Exception: Other exceptions if an issue is encountered during the + execution of ``discipline``. 
+ + """ # noqa:D205 D212 D415 + super().__init__(discipline.name) + self.__discipline = discipline + self.io.input_grammar = discipline.io.input_grammar + self.io.output_grammar = discipline.io.output_grammar + self.n_retry = n_retry + self.wait_time = wait_time + self.timeout = timeout + self.fatal_exceptions = fatal_exceptions + self.__n_executions = 0 + + @property + def n_executions(self) -> int: + """The number of times the discipline has been retried during execution.""" + return self.__n_executions + + def _run(self, input_data: StrKeyMapping) -> StrKeyMapping | None: + self.__n_executions = 0 + + for n_try in range(1, self.n_retry + 1): + self.__n_executions += 1 + + LOGGER.debug( + "Trying to execute the discipline: attempt %d/%d", n_try, self.n_retry + ) + + try: + if math.isinf(self.timeout): + return self.__discipline.execute(input_data) + return self._execute_discipline(input_data) + + except self.__time_out_exceptions: + msg = ( + "Timeout reached during the execution of " + f"discipline {self.__discipline.name}" + ) + LOGGER.debug(msg) + current_error = TimeoutError(msg) + + except Exception as error: # noqa: BLE001 + if isinstance(error, self.fatal_exceptions): + LOGGER.info( + "Failed to execute discipline %s, " + "aborting retry because of the exception type %s.", + self.__discipline.name, + type(error), + ) + raise + current_error = error + + time.sleep(self.wait_time) + + plural_suffix = "s" if self.n_retry > 1 else "" + LOGGER.error( + "Failed to execute discipline %s after %d attempt%s.", + self.__discipline.name, + self.n_retry, + plural_suffix, + ) + raise current_error + + def _execute_discipline(self, input_data: StrKeyMapping) -> StrKeyMapping: + """Execute the discipline with a timeout. + + Args: + input_data: The input data passed to the discipline. + + Returns: + The output returned by the discipline. 
+ """ + LOGGER.debug( + "Executing discipline %s with a timeout of %s s", + self.__discipline.name, + self.timeout, + ) + + with ProcessPoolExecutor() as executor: + run_discipline = executor.submit( + self.__discipline.execute, + input_data, + ) + + try: + return run_discipline.result(timeout=self.timeout) + + except self.__time_out_exceptions: + # Killing the children is mandatory to abort the discipline execution + # immediately: shutdown + kill children. + pid_child = [p.pid for p in executor._processes.values()] + executor.shutdown(wait=False, cancel_futures=True) + + LOGGER.debug("killing subprocesses: %s", pid_child) + for pid in pid_child: + os.kill(pid, signal.SIGTERM) + + LOGGER.exception( + "Process stopped as it exceeds timeout (%s s)", self.timeout + ) + raise + + except Exception as error: # noqa: BLE001 + LOGGER.debug(type(error)) + raise diff --git a/tests/disciplines/wrappers/test_retry_discipline.py b/tests/disciplines/wrappers/test_retry_discipline.py new file mode 100644 index 0000000000000000000000000000000000000000..5262ce0c1d2b840102cf81bf93dd150cfdfa37a7 --- /dev/null +++ b/tests/disciplines/wrappers/test_retry_discipline.py @@ -0,0 +1,234 @@ +# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License version 3 as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program; if not, write to the Free Software Foundation, +# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
+"""Tests for retry discipline.""" + +from __future__ import annotations + +import math +import re +import time +from typing import TYPE_CHECKING + +import pytest +from numpy import array + +from gemseo import create_discipline +from gemseo.core.discipline import Discipline +from gemseo.disciplines.wrappers.retry_discipline import RetryDiscipline +from gemseo.utils.timer import Timer + +if TYPE_CHECKING: + from collections.abc import Iterable + + from gemseo.typing import StrKeyMapping + + +@pytest.fixture +def an_analytic_discipline() -> Discipline: + """Analytic discipline.""" + return create_discipline("AnalyticDiscipline", expressions={"y": "x"}) + + +@pytest.fixture +def a_crashing_analytic_discipline() -> Discipline: + """Analytic discipline crashing when x=0.""" + return create_discipline("AnalyticDiscipline", expressions={"y": "1.0/x"}) + + +@pytest.fixture +def a_crashing_discipline_in_run() -> Discipline: + return CrashingDisciplineInRun(name="Crash_run") + + +@pytest.fixture +def a_long_time_running_discipline() -> Discipline: + return DisciplineLongTimeRunning() + + +class CrashingDisciplineInRun(Discipline): + """A discipline raising NotImplementedError in ``_run``.""" + + def _run(self, input_data: StrKeyMapping): + msg = "Error: This method is not implemented." 
+ raise NotImplementedError(msg) + + +class DisciplineLongTimeRunning(Discipline): + """A discipline that could run for a while, to test the timeout feature.""" + + def _run(self, input_data: StrKeyMapping) -> None: + time.sleep(5.0) + + +@pytest.mark.parametrize("timeout", [math.inf, 10.0]) +def test_retry_discipline(an_analytic_discipline, timeout, caplog) -> None: + """Test discipline, no timeout set.""" + retry_discipline = RetryDiscipline(an_analytic_discipline, timeout=timeout) + retry_discipline.execute({"x": array([4.0])}) + + assert retry_discipline.n_executions == 1 + assert retry_discipline.local_data == {"x": array([4.0]), "y": array([4.0])} + + assert caplog.text == "" + + +@pytest.mark.parametrize("wait_time", [0.5, 1.0]) +@pytest.mark.parametrize("n_retry", [1, 3]) +def test_failure_retry_discipline_with_timeout( + an_analytic_discipline, n_retry, wait_time, caplog +) -> None: + """Test failure of the discipline with a too much very short timeout.""" + disc_with_timeout = RetryDiscipline( + an_analytic_discipline, timeout=1e-4, n_retry=n_retry, wait_time=wait_time + ) + + with ( + Timer() as timer, + pytest.raises( + TimeoutError, + match="Timeout reached during the execution" + " of discipline AnalyticDiscipline", + ), + ): + disc_with_timeout.execute({"x": array([4.0])}) + + elapsed_time = timer.elapsed_time + assert elapsed_time > 0.05 + (n_retry - 1) * wait_time + + assert disc_with_timeout.n_executions == n_retry + assert disc_with_timeout.local_data == {"x": array([4.0])} + + assert "Process stopped as it exceeds timeout" in caplog.text + + plural_suffix = "s" if n_retry > 1 else "" + log_message = ( + f"Failed to execute discipline AnalyticDiscipline after {n_retry}" + f" attempt{plural_suffix}." + ) + assert log_message in caplog.text + + +def test_failure_zero_division_error(a_crashing_analytic_discipline, caplog) -> None: + """Test failure of the discipline with a bad x entry. 
+ + In order to catch the ZeroDivisionError, set n_retry=1 + """ + disc = RetryDiscipline(a_crashing_analytic_discipline, n_retry=1) + with pytest.raises(ZeroDivisionError, match="float division by zero"): + disc.execute({"x": array([0.0])}) + + assert disc.local_data == {"x": array([0.0])} + assert disc.n_executions == 1 + + log_message = "Failed to execute discipline AnalyticDiscipline after 1 attempt." + assert log_message in caplog.text + + +@pytest.mark.parametrize( + "fatal_exceptions", + [ + (ZeroDivisionError,), + (ZeroDivisionError, FloatingPointError, OverflowError), + (OverflowError, FloatingPointError, ZeroDivisionError), + ], +) +@pytest.mark.parametrize("n_try", [1, 3]) +def test_failure_zero_division_error_with_timeout( + n_try: int, + fatal_exceptions: Iterable[type[Exception]], + a_crashing_analytic_discipline, + caplog, +) -> None: + """Test failure of the discipline with timeout and a bad x entry. + + In order to catch the ZeroDivisionError that arises before timeout (5s), test with + n_retry=1 and 3 to be sure every case is ok. + """ + disc = RetryDiscipline( + a_crashing_analytic_discipline, + n_retry=n_try, + timeout=5.0, + fatal_exceptions=fatal_exceptions, + ) + with pytest.raises(ZeroDivisionError, match="float division by zero"): + disc.execute({"x": array([0.0])}) + + assert disc.n_executions == 1 + assert disc.local_data == {"x": array([0.0])} + + log_message = ( + "Failed to execute discipline AnalyticDiscipline," + " aborting retry because of the exception type ." + ) + assert log_message in caplog.text + + +def test_a_not_implemented_error_analytic_discipline( + a_crashing_discipline_in_run, caplog +) -> None: + """Test discipline with a_crashing_discipline_in_run and a tuple of + + a tuple of fatal_exceptions that abort the retry (ZeroDivisionError). 
+ """ + retry_discipline = RetryDiscipline( + a_crashing_discipline_in_run, + n_retry=5, + timeout=100.0, + fatal_exceptions=( + ZeroDivisionError, + FloatingPointError, + OverflowError, + NotImplementedError, + ), + ) + with pytest.raises( + NotImplementedError, match=re.escape("Error: This method is not implemented.") + ): + retry_discipline.execute({"x": array([1.0])}) + + assert retry_discipline.n_executions == 1 + assert retry_discipline.local_data == {} + + log_message = ( + "Failed to execute discipline Crash_run, aborting retry " + "because of the exception type ." + ) + assert log_message in caplog.text + + +def test_retry_discipline_timeout_feature( + a_long_time_running_discipline, caplog +) -> None: + """Test the timeout feature of discipline with a long computation.""" + n_retry = 1 + + disc_with_timeout = RetryDiscipline( + a_long_time_running_discipline, timeout=2.0, n_retry=n_retry + ) + with pytest.raises( + TimeoutError, + match="Timeout reached during the execution" + " of discipline DisciplineLongTimeRunning", + ): + disc_with_timeout.execute({"x": array([0.0])}) + + assert disc_with_timeout.n_executions == n_retry + assert disc_with_timeout.local_data == {} + + assert "Process stopped as it exceeds timeout" in caplog.text + log_message = ( + "Failed to execute discipline DisciplineLongTimeRunning after 1 attempt." + ) + assert log_message in caplog.text