Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
thorbjoernl committed Aug 5, 2024
1 parent 4836e13 commit dffdfa4
Show file tree
Hide file tree
Showing 14 changed files with 196 additions and 67 deletions.
33 changes: 33 additions & 0 deletions scripts/build_sqlite_test_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
import aerovaldb
import re


# Rebuild the sqlite test database from the json test database so that both
# back ends expose the same assets to the test suite.
JSON_RESOURCE = "json_files:tests/test-db/json"
SQLITE_PATH = "tests/test-db/sqlite/test.sqlite"

# Start from an empty database file so no stale rows survive a rebuild.
if os.path.exists(SQLITE_PATH):
    os.remove(SQLITE_PATH)

jsondb = aerovaldb.open(JSON_RESOURCE)
sqlitedb = aerovaldb.open(f"sqlitedb:{SQLITE_PATH}")

# Copy every asset as its raw json string. Only JSON_STR payloads are written:
# an earlier debugging step fetched the config with AccessType.FILE_PATH and
# stored that *path* as the document, which put a non-json value in the table
# and wrote the same URI twice.
asset_count = 0
for uri in jsondb.list_all():
    print(f"Processing uri {uri}")
    data = jsondb.get_by_uri(
        uri, access_type=aerovaldb.AccessType.JSON_STR, default="{}"
    )
    sqlitedb.put_by_uri(data, uri)
    asset_count += 1

# Counted while copying, which avoids walking the json directory a second time.
print(f"jsondb number of assets: {asset_count}")
# TODO: also report the sqlite asset count once the sqlite back end
# implements list_all().
46 changes: 46 additions & 0 deletions scripts/tmp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import simplejson
import sqlite3


# Scratch script: round-trip a json document through an in-memory sqlite
# table, then read a config document back from the json test database.
CREATE_TABLE_SQL = "\nCREATE TABLE test(key, value)\n"
INSERT_ROW_SQL = "\nINSERT INTO test\nVALUES(?, ?)\n"
SELECT_VALUE_SQL = "\nSELECT value FROM test\nWHERE key='test'\n"

connection = sqlite3.connect(":memory:")

payload = simplejson.dumps({"test": 1234})

print(payload)

cursor = connection.cursor()

# Create the table and store the serialized document under the key "test".
cursor.execute(CREATE_TABLE_SQL)
connection.commit()

cursor.execute(INSERT_ROW_SQL, ("test", payload))
connection.commit()

# Read the document back and decode it again.
cursor.execute(SELECT_VALUE_SQL)
stored_value = cursor.fetchone()[0]
print(simplejson.loads(stored_value))

import aerovaldb

with aerovaldb.open("tests/test-db/json") as db:
    print(payload)

    # Writing is left disabled:
    # db.put_by_uri(payload, "/v0/config/project/experiment")

    stored_config = db.get_by_uri(
        "/v0/config/project/experiment", access_type=aerovaldb.AccessType.JSON_STR
    )
    print(stored_config)
