From 493579fd647ba09ccc9556b44b04a9072f0b10ee Mon Sep 17 00:00:00 2001 From: Leo Germond Date: Sun, 12 Oct 2025 09:04:11 +0200 Subject: [PATCH] add more complex union tests --- tests/test_annotations.py | 43 +++++++++++++++++++++++++++++++++++++++ tests/test_typechecks.py | 23 ++++++++++++++++++++- 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/tests/test_annotations.py b/tests/test_annotations.py index e4b99cb..322abb6 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -649,3 +649,46 @@ def test_mutually_recursive_types(): return actual assert f([[[0]]]) == 0 + + +type Height = Annotated[float, Ge(0.0)] +type InternalPressureFlyingAtm = Annotated[float, Interval(ge=0.0, le=1.0)] +type Depth = Annotated[float, Le(0.0)] +type InternalPressureDivingAtm = Annotated[float, Ge(1.0)] +type FlyOrDivePressureAtm = tuple[Height, InternalPressureFlyingAtm] | tuple[ + Depth, InternalPressureDivingAtm +] + + +def test_advanced_union(): + @annotated + def pressure_is_ok(p: FlyOrDivePressureAtm): + if p[0] == 0.0: + # on the ground: pressure should be constant + return p[1] == 1.0 + elif p[0] > 0.0: + # flying: keep pressure high + return p[1] >= 0.75 + elif p[0] < 0.0: + # diving: keep pressure low + return p[1] <= 1.25 + + # at ground + assert pressure_is_ok((0.0, 1.0)) + assert not pressure_is_ok((0.0, 0.5)) + assert not pressure_is_ok((0.0, 2.0)) + + with pytest.raises(AssertionError): # type: ignore + _ = pressure_is_ok((0.0, -1.0)) + + # in flight + assert pressure_is_ok((1000.0, 0.9)) + assert not pressure_is_ok((1000.0, 0.5)) + with pytest.raises(AssertionError): # type: ignore + _ = pressure_is_ok((1000.0, 1.1)) + + # diving + assert pressure_is_ok((-1000.0, 1.1)) + assert not pressure_is_ok((-1000.0, 1.5)) + with pytest.raises(AssertionError): # type: ignore + _ = pressure_is_ok((-1000.0, 0.9)) diff --git a/tests/test_typechecks.py b/tests/test_typechecks.py index 4f63d21..86c3ead 100644 --- a/tests/test_typechecks.py +++ 
b/tests/test_typechecks.py @@ -1,5 +1,5 @@ from typing import Annotated -from annotated_types import Ge +from annotated_types import Ge, Lt, Gt from inspect import get_annotations import dataclasses from contractme.typecheck import ( @@ -61,3 +61,24 @@ def test_constraints_error_union(): an = get_annotations(h)["a"] assert not get_constraints_errors(0.0, an) assert get_constraints_errors(-1.0, an) + + +# ((<0.0) x (int)) + ((>0.0) x (((<0.0) x (int)) + (>0.0))) +type PosFloat = Annotated[float, Gt(0.0)] +type NegFloat = Annotated[float, Lt(0.0)] +type NegAndInt = tuple[NegFloat, int] +type NegAndIntOrPos = NegAndInt | PosFloat +type Complex = NegAndInt | tuple[PosFloat, NegAndIntOrPos] + + +def recursive_union(a: Complex): # pragma: no cover + pass + + +def test_complex(): + an = get_annotations(recursive_union)["a"] + assert not get_constraints_errors((-1.0, 1), an) + assert get_constraints_errors((1.0, 1), an) + assert not get_constraints_errors((1.0, (-1.0, 1)), an) + assert get_constraints_errors((1.0, (1.0, 1)), an) + assert not get_constraints_errors((1.0, 1.0), an) -- GitLab