From 0be3af72b706b3ed45f984eea161e55f39cf7326 Mon Sep 17 00:00:00 2001
From: "nicolas.roussouly"
Date: Thu, 26 Sep 2024 17:03:59 +0200
Subject: [PATCH] refactor: simplify BCD integration

Build the scenario adapters in BiLevel.__init__, expose them through the
BiLevel.sub_scenarios property, wrap them into a dedicated parallel or
sequential chain sharing the grammar type of the formulation, and put the
normalized residual norms of MDA1 and MDA2 in dedicated namespaces.
---
 src/gemseo/formulations/bilevel.py | 111 +++++++++++++++++------------
 tests/formulations/test_bilevel.py |   2 +-
 2 files changed, 67 insertions(+), 46 deletions(-)

diff --git a/src/gemseo/formulations/bilevel.py b/src/gemseo/formulations/bilevel.py
index be9b27c7cb..08fce02b2f 100644
--- a/src/gemseo/formulations/bilevel.py
+++ b/src/gemseo/formulations/bilevel.py
@@ -33,6 +33,7 @@ from gemseo.core.discipline import MDODiscipline
 from gemseo.core.mdo_functions.mdo_function import MDOFunction
 from gemseo.disciplines.scenario_adapters.mdo_scenario_adapter import MDOScenarioAdapter
 from gemseo.formulations.base_mdo_formulation import BaseMDOFormulation
+from gemseo.mda.base_mda import BaseMDA
 from gemseo.mda.factory import MDAFactory
 from gemseo.scenarios.scenario_results.bilevel_scenario_result import (
     BiLevelScenarioResult,
@@ -45,7 +46,6 @@ if TYPE_CHECKING:
     from gemseo.algos.design_space import DesignSpace
     from gemseo.core.execution_sequence import ExecutionSequence
     from gemseo.core.grammars.json_grammar import JSONGrammar
-    from gemseo.mda.base_mda import BaseMDA
     from gemseo.scenarios.scenario import Scenario
     from gemseo.typing import StrKeyMapping
 
@@ -66,13 +66,32 @@ class BiLevel(BaseMDOFormulation):
     1. a first MDA to compute the coupling variables,
     2. several disciplinary optimizations on the local design variables in parallel,
     3. a second MDA to update the coupling variables.
+
+    The normalized residual norms of MDA1 (if any) and MDA2 are outputs placed in the
+    namespaces :attr:`.BiLevel.MDA1_RESIDUAL_NAMESPACE` and
+    :attr:`.BiLevel.MDA2_RESIDUAL_NAMESPACE` so that they can be used as observables.
     """
 
     DEFAULT_SCENARIO_RESULT_CLASS_NAME: ClassVar[str] = BiLevelScenarioResult.__name__
+    """The default name of the scenario result class."""
+
+    SYSTEM_LEVEL: ClassVar[str] = "system"
+    """The name of the system level."""
+
+    SUBSCENARIOS_LEVEL: ClassVar[str] = "sub-scenarios"
+    """The name of the sub-scenarios level."""
 
-    SYSTEM_LEVEL = "system"
-    SUBSCENARIOS_LEVEL = "sub-scenarios"
     LEVELS = (SYSTEM_LEVEL, SUBSCENARIOS_LEVEL)
+    """The names of the levels."""
+
+    CHAIN_NAME: ClassVar[str] = "bilevel_chain"
+    """The name of the internal chain."""
+
+    MDA1_RESIDUAL_NAMESPACE: ClassVar[str] = "MDA1"
+    """The namespace of the normalized residual norm of MDA1."""
+
+    MDA2_RESIDUAL_NAMESPACE: ClassVar[str] = "MDA2"
+    """The namespace of the normalized residual norm of MDA2."""
 
     __sub_scenarios_log_level: int | None
     """The level of the root logger during the sub-scenarios executions.
@@ -134,7 +153,7 @@ class BiLevel(BaseMDOFormulation):
         self._mda1 = None
         self._mda2 = None
         self.reset_x0_before_opt = reset_x0_before_opt
-        self.scenario_adapters = []
+        self._scenario_adapters = []
         self.chain = None
         self._mda_factory = MDAFactory()
         self._apply_cstr_to_system = apply_cstr_to_system
@@ -142,11 +161,16 @@ class BiLevel(BaseMDOFormulation):
         self.__parallel_scenarios = parallel_scenarios
         self._multithread_scenarios = multithread_scenarios
         self.couplstr = CouplingStructure(self.get_sub_disciplines())
+        self.__sub_scenarios_log_level = sub_scenarios_log_level
 
         # Create MDA
-        self.__sub_scenarios_log_level = sub_scenarios_log_level
         self._build_mdas(main_mda_name, inner_mda_name, **main_mda_options)
 
+        # Build the scenario adapters
+        self._build_scenario_adapters(
+            reset_x0_before_opt=self.reset_x0_before_opt, keep_opt_history=True
+        )
+
         # Create MDOChain : MDA1 -> sub scenarios -> MDA2
         self._build_chain()
 
@@ -172,7 +196,7 @@ class BiLevel(BaseMDOFormulation):
         use_non_shared_vars: bool = False,
         adapter_class: type[MDOScenarioAdapter] = MDOScenarioAdapter,
         **adapter_options,
-    ) -> list[MDOScenarioAdapter]:
+    ) -> None:
         """Build the MDOScenarioAdapter required for each sub scenario.
 
         This is used to build the self.chain.
@@ -184,11 +208,7 @@ class BiLevel(BaseMDOFormulation):
                 of the scenarios adapters.
             adapter_class: The class of the adapters.
             **adapter_options: The options for the adapters' initialization.
-
-        Returns:
-            The adapters for the sub-scenarios.
         """
-        adapters = []
         scenario_log_level = adapter_options.pop(
             "scenario_log_level", self.__sub_scenarios_log_level
         )
@@ -203,8 +223,7 @@ class BiLevel(BaseMDOFormulation):
                 scenario_log_level=scenario_log_level,
                 **adapter_options,
             )
-            adapters.append(adapter)
-        return adapters
+            self._scenario_adapters.append(adapter)
 
     def _compute_adapter_outputs(
         self,
@@ -355,6 +374,9 @@ class BiLevel(BaseMDOFormulation):
                 **main_mda_options,
             )
             self._mda1.warm_start = True
+            self._mda1.add_namespace_to_output(
+                BaseMDA.NORMALIZED_RESIDUAL_NORM, self.MDA1_RESIDUAL_NAMESPACE
+            )
         else:
             LOGGER.warning(
                 "No strongly coupled disciplines detected, "
@@ -368,59 +390,58 @@ class BiLevel(BaseMDOFormulation):
             grammar_type=self._grammar_type,
             **main_mda_options,
         )
-        self._mda2.warm_start = False
+        self._mda2.add_namespace_to_output(
+            BaseMDA.NORMALIZED_RESIDUAL_NORM, self.MDA2_RESIDUAL_NAMESPACE
+        )
 
-    def _build_chain_dis_sub_opts(
-        self,
-    ) -> tuple[list | BaseMDA, list[MDOScenarioAdapter]]:
-        """Initialize the chain of disciplines and the sub-scenarios.
-
-        Returns:
-            The first MDA (if exists) and the sub-scenarios.
-        """
-        chain_dis = []
-        if self._mda1 is not None:
-            chain_dis = [self._mda1]
-        sub_opts = self.scenario_adapters
-        return chain_dis, sub_opts
+    @property
+    def sub_scenarios(self) -> list[MDODiscipline]:
+        """The scenario adapters wrapping the sub-scenarios."""
+        return self._scenario_adapters
 
     def _build_chain(self) -> None:
         """Build the chain on top of which all functions are built.
 
         This chain is: MDA -> MDOScenarios -> MDA.
         """
 
-        # Build the scenario adapters to be chained with MDAs
-        self.scenario_adapters = self._build_scenario_adapters(
-            reset_x0_before_opt=self.reset_x0_before_opt, keep_opt_history=True
-        )
-        chain_dis, sub_opts = self._build_chain_dis_sub_opts()
+        # Start the internal chain with MDA1, if any
+        chain_dis = [] if self._mda1 is None else [self._mda1]
 
-        if self.__parallel_scenarios:
-            use_threading = self._multithread_scenarios
-            par_chain = MDOParallelChain(
-                sub_opts, use_threading=use_threading, grammar_type=self._grammar_type
-            )
-            chain_dis += [par_chain]
-        else:
-            # Chain MDA -> scenarios exec -> MDA
-            chain_dis += sub_opts
+        # Add the sub-scenarios to the chain
+        chain_dis += [self._build_sub_scenarios_chain()]
 
-        # Add MDA2 if needed
+        # Add MDA2 to the chain
         if self._mda2:
             chain_dis += [self._mda2]
 
+        # Warm start the chain unless reset_x0_before_opt is set
         if not self.reset_x0_before_opt:
             self.chain = MDOWarmStartedChain(
                 chain_dis,
-                name="bilevel_chain",
+                name=self.CHAIN_NAME,
                 grammar_type=self._grammar_type,
                 variable_names_to_warm_start=self._get_variable_names_to_warm_start(),
             )
         else:
             self.chain = MDOChain(
-                chain_dis, name="bilevel_chain", grammar_type=self._grammar_type
+                chain_dis, name=self.CHAIN_NAME, grammar_type=self._grammar_type
+            )
+
+    def _build_sub_scenarios_chain(self) -> MDOChain | MDOParallelChain:
+        """Build the chain of sub-scenarios.
+
+        Returns:
+            The chain of sub-scenarios, with the grammar type of the formulation,
+            either parallel or sequential depending on ``parallel_scenarios``.
+        """
+        if self.__parallel_scenarios:
+            return MDOParallelChain(
+                self.sub_scenarios,
+                use_threading=self._multithread_scenarios,
+                grammar_type=self._grammar_type,
             )
+        return MDOChain(self.sub_scenarios, grammar_type=self._grammar_type)
 
     def _get_variable_names_to_warm_start(self) -> list[str]:
         """Retrieve the names of the variables to warm start.
@@ -432,8 +453,8 @@ class BiLevel(BaseMDOFormulation):
         """
         return [
             name
-            for adapter in self.scenario_adapters
-            for name in adapter.get_output_data_names()
+            for sub_sce in self.sub_scenarios
+            for name in sub_sce.get_output_data_names()
         ]
 
     def _update_design_space(self) -> None:
diff --git a/tests/formulations/test_bilevel.py b/tests/formulations/test_bilevel.py
index 6edccb4941..156b607c83 100644
--- a/tests/formulations/test_bilevel.py
+++ b/tests/formulations/test_bilevel.py
@@ -238,7 +238,7 @@ def test_grammar_type() -> None:
     for discipline in formulation.chain.disciplines:
         assert discipline.grammar_type == grammar_type
 
-    for scenario_adapter in formulation.scenario_adapters:
+    for scenario_adapter in formulation.sub_scenarios:
         assert scenario_adapter.grammar_type == grammar_type
 
     assert formulation.mda1.grammar_type == grammar_type
-- 
GitLab