From 7e02e98b9226a77273f95db6938f54d689523412 Mon Sep 17 00:00:00 2001 From: zach Date: Wed, 25 Sep 2024 11:11:59 -0700 Subject: [PATCH] cleanup: API improvements (#10) --- Makefile | 10 +-- bin/src/opt.rs | 2 +- bin/src/options.rs | 2 + build.py | 23 ++++-- examples/imports.py | 4 +- lib/src/prelude.py | 193 +++++++++++++++++++++++++++++++++++++------- 6 files changed, 189 insertions(+), 45 deletions(-) diff --git a/Makefile b/Makefile index ccf887f..c7037fb 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,5 @@ +PYTHON_FILES=lib/src/prelude.py bin/src/invoke.py build.py + build: ./build.py build @@ -5,14 +7,10 @@ install: ./build.py install format: - uv run ruff format lib/src/prelude.py - uv run ruff format bin/src/invoke.py - uv run ruff format build.py + uv run ruff format $(PYTHON_FILES) check: - uv run ruff check lib/src/prelude.py - uv run ruff check bin/src/invoke.py - uv run ruff check build.py + uv run ruff check $(PYTHON_FILES) clean: rm -rf bin/target lib/target diff --git a/bin/src/opt.rs b/bin/src/opt.rs index 8b3d4e3..c4d6915 100644 --- a/bin/src/opt.rs +++ b/bin/src/opt.rs @@ -28,7 +28,7 @@ fn find_deps() -> PathBuf { directories::BaseDirs::new() .unwrap() - .data_local_dir() + .data_dir() .join("extism-py") } diff --git a/bin/src/options.rs b/bin/src/options.rs index 49f546c..e02e489 100644 --- a/bin/src/options.rs +++ b/bin/src/options.rs @@ -7,6 +7,8 @@ pub struct Options { #[structopt(parse(from_os_str))] pub input_py: PathBuf, + // #[structopt(parse(from_os_str))] + // pub other_files: Vec, #[structopt(short = "o", parse(from_os_str), default_value = "index.wasm")] pub output: PathBuf, diff --git a/build.py b/build.py index 6ef7f23..c86e21f 100755 --- a/build.py +++ b/build.py @@ -26,7 +26,7 @@ def bin_dir() -> Path: if system in ["linux", "darwin"]: return home / ".local" / "bin" elif system == "windows": - return Path(os.getenv("APPDATA")) / "Python" / "Scripts" + return Path(os.getenv("USERHOME")) else: raise OSError(f"Unsupported OS {system}") @@ -42,6 +42,13 @@ def data_dir() -> Path: raise OSError(f"Unsupported OS {system}") +def exe_file() -> str: + if system == "windows": + return "extism-py.exe" + else: + return "extism-py" + + def run_subprocess(command, cwd=None, quiet=False): try: logging.info(f"Running command: {' '.join(command)} in {cwd or '.'}") @@ -78,7 +85,7 @@ def do_build(args): check_rust_installed() run_subprocess(["cargo", "build", "--release"], cwd="./lib", quiet=args.quiet) run_subprocess(["cargo", "build", "--release"], cwd="./bin", quiet=args.quiet) - shutil.copy2(Path("./bin/target/release/extism-py"), Path("./extism-py")) + shutil.copy2(Path("./bin/target/release") / exe_file(), Path(".") / exe_file()) logging.info("Build completed successfully.") @@ -89,16 +96,16 @@ def do_install(args): bin_dir.mkdir(parents=True, exist_ok=True) data_dir.mkdir(parents=True, exist_ok=True) - dest_path = bin_dir / "extism-py" + dest_path = bin_dir / exe_file() logging.info(f"Copying binary to {dest_path}") - shutil.copy2(Path("./bin/target/release/extism-py"), dest_path) + shutil.copy2(Path("./bin/target/release") / exe_file(), dest_path) logging.info(f"Copying data files to {data_dir}") shutil.copytree( Path("./lib/target/wasm32-wasi/wasi-deps/usr"), data_dir, dirs_exist_ok=True ) - logging.info(f"extism-py installed to {bin_dir}") + logging.info(f"{exe_file()} installed to {bin_dir}") logging.info(f"Data files installed to {data_dir}") logging.info("Installation completed successfully.") @@ -107,15 +114,15 @@ def do_clean(args): logging.info("Cleaning build artifacts...") shutil.rmtree("./lib/target", ignore_errors=True) shutil.rmtree("./bin/target", ignore_errors=True) - if Path("./extism-py").exists(): - Path("./extism-py").unlink() + if (Path(".") / exe_file()).exists(): + (Path(".") / exe_file()).unlink() logging.info("Clean completed successfully.") def get_version(): try: result = subprocess.run( - ["extism-py", "--version"], capture_output=True, text=True, check=True + [exe_file(), "--version"], capture_output=True, text=True, check=True ) return result.stdout.strip() except subprocess.CalledProcessError: diff --git a/examples/imports.py b/examples/imports.py index 7a14896..a59fc9d 100644 --- a/examples/imports.py +++ b/examples/imports.py @@ -10,7 +10,7 @@ def reflect(x: str) -> str: pass @extism.import_fn("example", "update_dict") -def update_dict(x: dict) -> dict: +def update_dict(x: extism.JsonObject) -> extism.JsonObject: pass @extism.plugin_fn @@ -22,5 +22,5 @@ def count_vowels(): if ch in ['A', 'a', 'E', 'e', 'I', 'i', 'O', 'o', 'U', 'u']: total += 1 extism.log(extism.LogLevel.Info, "Hello!") - extism.output_json(update_dict({"count": total})) + extism.output(update_dict({"count": total})) diff --git a/lib/src/prelude.py b/lib/src/prelude.py index 47842dd..69ac0ad 100644 --- a/lib/src/prelude.py +++ b/lib/src/prelude.py @@ -1,16 +1,27 @@ from typing import Union, Optional - import json +from enum import Enum +from abc import ABC, abstractmethod +from datetime import datetime +from base64 import b64encode, b64decode +from dataclasses import is_dataclass + import extism_ffi as ffi LogLevel = ffi.LogLevel -log = ffi.log input_str = ffi.input_str input_bytes = ffi.input_bytes output_str = ffi.output_str output_bytes = ffi.output_bytes memory = ffi.memory +def log(level, msg): + if isinstance(msg, bytes): + msg = msg.decode() + elif not isinstance(msg, str): + msg = str(msg) + ffi.log(level, msg) + HttpRequest = ffi.HttpRequest __exports = [] @@ -18,39 +29,117 @@ IMPORT_INDEX = 0 -class Codec: +class Codec(ABC): """ Codec is used to serialize and deserialize values in Extism memory - """ - - def __init__(self, value): - self.value = value - - def get(self): - """Method to get the inner value""" - return self.value - - def set(self, x): - """Method to set in the inner value""" - self.value = x + """ + @abstractmethod def encode(self) -> bytes: """Encode the inner value to bytes""" raise Exception("encode not implemented") - @staticmethod + @classmethod + @abstractmethod def decode(s: bytes): """Decode a value from bytes""" raise Exception("encode not implemented") + def __post_init__(self): + self._fix_fields() + + def _fix_fields(self): + if not hasattr(self, '__annotations__'): + return + for k in self.__annotations__: + ty = self.__annotations__[k] + v = getattr(self, k) + setattr(self, k, self._fix_field(ty, v)) + return self + + def _fix_field(self, ty: type, v): + def check_subclass(a, b): + try: + return issubclass(a, b) + except Exception as _: + return False + if isinstance(v, dict) and check_subclass(ty, Codec): + return ty(**v)._fix_fields() + elif isinstance(v, str) and check_subclass(ty, Enum): + return ty(v) + elif isinstance(v, list) and hasattr(ty, '__origin__') and ty.__origin__ is list: + ty = ty.__args__[0] + return [self._fix_field(ty, x) for x in v] + elif hasattr(ty, '__origin__') and ty.__origin__ is Union: + if len(ty.__args__) == 2 and ty.__args__[1] == type(None) and v is not None: + ty = ty.__args__[0] + return self._fix_field(ty, v) + return v + + +class JSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, Json): + return json.loads(o.encode().decode()) + elif isinstance(o, bytes): + return b64encode(o).decode() + elif isinstance(o, datetime): + return o.isoformat() + elif isinstance(o, Enum): + return str(o.value) + elif isinstance(o, list): + return [self.default(x) for x in o] + elif isinstance(o, dict): + return {k: self.default(x) for k, x in o.items()} + return super().default(o) + + +class JSONDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) + + def object_hook(self, dct): + if not isinstance(dct, dict): + return dct + for k, v in dct.items(): + if isinstance(v, str): + try: + dct[k] = datetime.fromisoformat(v) + continue + except Exception as _: + pass + + try: + dct[k] = b64decode(v.encode()) + continue + except Exception as _: + pass + elif isinstance(v, dict): + dct[k] = self.object_hook(v) + elif isinstance(v, list): + dct[k] = [self.object_hook(x) for x in v] + return dct + class Json(Codec): def encode(self) -> bytes: - return json.dumps(self.value).encode() + v = self + if not isinstance(self, (dict, datetime, bytes)) and hasattr(self, "__dict__"): + if len(self.__dict__) > 0: + v = self.__dict__ + return json.dumps(v, cls=JSONEncoder).encode() + + @classmethod + def decode(cls, s: bytes): + x = json.loads(s.decode(), cls=JSONDecoder) + if is_dataclass(cls): + return cls(**x) + else: + return cls(**x)._fix_fields() - @staticmethod - def decode(s: bytes): - return Json(json.loads(s.decode())) + +class JsonObject(Json, dict): + pass def _store(x) -> int: @@ -59,9 +148,11 @@ def _store(x) -> int: elif isinstance(x, bytes): return ffi.memory.alloc(x).offset elif isinstance(x, dict) or isinstance(x, list): - return ffi.memory.alloc(json.dumps(x).encode()).offset + return ffi.memory.alloc(json.dumps(x, cls=JSONEncoder).encode()).offset elif isinstance(x, Codec): return ffi.memory.alloc(x.encode()).offset + elif isinstance(x, Enum): + return ffi.memory.alloc(str(x.value).encode()).offset elif isinstance(x, ffi.memory.MemoryHandle): return x.offset elif isinstance(x, int): @@ -85,9 +176,11 @@ def _load(t, x): elif t is bytes: return ffi.memory.bytes(mem) elif t is dict or t is list: - return json.loads(ffi.memory.string(mem)) + return json.loads(ffi.memory.string(mem), cls=JSONDecoder) elif issubclass(t, Codec): return t.decode(ffi.memory.bytes(mem)) + elif issubclass(t, Enum): + return t(ffi.memory.string(mem)) elif t is ffi.memory.MemoryHandle: return mem elif t is type(None): @@ -139,14 +232,58 @@ def inner(*args): return inner -def input_json(): +def input_json(t: Optional[type] = None): """Get input as JSON""" - return json.loads(input_str()) + if t is int or t is float: + return t(json.loads(input_str(), cls=JSONDecoder)) + if issubclass(t, Json): + return t(**json.loads(input_str(), cls=JSONDecoder)) + return json.loads(input_str(), cls=JSONDecoder) def output_json(x): """Set JSON output""" - output_str(json.dumps(x)) + if isinstance(x, int) or isinstance(x, float): + output_str(json.dumps(str(x))) + return + + if hasattr(x, "__dict__"): + x = x.__dict__ + output_str(json.dumps(x, cls=JSONEncoder)) + + +def input(t: type = None): + if t is None: + return None + if t is str: + return input_str() + elif t is bytes: + return input_bytes() + elif issubclass(t, Codec): + return t.decode(input_bytes()) + elif t is dict or t is list: + return json.loads(input_str(), cls=JSONDecoder) + elif issubclass(t, Enum): + return t(input_str()) + else: + raise Exception(f"Unsupported type for input: {t}") + + +def output(x=None): + if x is None: + return + if isinstance(x, str): + output_str(x) + elif isinstance(x, bytes): + output_bytes(x) + elif isinstance(x, Codec): + output_bytes(x.encode()) + elif isinstance(x, dict) or isinstance(x, list): + output_json(x) + elif isinstance(x, Enum): + output_str(x.value) + else: + raise Exception(f"Unsupported type for output: {type(x)}") class Var: @@ -169,7 +306,7 @@ def get_json(key: str): x = Var.get_str(key) if x is None: return x - return json.loads(x) + return json.loads(x, cls=JSONDecoder) @staticmethod def set(key: str, value: Union[bytes, str]): @@ -191,7 +328,7 @@ def get_json(key: str): x = ffi.config_get(key) if x is None: return None - return json.loads(x) + return json.loads(x, cls=JSONDecoder) class HttpResponse: @@ -215,7 +352,7 @@ def data_str(self): def data_json(self): """Get response body JSON""" - return json.loads(self.data_str()) + return json.loads(self.data_str(), cls=JSONDecoder) class Http: