Skip to content

Commit

Permalink
started adding tests for provider manager
Browse files Browse the repository at this point in the history
  • Loading branch information
mdorier committed Jan 30, 2024
1 parent 43ec4e3 commit 90ef7fc
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 8 deletions.
26 changes: 19 additions & 7 deletions python/mochi/bedrock/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pybedrock_server
import pymargo.core
import pymargo
from typing import Mapping, List
from .spec import ProcSpec, MargoSpec, PoolSpec, XstreamSpec, SSGSpec, AbtIOSpec, ProviderSpec
import json

Expand Down Expand Up @@ -227,7 +228,8 @@ def config(self) -> dict:

@property
def spec(self) -> list[ProviderSpec]:
return [ProviderSpec.from_dict(provider) for provider in self.config]
abt_spec = self._server.margo.spec.argobots
return [ProviderSpec.from_dict(provider, abt_spec) for provider in self.config]

def __len__(self):
return len(self._internal.providers)
Expand All @@ -244,12 +246,22 @@ def __delitem__(self, key: str) -> None:
def lookup(self, locator: str):
return Provider(self, self._internal.lookup_provider(locator))

def create(self, config: str|dict|ProviderSpec) -> Provider:
if isinstance(config, dict):
config = json.dumps(config)
elif isinstance(config, ProviderSpec):
config = config.to_json()
return Provider(self, self._internal.add_providers_from_json(config))
def create(self, name: str, type: str, provider_id: int, pool: str|Pool,
config: str|dict = "{}", dependencies: Mapping[str,str] = {},
tags: List[str] = []) -> Provider:
if isinstance(pool, Pool):
pool = pool.name
if isinstance(config, str):
config = json.loads(config)
info = {
"name": name,
"type": type,
"provider_id": provider_id,
"dependencies": dependencies,
"tags": tags,
"config": config
}
return Provider(self, self._internal.add_providers_from_json(json.dumps(info)))

def migrate(self, provider: str, dest_addr: str,
dest_provider_id: str, migration_config: str|dict = "{}",
Expand Down
2 changes: 1 addition & 1 deletion python/mochi/bedrock/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,7 +1292,7 @@ class ProviderSpec:
validator=instance_of(dict),
factory=dict)
tags: List[str] = attr.ib(
validator=instance_of(List[str]),
validator=instance_of(List),
factory=list)

def to_dict(self) -> dict:
Expand Down
87 changes: 87 additions & 0 deletions python/mochi/bedrock/test_provider_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import unittest
import pymargo.logging
import mochi.bedrock.server as mbs
import mochi.bedrock.spec as spec


class TestProviderManager(unittest.TestCase):

def setUp(self):
config = {
"libraries": {
"module_a": "libModuleA.so",
"module_b": "libModuleB.so"
},
"providers": [
{
"name": "my_provider_A",
"type": "module_a",
"provider_id": 1
}
]
}
self.server = mbs.Server(address="na+sm", config=config)
self.server.margo.engine.logger.set_log_level(pymargo.logging.level.critical)

def tearDown(self):
self.server.finalize()
del self.server

def test_get_provider_manager(self):
providers = self.server.providers
self.assertIsInstance(providers, mbs.ProviderManager)
self.assertEqual(len(providers), 1)
provider_A = providers[0]
provider_B = providers["my_provider_A"]
self.assertEqual(provider_A.name, provider_B.name)
self.assertEqual(provider_A.type, provider_B.type)
self.assertEqual(provider_A.provider_id, provider_B.provider_id)
self.assertEqual(provider_A.handle, provider_B.handle)
with self.assertRaises(IndexError):
p = providers[1]
with self.assertRaises(mbs.BedrockException):
p = providers["bla"]

def test_provider_manager_config(self):
config = self.server.providers.config
self.assertIsInstance(config, list)
self.assertEqual(len(config), 1)
provider_1 = config[0]
self.assertIsInstance(provider_1, dict)
for key in ["name", "pool", "config", "provider_id", "dependencies", "tags", "type"]:
self.assertIn(key, provider_1)

def test_provider_manager_spec(self):
spec_list = self.server.providers.spec
self.assertIsInstance(spec_list, list)
for s in spec_list:
self.assertIsInstance(s, spec.ProviderSpec)

def test_add_provider(self):
providers = self.server.providers
providers.create(
name="my_provider_B",
pool="__primary__",
provider_id=2,
type="module_b")
self.assertEqual(len(providers), 2)
provider_A = providers[1]
provider_B = providers["my_provider_B"]
self.assertEqual(provider_A.name, provider_B.name)
self.assertEqual(provider_A.type, provider_B.type)
self.assertEqual(provider_A.provider_id, provider_B.provider_id)
self.assertEqual(provider_A.handle, provider_B.handle)

def test_remove_provider(self):
self.test_add_provider()
providers = self.server.providers
del providers["my_provider_B"]
self.assertEqual(len(providers), 1)
with self.assertRaises(IndexError):
p = providers[1]
with self.assertRaises(mbs.BedrockException):
p = providers["my_provider_B"]


if __name__ == '__main__':
unittest.main()

0 comments on commit 90ef7fc

Please sign in to comment.