From 1bb907ce5494ab54bb7f7b9d2d8674fe6e1a2896 Mon Sep 17 00:00:00 2001 From: Leo Germond Date: Sat, 11 Oct 2025 17:00:02 +0200 Subject: [PATCH 1/8] VSCode: set config for pytest --- .vscode/settings.json | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..9b38853 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file -- GitLab From 620c99de867df3cd3e8c2eb728f4269fa40f7595 Mon Sep 17 00:00:00 2001 From: Leo Germond Date: Wed, 4 Jun 2025 08:58:11 +0200 Subject: [PATCH 2/8] annotation type alias --- tests/test_annotations.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 55aa2c8..2d26a53 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -587,3 +587,18 @@ def test_tuple_ellipsis(): with pytest.raises(AssertionError): f((0, 1.0)) # type: ignore + + +type MyInt = Annotated[int, Gt(0)] + + +def test_type_alias(): + @annotated + def f(a: MyInt): + return a + + assert f(1) == 1 + with pytest.raises(AssertionError): + f(1.0) # type: ignore + with pytest.raises(AssertionError): + f(0) # type: ignore -- GitLab From 00607189bfc7376079e0c57cb6e662b725f4c3f6 Mon Sep 17 00:00:00 2001 From: Leo Germond Date: Wed, 4 Jun 2025 09:13:33 +0200 Subject: [PATCH 3/8] test recursive types --- tests/test_annotations.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 2d26a53..bd29b07 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -602,3 +602,24 @@ def test_type_alias(): f(1.0) # type: ignore with pytest.raises(AssertionError): f(0) # type: ignore + + +type RecList = list[RecList] + + +def 
test_recursive_type(): + @annotated + def f(a: RecList): + actual = a + while actual: + actual = actual[0] + return actual + + e = [] + + assert f([[[e]]]) is e + assert f([[[e], e], e]) is e + with pytest.raises(AssertionError): + assert f([[[e]], 1]) is e # type: ignore + with pytest.raises(AssertionError): + assert f([[[e, None]]]) is e # type: ignore -- GitLab From 4e767dc20e3f589f8c1b8661908964f43cd46716 Mon Sep 17 00:00:00 2001 From: Leo Germond Date: Wed, 4 Jun 2025 10:19:18 +0200 Subject: [PATCH 4/8] option types, and refactor support for option types split code into annotations / typechecks modules --- README.md | 4 + src/contractme/__init__.py | 4 +- src/contractme/annotations.py | 64 +++++++ src/contractme/contracting.py | 332 +--------------------------------- src/contractme/typecheck.py | 298 ++++++++++++++++++++++++++++++ tests/test_annotations.py | 30 ++- tests/test_iter.py | 36 ---- tests/test_typechecks.py | 48 +++++ 8 files changed, 445 insertions(+), 371 deletions(-) create mode 100644 src/contractme/annotations.py create mode 100644 src/contractme/typecheck.py delete mode 100644 tests/test_iter.py create mode 100644 tests/test_typechecks.py diff --git a/README.md b/README.md index fef526b..b11acc0 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,10 @@ using the [annotated-types](https://pypi.org/project/annotated-types/) library. **Note:** `annodated_types.MultipleOf` follows the Python semantics. **Note 2:** Following an open-world reasoning, any unknown annotation is considered to be correct, so it won't cause a check failure. +**Note 3:** Type checking follows Python's `isinstance` semantics, which means subclass +relationships are respected. Since `bool` is a subclass of `int` in Python, boolean values +will pass `int` type checks. Currently there's no built-in way to specify "exactly int, not bool" +in type annotations.
```python from typing import TypeAlias, Annotated diff --git a/src/contractme/__init__.py b/src/contractme/__init__.py index dc47ebe..0824818 100644 --- a/src/contractme/__init__.py +++ b/src/contractme/__init__.py @@ -1,4 +1,4 @@ -from contractme import contracting +from contractme import contracting, annotations ignore_preconditions = contracting.ignore_preconditions check_preconditions = contracting.check_preconditions @@ -7,4 +7,4 @@ check_postconditions = contracting.check_postconditions precondition = contracting.precondition postcondition = contracting.postcondition -annotated = contracting.annotated +annotated = annotations.annotated diff --git a/src/contractme/annotations.py b/src/contractme/annotations.py new file mode 100644 index 0000000..e7eb334 --- /dev/null +++ b/src/contractme/annotations.py @@ -0,0 +1,64 @@ +import sys +from typing import get_type_hints +from inspect import get_annotations +from contractme.contracting import precondition, postcondition +from contractme.typecheck import ( + get_first_structure_error, + get_constraints_errors, + resolve_type, + AnnotationForm, +) + + +def get_types_and_annotations_errors(value, an): + # we visit all of the possible nodes depth-first + # for each one, we check its type, and we delegate any annotation check + # then we loop over all these annotations, and we check them as well + # this allows for better errors handling: first the structure is checked to be fully OK, + # then we check the values and constraints. 
+ struct_err = get_first_structure_error(value, an) + if struct_err: + return f"structure is malformed: {struct_err}" + else: + constraints_err = get_constraints_errors(value, an) + if constraints_err: + return "structure ok, but value does not match constraints: " + ", ".join( + constraints_err + ) + else: + return None + + +def annotated(f): + def check_annotated_arg(arg_name, arg_an: AnnotationForm): + constraints = getattr(arg_an, "__metadata__", tuple()) + t = resolve_type(arg_an) + if constraints: + msg = f"{arg_name} should be instance of {t} under constraints {constraints}" + else: + msg = f"{arg_name} should be instance of {t}" + + if arg_name == "return": + cond_arg_name = "result" + cond_f = postcondition + else: + cond_arg_name = n + cond_f = precondition + + def chk(**kw): + err = get_types_and_annotations_errors(kw[cond_arg_name], arg_an) + if err: + print(f"arg {cond_arg_name}: {err}", file=sys.stderr) + + return err is None + + cond = cond_f(chk, msg) + return cond + + ths = get_type_hints(f) + ans = get_annotations(f) + + for n, t_or_typevar in ths.items(): + f = check_annotated_arg(n, ans[n])(f) + + return f diff --git a/src/contractme/contracting.py b/src/contractme/contracting.py index a21df6c..c595ae1 100644 --- a/src/contractme/contracting.py +++ b/src/contractme/contracting.py @@ -1,32 +1,8 @@ -import sys import inspect import copy from warnings import warn -import types -from typing import ( - Callable, - get_type_hints, - TypeVar, - get_origin, - get_args, - TypeAlias, - TypeAliasType, - Annotated, - Any, - Iterable, -) -from inspect import get_annotations +from typing import Callable import functools -import operator -import annotated_types -import datetime -import zoneinfo -from dataclasses import dataclass - -# Stolen from typeshed -assert sys.version_info >= (3, 10) -ClassInfo: TypeAlias = type | types.UnionType | tuple["ClassInfo", ...] 
-AnnotationForm: TypeAlias = Any def show_source(f: Callable) -> str: @@ -350,310 +326,4 @@ def postcondition(conditional, message=None): return condition_decorator(False, conditional, message) -def normalize_constraint(constraint): - if isinstance(constraint, functools.partial): - # https://github.com/annotated-types/annotated-types?tab=readme-ov-file#gt-ge-lt-le - match constraint.func, constraint.args: - case operator.lt, (n,): - n = annotated_types.Lt(n) - case operator.le, (n,): - n = annotated_types.Le(n) - case operator.gt, (n,): - n = annotated_types.Gt(n) - case operator.ge, (n,): - n = annotated_types.Ge(n) - case _: # pragma: no cover - raise NotImplementedError(constraint) - return n - else: - return constraint - - -def get_all_constraints_failures(val, constraints): - if not constraints: - return [] - else: - failing = [] - for constraint_denorm in constraints: - constraint = normalize_constraint(constraint_denorm) - match constraint: - case annotated_types.Interval(gt=None, lt=None, ge=None, le=None): - # null range - check = True - case annotated_types.Gt(gt=gt) | annotated_types.Interval( - gt=gt, lt=None, ge=None, le=None - ): - check = val > gt - case annotated_types.Lt(lt=lt) | annotated_types.Interval( - lt=lt, gt=None, ge=None, le=None - ): - check = val < lt - case annotated_types.Ge(ge=ge) | annotated_types.Interval( - ge=ge, lt=None, gt=None, le=None - ): - check = val >= ge - case annotated_types.Le(le=le) | annotated_types.Interval( - le=le, lt=None, ge=None, gt=None - ): - check = val <= le - case annotated_types.Interval(gt=gt, lt=lt, ge=None, le=None): - check = gt < val < lt - case annotated_types.Interval(ge=ge, lt=lt, gt=None, le=None): - check = ge <= val < lt - case annotated_types.Interval(gt=gt, le=le, ge=None, lt=None): - check = gt < val <= le - case annotated_types.Interval(ge=ge, le=le, gt=None, lt=None): - check = ge <= val <= le - case annotated_types.MultipleOf(multiple_of=multiple_of): - check = val % multiple_of == 0 - 
case annotated_types.MinLen(min_length=min_length) | annotated_types.Len( - min_length=min_length, max_length=None - ): - check = min_length <= len(val) - case annotated_types.MaxLen(max_length=max_length) | annotated_types.Len( - max_length=max_length, min_length=None - ): - check = len(val) <= max_length - case annotated_types.Len(min_length=min_length, max_length=max_length): - assert min_length is not None and max_length is not None - check = min_length <= len(val) <= max_length - case annotated_types.Timezone(tz=tz): - if tz is None: - check = val.tzinfo is None - elif tz is Ellipsis: - check = val.tzinfo is not None - elif isinstance(tz, datetime.tzinfo): - check = val.tzinfo == tz - elif isinstance(tz, str): - check = val.tzinfo == zoneinfo.ZoneInfo(tz) - else: # pragma: no cover - raise NotImplementedError(f"tz being {tz.__class__.__name__}") - case annotated_types.Predicate(func=func): - check = func(val) - case _: - check = True - - if not check: - failing.append(constraint) - return failing - - -ConstraintSet: TypeAlias = list[Any] - - -def resolve_type(t: AnnotationForm) -> AnnotationForm: - """ - Best-effort type resolution, this will remove as many - type var or alias as possible. - """ - alias = [] - while isinstance(t, (TypeVar, TypeAliasType)): - if t in alias: - # circular definition - break - else: - alias.append(t) - - if isinstance(t, TypeAliasType): - t = t.__value__ - elif isinstance(t, TypeVar): - t = t.__bound__ or Any - return t - - -@dataclass -class CompositeTypeInfo: - origin: ClassInfo - constraints: ConstraintSet - child_repeat: bool # 0 .. n - children: tuple[AnnotationForm, ...] 
- - def __init__(self, t: AnnotationForm): - origin = get_origin(t) or t - self.constraints = list() - while origin is Annotated: - self.constraints += get_args(t)[1:] - t = resolve_type(get_args(t)[0]) - origin = get_origin(t) or t - self.origin = origin - args = get_args(t) - if self.origin == tuple and len(args) == 2 and args[1] == Ellipsis: - self.child_repeat = True - self.children = (args[0],) - elif self.origin == list: - assert ( - len(args) == 1 - ), "Cannot have list[T1, T2...] must either be list[T1] or list[tuple[T1, T2]]" - self.child_repeat = True - self.children = args - else: - self.child_repeat = False - self.children = args - - -def get_recursive_composite_type_info(an: AnnotationForm) -> Iterable[CompositeTypeInfo]: - to_visit = [an] - - while to_visit: - t = resolve_type(to_visit.pop(0)) - r = CompositeTypeInfo(t) - if not r.child_repeat: - to_visit += r.children - yield r - - -def annotated(f): - def get_types_and_annotations_errors(value, an): - # we visit all of the possible nodes depth-first - # for each one, we check its type, and we delegate any annotation check - # then we loop over all these annotations, and we check them as well - # this allows for better errors handling: first the structure is checked to be fully OK, - # then we check the values and constraints. 
- @dataclass - class ErrorsCollector: - _context: list - _errors: list[str] - _local_error: bool - - def __init__(self): - self._context = [] - self._errors = [] - self._local_error = False - - def set_context(self, type_info): - self._context.append(type_info.origin) - - def check_cond(self, cond: bool, message: str): - self._local_error = not cond - if self._local_error: - self._errors.append(f"{self._context}: {message}") - - def is_ok(self) -> bool: - return not self._errors - - def locally_ok(self) -> bool: - return not self._local_error - - def check_with(self, check_function, value, an): - err = check_function(value, an) - if err: - self._errors.append(f"{self._context}: {err}") - - def maybe_first_error(self) -> str | None: - return None if not self._errors else self._errors[0] - - def all_errors(self) -> list[str]: - return self._errors - - def get_first_structure_error(value, an) -> str | None: - struct = ErrorsCollector() - - check_values = [value] - for type_info in get_recursive_composite_type_info(an): - # Struct check - val = check_values.pop(0) - struct.set_context(type_info) - is_sequence = type_info.origin in (list, tuple) - struct.check_cond( - isinstance(val, type_info.origin), - "expect {type_info.origin} got {val.__class__.__name__}", - ) - if struct.is_ok(): - if type_info.child_repeat: - assert len(type_info.children) == 1 - for v in val: - struct.check_with(get_first_structure_error, v, type_info.children[0]) - if not struct.is_ok(): - break - elif is_sequence: - # depth-first - struct.check_cond( - len(val) == len(type_info.children), - "not the right number of children: " - "expect {len(type_info.children)} got {len(val)}", - ) - if struct.is_ok(): - check_values = list(val) + check_values - - if not struct.is_ok(): - break - assert ( - not check_values - ), f"bug! 
values remains to be checked ({check_values}) for {value} {an}" - return struct.maybe_first_error() - - def get_constraints_errors(value, an): - constraints = ErrorsCollector() - - check_values = [value] - for type_info in get_recursive_composite_type_info(an): - # Constraints check - val = check_values.pop(0) - is_sequence = type_info.origin in (list, tuple) - constraint_failures = get_all_constraints_failures(val, type_info.constraints) - constraints.check_cond( - not constraint_failures, - f"{val!r} failed constraint {constraint_failures}", - ) - if constraints.locally_ok(): - # check only if parent constraints are already ok - if type_info.child_repeat: - assert len(type_info.children) == 1 - for v in val: - constraints.check_with(get_constraints_errors, v, type_info.children[0]) - elif is_sequence: - # depth-first - check_values = list(val) + check_values - - assert ( - not check_values - ), f"bug! values remains to be checked ({check_values}) for {value} {an}" - return constraints.all_errors() - - struct_err = get_first_structure_error(value, an) - if struct_err: - return f"structure is malformed: {struct_err}" - else: - constraints_err = get_constraints_errors(value, an) - if constraints_err: - return "structure ok, but value does not match constraints: " + ", ".join( - constraints_err - ) - else: - return None - - def check_annotated_arg(arg_name, arg_an: AnnotationForm): - constraints = getattr(arg_an, "__metadata__", tuple()) - t = resolve_type(arg_an) - if constraints: - msg = f"{arg_name} should be instance of {t} under constraints {constraints}" - else: - msg = f"{arg_name} should be instance of {t}" - - if arg_name == "return": - cond_arg_name = "result" - cond_f = postcondition - else: - cond_arg_name = n - cond_f = precondition - - def chk(**kw): - err = get_types_and_annotations_errors(kw[cond_arg_name], arg_an) - if err: - print(f"arg {cond_arg_name}: {err}", file=sys.stderr) - - return err is None - - cond = cond_f(chk, msg) - return cond - - ths 
= get_type_hints(f) - ans = get_annotations(f) - - for n, t_or_typevar in ths.items(): - f = check_annotated_arg(n, ans[n])(f) - - return f - - never_returns = postcondition(lambda: False, "should never return") diff --git a/src/contractme/typecheck.py b/src/contractme/typecheck.py new file mode 100644 index 0000000..5f39cad --- /dev/null +++ b/src/contractme/typecheck.py @@ -0,0 +1,298 @@ +import sys +import datetime +import zoneinfo +import functools +import operator +from enum import Enum, auto +from dataclasses import dataclass +import types +import typing +from typing import ( + Annotated, + Any, + Iterable, + TypeAliasType, + TypeVar, + get_origin, + get_args, +) +import annotated_types + +assert sys.version_info >= (3, 12) + +# Stolen from typeshed +type ClassInfo = type | types.UnionType | tuple[ClassInfo, ...] +type AnnotationForm = Any + + +def normalize_constraint(constraint): + if isinstance(constraint, functools.partial): + # https://github.com/annotated-types/annotated-types?tab=readme-ov-file#gt-ge-lt-le + match constraint.func, constraint.args: + case operator.lt, (n,): + n = annotated_types.Lt(n) + case operator.le, (n,): + n = annotated_types.Le(n) + case operator.gt, (n,): + n = annotated_types.Gt(n) + case operator.ge, (n,): + n = annotated_types.Ge(n) + case _: # pragma: no cover + raise NotImplementedError(constraint) + return n + else: + return constraint + + +def get_all_constraints_failures(val, constraints): + if not constraints: + return [] + else: + failing = [] + for constraint_denorm in constraints: + constraint = normalize_constraint(constraint_denorm) + match constraint: + case annotated_types.Interval(gt=None, lt=None, ge=None, le=None): + # null range + check = True + case annotated_types.Gt(gt=gt) | annotated_types.Interval( + gt=gt, lt=None, ge=None, le=None + ): + check = val > gt + case annotated_types.Lt(lt=lt) | annotated_types.Interval( + lt=lt, gt=None, ge=None, le=None + ): + check = val < lt + case 
annotated_types.Ge(ge=ge) | annotated_types.Interval( + ge=ge, lt=None, gt=None, le=None + ): + check = val >= ge + case annotated_types.Le(le=le) | annotated_types.Interval( + le=le, lt=None, ge=None, gt=None + ): + check = val <= le + case annotated_types.Interval(gt=gt, lt=lt, ge=None, le=None): + check = gt < val < lt + case annotated_types.Interval(ge=ge, lt=lt, gt=None, le=None): + check = ge <= val < lt + case annotated_types.Interval(gt=gt, le=le, ge=None, lt=None): + check = gt < val <= le + case annotated_types.Interval(ge=ge, le=le, gt=None, lt=None): + check = ge <= val <= le + case annotated_types.MultipleOf(multiple_of=multiple_of): + check = val % multiple_of == 0 + case annotated_types.MinLen(min_length=min_length) | annotated_types.Len( + min_length=min_length, max_length=None + ): + check = min_length <= len(val) + case annotated_types.MaxLen(max_length=max_length) | annotated_types.Len( + max_length=max_length, min_length=None + ): + check = len(val) <= max_length + case annotated_types.Len(min_length=min_length, max_length=max_length): + assert min_length is not None and max_length is not None + check = min_length <= len(val) <= max_length + case annotated_types.Timezone(tz=tz): + if tz is None: + check = val.tzinfo is None + elif tz is Ellipsis: + check = val.tzinfo is not None + elif isinstance(tz, datetime.tzinfo): + check = val.tzinfo == tz + elif isinstance(tz, str): + check = val.tzinfo == zoneinfo.ZoneInfo(tz) + else: # pragma: no cover + raise NotImplementedError(f"tz being {tz.__class__.__name__}") + case annotated_types.Predicate(func=func): + check = func(val) + case _: + check = True + + if not check: + failing.append(constraint) + return failing + + +def resolve_type(t: AnnotationForm) -> AnnotationForm: + """ + Best-effort type resolution, this will remove as many + type var or alias as possible. 
+ """ + alias = [] + while isinstance(t, (TypeVar, TypeAliasType)): + if t in alias: + # circular definition + break + else: + alias.append(t) + + if isinstance(t, TypeAliasType): + t = t.__value__ + elif isinstance(t, TypeVar): + t = t.__bound__ or Any + return t + + +class ChildMod(Enum): + NONE = auto() + REPEAT = auto() # 0..n + OPTION = auto() # A | B | C + + +type ConstraintSet = list[Any] + + +@dataclass +class CompositeTypeInfo: + origin: ClassInfo + constraints: ConstraintSet + child_mod: ChildMod + children: tuple[AnnotationForm, ...] + + def __init__(self, t: AnnotationForm): + origin = get_origin(t) or t + self.constraints = list() + while origin is Annotated: + self.constraints += get_args(t)[1:] + t = resolve_type(get_args(t)[0]) + origin = get_origin(t) or t + self.origin = origin + args = get_args(t) + if self.origin == tuple and len(args) == 2 and args[1] == Ellipsis: + self.child_mod = ChildMod.REPEAT + self.children = (args[0],) + elif self.origin == list: + assert ( + len(args) == 1 + ), "Cannot have list[T1, T2...] 
must either be list[T1] or list[tuple[T1, T2]]" + self.child_mod = ChildMod.REPEAT + self.children = args + elif self.origin in (types.UnionType, typing.Union): + self.child_mod = ChildMod.OPTION + self.children = args + else: + self.child_mod = ChildMod.NONE + self.children = args + + +def get_recursive_composite_type_info(an: AnnotationForm) -> Iterable[CompositeTypeInfo]: + to_visit = [an] + + while to_visit: + t = resolve_type(to_visit.pop(0)) + r = CompositeTypeInfo(t) + if r.child_mod == ChildMod.NONE: + to_visit += r.children + yield r + + +@dataclass +class ErrorsCollector: + _context: list + _errors: list[str] + _local_error: bool + + def __init__(self): + self._context = [] + self._errors = [] + self._local_error = False + + def set_context(self, type_info): + self._context.append(type_info.origin) + + def _append_error(self, msg): + self._errors.append(f"in {self._context}: {msg}") + + def check_cond(self, cond: bool, message: str): + self._local_error = not cond + if self._local_error: + self._append_error(message) + + def is_ok(self) -> bool: + return not self._errors + + def locally_ok(self) -> bool: + return not self._local_error + + def check_with(self, check_function, *args): + err = check_function(*args) + if err: + self._append_error(err) + + def maybe_first_error(self) -> str | None: + return None if not self._errors else self._errors[0] + + def all_errors(self) -> list[str]: + return self._errors + + +def get_first_structure_error(value, an) -> str | None: + struct = ErrorsCollector() + + check_values = [value] + for type_info in get_recursive_composite_type_info(an): + # Struct check + val = check_values.pop(0) + struct.set_context(type_info) + is_sequence = type_info.origin in (list, tuple) + if type_info.child_mod == ChildMod.OPTION: + found = False + for c in type_info.children: + err = get_first_structure_error(val, c) + if not err: + found = True + break + struct.check_cond(found, f"{val} does not match any of {type_info.children}") + 
else: + struct.check_cond( + isinstance(val, type_info.origin), + f"expect {type_info.origin} got {val.__class__.__name__}", + ) + if struct.is_ok(): + if type_info.child_mod == ChildMod.REPEAT: + assert len(type_info.children) == 1 + for v in val: + struct.check_with(get_first_structure_error, v, type_info.children[0]) + if not struct.is_ok(): + break + elif is_sequence: + # depth-first + struct.check_cond( + len(val) == len(type_info.children), + f"not the right number of children: " + f"expect {len(type_info.children)} got {len(val)}", + ) + if struct.is_ok(): + check_values = list(val) + check_values + + if not struct.is_ok(): + break + assert not check_values, f"bug! values remains to be checked ({check_values}) for {value} {an}" + return struct.maybe_first_error() + + +def get_constraints_errors(value, an): + constraints = ErrorsCollector() + + check_values = [value] + for type_info in get_recursive_composite_type_info(an): + # Constraints check + val = check_values.pop(0) + is_sequence = type_info.origin in (list, tuple) + constraint_failures = get_all_constraints_failures(val, type_info.constraints) + constraints.check_cond( + not constraint_failures, + f"{val!r} failed constraint {constraint_failures}", + ) + if constraints.locally_ok(): + # check only if parent constraints are already ok + if type_info.child_mod == ChildMod.REPEAT: + assert len(type_info.children) == 1 + for v in val: + constraints.check_with(get_constraints_errors, v, type_info.children[0]) + elif is_sequence: + # depth-first + check_values = list(val) + check_values + + assert not check_values, f"bug! 
values remains to be checked ({check_values}) for {value} {an}" + return constraints.all_errors() diff --git a/tests/test_annotations.py b/tests/test_annotations.py index bd29b07..e4b99cb 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -589,12 +589,12 @@ def test_tuple_ellipsis(): f((0, 1.0)) # type: ignore -type MyInt = Annotated[int, Gt(0)] +type IntGt0 = Annotated[int, Gt(0)] def test_type_alias(): @annotated - def f(a: MyInt): + def f(a: IntGt0): return a assert f(1) == 1 @@ -604,6 +604,17 @@ def test_type_alias(): f(0) # type: ignore +def test_union_types(): + @annotated + def f(a: int | None): + return a + + assert f(None) is None + assert f(0) == 0 + with pytest.raises(AssertionError): + f(1.0) # type: ignore + + type RecList = list[RecList] @@ -623,3 +634,18 @@ def test_recursive_type(): assert f([[[e]], 1]) is e # type: ignore with pytest.raises(AssertionError): assert f([[[e, None]]]) is e # type: ignore + + +type MutRec1 = list[MutRec2] +type MutRec2 = MutRec1 | int + + +def test_mutually_recursive_types(): + @annotated + def f(a: MutRec1): + actual = a + while isinstance(actual, list): + actual = actual[0] + return actual + + assert f([[[0]]]) == 0 diff --git a/tests/test_iter.py b/tests/test_iter.py deleted file mode 100644 index 503dd7b..0000000 --- a/tests/test_iter.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Annotated -from inspect import get_annotations -from contractme.contracting import get_recursive_composite_type_info -import dataclasses - - -type TR2 = TR # type: ignore -type TR = TR2 -type T0 = Annotated[TR, 3] -type T = Annotated[T0, 2] -type T2 = tuple[int, T] -type T3 = Annotated[T2, 1] - - -def f[TT: T3](a: TT) -> TT: # pragma: no cover - return a - - -def g(a: tuple[int, ...]): # pragma: no cover - pass - - -def test_complex(): - an = get_annotations(f)["a"] - type_infos = [dataclasses.astuple(e) for e in get_recursive_composite_type_info(an)] - assert type_infos == [ - (tuple, [1], False, (int, T)), - 
(int, [], False, tuple()), - (TR, [2, 3], False, tuple()), - ] - - -def test_tuple_ellipsis(): - an = get_annotations(g)["a"] - type_infos = [dataclasses.astuple(e) for e in get_recursive_composite_type_info(an)] - assert type_infos == [(tuple, [], True, (int,))] diff --git a/tests/test_typechecks.py b/tests/test_typechecks.py new file mode 100644 index 0000000..cddde9f --- /dev/null +++ b/tests/test_typechecks.py @@ -0,0 +1,48 @@ +from typing import Annotated +from annotated_types import Ge +from inspect import get_annotations +import dataclasses +from contractme.typecheck import get_recursive_composite_type_info, get_first_structure_error, ChildMod + + +type TR2 = TR # type: ignore +type TR = TR2 +type T0 = Annotated[TR, 3] +type T = Annotated[T0, 2] +type T2 = tuple[int, T] +type T3 = Annotated[T2, 1] + + +def f[TT: T3](a: TT) -> TT: # pragma: no cover + return a + + +def g(a: tuple[int, ...]): # pragma: no cover + pass + + +def h(a: int | Annotated[float, Ge(0.0)]): # pragma: no cover + pass + + +def test_get_info_complex(): + an = get_annotations(f)["a"] + type_infos = [dataclasses.astuple(e) for e in get_recursive_composite_type_info(an)] + assert type_infos == [ + (tuple, [1], ChildMod.NONE, (int, T)), + (int, [], ChildMod.NONE, tuple()), + (TR, [2, 3], ChildMod.NONE, tuple()), + ] + + +def test_get_info_tuple_ellipsis(): + an = get_annotations(g)["a"] + type_infos = [dataclasses.astuple(e) for e in get_recursive_composite_type_info(an)] + assert type_infos == [(tuple, [], ChildMod.REPEAT, (int,))] + + +def test_structure_error_union(): + an = get_annotations(h)["a"] + assert get_first_structure_error(1, an) is None + assert get_first_structure_error([], an) is not None + assert get_first_structure_error("string", an) is not None -- GitLab From 4a7b9094a18b10b0192dd1b18f4fb6640be3456a Mon Sep 17 00:00:00 2001 From: Leo Germond Date: Sat, 11 Oct 2025 17:49:42 +0200 Subject: [PATCH 5/8] docstring for get_types_and_annotations_errors --- 
src/contractme/annotations.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/contractme/annotations.py b/src/contractme/annotations.py index e7eb334..64f0ad4 100644 --- a/src/contractme/annotations.py +++ b/src/contractme/annotations.py @@ -11,11 +11,17 @@ from contractme.typecheck import ( def get_types_and_annotations_errors(value, an): + """ + Checks the structure and annotations of data provided compared to the annotations. + + The structure is checked first, then iif it is valid, the annotations are. This allows + for a better error messaging: you first provide a correct structure, then correct values. + """ # we visit all of the possible nodes depth-first # for each one, we check its type, and we delegate any annotation check # then we loop over all these annotations, and we check them as well - # this allows for better errors handling: first the structure is checked to be fully OK, - # then we check the values and constraints. + # this has a potentially big runtime cost for nested option or recursive + # types. 
struct_err = get_first_structure_error(value, an) if struct_err: return f"structure is malformed: {struct_err}" -- GitLab From ebbfa79a702b77601edab3ccabee24326590217d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Germond?= Date: Thu, 19 Jun 2025 06:34:03 +0200 Subject: [PATCH 6/8] Union support for constraint checking Also some more types, not very useful but heh --- src/contractme/annotations.py | 2 +- src/contractme/typecheck.py | 84 +++++++++++++++++++++++++---------- tests/test_typechecks.py | 17 ++++++- 3 files changed, 77 insertions(+), 26 deletions(-) diff --git a/src/contractme/annotations.py b/src/contractme/annotations.py index 64f0ad4..73353c9 100644 --- a/src/contractme/annotations.py +++ b/src/contractme/annotations.py @@ -48,7 +48,7 @@ def annotated(f): cond_arg_name = "result" cond_f = postcondition else: - cond_arg_name = n + cond_arg_name = arg_name cond_f = precondition def chk(**kw): diff --git a/src/contractme/typecheck.py b/src/contractme/typecheck.py index 5f39cad..4b55ba4 100644 --- a/src/contractme/typecheck.py +++ b/src/contractme/typecheck.py @@ -10,6 +10,7 @@ import typing from typing import ( Annotated, Any, + Callable, Iterable, TypeAliasType, TypeVar, @@ -24,8 +25,16 @@ assert sys.version_info >= (3, 12) type ClassInfo = type | types.UnionType | tuple[ClassInfo, ...] type AnnotationForm = Any +type Constraint = functools.partial | Any +# More specifically: some are taken into account (annotated-types), +# some are not (open world: it's yours and IDC about it) +# maybe they could be redispatched to extensions in the future? closing the +# world so that any malformed constraint could be detected? 
+type NormalizedConstraint = Any +# partial functions are re-contracted to annotated-types constraints -def normalize_constraint(constraint): + +def normalize_constraint(constraint: Constraint) -> NormalizedConstraint: if isinstance(constraint, functools.partial): # https://github.com/annotated-types/annotated-types?tab=readme-ov-file#gt-ge-lt-le match constraint.func, constraint.args: @@ -43,8 +52,7 @@ def normalize_constraint(constraint): else: return constraint - -def get_all_constraints_failures(val, constraints): +def get_all_constraint_errors(val, constraints: list[Constraint]) -> list[NormalizedConstraint]: if not constraints: return [] else: @@ -214,7 +222,7 @@ class ErrorsCollector: def locally_ok(self) -> bool: return not self._local_error - def check_with(self, check_function, *args): + def check_with(self, check_function: Callable[..., str], *args): err = check_function(*args) if err: self._append_error(err) @@ -224,13 +232,19 @@ class ErrorsCollector: def all_errors(self) -> list[str]: return self._errors + + def clear_errors(self) -> None: + self._errors = [] -def get_first_structure_error(value, an) -> str | None: +def get_first_structure_error(value, an, recursive_type_info=None) -> str | None: struct = ErrorsCollector() + if recursive_type_info is None: + recursive_type_info = get_recursive_composite_type_info(an) + check_values = [value] - for type_info in get_recursive_composite_type_info(an): + for type_info in recursive_type_info: # Struct check val = check_values.pop(0) struct.set_context(type_info) @@ -252,7 +266,8 @@ def get_first_structure_error(value, an) -> str | None: if type_info.child_mod == ChildMod.REPEAT: assert len(type_info.children) == 1 for v in val: - struct.check_with(get_first_structure_error, v, type_info.children[0]) + err = get_first_structure_error(v, type_info.children[0]) + struct.check_cond(not err, str(err)) if not struct.is_ok(): break elif is_sequence: @@ -271,28 +286,49 @@ def get_first_structure_error(value, an) 
-> str | None: return struct.maybe_first_error() -def get_constraints_errors(value, an): +def get_constraints_errors(value, an, recursive_type_info: Iterable[CompositeTypeInfo]|None=None): constraints = ErrorsCollector() + def only_valid(rti: Iterable[CompositeTypeInfo]): + # unimplemented + yield from rti + + if recursive_type_info is None: + recursive_type_info = get_recursive_composite_type_info(an) + check_values = [value] - for type_info in get_recursive_composite_type_info(an): + for type_info in only_valid(recursive_type_info): # Constraints check val = check_values.pop(0) - is_sequence = type_info.origin in (list, tuple) - constraint_failures = get_all_constraints_failures(val, type_info.constraints) - constraints.check_cond( - not constraint_failures, - f"{val!r} failed constraint {constraint_failures}", - ) - if constraints.locally_ok(): - # check only if parent constraints are already ok - if type_info.child_mod == ChildMod.REPEAT: - assert len(type_info.children) == 1 - for v in val: - constraints.check_with(get_constraints_errors, v, type_info.children[0]) - elif is_sequence: - # depth-first - check_values = list(val) + check_values + if type_info.child_mod == ChildMod.OPTION: + constraint_failures = None + for child in type_info.children: + # check structure, again, very inefficient + if get_first_structure_error(val, child): + continue + err = get_constraints_errors(val, child) + constraints.check_cond(not err, f"tried option {child}: {err}") + if not err: + # found a config that works; NOTE(review): clear_errors() also drops errors recorded for earlier values in this traversal (confirm intended) + constraints.clear_errors() + break + else: + is_sequence = type_info.origin in (list, tuple) + constraint_failures = get_all_constraint_errors(val, type_info.constraints) + constraints.check_cond( + not constraint_failures, + f"{val!r} failed constraint {constraint_failures}", + ) + if constraints.locally_ok(): + # check only if parent constraints are already ok + if type_info.child_mod == ChildMod.REPEAT: + assert len(type_info.children) == 1 + for v in
val: + err = get_constraints_errors(v, type_info.children[0]) + constraints.check_cond(not err, str(err)) + elif is_sequence: + # depth-first + check_values = list(val) + check_values assert not check_values, f"bug! values remains to be checked ({check_values}) for {value} {an}" return constraints.all_errors() diff --git a/tests/test_typechecks.py b/tests/test_typechecks.py index cddde9f..4f63d21 100644 --- a/tests/test_typechecks.py +++ b/tests/test_typechecks.py @@ -2,7 +2,12 @@ from typing import Annotated from annotated_types import Ge from inspect import get_annotations import dataclasses -from contractme.typecheck import get_recursive_composite_type_info, get_first_structure_error, ChildMod +from contractme.typecheck import ( + get_recursive_composite_type_info, + get_first_structure_error, + get_constraints_errors, + ChildMod, +) type TR2 = TR # type: ignore @@ -44,5 +49,15 @@ def test_get_info_tuple_ellipsis(): def test_structure_error_union(): an = get_annotations(h)["a"] assert get_first_structure_error(1, an) is None + + # tricky: constraint error but structure is OK + assert get_first_structure_error(-1.0, an) is None + assert get_first_structure_error([], an) is not None assert get_first_structure_error("string", an) is not None + + +def test_constraints_error_union(): + an = get_annotations(h)["a"] + assert not get_constraints_errors(0.0, an) + assert get_constraints_errors(-1.0, an) -- GitLab From a9a41b55f216ed2ec5b86d9e42143d11e9d52aca Mon Sep 17 00:00:00 2001 From: Leo Germond Date: Sat, 11 Oct 2025 18:43:39 +0200 Subject: [PATCH 7/8] remove unused check_with function --- src/contractme/typecheck.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/contractme/typecheck.py b/src/contractme/typecheck.py index 4b55ba4..1ebf8f4 100644 --- a/src/contractme/typecheck.py +++ b/src/contractme/typecheck.py @@ -222,11 +222,6 @@ class ErrorsCollector: def locally_ok(self) -> bool: return not self._local_error - def check_with(self, check_function: 
Callable[..., str], *args): - err = check_function(*args) - if err: - self._append_error(err) - def maybe_first_error(self) -> str | None: return None if not self._errors else self._errors[0] -- GitLab From ba25f1c340cb20c6410fa3947361501beeedc5c5 Mon Sep 17 00:00:00 2001 From: Leo Germond Date: Sat, 11 Oct 2025 18:45:12 +0200 Subject: [PATCH 8/8] black --- src/contractme/annotations.py | 2 +- src/contractme/typecheck.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/contractme/annotations.py b/src/contractme/annotations.py index 73353c9..87fa97e 100644 --- a/src/contractme/annotations.py +++ b/src/contractme/annotations.py @@ -13,7 +13,7 @@ from contractme.typecheck import ( def get_types_and_annotations_errors(value, an): """ Checks the structure and annotations of data provided compared to the annotations. - + The structure is checked first, then iif it is valid, the annotations are. This allows for a better error messaging: you first provide a correct structure, then correct values. 
""" diff --git a/src/contractme/typecheck.py b/src/contractme/typecheck.py index 1ebf8f4..a2390af 100644 --- a/src/contractme/typecheck.py +++ b/src/contractme/typecheck.py @@ -52,6 +52,7 @@ def normalize_constraint(constraint: Constraint) -> NormalizedConstraint: else: return constraint + def get_all_constraint_errors(val, constraints: list[Constraint]) -> list[NormalizedConstraint]: if not constraints: return [] @@ -227,7 +228,7 @@ class ErrorsCollector: def all_errors(self) -> list[str]: return self._errors - + def clear_errors(self) -> None: self._errors = [] @@ -281,7 +282,9 @@ def get_first_structure_error(value, an, recursive_type_info=None) -> str | None return struct.maybe_first_error() -def get_constraints_errors(value, an, recursive_type_info: Iterable[CompositeTypeInfo]|None=None): +def get_constraints_errors( + value, an, recursive_type_info: Iterable[CompositeTypeInfo] | None = None +): constraints = ErrorsCollector() def only_valid(rti: Iterable[CompositeTypeInfo]): -- GitLab