7 changes: 7 additions & 0 deletions src/aerovaldb/aerovaldb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,3 +1196,10 @@ def _normalize_access_type(
return default

assert False

def list_all(self) -> Generator[str, None, None]:
    """Iterate over the URI of every object stored in this aerovaldb connection.

    The base class provides no implementation.

    :raises NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
4 changes: 2 additions & 2 deletions src/aerovaldb/jsondb/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async def _read_json(self, file_path: str | Path) -> str:
return await f.read()

with open(abspath, "r") as f:
return f.read()
return f.read().replace("\n", "")

def _get(self, abspath: str) -> str:
"""Returns an element from the cache."""
Expand All @@ -140,7 +140,7 @@ def _get(self, abspath: str) -> str:

def _put(self, abspath: str, *, json: str, modified: float):
self._cache[abspath] = {
"json": json,
"json": "".join(json.split(r"\n")),
"last_modified": os.path.getmtime(abspath),
}
while self.size > self._max_size:
Expand Down
59 changes: 39 additions & 20 deletions src/aerovaldb/jsondb/jsonfiledb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
parse_uri,
parse_formatted_string,
build_uri,
extract_substitutions,
)
from .templatemapper import (
TemplateMapper,
Expand All @@ -34,9 +35,6 @@
from ..lock.lock import FakeLock, FileLock
from hashlib import md5
import simplejson # type: ignore
from ..sqlitedb.utils import (
extract_substitutions,
) # TODO: Move this to a more appropriate location before merging PR.

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -302,7 +300,7 @@ async def _get(

if access_type == AccessType.JSON_STR:
raw = await self._cache.get_json(file_path, no_cache=not use_caching)
return json_dumps_wrapper(raw)
return raw

raw = await self._cache.get_json(file_path, no_cache=not use_caching)

Expand Down Expand Up @@ -474,27 +472,35 @@ def _get_uri_for_file(self, file_path: str) -> str:
:param file_path : The file_path.
"""
file_path = os.path.join(self._basedir, file_path)
file_path = os.path.relpath(file_path, start=self._basedir)

for route in self.PATH_LOOKUP:
templates = self._get_templates(route)
# templates = self._get_templates(route)
if file_path.startswith("reports/"):
str = "/".join(file_path.split("/")[1:3])
subs = parse_formatted_string("{project}/{experiment}", str)
else:
str = "/".join(file_path.split("/")[0:2])
subs = parse_formatted_string("{project}/{experiment}", str)

for t in templates:
route_arg_names = extract_substitutions(t)
# project = args["project"]
# experiment = args["experiment"]

try:
all_args = parse_formatted_string(t, f"./{file_path}")

route_args = {
k: v for k, v in all_args.items() if k in route_arg_names
}
kwargs = {
k: v for k, v in all_args.items() if not (k in route_arg_names)
}
except:
continue
else:
return build_uri(route, route_args, kwargs)
template = self._get_template(route, subs)
route_arg_names = extract_substitutions(route)

try:
all_args = parse_formatted_string(template, f"./{file_path}")

route_args = {k: v for k, v in all_args.items() if k in route_arg_names}
kwargs = {
k: v for k, v in all_args.items() if not (k in route_arg_names)
}
except Exception:
continue
else:
return build_uri(route, route_args, kwargs)

raise ValueError(f"Unable to build URI for file path {file_path}")

Expand Down Expand Up @@ -634,3 +640,16 @@ def lock(self):
return FileLock(self._get_lock_file())

return FakeLock()

def list_all(self):
    """Yield the URI of every file below the base directory that maps to a URI.

    The base directory is searched recursively. Entries that are not regular
    files are ignored, and files for which ``_get_uri_for_file`` raises
    ``ValueError`` or ``KeyError`` are silently skipped.
    """
    pattern = os.path.join(self._basedir, "./**")

    for candidate in glob.iglob(pattern, recursive=True):
        if not os.path.isfile(candidate):
            continue

        try:
            resolved = self._get_uri_for_file(candidate)
        except (ValueError, KeyError):
            # No URI can be built for this file; leave it out of the listing.
            continue

        yield resolved
28 changes: 21 additions & 7 deletions src/aerovaldb/sqlitedb/sqlitedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@
from aerovaldb.exceptions import UnsupportedOperation
from ..aerovaldb import AerovalDB
from ..routes import *
from .utils import extract_substitutions
from ..types import AccessType
from ..utils import json_dumps_wrapper, parse_uri, async_and_sync, build_uri
from ..utils import (
json_dumps_wrapper,
parse_uri,
async_and_sync,
build_uri,
extract_substitutions,
)
import os


Expand All @@ -31,7 +36,8 @@ class AerovalSqliteDB(AerovalDB):
"model",
"modvar",
"time",
)
),
ROUTE_MODELS_STYLE: ("project", "experiment"),
}

# This lookup table stores the name of the table in which json
Expand Down Expand Up @@ -128,7 +134,7 @@ def _initialize_db(self):

cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {table_name}({column_names},json,
CREATE TABLE IF NOT EXISTS {table_name}({column_names},json TEXT,
UNIQUE({column_names}))
"""
Expand Down Expand Up @@ -170,14 +176,16 @@ async def _get(self, route, route_args, *args, **kwargs):
WHERE
({columnlist}) = ({substitutionlist})
""",
route_args,
route_args | kwargs,
)
fetched = cur.fetchone()[0]
fetched = fetched.replace('\\"', '"')
if access_type == AccessType.JSON_STR:
return fetched

if access_type == AccessType.OBJ:
return simplejson.loads(fetched, allow_nan=True)
dt = simplejson.loads(fetched, allow_nan=True)
return dt

assert False # Should never happen.

Expand All @@ -189,7 +197,7 @@ async def _put(self, obj, route, route_args, *args, **kwargs):
table_name = AerovalSqliteDB.TABLE_NAME_LOOKUP[route]

columnlist, substitutionlist = self._get_column_list_and_substitution_list(
route_args
route_args | kwargs
)

json = obj
Expand All @@ -204,6 +212,7 @@ async def _put(self, obj, route, route_args, *args, **kwargs):
""",
route_args | kwargs,
)
self._con.commit()

@async_and_sync
async def get_by_uri(
Expand Down Expand Up @@ -232,4 +241,9 @@ async def get_by_uri(
async def put_by_uri(self, obj, uri: str):
route, route_args, kwargs = parse_uri(uri)

# if isinstance(obj, str):
# obj = "".join(obj.split(r"\n"))
await self._put(obj, route, route_args, **kwargs)

def list_all(self):
    """Listing stored URIs is not implemented for this database yet.

    :raises NotImplementedError: Always.
    """
    raise NotImplementedError()
10 changes: 0 additions & 10 deletions src/aerovaldb/sqlitedb/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +0,0 @@
import re


def extract_substitutions(template: str):
    """
    Return the names found between curly brackets in a python template string,
    in order of appearance.

    For example 'blah blah {test} blah {test2}' returns ['test', 'test2'].
    """
    placeholders = re.finditer(r"\{(.*?)\}", template)
    return [placeholder.group(1) for placeholder in placeholders]
16 changes: 13 additions & 3 deletions src/aerovaldb/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import re
import regex as re
import asyncio
import functools
from typing import Callable, ParamSpec, TypeVar
Expand All @@ -7,6 +7,15 @@
import urllib


def extract_substitutions(template: str):
    """
    Collect the names written between curly brackets in a python template
    string, in the order in which they occur.

    For example 'blah blah {test} blah {test2}' returns ['test', 'test2'].
    """
    names = []
    for match in re.finditer(r"\{(.*?)\}", template):
        names.append(match.group(1))
    return names


def json_dumps_wrapper(obj, **kwargs) -> str:
"""
Wrapper which calls simplejson.dumps with the correct options, known to work for objects
Expand All @@ -29,10 +38,11 @@ def parse_formatted_string(template: str, s: str) -> dict:
# First split on any keyword arguments, note that the names of keyword arguments will be in the
# 1st, 3rd, ... positions in this list
tokens = re.split(r"\{(.*?)\}", template)
keywords = tokens[1::2]

# keywords = tokens[1::2]
keywords = extract_substitutions(template)
# Now replace keyword arguments with named groups matching them. We also escape between keyword
# arguments so we support meta-characters there. Re-join tokens to form our regexp pattern

tokens[1::2] = map("(?P<{}>.*)".format, keywords)
tokens[0::2] = map(re.escape, tokens[0::2])
pattern = "".join(tokens)
Expand Down
16 changes: 0 additions & 16 deletions tests/sqlitedb/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +0,0 @@
import pytest
from aerovaldb.sqlitedb.utils import extract_substitutions


@pytest.mark.parametrize(
    "template,result",
    (
        ("{A}{B}{C}", {"A", "B", "C"}),
        ("{A}hello world{B} test {C}", {"A", "B", "C"}),
        ("", set()),
    ),
)
def test_extract_substitutions(template: str, result: set[str]):
    """Check that the placeholder names of ``template`` are extracted, ignoring order."""
    extracted = set(extract_substitutions(template))

    assert extracted == result

This file was deleted.

This file was deleted.

Binary file added tests/test-db/sqlite/test.sqlite
Binary file not shown.
22 changes: 20 additions & 2 deletions tests/test_aerovaldb.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,17 @@ def tmpdb(tmp_path, dbtype: str) -> aerovaldb.AerovalDB:


@pytest.mark.asyncio
@pytest.mark.parametrize("resource", (("json_files:./tests/test-db/json",)))
@pytest.mark.parametrize(
"resource",
(
pytest.param(
"json_files:./tests/test-db/json",
),
pytest.param(
"sqlitedb:./tests/test-db/sqlite/test.sqlite",
),
),
)
@GET_PARAMETRIZATION
async def test_getter(resource: str, fun: str, args: list, kwargs: dict, expected):
"""
Expand All @@ -288,7 +298,15 @@ async def test_getter(resource: str, fun: str, args: list, kwargs: dict, expecte
assert data["path"] == expected


@pytest.mark.parametrize("resource", (("json_files:./tests/test-db/json",)))
@pytest.mark.parametrize(
"resource",
(
pytest.param(
"json_files:./tests/test-db/json",
),
pytest.param("sqlitedb:./tests/test-db/sqlite/test.sqlite"),
),
)
@GET_PARAMETRIZATION
def test_getter_sync(resource: str, fun: str, args: list, kwargs: dict, expected):
with aerovaldb.open(resource, use_async=False) as db:
Expand Down
Loading

0 comments on commit dffdfa4

Please sign in to comment.