diff --git a/src/charm.py b/src/charm.py index 5de7488..2e4cd76 100755 --- a/src/charm.py +++ b/src/charm.py @@ -12,9 +12,10 @@ from charms.prometheus_k8s.v0.prometheus_scrape import MetricsEndpointProvider from ops.charm import CharmBase, CollectStatusEvent from ops.framework import StoredState -from ops.charm import RelationJoinedEvent, RelationDepartedEvent +from ops.charm import InstallEvent, RelationJoinedEvent, RelationDepartedEvent from ops.main import main from ops.model import ActiveStatus, BlockedStatus, ErrorStatus, Relation +from pathlib import Path from typing import List logger = logging.getLogger(__name__) @@ -22,11 +23,14 @@ class JujuControllerCharm(CharmBase): DB_BIND_ADDR_KEY = 'db-bind-address' + ALL_BIND_ADDRS_KEY = 'db-bind-addresses' + _stored = StoredState() def __init__(self, *args): super().__init__(*args) + self.framework.observe(self.on.install, self._on_install) self.framework.observe(self.on.collect_unit_status, self._on_collect_status) self.framework.observe(self.on.config_changed, self._on_config_changed) self.framework.observe( @@ -34,7 +38,8 @@ def __init__(self, *args): self.framework.observe( self.on.website_relation_joined, self._on_website_relation_joined) - self._stored.set_default(db_bind_address='', last_bind_addresses=[], all_bind_addresses=dict()) + self._stored.set_default( + db_bind_address='', last_bind_addresses=[], all_bind_addresses=dict()) self.framework.observe( self.on.dbcluster_relation_changed, self._on_dbcluster_relation_changed) @@ -45,6 +50,12 @@ def __init__(self, *args): self.framework.observe( self.on.metrics_endpoint_relation_broken, self._on_metrics_endpoint_relation_broken) + def _on_install(self, event: InstallEvent): + """Ensure that the controller configuration file exists.""" + file_path = self._controller_config_path() + Path(file_path).parent.mkdir(parents=True, exist_ok=True) + open(file_path, 'w+').close() + def _on_collect_status(self, event: CollectStatusEvent): if len(self._stored.last_bind_addresses) > 1: event.add_status(BlockedStatus( @@ -133,12 +144,13 @@ def _on_metrics_endpoint_relation_broken(self, event: RelationDepartedEvent): self.control_socket.remove_metrics_user(username) def _on_dbcluster_relation_changed(self, event): - """Ensure that a bind address for Dqlite is set in relation data, - if we can determine a unique one from the relation's bound space. + """Maintain our own bind address in relation data. If we are the leader, aggregate the bind addresses for all the peers, and ensure the result is set in the application data bag. + If the aggregate addresses have changed, rewrite the config file. """ - self._ensure_db_bind_address(event) + relation = event.relation + self._ensure_db_bind_address(relation) if self.unit.is_leader(): # The event only has *other* units so include this @@ -146,19 +158,33 @@ def _on_dbcluster_relation_changed(self, event): ip = self._stored.db_bind_address all_bind_addresses = {self.unit.name: ip} if ip else dict() - for unit in event.relation.units: - unit_data = event.relation.data[unit] + for unit in relation.units: + unit_data = relation.data[unit] if self.DB_BIND_ADDR_KEY in unit_data: all_bind_addresses[unit.name] = unit_data[self.DB_BIND_ADDR_KEY] if self._stored.all_bind_addresses == all_bind_addresses: return - event.relation.data[self.app]['db-bind-addresses'] = json.dumps(all_bind_addresses) - self._stored.all_bind_addresses = all_bind_addresses + relation.data[self.app][self.ALL_BIND_ADDRS_KEY] = json.dumps(all_bind_addresses) + self._update_config_file(all_bind_addresses) + else: + app_data = relation.data[self.app] + if self.ALL_BIND_ADDRS_KEY in app_data: + all_bind_addresses = json.loads(app_data[self.ALL_BIND_ADDRS_KEY]) + else: + all_bind_addresses = dict() + + if self._stored.all_bind_addresses == all_bind_addresses: + return + + self._update_config_file(all_bind_addresses) - def _ensure_db_bind_address(self, event): - ips = [str(ip) for ip in self.model.get_binding(event.relation).network.ingress_addresses] + def _ensure_db_bind_address(self, relation): + """Ensure that a bind address for Dqlite is set in relation data, + if we can determine a unique one from the relation's bound space. + """ + ips = [str(ip) for ip in self.model.get_binding(relation).network.ingress_addresses] self._stored.last_bind_addresses = ips if len(ips) > 1: @@ -170,9 +196,24 @@ def _ensure_db_bind_address(self, event): if self._stored.db_bind_address == ip: return - event.relation.data[self.unit].update({self.DB_BIND_ADDR_KEY: ip}) + logger.info('setting new DB bind address: %s', ip) + relation.data[self.unit].update({self.DB_BIND_ADDR_KEY: ip}) self._stored.db_bind_address = ip + def _update_config_file(self, bind_addresses): + file_path = self._controller_config_path() + with open(file_path) as conf_file: + conf = yaml.safe_load(conf_file) + + if not conf: + conf = dict() + conf['db-bind-addresses'] = bind_addresses + + with open(file_path, 'w') as conf_file: + yaml.dump(conf, conf_file) + + self._stored.all_bind_addresses = bind_addresses + def api_port(self) -> str: """Return the port on which the controller API server is listening.""" api_addresses = self._agent_conf('apiaddresses') @@ -199,6 +240,10 @@ def _agent_conf(self, key: str): agent_conf = yaml.safe_load(agent_conf_file) return agent_conf.get(key) + def _controller_config_path(self) -> str: + unit_num = self.unit.name.split('/')[1] + return f'/var/lib/juju/agents/controller-{unit_num}/agent.conf' + def metrics_username(relation: Relation) -> str: """