From 282648a952bcb8e7c0b6a5b79e06e50dbffe3d87 Mon Sep 17 00:00:00 2001 From: zach Date: Fri, 20 Sep 2024 09:57:52 -0700 Subject: [PATCH] cleanup: make Codec/Json base classes --- bin/src/options.rs | 2 ++ examples/imports.py | 4 +-- lib/src/prelude.py | 76 +++++++++++++++++++++++++++++++++------------ 3 files changed, 61 insertions(+), 21 deletions(-) diff --git a/bin/src/options.rs b/bin/src/options.rs index 49f546c..e02e489 100644 --- a/bin/src/options.rs +++ b/bin/src/options.rs @@ -7,6 +7,8 @@ pub struct Options { #[structopt(parse(from_os_str))] pub input_py: PathBuf, + // #[structopt(parse(from_os_str))] + // pub other_files: Vec, #[structopt(short = "o", parse(from_os_str), default_value = "index.wasm")] pub output: PathBuf, diff --git a/examples/imports.py b/examples/imports.py index 7a14896..a59fc9d 100644 --- a/examples/imports.py +++ b/examples/imports.py @@ -10,7 +10,7 @@ def reflect(x: str) -> str: pass @extism.import_fn("example", "update_dict") -def update_dict(x: dict) -> dict: +def update_dict(x: extism.JsonObject) -> extism.JsonObject: pass @extism.plugin_fn @@ -22,5 +22,5 @@ def count_vowels(): if ch in ['A', 'a', 'E', 'e', 'I', 'i', 'O', 'o', 'U', 'u']: total += 1 extism.log(extism.LogLevel.Info, "Hello!") - extism.output_json(update_dict({"count": total})) + extism.output(update_dict({"count": total})) diff --git a/lib/src/prelude.py b/lib/src/prelude.py index 47842dd..92b74a7 100644 --- a/lib/src/prelude.py +++ b/lib/src/prelude.py @@ -1,6 +1,8 @@ from typing import Union, Optional - import json +from enum import Enum +from abc import ABC, abstractmethod + import extism_ffi as ffi LogLevel = ffi.LogLevel @@ -18,27 +20,18 @@ IMPORT_INDEX = 0 -class Codec: +class Codec(ABC): """ Codec is used to serialize and deserialize values in Extism memory """ - def __init__(self, value): - self.value = value - - def get(self): - """Method to get the inner value""" - return self.value - - def set(self, x): - """Method to set in the inner value""" - self.value = x - + @abstractmethod def encode(self) -> bytes: """Encode the inner value to bytes""" raise Exception("encode not implemented") - @staticmethod + @classmethod + @abstractmethod def decode(s: bytes): """Decode a value from bytes""" raise Exception("encode not implemented") @@ -46,11 +39,18 @@ def decode(s: bytes): class Json(Codec): def encode(self) -> bytes: - return json.dumps(self.value).encode() + v = self + if not isinstance(self, dict) and hasattr(self, "__dict__"): + v = self.__dict__ + return json.dumps(v).encode() - @staticmethod - def decode(s: bytes): - return Json(json.loads(s.decode())) + @classmethod + def decode(cls, s: bytes): + return cls(**json.loads(s.decode())) + + +class JsonObject(Json, dict): + pass def _store(x) -> int: @@ -139,16 +139,54 @@ def inner(*args): return inner -def input_json(): +def input_json(t: Optional[type] = None): """Get input as JSON""" + if t is not None: + return json.loads(input_str(), object_hook=lambda x: t(**x)) return json.loads(input_str()) def output_json(x): """Set JSON output""" + if hasattr(x, "__dict__"): + x = x.__dict__ output_str(json.dumps(x)) +def input(t: type = None): + if t is None: + return None + if t is str: + return input_str() + elif t is bytes: + return input_bytes() + elif issubclass(t, Codec): + return t.decode(input_bytes()) + elif t is dict or t is list: + return json.loads(input_str()) + elif issubclass(t, Enum): + return t(input_str()) + else: + raise Exception(f"Unsupported type for input: {t}") + + +def output(x=None): + if x is None: + return + if isinstance(x, str): + output_str(x) + elif isinstance(x, bytes): + output_bytes(x) + elif isinstance(x, Codec): + output_bytes(x.encode()) + elif isinstance(x, dict) or isinstance(x, list): + output_json(x) + elif isinstance(x, Enum): + output_str(x.value) + else: + raise Exception(f"Unsupported type for output: {type(x)}") + + class Var: @staticmethod def get_bytes(key: str) -> Optional[bytes]: