Skip to content

Commit

Permalink
WASM support
Browse files Browse the repository at this point in the history
Add wasm support
  • Loading branch information
CGamesPlay committed Oct 13, 2024
1 parent 7c7db77 commit 3192357
Show file tree
Hide file tree
Showing 12 changed files with 415 additions and 13 deletions.
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ The package has no library dependencies and provides pre-compiled wheels for all
```sh
pip install tree-sitter
# For wasm support
pip install "tree-sitter[wasm]"
```

## Usage
Expand All @@ -39,6 +41,22 @@ from tree_sitter import Language, Parser
PY_LANGUAGE = Language(tspython.language())
```

#### Wasm support

If you install the `wasm` extra, tree-sitter can use wasmtime to load languages that were compiled to wasm and parse with them. Example:

```python
from pathlib import Path
from wasmtime import Engine
from tree_sitter import Language, Parser

engine = Engine()
wasm_bytes = Path("my_language.wasm").read_bytes()
MY_LANGUAGE = Language.from_wasm("my_language", engine, wasm_bytes)
```

Languages loaded this way work identically to native-binary languages.

### Basic parsing

Create a `Parser` and configure it to use a language:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ tests = [
"tree-sitter-python>=0.23.0",
"tree-sitter-rust>=0.23.0",
]
wasm = ["wasmtime>=23"]

[tool.ruff]
target-version = "py39"
Expand Down
28 changes: 16 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,33 @@
"tree_sitter/binding/range.c",
"tree_sitter/binding/tree.c",
"tree_sitter/binding/tree_cursor.c",
"tree_sitter/binding/wasmtime.c",
"tree_sitter/binding/module.c",
],
include_dirs=[
"tree_sitter/binding",
"tree_sitter/core/lib/include",
"tree_sitter/core/lib/src",
"tree_sitter/core/lib/src/wasm",
],
define_macros=[
("PY_SSIZE_T_CLEAN", None),
("TREE_SITTER_HIDE_SYMBOLS", None),
("TREE_SITTER_FEATURE_WASM", None),
],
undef_macros=[
"TREE_SITTER_FEATURE_WASM",
],
extra_compile_args=[
"-std=c11",
"-fvisibility=hidden",
"-Wno-cast-function-type",
"-Werror=implicit-function-declaration",
] if system() != "Windows" else [
"/std:c11",
"/wd4244",
],
extra_compile_args=(
[
"-std=c11",
"-fvisibility=hidden",
"-Wno-cast-function-type",
"-Werror=implicit-function-declaration",
]
if system() != "Windows"
else [
"/std:c11",
"/wd4244",
]
),
)
],
)
36 changes: 36 additions & 0 deletions tests/test_wasm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import importlib.resources
from unittest import TestCase

from tree_sitter import Language, Parser, Tree

# Guard only the import itself: an ImportError raised anywhere else (for
# example while a test class body is being evaluated) should surface as a real
# failure instead of silently selecting the "wasm disabled" tests.
try:
    import wasmtime
except ImportError:
    wasmtime = None


if wasmtime is not None:

    class TestWasm(TestCase):
        """Exercise wasm-backed languages when wasmtime can be imported."""

        @classmethod
        def setUpClass(cls):
            """Load the bundled JavaScript grammar from its wasm build once."""
            javascript_wasm = (
                importlib.resources.files("tests")
                .joinpath("wasm/tree-sitter-javascript.wasm")
                .read_bytes()
            )
            engine = wasmtime.Engine()
            cls.javascript = Language.from_wasm("javascript", engine, javascript_wasm)

        def test_parser(self):
            parser = Parser(self.javascript)
            self.assertIsInstance(parser.parse(b"test"), Tree)

        def test_language_is_wasm(self):
            # Check identity rather than equality: ``assertEqual(1, True)``
            # would also pass, which is looser than what we want to pin here.
            self.assertIs(self.javascript.is_wasm, True)

else:

    class TestWasmDisabled(TestCase):
        """Check the error reported when wasmtime cannot be imported."""

        def test_parser(self):
            with self.assertRaisesRegex(
                RuntimeError, "wasmtime module is not loaded"
            ):
                Language.from_wasm("javascript", None, b"")
Binary file added tests/wasm/tree-sitter-javascript.wasm
Binary file not shown.
5 changes: 4 additions & 1 deletion tree_sitter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
MIN_COMPATIBLE_LANGUAGE_VERSION,
)

Point.__doc__ = "A position in a multi-line text document, in terms of rows and columns."

Point.__doc__ = (
"A position in a multi-line text document, in terms of rows and columns."
)
Point.row.__doc__ = "The zero-based row of the document."
Point.column.__doc__ = "The zero-based column of the document."

Expand Down
127 changes: 127 additions & 0 deletions tree_sitter/binding/language.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "types.h"

extern void wasm_engine_delete(TSWasmEngine *engine);
extern TSWasmEngine *wasmtime_engine_clone(TSWasmEngine *engine);

int language_init(Language *self, PyObject *args, PyObject *Py_UNUSED(kwargs)) {
PyObject *language;
if (!PyArg_ParseTuple(args, "O:__init__", &language)) {
Expand Down Expand Up @@ -30,10 +33,119 @@ int language_init(Language *self, PyObject *args, PyObject *Py_UNUSED(kwargs)) {
}

// Destructor for Language objects: releases the wasm engine handle (if this
// language holds one), then the underlying TSLanguage, then the object itself.
void language_dealloc(Language *self) {
    TSWasmEngine *owned_engine = self->wasm_engine;
    if (owned_engine) {
        wasm_engine_delete(owned_engine);
    }

    ts_language_delete(self->language);

    PyTypeObject *type = Py_TYPE(self);
    type->tp_free(self);
}

// C equivalent of the Python expression
//     cast(managed_pointer.ptr(), c_void_p).value
// where `cast` and `c_void_p` are the callables/objects supplied by the caller.
//
// Returns the resulting address, or NULL whenever a Python exception is
// pending after the evaluation (the exception is left set for the caller).
static void *get_managed_pointer(PyObject *cast, PyObject *c_void_p, PyObject *managed_pointer) {
    void *address = NULL;

    // Each step only runs if the previous one produced an object; a failed
    // step leaves its result (and every later one) as NULL.
    PyObject *bound_ptr = PyObject_GetAttrString(managed_pointer, "ptr");
    PyObject *raw = bound_ptr != NULL ? PyObject_CallObject(bound_ptr, NULL) : NULL;
    PyObject *as_void_p =
        raw != NULL ? PyObject_CallFunctionObjArgs(cast, raw, c_void_p, NULL) : NULL;
    PyObject *value = as_void_p != NULL ? PyObject_GetAttrString(as_void_p, "value") : NULL;

    if (value != NULL) {
        // Turn the Python integer into a raw C pointer.
        address = PyLong_AsVoidPtr(value);
    }

    // Drop the temporaries in reverse order of creation.
    Py_XDECREF(value);
    Py_XDECREF(as_void_p);
    Py_XDECREF(raw);
    Py_XDECREF(bound_ptr);

    return PyErr_Occurred() ? NULL : address;
}

// Language.from_wasm(name, engine, wasm): class method that builds a Language
// from a grammar compiled to wasm.
//
// Arguments (parsed from `args`):
//   name   - str, the language name handed to ts_wasm_store_load_language
//   engine - instance of the wasmtime Engine type cached in the module state
//   wasm   - bytes-like object holding the wasm module
//
// Returns a new Language reference, or NULL with a Python exception set:
//   RuntimeError - wasmtime was not available when the module was initialised,
//                  the engine could not be cloned, or tree-sitter failed to
//                  create the store / load the language
//   ValueError   - the engine wraps a NULL pointer, or `wasm` is longer than
//                  the uint32_t length accepted by tree-sitter
//   (plus anything raised by argument parsing or get_managed_pointer)
PyObject *language_from_wasm(PyTypeObject *cls, PyObject *args) {
    ModuleState *state = (ModuleState *)PyType_GetModuleState(cls);
    TSWasmError error;
    TSWasmEngine *engine = NULL;
    TSWasmStore *wasm_store = NULL;
    TSLanguage *language = NULL;
    Language *self = NULL;
    // "s" and "y#" hand out pointers into buffers owned by the argument
    // objects, so they must never be written through.
    const char *name;
    PyObject *py_engine = NULL;
    const char *wasm;
    Py_ssize_t wasm_length;

    if (state->wasmtime_engine_type == NULL) {
        PyErr_SetString(PyExc_RuntimeError, "wasmtime module is not loaded");
        return NULL;
    }
    if (!PyArg_ParseTuple(args, "sO!y#:from_wasm", &name, state->wasmtime_engine_type, &py_engine,
                          &wasm, &wasm_length)) {
        return NULL;
    }

    // ts_wasm_store_load_language takes a 32-bit length; refuse anything that
    // would be silently truncated.
    if ((uint64_t)wasm_length > UINT32_MAX) {
        PyErr_SetString(PyExc_ValueError, "wasm module is too large");
        return NULL;
    }

    TSWasmEngine *borrowed_engine =
        (TSWasmEngine *)get_managed_pointer(state->ctypes_cast, state->c_void_p, py_engine);
    if (borrowed_engine == NULL) {
        // Never return NULL to the interpreter without an exception set.
        if (!PyErr_Occurred()) {
            PyErr_SetString(PyExc_ValueError, "wasmtime engine pointer is NULL");
        }
        goto fail;
    }

    // Work on our own handle so this object's lifetime is independent of the
    // Python-side Engine object.
    engine = wasmtime_engine_clone(borrowed_engine);
    if (engine == NULL) {
        PyErr_SetString(PyExc_RuntimeError, "Failed to clone wasmtime engine");
        goto fail;
    }

    // NOTE(review): `wasm_store` and `error.message` are never released in
    // this function. Whether they can be freed here depends on ownership
    // rules of the tree-sitter wasm API that are not visible from this file;
    // confirm before adding ts_wasm_store_delete / freeing the message.
    wasm_store = ts_wasm_store_new(engine, &error);
    if (wasm_store == NULL) {
        PyErr_Format(PyExc_RuntimeError, "Failed to create TSWasmStore: %s", error.message);
        goto fail;
    }

    language = (TSLanguage *)ts_wasm_store_load_language(wasm_store, name, wasm,
                                                         (uint32_t)wasm_length, &error);
    if (language == NULL) {
        PyErr_Format(PyExc_RuntimeError, "Failed to load language: %s", error.message);
        goto fail;
    }

    self = (Language *)cls->tp_alloc(cls, 0);
    if (self == NULL) {
        goto fail;
    }

    // From here on the Language object owns both the language and the engine
    // clone; language_dealloc releases them.
    self->language = language;
    self->wasm_engine = engine;
    self->version = ts_language_version(self->language);
#if HAS_LANGUAGE_NAMES
    self->name = ts_language_name(self->language);
#endif
    return (PyObject *)self;

fail:
    if (engine != NULL) {
        wasm_engine_delete(engine);
    }
    ts_language_delete(language);
    return NULL;
}

PyObject *language_repr(Language *self) {
#if HAS_LANGUAGE_NAMES
if (self->name == NULL) {
Expand Down Expand Up @@ -82,6 +194,10 @@ PyObject *language_get_field_count(Language *self, void *Py_UNUSED(payload)) {
return PyLong_FromUnsignedLong(ts_language_field_count(self->language));
}

// Getter for Language.is_wasm: True when tree-sitter reports the underlying
// language as a wasm language, False otherwise.
PyObject *language_is_wasm(Language *self, void *Py_UNUSED(payload)) {
    if (ts_language_is_wasm(self->language)) {
        Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
}

PyObject *language_node_kind_for_id(Language *self, PyObject *args) {
TSSymbol symbol;
if (!PyArg_ParseTuple(args, "H:node_kind_for_id", &symbol)) {
Expand Down Expand Up @@ -190,6 +306,9 @@ PyObject *language_query(Language *self, PyObject *args) {
return PyObject_CallFunction((PyObject *)state->query_type, "Os#", self, source, length);
}

// from_wasm is registered with METH_CLASS, so its first parameter is the
// class, not an instance.
PyDoc_STRVAR(language_from_wasm_doc,
             "from_wasm(cls, name, engine, wasm, /)\n--\n\n"
             "Load a language compiled as wasm.");
PyDoc_STRVAR(language_node_kind_for_id_doc,
"node_kind_for_id(self, id, /)\n--\n\n"
"Get the name of the node kind for the given numerical id.");
Expand Down Expand Up @@ -220,6 +339,12 @@ PyDoc_STRVAR(
"Create a new :class:`Query` from a string containing one or more S-expression patterns.");

static PyMethodDef language_methods[] = {
{
.ml_name = "from_wasm",
.ml_meth = (PyCFunction)language_from_wasm,
.ml_flags = METH_CLASS | METH_VARARGS,
.ml_doc = language_from_wasm_doc,
},
{
.ml_name = "node_kind_for_id",
.ml_meth = (PyCFunction)language_node_kind_for_id,
Expand Down Expand Up @@ -291,6 +416,8 @@ static PyGetSetDef language_accessors[] = {
PyDoc_STR("The number of valid states in this language."), NULL},
{"field_count", (getter)language_get_field_count, NULL,
PyDoc_STR("The number of distinct field names in this language."), NULL},
{"is_wasm", (getter)language_is_wasm, NULL,
PyDoc_STR("Check if the language came from a wasm module."), NULL},
{NULL},
};

Expand Down
35 changes: 35 additions & 0 deletions tree_sitter/binding/module.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <wasm.h>
#include "types.h"

extern PyType_Spec language_type_spec;
Expand All @@ -15,6 +16,8 @@ extern PyType_Spec range_type_spec;
extern PyType_Spec tree_cursor_type_spec;
extern PyType_Spec tree_type_spec;

void tsp_load_wasmtime_symbols();

// TODO(0.24): drop Python 3.9 support
#if PY_MINOR_VERSION > 9
#define AddObjectRef PyModule_AddObjectRef
Expand Down Expand Up @@ -62,6 +65,9 @@ static void module_free(void *self) {
Py_XDECREF(state->tree_type);
Py_XDECREF(state->query_error);
Py_XDECREF(state->re_compile);
Py_XDECREF(state->wasmtime_engine_type);
Py_XDECREF(state->ctypes_cast);
Py_XDECREF(state->c_void_p);
}

static struct PyModuleDef module_definition = {
Expand Down Expand Up @@ -147,6 +153,35 @@ PyMODINIT_FUNC PyInit__binding(void) {
if (namedtuple == NULL) {
goto cleanup;
}

PyObject *wasmtime_engine = import_attribute("wasmtime", "Engine");
if (wasmtime_engine == NULL) {
// No worries, disable functionality.
PyErr_Clear();
} else {
// Ensure wasmtime_engine is a PyTypeObject
if (!PyType_Check(wasmtime_engine)) {
PyErr_SetString(PyExc_TypeError, "wasmtime.Engine is not a type");
goto cleanup;
}
state->wasmtime_engine_type = (PyTypeObject *)wasmtime_engine;

tsp_load_wasmtime_symbols();
if (PyErr_Occurred()) {
goto cleanup;
}

state->ctypes_cast = import_attribute("ctypes", "cast");
if (state->ctypes_cast == NULL) {
goto cleanup;
}

state->c_void_p = import_attribute("ctypes", "c_void_p");
if (state->c_void_p == NULL) {
goto cleanup;
}
}

PyObject *point_args = Py_BuildValue("s[ss]", "Point", "row", "column");
PyObject *point_kwargs = PyDict_New();
PyDict_SetItemString(point_kwargs, "module", PyUnicode_FromString("tree_sitter"));
Expand Down
Loading

0 comments on commit 3192357

Please sign in to comment.