diff --git a/examples/remote_code.py b/examples/remote_code.py
new file mode 100644
index 0000000000000000000000000000000000000000..35aab66026c5ad3faad5983f0809569b395d6def
--- /dev/null
+++ b/examples/remote_code.py
@@ -0,0 +1,62 @@
+"""Serve a HERO whose methods accept and return plain Python functions.
+
+Callables other than HERO methods are encoded with ``dill`` by
+``heros.serdes``, so a remote peer can fetch the closure built by
+``get_power_func`` or send its own function to ``run_func`` to have it
+executed on this side.
+"""
+import argparse
+import time
+
+from heros import LocalHERO
+
+
+class TestObject(LocalHERO):
+    """Minimal HERO exposing methods that take or return functions."""
+
+    foovar: str = ""
+    testme: int = 0
+
+    def hello(self) -> str:
+        """Increment ``testme`` and return ``"world"``."""
+        self.testme += 1
+        return "world"
+
+    def get_power_func(self, power):
+        """Return a closure computing ``x**power``."""
+
+        def inner(x):
+            return x**power
+
+        return inner
+
+    def run_func(self, func, args=(), kwargs=None):
+        """Call ``func`` as if it were a method of this HERO.
+
+        ``func`` is bound to ``self`` through its ``__get__`` descriptor, so it
+        must take the HERO instance as its first argument; the effective call
+        is ``func(self, *args, **kwargs)``.
+        """
+        if kwargs is None:
+            kwargs = {}
+        return func.__get__(self)(*args, **kwargs)
+
+
+def main():
+    """Parse the command line and serve a ``TestObject`` until interrupted."""
+    parser = argparse.ArgumentParser(
+        prog="remote_code",
+        description="Example HERO that sends and receives Python functions",
+    )
+    parser.add_argument("--realm", "-r", default="heros", type=str)
+    parser.add_argument("name", help="identifier of the HERO served by this script")
+    args = parser.parse_args()
+
+    with TestObject(args.name, realm=args.realm):
+        # keep the HERO alive until the process is interrupted
+        while True:
+            time.sleep(1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
index 63cf2921eb4b14e0b56731a3ffcb791ae43f5252..27b0aec0865ec9e483ce10b5225ea5384b243761 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
     "eclipse-zenoh>=1.1.0",
     "cbor2",
     "numpy",
+    "dill>=0.4.0",
 ]
 
 classifiers = [
diff --git a/src/heros/serdes.py b/src/heros/serdes.py
index eb7d6f0f9684095d9eb486153905d68519d9d9d6..b02549f7fc926cc86fa337b8bc35e22c26e18d4b 100644
--- a/src/heros/serdes.py
+++ b/src/heros/serdes.py
@@ -1,11 +1,13 @@
 import cbor2
 import numpy as np
 import weakref
+import dill
 
 from heros.helper import full_classname
 from .inspect import is_hero_event, is_hero_method
 
 ndarray_tag = 4242
+function_tag = 4243
 hero_tag = 9000
 unserializable_tag = 9001
 unserializable_object_reference_tag = 9002
@@ -112,6 +114,13 @@ def cbor_default_encoder(encoder, value):
         # value.__self__ corresponds is the HERO instance
         encoder.encode(cbor2.CBORTag(hero_method_tag, [value.__self__, value.__name__]))
 
+    elif callable(value):
+        # Any other callable (function, lambda, closure, ...) is pickled with
+        # dill. NOTE(review): this catches *every* callable, including classes
+        # and objects defining __call__, and dill.dumps raises if the object
+        # cannot be pickled; consider narrowing the check if that bites.
+        encoder.encode(cbor2.CBORTag(function_tag, dill.dumps(value)))
+
     elif type(value) is UnserializableRemoteObject:
         # encode an reference on an remote object that cannot be serialized
         encoder.encode(cbor2.CBORTag(unserializable_object_reference_tag, value.id))
@@ -147,6 +156,12 @@ def cbor_tag_hook(decoder, tag, shareable_index=None):
         remote_hero, method_name = tag.value
         return getattr(remote_hero, method_name)
 
+    if tag.tag == function_tag:
+        # SECURITY: dill.loads can run arbitrary code contained in the payload.
+        # Only accept function payloads from peers you fully trust (e.g. a
+        # private, access-controlled realm).
+        return dill.loads(tag.value)
+
     if tag.tag == unserializable_tag:
         # decode
         t, i, s = tag.value
diff --git a/tests/test_serdes.py b/tests/test_serdes.py
index 5f023dc4a2624777493736e67472dfcf43c90c07..9f199f2d63e925e8f7b382ce2fc20607090e31c9 100644
--- a/tests/test_serdes.py
+++ b/tests/test_serdes.py
@@ -1,3 +1,4 @@
+import pytest
 from heros.serdes import serialize, deserialize
 import numpy as np
 
@@ -16,3 +17,65 @@ def test_ndarray():
     assert not arr_none_cont.flags.f_contiguous
     assert not arr_none_cont.flags.c_contiguous
     assert np.array_equiv(arr_none_cont, deserialize(serialize(arr_none_cont)))
+
+
+def add(a, b):
+    return a + b
+
+
+def greet(name="world"):
+    return f"Hello, {name}!"
+
+
+def make_power(power):
+    def inner(x):
+        return x**power
+
+    return inner
+
+
+lambda_fn = lambda x: x * 2  # noqa: E731
+
+
+@pytest.mark.parametrize(
+    "func,args,kwargs,expected",
+    [
+        (add, (1, 2), {}, 3),
+        (greet, (), {}, "Hello, world!"),
+        (greet, (), {"name": "Alice"}, "Hello, Alice!"),
+        (lambda_fn, (5,), {}, 10),
+        (make_power(3), (2,), {}, 8),
+    ],
+)
+def test_function_serialization_roundtrip(func, args, kwargs, expected):
+    """Ensure serialized + deserialized function keeps its behavior."""
+    serialized = serialize(func)
+    deserialized = deserialize(serialized)
+
+    assert callable(deserialized), "Deserialized object is not callable"
+    assert deserialized(*args, **kwargs) == expected, "Function behavior changed after roundtrip"
+
+
+def test_closure_keeps_captured_state():
+    """Closures must carry their captured variables across the roundtrip."""
+    square = deserialize(serialize(make_power(2)))
+    cube = deserialize(serialize(make_power(3)))
+
+    assert square(4) == 16
+    assert cube(4) == 64
+
+
+def test_different_functions_remain_distinct():
+    """Ensure two different functions don't deserialize to the same object."""
+    ser_add = serialize(add)
+    ser_greet = serialize(greet)
+
+    # distinct functions must not collapse onto the same payload
+    assert ser_add != ser_greet
+
+    f1 = deserialize(ser_add)
+    f2 = deserialize(ser_greet)
+
+    assert f1 is not f2
+    assert f1(1, 2) == 3
+    assert f2("Bob") == "Hello, Bob!"