diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml new file mode 100644 index 0000000..9d0f1d8 --- /dev/null +++ b/.github/workflows/mypy.yml @@ -0,0 +1,26 @@ +name: Mypy + +on: [push, workflow_dispatch, pull_request] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh | + bash -s -- --batch + conda env create -f env-dev.yml + conda run --name maize-dev pip install --no-deps . + conda run --name maize-dev pip install types-PyYAML types-toml + - name: Analysing the code with mypy + run: | + conda run --name maize-dev mypy --strict --explicit-package-bases maize/core maize/utilities maize/steps diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml new file mode 100644 index 0000000..afdc66a --- /dev/null +++ b/.github/workflows/pylint.yml @@ -0,0 +1,33 @@ +name: Pylint + +on: [push, workflow_dispatch, pull_request] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@v3 + - uses: actions/cache@v3 + id: conda-cache + with: + path: ./miniconda3 + key: ${{ runner.os }}-conda-${{ hashFiles('**/env-dev.yml') }} + restore-keys: | + ${{ runner.os }}-conda-${{ hashFiles('**/env-dev.yml') }} + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + if: steps.conda-cache.outputs.cache-hit != true + run: | + wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh | + bash -s -- --batch + conda env create -f env-dev.yml + conda run --name maize-dev pip install --no-deps . + - name: Analysing the code with pylint + run: | + conda run --name maize-dev pylint --exit-zero maize/core maize/utilities maize/steps diff --git a/.github/workflows/tests-quick.yml b/.github/workflows/tests-quick.yml new file mode 100644 index 0000000..8e8f4fb --- /dev/null +++ b/.github/workflows/tests-quick.yml @@ -0,0 +1,46 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Pytest-quick + +on: + pull_request: + workflow_dispatch: + workflow_call: + inputs: + role: + required: true + type: string + default: "test" + push: + branches: [ "master" ] + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh | + bash -s -- --batch + conda env create -f env-dev.yml + conda run --name maize-dev pip install --no-deps . 
+ - name: Run fast test suite with pytest + run: | + conda run --name maize-dev pytest --cov=maize --cov-branch --cov-report xml -k "not random" ./tests + - name: Upload coverage + uses: actions/upload-artifact@v4 + with: + name: coverage-xml + path: ${{ github.workspace }}/coverage.xml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..47f4c78 --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +.ipynb_checkpoints +.VSCodeCounter +__pycache__ +*.egg-info +.coverage +.vscode +.pylintrc +.*_cache +*_autosummary +build +coverage +docs/_build +docs/_static +docs/autosummary +docs/Makefile +docs/make.bat +test-config.toml +examples/testing.ipynb diff --git a/AUTHORS.md b/AUTHORS.md new file mode 100644 index 0000000..8b28073 --- /dev/null +++ b/AUTHORS.md @@ -0,0 +1,8 @@ +Primary Authors +=============== + +- [Thomas Löhr](https://github.com/tlhr) +- Marco Klähn +- [Finlay Clark](https://github.com/fjclark) +- Bob van Schendel +- Lili Cao \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..6f54165 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,355 @@ +# CHANGELOG + +## Version 0.8.3 +### Fixes +- Fixed incorrect working directory being used in some cases + +## Version 0.8.2 +### Features +- Command batching option for `run_multi` + +### Changes +- Improved GPU status querying + +### Fixes +- Fixed excessive writing of temporary job output files +- Pinned Numpy to <2.0 for now + +## Version 0.8.1 +### Features +- Workflow submission improvements + +### Changes +- Added more detailed job logging + +### Fixes +- Fix for excessive queue polling on workflow submission +- Some minor SonarQube fixes + +## Version 0.8.0 +### Features +- Workflow submission functionality +- Added `wait_for_all` command to wait for multiple submitted workflows +- Added `gpu_info` utility to determine if it's safe to launch a command on GPU + +### Changes +- Parameters with a default are now automatically considered optional +- Modified JobHandler to cancel jobs on interrupt (#36) +- Multiple async job submission attempts + +### Fixes +- Ensure that iterables of paths are correctly handled +- Various typing fixes + +## Version 0.7.10 +### Changes +- `parallel` can now run with only constant inputs + +### Fixes +- All node superclasses are now searched for type information +- Fix for incorrect generic type updates +- Fixed inconsistencies between `flow.scratch` and `flow.config.scratch` + +## Version 0.7.9 +### Changes +- Added logger flushing on shutdown +- Slightly improved error message for duplicate nodes +- Added extra batch job logging + +## Version 0.7.8 +### Features +- Added version flag + +### Changes +- Unset workflow-level parameters no longer raise (fixes #22) +- Writing to yaml now encodes paths as strings + +### Fixes +- Fixed sonarqube bugs +- Fixed off-by-one in job completion logging + +## Version 0.7.7 +### Features +- Added active flag to all nodes to enable instant shutdown + +### Changes +- Improved graph connection logging + +## Version 0.7.6 +### Features +- Added `MultiParameter` hook functionality to allow complex parameter mappings + +### Changes +- As a result of the hook functionality to `MultiParameter`, this class is now a double generic. If you use it in your code directly (unlikely), you will need to update the type signature.
+ +### Fixes +- Fixed wrong order in shell variable expansion for config command spec +- Fixed `MergeLists` shutting down prematurely when an input closes while others still have data + +## Version 0.7.5 +### Features +- Added `Choice` plumbing node +- Added `ContentValidator` + +### Changes +- Added type information to serialized output + +### Fixes +- Fixed ordering for serialized output +- Casting for scratch path spec + +## Version 0.7.4 +### Changes +- Added option to specify scratch at workflow level + +### Fixes +- Install package-based config if available + +## Version 0.7.3 +### Features +- Added `IndexDistribute` +- Added `IntegerMap` +- Added `FileBuffer` + +### Changes +- Improved typing for dict utility functions +- Improved visualization colors +- Improved port activity determination for Multiplex +- Cleaned up workflow serialization + +### Fixes +- Fixed typechecking of files passed as strings + +## Version 0.7.2 +### Fixes +- Fixed paths passed to inputs as parameters in a separate JSON failing to be cast + +## Version 0.7.1 +### Features +- Allowed mapping of inputs in serialized workflows + +### Fixes +- Node list output is now sorted alphabetically + +## Version 0.7.0 +### Features +- Added diagrams to node documentation +- Added multiple new plumbing nodes (`TimeDistribute`, `MergeLists`, `CopyEveryNIter`) +- Node preparation is now cached, avoiding multiple dependency lookups +- `FileParameter` will now attempt casting strings to paths +- Allowed caching in `MultiInput` + +### Changes +- Job queues will now be kept saturated +- Deprecated static `MultiPort`s +- Environment variables in interpreter specification are now expanded +- Split `TestRig.setup_run` for explicit use with variable outputs + +### Fixes +- Fixed incorrect job submission counts +- Fixed typing issues in `TestRig` +- Added proper shutdown for `Multiplex` + +## Version 0.6.2 +### Features +- Interpreter - script pairs can now be non-path commands +- Added option to use list of paths for FileParameters + +### Fixes +- Updated guide + dev instructions + +## Version 0.6.1 +### Features +- Added package directory as a search path for global config + +### Changes +- Made class tags private + +## Version 0.6.0 +### Features +- Added send and receive hook support +- Added component tagging option + +### Changes +- Config dependencies are now converted to absolute paths +- Removed init files causing problems with contrib +- Refactored execution to use correct logging + +### Fixes +- Expanded test coverage +- Fix for `_prepare` calls with missing interpreter +- Fix for premature channel flush when handling converging data streams + +## Version 0.5.1 +### Features +- Added queue option to `JobResourceConfig` +- Added option to explicitly prefer batch submission + +### Changes +- Warning when receiving multiple times from the same port without looping +- Added warning for single char custom batch attributes +- Job submission will now only submit n_jobs if larger than max_jobs +- Improved file validation, will now wait for files until timeout +- Changed handling of flags to explicit `--flag` / `--no-flag` +- `prepare()` no longer requires a call to the parent method + +### Fixes +- Fix for receive not recognising optional unconnected ports +- Fixed looped nodes not handling cached input files properly +- Fix for `Workflow.from_dict` not recognizing input setting +- More robust batch job submissions +- Fixed occassional deadlocks under high logging loads +- Fixed `Return` nodes potentially freezing looped 
workflows + +## Version 0.5.0 +### Features +- Set parameters are logged at workflow start +- Added asynchronous command execution +- It is now possible to map free inputs on the workflow level +- Added checks for common graph construction problems +- Added support for CUDA MPS to run multiple processes on one GPU + +### Changes +- Set default batch polling interval to 120s +- Added functionality to skip node execution if all parameters are unset / optional +- `Void` can now take any number of inputs +- Dynamic workflow creation is now possible using `expose` +- Added `working_dir` option to `run_command` +- Improved workflow status reporting +- Custom job attributes are now formatted correctly based on batch system +- Status updates now show full path for nodes in subgraphs with duplicate names + +### Fixes +- Fixed missing cleanup when using relative scratch directory +- Avoid error when specifying duplicate `loop=True` parameter +- Fixed `typecheck` not handling dictionaries properly +- Fixed `common_parent` not breaking after divergence +- Fixed looped nodes not sending the correct file when dealing with identical names +- Temporary fix for paths not being parsed from YAML files + +## Version 0.4.1 +### Features +- Added command inputs to node `run_multi` method +- `FileChannel` can now send and receive dicts of files + +### Changes +- Changed `exclusive_use` for batch submission to `False` by default +- Changed required Python to 3.10 to avoid odd beartype behaviour +- Documentation cleanup + +### Fixes +- Fixed missing documentation propagation for inherited nodes +- Fixed `MultiPort` not being mapped properly in subgraphs +- Removed weird error code check +- Added missing `default` property to MultiInput +- Fixed misbehaving subgraph looping test + +## Version 0.4.0 +### Breaking changes +- Refactored looping system, only way to loop a node now is to use `loop=True` +- Removed dynamic interface building + +### Features +- Allowed setting walltime in per-job config + +### Changes +- Lowered logging noise +- Looping is now inherited from subgraphs +- Interface mapping sets attribute by default + +### Fixes +- Fixed error in building docs +- Fixed validation failures not showing command output +- Fixed tests to match parameterization requirement +- Fix for incorrect walltime parsing for batch submission +- Throw a proper error when mapping ports with existing names +- More verbose message for missing connection + +## Version 0.3.3 +### Features +- Added `--parameters` commandline option to override workflow parameters using a JSON file +- Added timeout option to `run_command` +- Added simple multiple file I/O nodes + +### Changes +- Unparameterized generic nodes will now cause graph construction to fail +- `FileParameter` will now cast strings to `Path` objects when setting +- Maize is now a proper namespace package + +### Fixes +- Fixed cascading generic nodes sometimes not passing files correctly +- Fixed overeager parameter checks for serialized workflows +- Fixed bug preventing `run_multi` from running without `batch_options` + +## Version 0.3.2 +### Changes +- Updated dependencies +- Added contribution guidelines +- Prepared for initial public release + +## Version 0.3.1 +### Features +- Added mode option to `MultiPort` +- Added `custom_attributes` option for batch jobs + +### Changes +- Better batch submission tests +- Batch submission now only requires `batch_options` + +### Fixes +- Fixed resource management issue +- Fixed file copy issues with the `LoadFile` node +- Fixed 
off-by-one in finished job count +- Various improvements to `parallel` + +## Version 0.3 +### Features +- Added pre-execution option for `run_command` +- Allow setting default parameters in global config +- Added interface / component serialization + +### Changes +- `Input` can now also act as a `Parameter` +- `Input` can now cache its data +- Overhauled `FileChannel`, now supports lists of files + +### Fixes +- Improved type checking for paths +- Fix for PSI/J execution freezing when using `fork` +- Resolved occasional node looping issues +- `SaveFile` now handles directories properly +- Inherited nodes now set datatypes properly +- Better missing graphviz information +- Improved working directory cleanup +- `LoadFile` now never moves files +- `run_command` now parses whitespace / strings correctly +- Missing config warnings + +## Version 0.2 +### Features +- Can now submit jobs to arbitrary resource manager systems (SLURM, PBS, etc) +- `run_command` now accepts `stdin` command input (Thanks Marco) +- Added tool to convert from functions to nodes +- Added experimental node parallelization macro +- Added utility nodes for batching data, with example workflow +- Added `Barrier` node for conditional sending, `Yes` as an equivalent to the Unix command +- Workflow visualization improvements (depth limit, node status) + +### Changes +- All execution is now performed with Exaworks PSI/J (new `psij-python` dependency) +- Dynamic typechecking now uses `beartype` +- Parameters with no default value will cause an error if not set to optional +- Status summaries now show approximate number of items in channel +- Channel size can be set globally for the whole workflow + +### Fixes +- Many fixes in dynamic typechecking +- More informative error messages +- Fixed a race condition during certain executions +- Fixed an issue where channel data would get flushed in long-running workflows +- Fixed issues relating to Graphviz +- Lowered chances of zombie / orphan processes +- Fixed issue where the config would not get read properly + +## Version 0.1 +Initial release. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..c026f50 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,75 @@ +Contributing +============ +This is a summary of code conventions and contributing guidelines for *maize*. + +Issues & Pull requests +---------------------- +To report bugs or suggest new features, open an issue, or submit a pull request with your changes. + +Installation +------------ +If you will be working on both *maize* and *maize-contrib* it will be easiest to install the environment for *maize-contrib* first, as it encompasses all dependencies for maize and domain-specific extensions. You can then install both packages using editable installs. + +Code style +---------- +We use the [PEP8](https://peps.python.org/pep-0008/) convention with a 100 character line length - you can use `black` as a formatter with the `--line-length=100` argument. The code base features full static typing, use the following [mypy](https://mypy.readthedocs.io/en/stable/) command to check it: + +```shell +mypy --follow-imports=silent --ignore-missing-imports --strict maize +``` + +Type hints should only be omitted when either mypy or typing doesn't yet fully support the required feature, such as [higher-kinded types](https://github.com/python/typing/issues/548) or type-tuples ([PEP646](https://peps.python.org/pep-0646/)). 
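As a rough sketch of what this looks like in practice, here is a fully annotated node in the style of the teaser example from the README (the `Shout` node itself is made up for illustration): every port and parameter carries an explicit type argument, and `run` is annotated as returning `None`.

```python
from maize.core.interface import Input, Output, Parameter
from maize.core.node import Node


class Shout(Node):
    """Upper-cases incoming strings and adds some excitement."""

    inp: Input[str] = Input()
    out: Output[str] = Output()
    excitement: Parameter[int] = Parameter(default=1)

    def run(self) -> None:
        # Receive a single item, transform it, and send it downstream
        data = self.inp.receive()
        self.out.send(data.upper() + "!" * self.excitement.value)
```

With annotations like these, `mypy` can verify that what the node sends and receives matches the declared port types.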
+ +> [!IMPORTANT] +> If you installed maize in editable mode you may need to specify its location with `$MYPYPATH` to ensure `mypy` can find it. See this [setuptools issue](https://github.com/pypa/setuptools/issues/3518). + +Documentation +------------- +Every public class or function should feature a full docstring with a full description of parameters / attributes. We follow the [numpy docstring](https://numpydoc.readthedocs.io/en/latest/format.html) standard for readability reasons. Docs are built using [sphinx](https://www.sphinx-doc.org/en/master/) in the `docs` folder: + +```shell +make html +``` + +There will be some warnings from `autosummary` that can be ignored. The built docs can then be found in `docs/_build/html`. To preview them locally you can start a local webserver running the following command in the `docs/_build/html` folder: + +```shell +python -m http.server 8000 +``` + +The docs are then available at `http://localhost:8000/index.html`. + +If you add a new feature, you should mention the new behaviour in the `userguide`, in the `cookbook`, and ideally add an example under `examples`. If the feature necessitated a deeper change to the fundamental design, you should also update `design`. + +Testing +------- +Tests are written using [pytest](https://docs.pytest.org/en/7.2.x/contents.html) and cover the lower-level components as well as higher-level graph execution, and can be run with: + +```shell +pytest --log-cli-level=DEBUG tests/ +``` + +Any new features or custom nodes should be covered by suitable tests. To make testing the latter a bit easier, you can use the `maize.utilities.testing.TestRig` class together with `maize.utilities.testing.MockChannel` if required. + +Coverage can be reported using: + +```shell +pytest tests/ -v --cov maize --cov-report html:coverage +``` + +New versions +------------ +To release a new version of maize, perform the following steps: + +1. Create a new branch titled `release-x.x.x` +2. Add your changes to `CHANGELOG.md` +3. Increment `maize.__version__` +4. Commit your changes +5. Rebuild and update the remote documentation (see above) +6. Create a tag using `git tag vx.x.x` +7. Push your changes with `git push` and `git push --tags` +8. Update `master`: + 1. `git checkout master` + 2. `git pull origin master` + 3. `git merge release-x.x.x` + 4. `git push origin master` diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..1f120b7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,169 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + 1. Definitions. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + END OF TERMS AND CONDITIONS + APPENDIX: How to apply the Apache License to your work. + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + Copyright 2022 Molecular AI, AstraZeneca + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..82e5ba8 --- /dev/null +++ b/README.md @@ -0,0 +1,114 @@ +![Maturity level-1](https://img.shields.io/badge/Maturity%20Level-ML--1-yellow) + + + +*maize* is a graph-based workflow manager for computational chemistry pipelines. 
+ +It is based on the principles of [*flow-based programming*](https://github.com/flowbased/flowbased.org/wiki) and thus allows arbitrary graph topologies, including cycles, to be executed. Each task in the workflow (referred to as *nodes*) is run as a separate process and interacts with other nodes in the graph by communicating through unidirectional *channels*, connected to *ports* on each node. Every node can have an arbitrary number of input or output ports, and can read from them at any time, any number of times. This allows complex task dependencies and cycles to be modelled effectively. + +This repository contains the core workflow execution functionality. For domain-specific steps and utilities, you should additionally install [**maize-contrib**](https://github.com/MolecularAI/maize-contrib), which will have additional dependencies. + +Teaser +------ +A taste for defining and running workflows with *maize*. + +```python +"""A simple hello-world-ish example graph.""" + +from maize.core.interface import Parameter, Output, MultiInput +from maize.core.node import Node +from maize.core.workflow import Workflow + +# Define the nodes +class Example(Node): + data: Parameter[str] = Parameter(default="Hello") + out: Output[str] = Output() + + def run(self) -> None: + self.out.send(self.data.value) + + +class ConcatAndPrint(Node): + inp: MultiInput[str] = MultiInput() + + def run(self) -> None: + result = " ".join(inp.receive() for inp in self.inp) + self.logger.info("Received: '%s'", result) + + +# Build the graph +flow = Workflow(name="hello") +ex1 = flow.add(Example, name="ex1") +ex2 = flow.add(Example, name="ex2", parameters=dict(data="maize")) +concat = flow.add(ConcatAndPrint) +flow.connect(ex1.out, concat.inp) +flow.connect(ex2.out, concat.inp) + +# Check and run! +flow.check() +flow.execute() +``` + +Installation +------------ +If you plan on not modifying maize, and will be using [maize-contrib](https://github.com/MolecularAI/maize-contrib), then you should just follow the installation instructions for the latter. Maize will be installed automatically as a dependency. + +Note that [maize-contrib](https://github.com/MolecularAI/maize-contrib) requires several additional domain-specific packages, and you should use its own environment file instead if you plan on using these extensions. + +To get started quickly with running maize, you can install from an environment file: + +```bash +conda env create -f env-users.yml +conda activate maize +pip install --no-deps ./ +``` + +If you want to develop the code or run the tests, use the development environment and install the package in editable mode: + +```bash +conda env create -f env-dev.yml +conda activate maize-dev +pip install --no-deps ./ +``` + +### Manual install +Maize requires the following packages and also depends on python 3.10: + +- dill +- networkx +- pyyaml +- toml +- numpy +- matplotlib +- graphviz +- beartype +- psij-python + +We also strongly recommend the installation of [mypy](https://mypy.readthedocs.io/en/stable/). To install everything use the following command: + +```bash +conda install -c conda-forge python=3.10 dill networkx yaml toml mypy +``` + +If you wish to develop or add additional modules, the following additional packages will be required: + +- pytest +- sphinx + +Docs +---- +You can find guides, examples, and the API in the [**documentation**](https://molecularai.github.io/maize). 
+ +Status +------ +*maize* is still in an experimental stage, but the core of it is working: +- Arbitrary workflows with conditionals and loops +- Subgraphs, Subsubgraphs, ... +- Type-safe channels, graph will not build if port types mismatch +- Nodes for broadcasting, merging, round-robin, ... +- Potential deadlock warnings +- Multiple retries +- Fail-okay nodes +- Channels can send most types of data (using dill in the background) +- Commandline exposure +- Custom per-node python environments diff --git a/docs/_templates/custom-base.rst b/docs/_templates/custom-base.rst new file mode 100644 index 0000000..5536fa1 --- /dev/null +++ b/docs/_templates/custom-base.rst @@ -0,0 +1,5 @@ +{{ name | escape | underline}} + +.. currentmodule:: {{ module }} + +.. auto{{ objtype }}:: {{ objname }} diff --git a/docs/_templates/custom-class.rst b/docs/_templates/custom-class.rst new file mode 100644 index 0000000..008baab --- /dev/null +++ b/docs/_templates/custom-class.rst @@ -0,0 +1,32 @@ +{{ name | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :members: + :show-inheritance: + :inherited-members: + + {% block methods %} + .. automethod:: __init__ + + {% if methods %} + .. rubric:: {{ _('Methods') }} + + .. autosummary:: + {% for item in methods %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block attributes %} + {% if attributes %} + .. rubric:: {{ _('Attributes') }} + + .. autosummary:: + {% for item in attributes %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} diff --git a/docs/_templates/custom-module.rst b/docs/_templates/custom-module.rst new file mode 100644 index 0000000..4850bb1 --- /dev/null +++ b/docs/_templates/custom-module.rst @@ -0,0 +1,69 @@ +{{ name | escape | underline}} + +.. automodule:: {{ fullname }} + + {% block attributes %} + {% if attributes %} + .. rubric:: {{ _('Module Attributes') }} + + .. autosummary:: + :toctree: + :template: custom-base.rst + {% for item in attributes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block functions %} + {% if functions %} + .. rubric:: {{ _('Functions') }} + + .. autosummary:: + :toctree: + :template: custom-base.rst + {% for item in functions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block classes %} + {% if classes %} + .. rubric:: {{ _('Classes') }} + + .. autosummary:: + :toctree: + :template: custom-class.rst + {% for item in classes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block exceptions %} + {% if exceptions %} + .. rubric:: {{ _('Exceptions') }} + + .. autosummary:: + :toctree: + :template: custom-base.rst + {% for item in exceptions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + +{% block modules %} +{% if modules %} +.. rubric:: Modules + +.. autosummary:: + :toctree: + :template: custom-module.rst + :recursive: +{% for item in modules %} + {{ item }} +{%- endfor %} +{% endif %} +{% endblock %} diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..83c814a --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,99 @@ +# Configuration file for the Sphinx documentation builder. 
+# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +# pylint: disable=all + +import sys +import time +import os + +sys.path.insert(0, os.path.abspath("..")) +import maize.maize + +project = "maize" +copyright = f"{time.localtime().tm_year}, Molecular AI group" +author = "Thomas Löhr" +release = version = maize.maize.__version__ + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.viewcode", + "sphinx.ext.graphviz", +] + +autosummary_generate = True +add_module_names = True + +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "maize-contrib": ("https://molecularai.github.io/maize-contrib/doc", None), +} + + +# These handlers ensure that the value for `required_callables` is always emitted +def process_required_callables_sig(app, what, name, obj, options, signature, return_annotation): + if what == "attribute" and ( + name.endswith("required_callables") or name.endswith("required_packages") + ): + options["no-value"] = False + else: + options["no-value"] = True + + +def include_interfaces(app, what, name, obj, skip, options): + if what == "attribute" and ( + name.endswith("required_callables") or name.endswith("required_packages") + ): + return False + return None + + +def setup(app): + app.connect("autodoc-process-signature", process_required_callables_sig) + app.connect("autodoc-skip-member", include_interfaces) + + +# -- AZ Colors +_COLORS = { + "mulberry": "rgb(131,0,81)", + "lime-green": "rgb(196,214,0)", + "navy": "rgb(0,56,101)", + "graphite": "rgb(63,68,68)", + "light-blue": "rgb(104,210,223)", + "magenta": "rgb(208,0,111)", + "purple": "rgb(60,16,83)", + "gold": "rgb(240,171,0)", + "platinum": "rgb(157,176,172)", +} + +graphviz_output_format = "svg" + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_title = "maize" +html_logo = "resources/maize-logo.svg" +html_theme = "furo" +html_theme_options = { + "sidebar_hide_name": True, + "light_css_variables": { + "color-brand-primary": _COLORS["mulberry"], + "color-brand-content": _COLORS["mulberry"], + "color-api-name": _COLORS["navy"], + "color-api-pre-name": _COLORS["navy"], + }, +} +html_static_path = ["_static"] diff --git a/docs/docs/cookbook.rst b/docs/docs/cookbook.rst new file mode 100644 index 0000000..8f62592 --- /dev/null +++ b/docs/docs/cookbook.rst @@ -0,0 +1,318 @@ +Cookbook +======== + +I want to create a... + +Simple linear workflow +---------------------- +We will create a simple workflow that just reads some data, processes it, and logs the result. First, import the required nodes and modules: + +.. code-block:: python + + from maize.core.workflow import Workflow + from maize.steps.io import LoadData, LogResult + from maize.steps.your_module import YourNode + +Then construct your workflow by instantiating a :class:`~maize.core.workflow.Workflow` object and adding nodes: + +.. 
code-block:: python + + flow = Workflow(name="example") + load = flow.add(LoadData[int], parameters={"data": 17}) + my_node = flow.add(YourNode, parameters={"param": False}) + result = flow.add(LogResult) + +Now connect them together: + +.. code-block:: python + + flow.connect(load.out, my_node.inp) + flow.connect(my_node.out, result.inp) + +Finally check everything's okay and run! + +.. code-block:: python + + flow.check() + flow.execute() + +You should see the result being written to the log. + +Cyclic workflow +--------------- +Cyclic workflows often appear whenever we want to iteratively improve a solution until a certain threshold has been reached. The general outline for such a workflow will look like this: + +.. image:: ../resources/cyclic-workflow.svg + +This means that the node labelled ``Calc`` will return only once the final solution has been produced and sent off to the final node. It is further important that ``Calc`` runs in a loop (e.g. by using :meth:`~maize.core.node.Node.loop`) to allow continuous processing. To build this workflow in maize, instantiate your workflow and add the required nodes as normal: + +.. code-block:: python + + from maize.core.workflow import Workflow + from maize.steps.io import LoadData, LogResult + from maize.steps.plumbing import Merge + + flow = Workflow(name="cycle") + load = flow.add(LoadData[int], parameters={"data": 17}) + merge = flow.add(Merge[int]) + calc = flow.add(Calc) + result = flow.add(LogResult) + +We can now connect everything together analogously to the diagram above: + +.. code-block:: python + + flow.connect_all( + (load.out, merge.inp), + (merge.out, calc.inp), + (calc.out, merge.inp), + (calc.out_final, result.inp) + ) + +Note that although it appears like we are connecting two nodes to the same input, the input in :class:`~maize.steps.plumbing.Merge` is a :class:`~maize.core.interface.MultiInput`, which means that it will create additional ports as necessary. + +Conditionally branched workflow +------------------------------- +Remember that in maize, all nodes are constantly running as their own process, no matter if they have data to process or not. This makes branched workflows easy to implement: We only need to output data to the corresponding port. Here's an example node definition: + +.. code-block:: python + + from maize.core.node import Node + from maize.core.interface import Input, Output + + class Condition(Node): + inp: Input[int] = Input() + out_a: Output[int] = Output() + out_b: Output[int] = Output() + + def run(self) -> None: + data = self.inp.receive() + if data < 10: + self.out_a.send(data) + else: + self.out_b.send(data) + +You can of course optionally wrap this in a loop (using :meth:`~maize.core.node.Node.loop`) and also define the condition for branching using a :class:`~maize.core.interface.Parameter` or even using another input. Creating the workflow can now be done as before: + +.. code-block:: python + + from maize.core.workflow import Workflow + from maize.steps.io import LoadData, LogResult + + flow = Workflow(name="cycle") + load = flow.add(LoadData[int], parameters={"data": 17}) + cond = flow.add(Condition) + out_a = flow.add(LogResult, name="out_a") + out_b = flow.add(LogResult, name="out_b") + flow.connect_all( + (load.out, cond.inp), + (cond.out_a, out_a.inp), + (cond.out_b, out_b.inp) + ) + + +Pass a result back to python +---------------------------- +We can use the special :class:`~maize.steps.io.Return` node to transfer a value back to the main process. 
This can be useful if you want to incorporate a workflow into an existing python script without having to save and load from a file. All you need to do is call :meth:`~maize.steps.io.Return.get` after workflow execution: + +.. code-block:: python + + from maize.core.workflow import Workflow + from maize.steps.io import LoadData, Return + + flow = Workflow(name="example") + load = flow.add(LoadData[int], parameters={"data": 17}) + result = flow.add(Return[int]) + flow.connect(load.out, result.inp) + + flow.check() + flow.execute() + + res = result.get() + # 17 + +.. _workflow-load-balance: + +Workflow with load balancing +---------------------------- +Using nodes from the :mod:`~maize.steps.plumbing` module allows to distribute data over multiple nodes. We assume our data arrives in the form of a sequence of items that should be split across nodes and then collected together again at the end. We'll use :class:`~maize.steps.plumbing.Delay` nodes to emulate our workers: + +.. code-block:: python + + from maize.core.workflow import Workflow + from maize.steps.io import LoadData, LogResult + from maize.steps.plumbing import Scatter, Accumulate, RoundRobin, Merge, Delay + + flow = Workflow(name="balance") + load = flow.add(LoadData, parameters={"data": ["a", "b", "c"]}) + + # Decomposes our list into items and sends them separately + scatter = flow.add(Scatter[str]) + + # Sends each item it receives to a different output + bal = flow.add(RoundRobin[str], name="bal") + + # Our processing nodes + worker1 = flow.add(Delay[str], name="worker1") + worker2 = flow.add(Delay[str], name="worker2") + worker3 = flow.add(Delay[str], name="worker3") + + # Merges multiple inputs into one output + merge = flow.add(Merge[str]) + + # Accumulate multiple items into one list + accu = flow.add(Accumulate[str], parameters={"n_packets": 3}) + out = flow.add(LogResult) + + flow.connect_all( + (load.out, scatter.inp), + (scatter.out, bal.inp), + (bal.out, worker1.inp), + (bal.out, worker2.inp), + (bal.out, worker3.inp), + (worker1.out, merge.inp), + (worker2.out, merge.inp), + (worker3.out, merge.inp), + (merge.out, accu.inp), + (accu.out, out.inp) + ) + +Alternatively, you can make use of the :func:`~maize.utilities.macros.parallel` macro to automate this sometimes tedious procedure: + +.. code-block:: python + + from maize.utilities.macros import parallel + from maize.core.workflow import Workflow + from maize.steps.io import LoadData, LogResult + from maize.steps.plumbing import Scatter, Accumulate, RoundRobin, Merge, Delay + + flow = Workflow(name="balance") + load = flow.add(LoadData, parameters={"data": ["a", "b", "c"]}) + + # Decomposes our list into items and sends them separately + scatter = flow.add(Scatter[str]) + + # Apply our macro + worker_subgraph = flow.add(parallel(Delay[str], n_branches=3)) + + # Accumulate multiple items into one list + accu = flow.add(Accumulate[str], parameters={"n_packets": 3}) + out = flow.add(LogResult) + + flow.connect_all( + (load.out, scatter.inp), + (scatter.out, worker_subgraph.inp), + (worker_subgraph.out, accu.inp), + (accu.out, out.inp) + ) + +Workflow as a script +-------------------- +We can make a workflow callable on the commandline as a normal script by using the :func:`~maize.utilities.io.setup_workflow` function and exposing workflow parameters using :meth:`~maize.core.graph.Graph.map`: + +.. 
code-block:: python + + from maize.core.workflow import Workflow + from maize.utilities.io import setup_workflow + from maize.steps.io import LoadData, LogResult + from maize.steps.your_module import YourNode + + if __name__ == "__main__": + flow = Workflow(name="example") + load = flow.add(LoadData[int], parameters={"data": 17}) + my_node = flow.add(YourNode, parameters={"param": False}) + result = flow.add(LogResult) + + flow.connect(load.out, my_node.inp) + flow.connect(my_node.out, result.inp) + + flow.map(load.data, my_node.param) + + setup_workflow(flow) + +You will now have a workflow that behaves as a normal script, with a help message listing all maize and workflow-specific parameters (using ``-h`` or ``--help``). + +Workflow with ultra-high throughput +----------------------------------- +You may sometimes run into situations where you have a very large amount of individual datapoints to send through a workflow. For instance, you might have a simple docking workflow and want to dock 100000 SMILES codes in one go. For some machines, this naive implementation might actually work: + +.. code-block:: python + + from pathlib import Path + + from maize.core.workflow import Workflow + from maize.steps.io import Return, Void + from maize.steps.mai.docking import AutoDockGPU + from maize.steps.mai.molecule import LoadSmiles, Smiles2Molecules + + flow = Workflow(name="vina") + smi = flow.add(LoadSmiles) + gyp = flow.add(Smiles2Molecules) + adg = flow.add(AutoDockGPU) + sco = flow.add(Return[list[float]], name="scores") + log = flow.add(Void) + + smi.path.set(Path("smiles-100k.smi")) + adg.grid_file.set(Path("grid/rec.maps.fld")) + adg.scores_only.set(True) + adg.strict.set(False) + + flow.connect_all( + (smi.out, gyp.inp), + (gyp.out, adg.inp), + (adg.out, log.inp), + (adg.out_scores, sco.inp), + ) + flow.check() + flow.execute() + +Here, 100000 SMILES will be loaded and sent to the embedding and docking steps at once. This can cause memory issues with the channels and possibly the backend software that you're using. It is also inefficient as docking will have to wait for all molecules to be embedded before it can start docking. A simple solution is to split the data into chunks and send these individual batches to looped variants of the inner processing nodes. This can be accomplished with :class:`~maize.steps.plumbing.Batch` and :class:`~maize.steps.plumbing.Combine`: + +.. code-block:: python + + from pathlib import Path + + from maize.core.workflow import Workflow + from maize.steps.io import Return, Void + from maize.steps.plumbing import Batch, Combine + from maize.steps.mai.docking import AutoDockGPU + from maize.steps.mai.molecule import LoadSmiles, Smiles2Molecules + + flow = Workflow(name="vina") + smi = flow.add(LoadSmiles) + + # Split the data into batches + bat = flow.add(Batch[str]) + + # Important: You must specify looped execution for the inner nodes! 
+ gyp = flow.add(Smiles2Molecules, loop=True) + adg = flow.add(AutoDockGPU, loop=True) + + # Reassemble original data shape + com = flow.add(Combine[NDArray[np.float32]]) + sco = flow.add(Return[list[float]], name="scores") + log = flow.add(Void) + + smi.path.set(Path("smiles-100k.smi")) + adg.grid_file.set(Path("grid/rec.maps.fld")) + adg.scores_only.set(True) + adg.strict.set(False) + + flow.connect_all( + (smi.out, bat.inp), + (bat.out, gyp.inp), + (gyp.out, adg.inp), + (adg.out, log.inp), + (adg.out_scores, com.inp), + (com.out, sco.inp) + ) + + # This is the number of batches we will use + n_batches = flow.combine_parameters(bat.n_batches, com.n_batches) + n_batches.set(100) + + flow.check() + flow.execute() + +The embedding and docking steps will now only receive batches of 100 molecules at a time, avoiding potential memory issues and allowing parallel execution of embedding and docking. This technique can be easily combined with `load balancing <#workflow-load-balance>`_. + diff --git a/docs/docs/design.rst b/docs/docs/design.rst new file mode 100644 index 0000000..1c02e74 --- /dev/null +++ b/docs/docs/design.rst @@ -0,0 +1,60 @@ +Design +====== +This document describes the design of the core underlying *maize*, intended for anyone wishing to make modifications or just get an overview of the code. While we will briefly cover some user-facing aspects, most details on the exact use can be found in the :doc:`user guide `. + +At its core, *maize* is an implementation of the flow-based programming paradigm. A workflow is modelled as a graph, with individual nodes performing certain tasks. Nodes are connected among each other through channels, and this is indeed the only communication between components. Contrary to directed-acyclic-graph execution, each node runs *simultaneously* in its own process, and performs computation as soon as one or more inputs are received. + +Nodes and Graphs +---------------- +Nodes (:mod:`maize.core.node`) are defined by creating a class inheriting from :class:`~maize.core.node.Node`. It requires parameters as well as at least one input or output port to be declared in the class body. It also requires a :meth:`~maize.core.node.Node.run` method with the actual code to be run as part of the graph. The user can receive or send data at any time from any of the ports and make use of the parameters. + +Subgraphs (:mod:`maize.core.graph`) allow logical groupings of individual nodes and are defined by inheriting from :class:`~maize.core.graph.Graph` and defining a :meth:`~maize.core.graph.Graph.build` method. This method defines the graph structure by adding nodes and connecting their ports using the :meth:`~maize.core.graph.Graph.add` and :meth:`~maize.core.graph.Graph.connect` methods respectively. The :meth:`~maize.core.graph.Graph.map_parameters` and :meth:`~maize.core.graph.Graph.map_port` methods allow one to reference ports and parameters from contained nodes and expose them externally, thus making the subgraph appear like a separate node. In fact, both :class:`~maize.core.node.Node` and :class:`~maize.core.graph.Graph` inherit most functionality from :class:`~maize.core.component.Component`, representing a general node in a graph. + +Workflows +--------- +Workflows (:mod:`maize.core.workflow`) are defined by instantiating a new :class:`~maize.core.workflow.Workflow` object and adding nodes or subgraphs using the :meth:`~maize.core.graph.Graph.add` method, as one would for a subgraph. 
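As a condensed illustration (essentially the cookbook's linear example, using the generic :class:`~maize.steps.io.LoadData` and :class:`~maize.steps.io.LogResult` nodes), a complete workflow definition looks roughly like this:

.. code-block:: python

    from maize.core.workflow import Workflow
    from maize.steps.io import LoadData, LogResult

    # Build the graph: add nodes, then wire their ports together
    flow = Workflow(name="example")
    load = flow.add(LoadData[int], parameters={"data": 17})
    result = flow.add(LogResult)
    flow.connect(load.out, result.inp)

    # Validate the topology and run
    flow.check()
    flow.execute()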
These are connected using :meth:`~maize.core.graph.Graph.connect`, and parameters can be declared for commandline use using :meth:`~maize.core.graph.Graph.map_parameters`. The workflow may not have any unconnected ports though. This and other basic properties can be checked using :meth:`~maize.core.graph.Graph.check`. Workflows have the same functionality as graphs, but with some additional methods to allow for instantiating from a file or saving state. They also expose the :meth:`~maize.core.workflow.Workflow.execute` method for actually running the workflow. + +While the computational layout of a *maize* workflow is obviously a graph, it is also a tree: + +.. image:: ../resources/graph-as-tree.svg + :width: 70 % + :align: center + +The workflow object sits at the root, with subgraphs representing branches, and individual nodes as leaves. Every class in this hierarchy inherits from :class:`~maize.core.component.Component`, encapsulating behaviour for referring to parents and children in the tree structure, as well as addressing them. Many properties of a node are by default inherited from the parent object: + +.. image:: ../resources/component-class-diagram.svg + :width: 50 % + :align: center + +Interfaces +---------- +Interfaces (:mod:`maize.core.interface`) describe how nodes, graphs, and workflows interact with each other and the outside world. There are two main types of interfaces - :class:`~maize.core.interface.Parameter` and :class:`~maize.core.interface.Port`. The former describes any values passed to a component at compile time, before actually running the workflow. These will typically be initial inputs, configuration files, or program running options. The latter describes how a component communicates with others, specifically in the form of inputs (:class:`~maize.core.interface.Input`) and outputs (:class:`~maize.core.interface.Output`). Every interface inherits from :class:`~maize.core.interface.Interface` and contains a name and reference to it's owning component. + +.. image:: ../resources/interface-class-diagram.svg + :align: center + +Because interfaces are declared as class attributes to avoid excess boilerplate, they use a special :meth:`~maize.core.interface.Interface.build` method to create a new separate instance when building the parent component. A further unique property of all interfaces is that they are type-safe, i.e. each interface has a specific datatype (using the python :mod:`typing` module). Type consistency can be checked by static type-checkers such as `mypy `_, but the type information is also saved in a :attr:`~maize.core.interface.Interface.datatype` attribute so it can be used at compile- or run-time to verify if connections are compatible or parameters conform to the correct type. + +Running +------- + +Workflow +^^^^^^^^ +*maize* can be divided into two separate systems - the code that is run as part of the main process (while constructing the graph), and code that is run in a separate child process, typically the individual nodes. The former includes the class bodies and :meth:`~maize.core.graph.Graph.build` methods, while the latter will always be the :meth:`~maize.core.node.Node.run` method. Upon calling :meth:`~maize.core.workflow.Workflow.execute` on a workflow, *maize* will start each :meth:`~maize.core.node.Node.run` method in a separate python process and receive messages from each node. These messages include status updates (:class:`~maize.core.runtime.StatusUpdate`) and possible errors. The workflow can stop in the following ways: + +1. 
*maize* collects enough status updates indicating stopped or completed nodes to complete the graph. +2. One of the nodes sets the workflow-wide shutdown signal (:attr:`~maize.core.workflow.Workflow.signal`) +3. *maize* catches an error raised by a node (as long as the node is not configured to ignore it) + +Upon shutdown, *maize* will attempt to join all processes with a timeout. In addition to the nodes, a separate :class:`~maize.core.runtime.Logger` process is also started, and shut down last. + +Node +^^^^ +While the :meth:`~maize.core.node.Node.run` method is user defined, it is not called directly by the workflow. Instead, it is wrapped by a general :meth:`~maize.core.node.Node.execute` method responsible for cleanly executing the user code. It (and the private :meth:`~maize.core.node.Node._attempt_loop`) is responsible for attempting to run any code multiple times if so requested, handling certain logging aspects, and sending errors on to the main process. + +A key aspect of the individual node execution procedure is how ports are handled. Because channels have a limited capacity, and may be full or empty when trying to send or receive items, they need to be polled in such a way as to avoid unresponsive or frozen processes. This is generally accomplished by attempting to :meth:`~maize.core.interface.Output.send` or :meth:`~maize.core.interface.Input.receive` with a timeout continuously in a loop and monitoring the parent shutdown signal and connected channel. If, for example, an upstream node has finished processing, the receiving input port will attempt to receive a final item and then shut down the port. This can cause an intentional shutdown cascade, mimicking the behaviour expected from a directed acyclic graph. + +Nodes can be classified into two groups: ones performing only a single calculation, and ones running continuously. The latter should use the :meth:`~maize.core.node.Node.loop` generator to allow a clean shutdown, and make use of the :meth:`~maize.core.interface.Input.ready` method to check if data can be received from a port. This is to ensure that optional inputs can be skipped, for example when merging data from different sources (see :class:`~maize.steps.plumbing.Merge` for an example). + +.. caution:: + Using continuous nodes with cyclic topologies can easily result in deadlock-prone workflows. Make sure you have a clean exit strategy for each node using :meth:`~maize.core.node.Node.loop`. diff --git a/docs/docs/development.rst b/docs/docs/development.rst new file mode 100644 index 0000000..5cb5243 --- /dev/null +++ b/docs/docs/development.rst @@ -0,0 +1,77 @@ +Development +=========== +This is a summary of code conventions used in *maize*. + +Installation +------------ +If you will be working on both *maize* and *maize-contrib*, it will be easiest to install the environment for *maize-contrib* first, as it encompasses all dependencies for maize and domain-specific extensions. You can then install both packages using editable installs. + +Code style +---------- +We use the :pep:`8` convention with a 100-character line length - you can use ``black`` as a formatter with the ``--line-length=100`` argument. The code base features full static typing; use the following `mypy `_ command to check it: + +.. code-block:: shell + + mypy --explicit-package-bases --strict maize/steps maize/core maize/utilities + +You may need to install missing types using :code:`mypy --install-types`.
Type hints should only be omitted when either mypy or typing doesn't yet fully support the required feature, such as `higher-kinded types `_ or type-tuples (:pep:`646`). + +.. caution:: + If you installed maize in editable mode, you may need to specify its location with ``MYPYPATH`` to ensure ``mypy`` can find it. See this `setuptools issue `_. + +Documentation +------------- +Every public class or function should feature a complete docstring with a full description of parameters / attributes. We follow the `numpy docstring `_ standard for readability reasons. Docs are built using `sphinx `_ in the ``docs`` folder: + +.. code-block:: shell + + cd docs + sphinx-build -b html ./ _build/html + +There will be some warnings from ``autosummary`` that can be ignored. The built docs can then be found in ``docs/_build/html``. To preview them locally, you can start a local webserver by running the following command in the ``docs/_build/html`` folder: + +.. code-block:: shell + + python -m http.server 8000 + +The docs are then available at `<http://localhost:8000>`_. + +If you add a new feature, you should mention the new behaviour in the :doc:`userguide`, in the :doc:`cookbook`, and ideally add an example under :doc:`examples`. If the feature necessitated a deeper change to the fundamental design, you should also update :doc:`design`. + +Testing +------- +Tests are written using `pytest `_ and cover the lower-level components as well as higher-level graph execution. They can be run with: + +.. code-block:: shell + + pytest --log-cli-level=DEBUG tests/ + +Any new features or custom nodes should be covered by suitable tests. To make testing the latter a bit easier, you can use the :class:`~maize.utilities.testing.TestRig` class together with :class:`~maize.utilities.testing.MockChannel` if required. + +Coverage can be reported using: + +.. code-block:: shell + + pytest tests/ -v --cov maize --cov-report html:coverage + +New versions +------------ +To release a new version of maize, perform the following steps: + +1. Create a new branch titled ``release-x.x.x`` +2. Add your changes to ``CHANGELOG.md`` +3. Increment :attr:`maize.__version__` +4. Commit your changes +5. Rebuild and update the remote documentation (see above) +6. Create a tag using :code:`git tag vx.x.x` +7. Push your changes with :code:`git push` and :code:`git push --tags` +8. Update ``master``: + + 1. :code:`git checkout master` + 2. :code:`git pull origin master` + 3. :code:`git merge release-x.x.x` + 4. :code:`git push origin master` + +9. Create a wheel for bundling with *maize-contrib* or other dependent repositories: + + :code:`pip wheel --no-deps .` diff --git a/docs/docs/examples.rst b/docs/docs/examples.rst new file mode 100644 index 0000000..4dbab55 --- /dev/null +++ b/docs/docs/examples.rst @@ -0,0 +1,18 @@ +Examples +======== + +Simple DAG +^^^^^^^^^^ +A very simple directed acyclic graph workflow: + +.. literalinclude:: ../../examples/helloworld.py + :language: python + :linenos: + +Simple DCG +^^^^^^^^^^ +An example of a workflow containing a cycle: + +.. literalinclude:: ../../examples/simpledcg.py + :language: python + :linenos: diff --git a/docs/docs/glossary.rst b/docs/docs/glossary.rst new file mode 100644 index 0000000..96e57e7 --- /dev/null +++ b/docs/docs/glossary.rst @@ -0,0 +1,38 @@ +Glossary +======== +Brief definitions of terms used by the *maize* documentation. + +.. glossary:: + + component + A component of a graph, or the graph itself. It is also the base class of :class:`~maize.core.node.Node` and :class:`~maize.core.graph.Graph`.
+ + node + A :term:`component` of a graph that contains no further nodes, i.e. it is atomic, or from a tree point-of-view, a leaf node. See :class:`~maize.core.node.Node`. + + graph + A :term:`component` that contains multiple nodes. If it's the root-level graph it is referred to as a :term:`workflow`. + + subgraph + A :term:`graph` that is itself a :term:`component` of a graph, i.e. not a top-level :term:`workflow`. + + workflow + The top / root-level :term:`graph` containing all subgraphs or nodes. See :class:`~maize.core.workflow.Workflow`. + + interface + An interface to a :term:`component`, i.e. either some kind of :term:`parameter` or :term:`port`. See the :class:`~maize.core.interface.Interface` base class. + + port + A port can be either an :term:`input` or :term:`output` and represents the primary :term:`interface` for communication between components. See :class:`~maize.core.interface.Port`. + + parameter + A parameter allows setting up a :term:`component` before execution, using either files (:class:`~maize.core.interface.FileParameter`) or arbitrary data (:class:`~maize.core.interface.Parameter`). + + input + Any kind of :term:`component` input, can be set to a :term:`channel`. See :class:`~maize.core.interface.Input` and :class:`~maize.core.interface.MultiInput`. + + output + Any kind of :term:`component` output, can be set to a :term:`channel`. See :class:`~maize.core.interface.Output` and :class:`~maize.core.interface.MultiOutput`. + + channel + Any kind of unidirectional communication channel between components. They have no information about their connection partners and can only be connected to a single :term:`input` and :term:`output`. See :class:`~maize.core.channels.DataChannel` and :class:`~maize.core.channels.FileChannel`. \ No newline at end of file diff --git a/docs/docs/quickstart.rst b/docs/docs/quickstart.rst new file mode 100644 index 0000000..41e767d --- /dev/null +++ b/docs/docs/quickstart.rst @@ -0,0 +1,50 @@ +Quickstart +========== + +Installation +------------ +If you plan on using `maize-contrib `_, then you should just follow the `installation instructions `_ for the latter. Maize will be installed automatically as a dependency. + +Note that `maize-contrib `_ requires several additional domain-specific packages, and you should use its own environment file instead if you plan on using these extensions. + +To get started quickly with running maize, you can install from an environment file: + +.. code-block:: bash + + conda env create -f env-users.yml + +If you want to develop the code or run the tests, use the development environment and install the package in editable mode: + +.. code-block:: bash + + conda env create -f env-dev.yml + conda activate maize-dev + pip install --no-deps -e ./ + +.. caution:: + If you want to develop both maize and *maize-contrib* you may need to install both using legacy editable packages by adding ``--config-settings editable_mode=compat`` as arguments to ``pip``, as not doing so will stop tools like ``pylint`` and ``pylance`` from finding the imports. See this `setuptools issue `_. + +Manual install +^^^^^^^^^^^^^^ +Maize requires the following packages and also depends on python 3.10: + +* dill +* networkx +* pyyaml +* toml +* numpy +* matplotlib +* graphviz +* beartype +* psij-python + +We also strongly recommend the installation of `mypy `_. To install everything use the following command: + +.. 
code-block:: bash + + conda install -c conda-forge python=3.10 dill networkx yaml toml mypy + +If you wish to develop or add additional modules, the following additional packages will be required: + +* pytest +* sphinx diff --git a/docs/docs/reference.rst b/docs/docs/reference.rst new file mode 100644 index 0000000..5784ddd --- /dev/null +++ b/docs/docs/reference.rst @@ -0,0 +1,13 @@ +Reference +========= +This is the technical documentation for developers: + +Core +---- +.. autosummary:: + :toctree: _autosummary + :template: custom-module.rst + :recursive: + + maize.core + maize.utilities diff --git a/docs/docs/roadmap.rst b/docs/docs/roadmap.rst new file mode 100644 index 0000000..6f57f01 --- /dev/null +++ b/docs/docs/roadmap.rst @@ -0,0 +1,12 @@ +Roadmap +======= + +* Restart handling, including checkpoints and data / state persistence +* Read in custom steps from user-supplied folder, or even on the commandline + + * Add option for plugin search path + +* GUI +* Replace configuration parsing with APISchema or Pydantic +* Template parameter using jinja2 and dictionaries +* Video tutorial diff --git a/docs/docs/steps/index.rst b/docs/docs/steps/index.rst new file mode 100644 index 0000000..e09d3d0 --- /dev/null +++ b/docs/docs/steps/index.rst @@ -0,0 +1,10 @@ +Steps +===== +Documentation on pre-defined steps. For external and domain-specific steps, see the :ref:`maize-contrib documentation `. + +.. toctree:: + :maxdepth: 1 + :caption: Steps + + plumbing + io diff --git a/docs/docs/steps/io.rst b/docs/docs/steps/io.rst new file mode 100644 index 0000000..601bd31 --- /dev/null +++ b/docs/docs/steps/io.rst @@ -0,0 +1,10 @@ +Input / Output +============== +These nodes allow getting data in and out of workflows via parameters. + +.. automodule:: maize.steps.io + :members: + :no-value: + :noindex: + :exclude-members: full_timer, run_timer, logger, run, build + :no-inherited-members: \ No newline at end of file diff --git a/docs/docs/steps/plumbing.rst b/docs/docs/steps/plumbing.rst new file mode 100644 index 0000000..6bf4dca --- /dev/null +++ b/docs/docs/steps/plumbing.rst @@ -0,0 +1,10 @@ +Plumbing +======== +These nodes make piping data between varying numbers of other steps a bit easier. + +.. automodule:: maize.steps.plumbing + :members: + :no-value: + :noindex: + :exclude-members: full_timer, run_timer, logger, run, build + :no-inherited-members: diff --git a/docs/docs/userguide.rst b/docs/docs/userguide.rst new file mode 100644 index 0000000..e50db1e --- /dev/null +++ b/docs/docs/userguide.rst @@ -0,0 +1,607 @@ +User guide +========== + +This is a detailed user guide for *maize*. We will approach workflows in a top-down manner, starting with workflow definitions, followed by grouping multiple nodes together in subgraphs, and finally discussing how to implement custom functionality in your own nodes. + +Workflows +--------- +A :term:`workflow` is a high-level description of a :term:`graph`, allowing execution. It contains multiple :term:`nodes ` or :term:`subgraphs `, joined together with :term:`channels `. You can construct a workflow from both pre-defined nodes and custom ones to tailor a workflow to your particular needs. The following image summarizes the anatomy of a workflow with exposed parameters, with :term:`parameters ` shown in yellow, :term:`inputs ` in green, and :term:`outputs ` in red: + +.. 
image:: ../resources/workflow-anatomy.svg + +Workflows can be defined programmatically in python (the most flexible approach), or described in a tree-based serialization format (``JSON``, ``YAML``, or ``TOML``). They can then be run within python, or exposed as a commandline tool to be integrated into other pipelines. In contrast to pipelining tools like `*airflow* `_ and `*luigi* `_, *maize* can run workflows with arbitrary topologies including cycles and conditionals. + +.. _custom-workflows: + +Adding nodes +^^^^^^^^^^^^ +Defining a workflow starts by creating a :class:`~maize.core.workflow.Workflow` object: + +.. code-block:: python + + from maize.core.workflow import Workflow + + flow = Workflow(name="Example") + +Not specifying a name will result in a random 6-character sequence being used instead. There are additional useful options: + +.. code-block:: python + + flow = Workflow( + name="Example", + level="debug", + cleanup_temp=False, + default_channel_size=5, + logfile=Path("out.log") + ) + +`level` specifies the logging verbosity (see the :mod:`python logging module `), `cleanup_temp` specifies whether the directories created during execution should be cleaned up, `default_channel_size` determines how many items can sit in an inter-node channel at a time, and `logfile` allows one to write the logs to a file, as opposed to ``STDOUT``. + +We can then start adding nodes to the workflow using the :meth:`~maize.core.workflow.Workflow.add` method: + +.. code-block:: python + + node = flow.add(Example) + +Note that if you want to add another ``Example`` node, you will have to specify a custom name: + +.. code-block:: python + + node2 = flow.add(Example, name="other_example") + +We can again specify additional options that change the way the node is set up and how it is run, including :term:`parameters ` to be overridden: + +.. code-block:: python + + other = flow.add( + OtherExample, + name="other", + parameters=dict(value=42), + loop=True, + fail_ok=True, + n_attempts=3, + max_loops=5 + ) + +If not explicitly told to do so, the :meth:`~maize.core.node.Node.run` method of the node, containing all user code, will only run once. That means it will typically wait to receive some data, process it, send it onwards, and then shutdown. If you want to keep it running and continuously accept input, you can set ``loop`` to ``True``. If you want to limit the maximum number of iterations, use ``max_loops``, by default the node will loop until it receives a shutdown signal or detects neighbouring nodes shutting down. + +.. tip:: + The ``max_loops`` argument can be useful when testing the behaviour of continuously running nodes. For some examples, see the test suites for :mod:`~maize.steps.plumbing`. + +Any failures encountered during execution will raise an exception and cause the whole workflow to shutdown (the default), unless ``fail_ok`` is enabled. This can be useful for additional optional calculations that are not upstream of other essential nodes. Similarly, if a node is expected to fail occasionally, one can increase ``n_attempts`` from the default ``1``. + +The order nodes are added doesn't matter. Alternatively, if you want to add a lot of nodes at once, you can use the :meth:`~maize.core.workflow.Workflow.add_all` method and specify multiple node classes: + +.. code-block:: python + + node, other = flow.add_all(Example, OtherExample) + +You won't be able to directly override any parameters or specify additional keyword arguments though. + +.. 
tip:: + When defining a new node, it will be automatically added to an internal registry of node types. The node class (not instance) can then be retrieved from the name only using the :meth:`~maize.core.component.Component.get_node_class` function. + +Setting parameters +^^^^^^^^^^^^^^^^^^ +Configuration of nodes is performed using :term:`parameters `. These are typically node settings that are unique to that node and the wrapped software, and would make little sense to change during execution of a workflow. These are things like configuration files or other set-and-forget options. Parameters can be set at node instantiation with :meth:`~maize.core.workflow.Workflow.add` as mentioned above, or on the nodes themselves using the :meth:`~maize.core.interface.Parameter.set` method: + +.. code-block:: python + + node.value.set(37) + other.config.set(Path("config.yml")) + +You can also set them on the commandline, if they are correctly exposed (see `running workflows <#running-workflows>`_). + +Alternatively, :term:`inputs ` can also act as :term:`parameters `. This simplifies cases in which values might normally be set statically at workflow definition, but should allow changes in some special workflow cases. To enable this, :class:`~maize.core.interface.Input` can be instantiated with the ``optional`` flag and / or default values (using ``default`` and ``default_factory``). In those cases, it will not have to be connected to another port and can be used as a parameter instead. Note that once the port is connected, it will not be able to be used as a parameter. + +Handling external software +^^^^^^^^^^^^^^^^^^^^^^^^^^ +In many cases, nodes will depend on packages that may have conflicts with other packages from other nodes. In this situation it is possible to run a node in a different python environment by specifying the path to the desired python executable (e.g. as part of a ``conda`` environment) to the special, always available :attr:`~maize.core.node.Node.python` parameter. In this case you must make sure that the relevant imports are defined in :meth:`~maize.core.node.Node.run` and not in the top-level. In addition, the other environment must also have maize installed. + +If your system has a *module* framework, and the node you're using requires custom software, you can use the :attr:`~maize.core.node.Node.modules` parameter to list modules to load. + +.. danger:: + Module loading is performed in the python process of the node, not in a subprocess. Thus, any modifications to the existing python environment can have unintended consequences. + +If a node requires custom scripts you can use the :attr:`~maize.core.node.Node.scripts` parameter to specify an ``interpreter`` - ``script`` pair, or if you installed the required command in a non-standard location (and it's not in your ``PATH``), you can use the :attr:`~maize.core.node.Node.commands` parameter. + +.. tip:: + All these options (:attr:`~maize.core.node.Node.python`, :attr:`~maize.core.node.Node.modules`, :attr:`~maize.core.node.Node.scripts`, :attr:`~maize.core.node.Node.commands`) can be specified in a workflow-agnostic configuration file (see `global config <#configuring-workflows>`_). This is useful for one-time configurations of HPC or other systems. + +Connecting nodes +^^^^^^^^^^^^^^^^ +The next step is connecting the individual nodes together. In most cases, you should prefer the :meth:`~maize.core.graph.Graph.connect` method: + +.. 
code-block:: python + + flow.connect(node.output, other.input) + +Maize will assign a :term:`channel` connecting these ports together based on their types. If the type is a file or a list or dictionary of files (using :class:`pathlib.Path`), Maize will use a special :class:`~maize.core.channels.FileChannel` to connect the ports, with the option of either copying the file(s) (the default, and appropriate for smaller files), creating a symlink (when dealing with potentially large files) with ``mode="link"``, or simply moving the file(s) using ``mode="move"``. + +In some cases you might be faced with creating a long sequential workflow in which you are joining many nodes with a single :term:`input` and single :term:`output` each. In that case you can use the more convenient :meth:`~maize.core.graph.Graph.auto_connect` or :meth:`~maize.core.graph.Graph.chain` methods on the nodes instead of the ports: + +.. code-block:: python + + flow.auto_connect(node, other) + + # Or: + flow.chain(node, other, another, and_another) + +In general it is better to explicitly define the connection though, as these methods will attempt to connect the first compatible pair of ports found. If you are faced with creating a lot of connections at once, you can use the :meth:`~maize.core.graph.Graph.connect_all` method with (:term:`output` - :term:`input`) pairs: + +.. code-block:: python + + flow.connect_all( + (node.output, other.input), + (other.output, another.input) + ) + +.. caution:: + Static type checking for :meth:`~maize.core.graph.Graph.connect_all` and :meth:`~maize.core.graph.Graph.add_all` is only implemented for up to 6 items due to limitations in Python's type system. + +Some nodes, especially general purpose data-piping tools such as those provided under :doc:`steps `, use special ports that allow multiple connections (:class:`~maize.core.interface.MultiPort`). In this case, just call :meth:`~maize.core.graph.Graph.connect` multiple times on the same port. This will create more ports and connections as required: + +.. code-block:: python + + flow.connect_all( + (node.output, other.input), + (node.output, another.input) + ) + +Here we have connected the output of ``node`` to two different node inputs. Note that this is *only* possible with :class:`~maize.core.interface.MultiPort`, this is because there is some ambiguity on whether data packets should be copied or distributed (if you want to implement behaviour like this, look at the :mod:`~maize.steps.plumbing` module). Under the hood, :class:`~maize.core.interface.MultiPort` creates multiple individual ports as required, and the node using them must loop through them to send or receive items. + +Handling parameters +^^^^^^^^^^^^^^^^^^^ +One step we might want to perform is to expose node-specific parameters on the workflow level. This can be done using :meth:`~maize.core.graph.Graph.map`, it will simply map all node parameters to the workflow, using the same name: + +.. code-block:: python + + flow.map(other.value, another.param, and_another.temperature) + +The workflow will now have parameters named ``value``, ``param``, and ``temperature``. These will be accessible as attributes on the workflow itself, or in the :attr:`~maize.core.component.Component.parameters` dictionary. For more fine-grained control over naming, and mapping multiple node parameters to a single workflow parameter, use :meth:`~maize.core.graph.Graph.combine_parameters`: + +.. 
code-block:: python + + flow.val = flow.combine_parameters(other.value, another.value, name="val") + +Here, both passed parameters are mapped to a single one, allowing a single call to :meth:`~maize.core.interface.Parameter.set` to adjust multiple values at once. One example where this can be useful is a temperature setting for many physical computations. Thus, a call to :meth:`~maize.core.interface.Parameter.set` will set both ``other.value`` and ``another.value`` to ``37``: + +.. code-block:: python + + flow.val.set(37) + +.. _alt-workflow: + +Alternative workflow definitions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +While a workflow definition through python is the most flexible, it can also be done using a suitable serialization format. In this example we will use ``YAML`` due to its readability, but you can also use ``JSON`` or ``TOML``: + +.. code-block:: yaml + + name: workflow + nodes: + - name: node + type: Example + parameters: + value: 42 + - name: term + type: Return + + channels: + - receiving: + term: input + sending: + node: output + +This file is equivalent to the following python code: + +.. code-block:: python + + flow = Workflow("workflow") + node = flow.add(Example, name="node", parameters=dict(value=42)) + term = flow.add(Return, name="term") + flow.connect(node.output, term.input) + +Any arguments you would normally pass to the node initialization can be defined under the node list item. To read in the workflow defined above, the :meth:`~maize.core.workflow.Workflow.from_file` method can be used: + +.. code-block:: python + + flow = Workflow.from_file("file.yaml") + +.. important:: + You may need to import the relevant node definitions before loading the workflow. + +You can also save workflows to a file to recover them later: + +.. code-block:: python + + flow.to_file("file.yaml") + +You can also save a workflow to a normal python dictionary and use your own serialization method using :meth:`~maize.core.workflow.Workflow.to_dict` and :meth:`~maize.core.workflow.Workflow.from_dict`. + +Running workflows +^^^^^^^^^^^^^^^^^ +Before we run our workflow, it is good practice to check that it has been constructed correctly and that all dependencies are accessible: + +.. code-block:: python + + flow.check() + +:meth:`~maize.core.graph.Graph.check` ensures that all nodes are connected, that the types used are consistent, and that all non-optional parameters are set. It will also attempt to run each node in its designated environment (using :meth:`~maize.core.node.Node._prepare`) by loading modules and packages and ensuring all required software is available. Running our constructed workflow is now just as simple as: + +.. code-block:: python + + flow.execute() + +This will log the execution progress to ``STDOUT`` by default. However, in many cases we will want to execute a given workflow as a commandline tool, in the form of a python script. To do so we can use the :func:`~maize.utilities.io.setup_workflow` function: + +.. code-block:: python + + setup_workflow(flow) + +.. hint:: + On some systems it might be necessary to use the :code:`if __name__ == "__main__": ...` guard with workflow scripts to avoid issues with spawning new processes. + +It will create an argument parser with two groups: one for general *maize* parameters governing verbosity, global configuration, and whether to only check the graph without running it; and another one for any exposed parameters. This means that a workflow script can be called like this: + +..
code-block:: shell + + python workflow.py --val 42 + +Calling the script with ``--help`` will show all available maize and workflow-specific options: + +======== ============================ ============================== +Short Long Information +======== ============================ ============================== + ``-c`` ``--check`` Check if the graph was built correctly and exit + ``-l`` ``--list`` List all available nodes and exit + ``-o`` ``--options`` List all exposed workflow parameters and exit + ``-d`` ``--debug`` Provide debugging information + ``-q`` ``--quiet`` Silence all output except errors and warnings + ``NA`` ``--keep`` Keep all output files + ``NA`` ``--config CONFIG`` Global configuration file to use + ``NA`` ``--log LOG`` Logfile to use (instead of ``STDOUT``) + ``NA`` ``--parameters PARAMETERS`` A serialized file containing additional parameters +======== ============================ ============================== + +If you have a workflow in a serialized format, you can run it using the ``maize`` command: + +.. code-block:: bash + + maize flow.yaml + +.. note:: + In *maize-contrib*, complete workflows are defined as functions and then *exposed* in ``pyproject.toml`` as runnable scripts using the :func:`~maize.core.workflow.expose` decorator. To use it, define a function taking no arguments, and returning a workflow instance, then use the decorator on this function. You can then refer to it in ``pyproject.toml`` to make it globally available as a commandline tool (upon reinstallation). + +.. _config-workflow: + +Configuring workflows +^^^^^^^^^^^^^^^^^^^^^ +Maize further allows the use of a global configuration file (``--config``) to adjust options that are more workflow-independent. Here's an example: + +.. literalinclude:: ../../examples/config.toml + :language: toml + :linenos: + +While not recommended, you can also specify the config in python, or overwrite specific parts of it: + +.. code-block:: python + + from maize.utilities.io import Config, NodeConfig + + flow.config = Config() + flow.config.update(Path("path/to/config.toml")) + flow.config.scratch = Path("./") + flow.config.nodes["vina"] = NodeConfig(modules=["OtherVinaModule"]) + +Here, :class:`~maize.utilities.io.NodeConfig` is a node-level configuration class allowing the specification of paths to any required software. + +By default :class:`~maize.utilities.io.Config` will look for a configuration file named ``maize.toml`` in ``$XDG_CONFIG_HOME`` (usually at ``~/.config/``, see `here for more information on the XDG standard `_) or one specified using the ``MAIZE_CONFIG`` environment variable. If you're confused about what to add to your config for a particular workflow, you can use :meth:`~maize.core.workflow.Workflow.generate_config_template` to create a TOML template that you can populate with the correct paths. Note that only one of ``scripts`` and ``commands`` needs to be specified for a given command. + +.. _custom-graphs: + +Subgraphs +--------- +When creating complex workflows, we will often find ourselves in a situation where multiple nodes can be grouped together to one logical unit -- a :term:`subgraph`: + +.. image:: ../resources/graph-anatomy.svg + +This is where :term:`subgraphs ` can be helpful. To define them, create a new class with :class:`~maize.core.graph.Graph` as a base, and add nodes by calling the :meth:`~maize.core.graph.Graph.add` and :meth:`~maize.core.graph.Graph.connect` methods in a custom ``build()`` method, as if we were creating a normal workflow: + +.. 
code-block:: python + + from maize.core.graph import Graph + from maize.steps.plumbing import Delay + + class SubGraph(Graph): + out: Output[int] + delay: Parameter[int] + + def build(self) -> None: + node = self.add(Example) + delay = self.add(Delay, parameters=dict(delay=2)) + self.connect(node.out, delay.inp) + self.out = self.map_port(delay.out) + self.map(delay.delay) + +A key difference between a :term:`subgraph` and a workflow is that the former will always have exposed ports. We however have to clarify which port should be exposed how, by using the :meth:`~maize.core.graph.Graph.map_port` method and specifying a reference to the original port of a contained node and optionally a new name. We can again use the :meth:`~maize.core.graph.Graph.map` convenience method to automatically expose parameters. Note that to get the benefits of type-checking you should in those cases declare all interfaces in the class body. To group multiple parameters together, we can again use :meth:`~maize.core.graph.Graph.combine_parameters`. ``SubGraph`` will now behave just like any other :term:`component` in the workflow: + +.. code-block:: python + + sg = flow.add(SubGraph, parameters=dict(delay=3)) + flow.connect(sg.out, another.inp) + +A common situation is running all contained nodes in a loop, in this case you can pass ``loop=True`` just like for a normal node. At execution, the whole workflow is flattened and each node executed normally, irrespective of nesting. The :term:`subgraph` paradigm is therefore mostly a conceptual aid for complex workflows. + +.. _custom-nodes: + +Custom nodes +------------ +A :term:`node` is the fundamental unit of computation in *maize*. It features at least one :term:`port`, and any number of :term:`parameters ` that allow communication with other nodes and the user, respectively: + +.. image:: ../resources/node-anatomy.svg + +Each port (and by extension :term:`channel`) has a specific type associated that will prevent a graph from succeeding with a call to :meth:`~maize.core.graph.Graph.check` in case of type mismatches. All interfaces of a port will typically be available as attributes, but can also be accessed through specific dictionaries (:attr:`~maize.core.component.Component.inputs`, :attr:`~maize.core.component.Component.outputs` and :attr:`~maize.core.component.Component.parameters`). All computation is performed in :meth:`~maize.core.node.Node.run`, which is defined by the user. This is an example node definition: + +.. code-block:: python + + from maize.core.node import Node + from maize.core.interface import Parameter, Output + + class Example(Node): + out: Output[str] = Output() + data: Parameter[str] = Parameter(default="hello") + + def run(self) -> None: + self.out.send(self.data.value) + +This node takes a ``str`` as a parameter (with a default value of ``"hello"``) and outputs it to an :term:`output` port. It only runs this once (unless added to the workflow with ``loop`` set to ``True``), and if the call to :meth:`~maize.core.interface.Output.send` was successful it will immediately return and complete. It's also possible to specify optional parameters with no default value using the ``optional`` keyword in the :class:`~maize.core.interface.Parameter` constructor. + +.. important:: + Any custom imports must be made inside the :meth:`~maize.core.node.Node.run` method, or in functions called by :meth:`~maize.core.node.Node.run`. + +Handling files +^^^^^^^^^^^^^^ +A custom node can also send and receive files. 
This can be accomplished by specifying :class:`pathlib.Path` as a type. If you expect the files you are receiving or sending to be very large, you should also set the ``mode`` parameter to ``'link'`` or ``'move'``, to ensure that large files don't get copied. + +.. code-block:: python + + from maize.core.node import Node + from maize.core.interface import Parameter, Output + + class Example(Node): + out: Output[Path] = Output(mode="link") + data: Parameter[str] = Parameter(default="hello") + + def run(self) -> None: + path = create_large_file(self.data.value) + self.out.send(path) + +Behind the scenes, maize lets the receiving node know that one or more files are available. The files only get copied or linked once the other node calls the :meth:`~maize.core.interface.Input.receive` method, avoiding most situations in which files could be overwritten. + +.. _looped-nodes: + +Looped execution +^^^^^^^^^^^^^^^^ +The above example represents a case of a single execution. We may however be interested in performing some form of continuous repeating computation. This can be accomplished by passing ``loop=True`` to the node or subgraph when adding it to the workflow. The following node, when used with looping, will continuously send the same value, akin to the Unix ``yes`` command: + +.. code-block:: python + + from maize.core.node import Node + from maize.core.interface import Parameter, Output + + class Example(Node): + out: Output[str] = Output() + data: Parameter[str] = Parameter(default="hello") + + def run(self) -> None: + self.out.send(self.data.value) + +However, in some cases you might want to keep state over multiple loops. In that situation, you can setup any data structures you need in the :meth:`~maize.core.node.Node.prepare` method, which will be called before :meth:`~maize.core.node.Node.run`. The :class:`~maize.steps.plumbing.RoundRobin` node is a good example of this: + +.. literalinclude:: ../../maize/steps/plumbing.py + :pyobject: RoundRobin + +Here, we created a :meth:`~maize.core.node.Node.prepare` method, followed by creating an iterator over all outputs, and initializing the first output by calling `next`. In :meth:`~maize.core.node.Node.run`, we can use this output as normal and increment the iterator. + +.. caution:: + Patterns using continuous loops like this always have the potential to cause deadlocks, as they have no explicit exit condition. In many cases however downstream nodes that finish computation will signal a port shutdown and consequently cause the sending port to exit. + +A common pattern with looped nodes is an optional receive, i.e. we will want to receive one or multiple values (see :class:`~maize.core.interface.MultiPort`) only if they are available and then continue. This can be accomplished by using optional ports, and querying them using :meth:`~maize.core.interface.Input.ready` before attempting to receive: + +.. code-block:: python + + from maize.core.node import Node + from maize.core.interface import MultiInput, Output + + class Example(Node): + inp: MultiInput[str] = MultiInput(optional=True) + out: Output[str] = Output() + + def run(self) -> None: + concat = "" + for inp in self.inp: + if inp.ready(): + concat += inp.receive() + self.out.send(concat) + +This node will always send a value every iteration, no matter if data is available or not. The optional flag will also ensure it can shutdown correctly when neighbouring nodes stop. 
Alternatively, you can use the :meth:`~maize.core.interface.Input.receive_optional` method to unconditionally receive a value, with the possibility of receiving ``None``. + +Another useful option is to allow an input to cache a previously received value by adding the ``cached`` flag to the constructor: + +.. code-block:: python + + from maize.core.node import Node + from maize.core.interface import Input, Output + + class Example(Node): + inp: Input[str] = Input(cached=True) + out: Output[str] = Output() + + def run(self) -> None: + data = self.inp.receive() + self.out.send(data + "-bar") + +In this case, if the node received the string ``"foo"`` in the previous iteration, but hasn't been sent a new value this iteration, it will still receive ``"foo"``. This is particularly useful for setting up parameters at the beginning of a workflow and then keeping them unchanged over various internal loops. + +Generic nodes +^^^^^^^^^^^^^ +When designing nodes for general-purpose "plumbing" use, it is attractive to allow generic types. Rather than using :obj:`typing.Any`, it is safer to use a :class:`typing.TypeVar`, for example like this: + +.. code-block:: python + + import time + from typing import TypeVar, Generic + + from maize.core.node import Node + from maize.core.interface import Input, Output + + T = TypeVar("T") + + class Example(Node, Generic[T]): + inp: Input[T] = Input() + out: Output[T] = Output() + + def run(self) -> None: + data = self.inp.receive() + time.sleep(5) + self.out.send(data) + +This informs the type system that ``Example`` can receive any value, but the sent value will be of the same type. When creating a workflow, we should, however, explicitly specify the type when adding nodes: + +.. code-block:: python + + node = flow.add(Example[int], name="node") + +This way we get static typing support throughout our workflow, minimizing errors in graph construction. + +.. caution:: + Dynamic type-checking with generic nodes is currently in an experimental phase. If you encounter problems, using :obj:`typing.Any` is a temporary work-around. + +Running commands +^^^^^^^^^^^^^^^^ +There are two main ways of running commands: locally or using a resource manager (such as `SLURM `_). Both can be used through :meth:`~maize.core.node.Node.run_command` and :meth:`~maize.core.node.Node.run_multi`: by default any command will simply be run locally (with optional validation) and return a :class:`subprocess.CompletedProcess` instance containing the returncode and any output generated on *standard output* or *standard error*. + +Software dependencies +""""""""""""""""""""" +A common issue is that many programs will require some environment preparation that is often heavily system-dependent. To accommodate this, any node definitions should include a :attr:`~maize.core.node.Node.required_callables` definition listing the commands or software needed to run the node, and / or a :attr:`~maize.core.node.Node.required_packages` attribute listing python packages required in the environment. They can then be specified in the `global config <#configuring-workflows>`_ using the ``modules``, ``scripts``, and ``commands`` parameters or using the corresponding pre-defined parameters (see `handling software <#handling-external-software>`_). For example, if the node ``MyNode`` requires an executable named ``executable``, it will first load any modules under the ``MyNode`` heading, followed by looking for an entry including ``executable`` in the ``commands`` and ``scripts`` sections.
Any discovered matching commands will be placed in the :attr:`~maize.core.node.Node.runnable` dictionary, which can be used with any command invocation: + +.. code-block:: python + + class Example(Node): + + required_callables = ["executable"] + required_packages = ["my_package"] + + inp: Input[float] = Input() + out: Output[float] = Output() + + def run(self) -> None: + import my_package + data = self.inp.receive() + res = self.run_command(f"{self.runnable['executable']} --data {data}") + self.out.send(float(res.stdout)) + +Here, we are running a command that takes some floating-point value as input, and outputs a result to *standard output*. We convert this output to a float and send it on. In practice, you will probably need more sophisticated parsing of command outputs. The associated configuration section might look something like this: + +.. code-block:: toml + + [example] + python = "/path/to/python/interpreter" # must contain 'my_package' + commands.executable = "/path/to/executable" + +If ``executable`` is a script and requires a preceding interpreter to run, your configuration might look like this instead: + +.. code-block:: toml + + [example] + python = "/path/to/python/interpreter" # must contain 'my_package' + scripts.executable.interpreter = "/path/to/interpreter" + scripts.executable.location = "/path/to/script" + +The ``interpreter`` specification is fairly liberal and also allows the use of *singularity* containers or other non-path objects. If your node requires more customized environment setups, you can implement the :meth:`~maize.core.node.Node.prepare` method with your own initialization logic. + +Running in parallel +""""""""""""""""""" +You can also run multiple commands in parallel using :meth:`~maize.core.node.Node.run_multi`. It takes a list of commands to run and runs them in batches according to the ``n_jobs`` parameter. This can be useful when processing potentially large batches of data with software that does not have its own internal parallelization. Each command can optionally be run in a separate working directory, and otherwise accepts the same parameters as :meth:`~maize.core.node.Node.run_command`: + +.. code-block:: python + + class Example(Node): + + required_callables = ["executable"] + + inp: Input[list[float]] = Input() + out: Output[list[float]] = Output() + + def run(self) -> None: + data = self.inp.receive() + commands = [f"{self.runnable['executable']} --data {d}" for d in data] + results = self.run_multi(commands, n_jobs=4) + output = [float(res.stdout) for res in results] + self.out.send(output) + +We did the same thing as above, but now receive and send lists of floats and run our executable in parallel, using 4 jobs. + +Job submission +"""""""""""""" +To make use of batch processing systems common in HPC environments, pass execution options (:class:`~maize.utilities.execution.JobResourceConfig`) to :meth:`~maize.core.node.Node.run_command`: + +.. code-block:: python + + class Example(Node): + + inp: Input[float] = Input() + out: Output[float] = Output() + + def run(self) -> None: + data = self.inp.receive() + options = JobResourceConfig(nodes=2) + self.run_command(f"echo {data}", batch_options=options) + self.out.send(data) + +Batch system settings are handled in the maize configuration (see `configuring workflows <#config-workflow>`_) using :class:`~maize.utilities.execution.ResourceManagerConfig`, for example: + +..
code-block:: toml + + [batch] + system = "slurm" + queue = "core" + walltime = "00:05:00" + +Running batch commands in parallel can be done using :meth:`~maize.core.node.Node.run_multi` in the same way, i.e. passing a :class:`~maize.utilities.execution.JobResourceConfig` object to ``batch_options``. In this case, ``n_jobs`` refers to the maximum number of jobs to submit at once. A common pattern of use is to first prepare the required directory structure and corresponding commands, and then send all commands for execution at once. + +Resource management +""""""""""""""""""" +Because all nodes run simultaneously on a single machine with limited resources, maize features some simple management tools to reserve computational resources: + +.. code-block:: python + + class Example(Node): + + inp: Input[float] = Input() + out: Output[float] = Output() + + def run(self) -> None: + data = self.inp.receive() + with self.cpus(8): + data = do_something_heavy(data) + self.out.send(data) + +You can also reserve GPUs (using :attr:`~maize.core.node.Node.gpus`) using the same syntax. + +Advanced options +"""""""""""""""" +There are multiple additional options for :meth:`~maize.core.node.Node.run_command` and :meth:`~maize.core.node.Node.run_multi` that are worth knowing about: + +===================== ============================== +Option Information +===================== ============================== +``validators`` A list of :class:`~maize.utilities.validation.Validator` objects, allowing output files or *standard output / error* to be checked for content indicating success or failure. +``verbose`` If ``True``, will also log command output to the Maize log. +``raise_on_failure`` If ``True``, will raise an exception if something goes wrong, otherwise will just log a warning. This can be useful when handling batches of data in which some datapoints might be expected to fail. +``command_input`` Can be used to send data to *standard input*. This can be used for commands that might normally require manual user input or interactivity on the commandline. +``pre_execution`` Any command to run just before the main command. Note that if you need to load modules or set environment variables, you should use the options in the configuration system instead (see `handling software <#handling-external-software>`_). Not only does this allow full de-coupling of system and workflow configuration, but it is also more efficient as a module will only be loaded once. +``timeout`` Maximum runtime for a command in seconds. +``working_dirs`` Run in this directory instead of the node working directory (:meth:`~maize.core.node.Node.run_multi` only). +===================== ============================== diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..72aa4d8 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,67 @@ +.. maize documentation master file, created by + sphinx-quickstart on Thu Jan 12 09:54:03 2023. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +maize +===== +This is the documentation for *maize* |version|. + +*maize* is a graph-based workflow manager for computational chemistry pipelines. It is based on the principles of `flow-based programming `_ and thus allows arbitrary graph topologies, including cycles, to be executed. 
Each task in the workflow (referred to as *nodes*) is run as a separate process and interacts with other nodes in the graph by communicating through unidirectional *channels*, connected to *ports* on each node. Every node can have an arbitrary number of input or output ports, and can read from them at any time, any number of times. This allows complex task dependencies and cycles to be modelled effectively. + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Guide + + docs/quickstart + docs/userguide + docs/cookbook + docs/examples + docs/steps/index + docs/glossary + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Development + + docs/roadmap + docs/design + docs/development + docs/reference + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: External + + Steps + Utilities + +Installation +------------ + +If you plan on using `maize-contrib `_, then you should just follow the installation instructions for the latter. Maize will be installed automatically as a dependency. + +To get started quickly with running maize, you can install from an environment file: + +.. code-block:: bash + + conda env create -f env-users.yml + +Teaser +------ +A taste for defining and running workflows with *maize*. + +.. literalinclude:: ../examples/helloworld.py + :language: python + :linenos: + + +Indices and tables +------------------ + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` \ No newline at end of file diff --git a/docs/resources/cob-mulberry.png b/docs/resources/cob-mulberry.png new file mode 100644 index 0000000..5a62939 Binary files /dev/null and b/docs/resources/cob-mulberry.png differ diff --git a/docs/resources/component-class-diagram.svg b/docs/resources/component-class-diagram.svg new file mode 100755 index 0000000..83a007e --- /dev/null +++ b/docs/resources/component-class-diagram.svg @@ -0,0 +1 @@ +name: strparent:Componentinputs:dict[str, Input[Any]]outputs: dict[str, Output[Any]]parameters:dict[str,Parameter[Any]]status:StatusComponentname: strsignal:multiprocessing.EventRunnableexecute()cleanup()max_loops:intcpus: Resourcesgpus:ResourcesNodebuild()check()loop(step: float)shutdown()run()nodes: dict[str, Component]channels:dict[Key, Channel[Any]]Graphadd(name: str, component: Component)connect(sending:Output[T],receiving:Input[T])map_port(name:str,port:Port[Any])map_parameters(name: str, *parameters)Workflowfrom_file(file: Path) -> Workflowfrom_checkpoint(file: Path) -> Workflowto_file(file:Path)to_checkpoint(file: Path)add_arguments(parser: ArgumentGroup)execute() \ No newline at end of file diff --git a/docs/resources/cyclic-workflow.svg b/docs/resources/cyclic-workflow.svg new file mode 100755 index 0000000..e64151d --- /dev/null +++ b/docs/resources/cyclic-workflow.svg @@ -0,0 +1 @@ +CalcInitialMergeOut \ No newline at end of file diff --git a/docs/resources/graph-anatomy.svg b/docs/resources/graph-anatomy.svg new file mode 100755 index 0000000..b4f46d0 --- /dev/null +++ b/docs/resources/graph-anatomy.svg @@ -0,0 +1 @@ +SubgraphInput[T]Output[T]MultiParameter[T]BAInput[T]C \ No newline at end of file diff --git a/docs/resources/graph-as-tree.svg b/docs/resources/graph-as-tree.svg new file mode 100755 index 0000000..2a26d08 --- /dev/null +++ b/docs/resources/graph-as-tree.svg @@ -0,0 +1 @@ +WorkflowSubgraphNode 1Node 2Node 3Parent -Child \ No newline at end of file diff --git a/docs/resources/interface-class-diagram.svg b/docs/resources/interface-class-diagram.svg new file mode 100755 index 0000000..886f29c --- /dev/null +++ b/docs/resources/interface-class-diagram.svg 
@@ -0,0 +1 @@ +name: strparent:Componentdatatype: Anypath: tuple[str, …]Interface[T]build(name: str, parent: Component)check()default: T | Noneis_set:boolvalue: TParameter[T]set(value: T)filepath: PathFileParameterparameters: list[Parameter[T]]MultiParameter[T]timeout: floatoptional:boolchannel: Channel[T] | Noneis_connected: boolPort[T]set_channel(channel: Channel[T])close()n_ports: int | None_ports: list[Port[T]]MultiPort[T]MultiInput[T]MultiOutput[T]active: boolInput[T]dump() -> list[T]preload(data: T)ready() -> boolreceive()->Treceive_optional() -> T | NoneOutput[T]send(item: T) \ No newline at end of file diff --git a/docs/resources/maize-logo.svg b/docs/resources/maize-logo.svg new file mode 100755 index 0000000..97e9aff --- /dev/null +++ b/docs/resources/maize-logo.svg @@ -0,0 +1 @@ +IMAEZ \ No newline at end of file diff --git a/docs/resources/node-anatomy.svg b/docs/resources/node-anatomy.svg new file mode 100755 index 0000000..89877a0 --- /dev/null +++ b/docs/resources/node-anatomy.svg @@ -0,0 +1 @@ +NodeInput[T]Output[T]Parameter[T]Channel[T]Channel[T] \ No newline at end of file diff --git a/docs/resources/workflow-anatomy.svg b/docs/resources/workflow-anatomy.svg new file mode 100755 index 0000000..62b57c0 --- /dev/null +++ b/docs/resources/workflow-anatomy.svg @@ -0,0 +1 @@ +WorkflowMultiParameter[T]BACMultiParameter[T]DMultiParameter[T] \ No newline at end of file diff --git a/env-dev.yml b/env-dev.yml new file mode 100644 index 0000000..41d7dfe --- /dev/null +++ b/env-dev.yml @@ -0,0 +1,29 @@ +name: maize-dev +channels: + - https://conda.anaconda.org/conda-forge + - defaults +dependencies: + - python=3.10 + - toml>=0.10.2 + - yaml>=0.2.5 + - dill>=0.3.6 + - networkx>=3.1 + - numpy>=1.24.3, <2.0 + - matplotlib>=3.7.1 + - pytest>=7.3.1 + - pytest-datadir>=1.4.1 + - pytest-mock>=3.10.0 + - pytest-cov>=5.0 + - mypy>=1.2.0, <1.7 + - pylint>=2.17.3 + - sphinx>=6.2.0 + - pandoc>=3.1.1 + - graphviz>=8.0.3 + - beartype>=0.13.1 + - pip>=20.0 + - pip: + - git+https://github.com/ExaWorks/psij-python.git@9e1a777 + - graphviz>=0.20.1 + - nbsphinx>=0.9.1 + - furo>=2023.3.27 + - pyyaml>=0.2.5 diff --git a/env-users.yml b/env-users.yml new file mode 100644 index 0000000..094094e --- /dev/null +++ b/env-users.yml @@ -0,0 +1,19 @@ +name: maize +channels: + - conda-forge + - defaults +dependencies: + - python=3.10 + - toml>=0.10.2 + - yaml>=0.2.5 + - dill>=0.3.6 + - networkx>=3.1 + - mypy>=1.2.0, <1.7 + - graphviz>=8.0.3 + - beartype>=0.13.1 + - pip>=20.0 + - pip: + - git+https://github.com/ExaWorks/psij-python.git@9e1a777 + - graphviz>=0.20.1 + - pyyaml>=0.2.5 + - git+ssh://git@github.com/molecularai/maize.git diff --git a/examples/config.toml b/examples/config.toml new file mode 100644 index 0000000..0da2203 --- /dev/null +++ b/examples/config.toml @@ -0,0 +1,44 @@ +# maize global configuration file example + +# Where to save temporary files and all workflow directories +scratch = "/tmp" + +# Additional packages to load +packages = [ + "maize.steps.mai" +] + +# Environment variables to be set globally +[environment] +EXAMPLE = "VARIABLE" +OTHER_EXAMPLE = "OTHER_VARIABLE" + +# Batch job system options +[batch] +system = "slurm" # Can be one of {'cobalt', 'flux', 'local', 'lsf', 'pbspro', 'rp', 'slurm'} +max_jobs = 100 # The maximum number of jobs that can be submitted by a node at once +queue = "core" +project = "your_project" +launcher = "srun" # The launcher to use for the command, usually one of {'srun', 'mpirun', 'mpiexec'} +walltime = "24:00:00" # Job walltime limit, shorter times may 
improve queueing times + +# The next sections configure nodes requiring specific executables, +# here we are configuring the node defined as ``Example``: +[example] +python = "path/to/python" + +# Modules to load +modules = ["program_module/1.0"] + +# Commands and their paths +commands.other_program = "path/to/other_program" + +# You can also expand environment variables that might have been set by the module +commands.another_program = "$MODULE_ENV_VAR/another_program" + +# Scripts that require a specific interpreter +scripts.program.interpreter = "path/to/python" +scripts.program.location = "path/to/program/or/folder" + +# Default parameter settings for this node +parameters.value = 42 diff --git a/examples/graph.yaml b/examples/graph.yaml new file mode 100644 index 0000000..8dfb24d --- /dev/null +++ b/examples/graph.yaml @@ -0,0 +1,23 @@ +name: Example +level: DEBUG +nodes: +- name: sg + type: SubGraph +- name: ex + type: Example +- name: concat + type: ConcatAndPrint +channels: +- receiving: + concat: inp + sending: + sg: out +- receiving: + concat: inp + sending: + ex: out +parameters: +- name: data + value: "World" + map: + - ex: data diff --git a/examples/helloworld.py b/examples/helloworld.py new file mode 100644 index 0000000..c77acae --- /dev/null +++ b/examples/helloworld.py @@ -0,0 +1,34 @@ +"""A simple hello-world-ish example graph.""" + +from maize.core.interface import Parameter, Output, MultiInput +from maize.core.node import Node +from maize.core.workflow import Workflow + +# Define the nodes +class Example(Node): + data: Parameter[str] = Parameter(default="Hello") + out: Output[str] = Output() + + def run(self) -> None: + self.out.send(self.data.value) + + +class ConcatAndPrint(Node): + inp: MultiInput[str] = MultiInput() + + def run(self) -> None: + result = " ".join(inp.receive() for inp in self.inp) + self.logger.info("Received: '%s'", result) + + +# Build the graph +flow = Workflow(name="hello") +ex1 = flow.add(Example, name="ex1") +ex2 = flow.add(Example, name="ex2", parameters=dict(data="maize")) +concat = flow.add(ConcatAndPrint) +flow.connect(ex1.out, concat.inp) +flow.connect(ex2.out, concat.inp) + +# Check and run! 
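+# check() verifies that all ports are connected and required parameters are set;
+# execute() then runs each node in its own separate process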
+flow.check() +flow.execute() diff --git a/examples/simpledcg.py b/examples/simpledcg.py new file mode 100644 index 0000000..463f36b --- /dev/null +++ b/examples/simpledcg.py @@ -0,0 +1,53 @@ +"""Simple directed-cyclic-graph example with a subgraph""" + +from maize.core.graph import Graph +from maize.core.interface import Parameter, Input, Output +from maize.core.node import Node +from maize.core.workflow import Workflow + +from maize.steps.plumbing import Delay, Merge +from maize.steps.io import Return + + +class A(Node): + out: Output[int] = Output() + send_val: Parameter[int] = Parameter() + + def run(self) -> None: + self.out.send(self.send_val.value) + + +class B(Node): + inp: Input[int] = Input() + out: Output[int] = Output() + final: Output[int] = Output() + + def run(self) -> None: + val = self.inp.receive() + if val > 48: + self.logger.debug("%s stopping", self.name) + self.final.send(val) + return + self.out.send(val + 2) + +class SubGraph(Graph): + def build(self) -> None: + a = self.add(A, parameters=dict(send_val=36)) + d = self.add(Delay[int], parameters=dict(delay=1)) + self.connect(a.out, d.inp) + self.out = self.map_port(d.out) + self.val = self.combine_parameters(a.send_val, name="val") + + +flow = Workflow(name="test") +sg = flow.add(SubGraph) +b = flow.add(B, loop=True) +merge = flow.add(Merge[int]) +ret = flow.add(Return[int]) +flow.connect(sg.out, merge.inp) +flow.connect(merge.out, b.inp) +flow.connect(b.out, merge.inp) +flow.connect(b.final, ret.inp) +flow.combine_parameters(sg.val, name="val") +flow.check() +flow.execute() diff --git a/maize/core/__init__.py b/maize/core/__init__.py new file mode 100644 index 0000000..44a51aa --- /dev/null +++ b/maize/core/__init__.py @@ -0,0 +1,7 @@ +""" +Core +==== + +Core *maize* functionality. + +""" diff --git a/maize/core/channels.py b/maize/core/channels.py new file mode 100644 index 0000000..67c0464 --- /dev/null +++ b/maize/core/channels.py @@ -0,0 +1,441 @@ +""" +Channels +-------- +Communication channels used for passing files and data between components. + +The two main channel types of interest to users are `DataChannel` and `FileChannel`. +The former is a wrapper around `multiprocessing.Queue`, while the latter uses +channel-specific directories as a shared space to transmit files. All channels +are unidirectional - a connected node cannot send and receive to the same channel. +Either can be assigned to a `Port` using the `interface.Port.set_channel` method. + +Custom channels are possible and should subclass from the `Channel` abstract base class. +This will also require modifying the `graph.Graph.connect` method to use this new channel. 
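+
+A minimal sketch of such a subclass is shown below. It is purely illustrative
+(a plain in-memory list, not safe to share between processes), and the class
+name is made up:
+
+.. code-block:: python
+
+    class ListChannel(Channel[T]):
+        # Toy channel backed by a local list (illustration only)
+        def __init__(self) -> None:
+            self._items: list[T] = []
+            self._active = True
+
+        @property
+        def active(self) -> bool:
+            return self._active
+
+        @property
+        def ready(self) -> bool:
+            return bool(self._items)
+
+        @property
+        def size(self) -> int:
+            return len(self._items)
+
+        def close(self) -> None:
+            self._active = False
+
+        def kill(self) -> None:
+            pass
+
+        def send(self, item: T, timeout: float | None = None) -> None:
+            # Timeouts are ignored in this toy implementation
+            self._items.append(item)
+
+        def receive(self, timeout: float | None = None) -> T | None:
+            return self._items.pop(0) if self._items else None
+
+        def flush(self, timeout: float = 0.1) -> list[T]:
+            items, self._items = self._items, []
+            return items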
+ +""" + +from abc import ABC, abstractmethod +from multiprocessing import get_context +from pathlib import Path +import queue +import shutil +from tempfile import mkdtemp +import time +from typing import TYPE_CHECKING, Literal, TypeVar, Generic, Any, cast + +import dill +from maize.utilities.execution import DEFAULT_CONTEXT +from maize.utilities.io import common_parent, sendtree +from maize.utilities.utilities import has_file + +if TYPE_CHECKING: + from multiprocessing import Queue + +DEFAULT_TIMEOUT = 0.5 + + +T = TypeVar("T") + + +class ChannelException(Exception): + """Raised for miscallaneous channel issues.""" + + +class ChannelFull(ChannelException): + """Raised for full channels, will typically be caught by the port.""" + + +class ChannelEmpty(ChannelException): + """Raised for empty channels, will typically be caught by the port.""" + + +class Channel(ABC, Generic[T]): + """Represents a communication channel that will be plugged into a `Port`.""" + + @property + @abstractmethod + def active(self) -> bool: + """Returns whether the channel is active.""" + + @property + @abstractmethod + def ready(self) -> bool: + """Returns whether the channel is ready to receive from.""" + + @property + @abstractmethod + def size(self) -> int: + """Returns the current approximate size of the buffer""" + + @abstractmethod + def close(self) -> None: + """Closes the channel.""" + + @abstractmethod + def kill(self) -> None: + """Kills the channel, called at network shutdown and ensures no orphaned processes.""" + + @abstractmethod + def send(self, item: T, timeout: float | None = None) -> None: + """ + Send an item. + + Will attempt to send an item of any type into a channel, + potentially blocking if the channel is full. + + Parameters + ---------- + item + Item to send + timeout + Timeout in seconds to wait for space in the channel + + Raises + ------ + ChannelFull + If the channel is already full + + """ + + @abstractmethod + def receive(self, timeout: float | None = None) -> T | None: + """ + Receive an item. + + Parameters + ---------- + timeout + Timeout in seconds to wait for an item + + Returns + ------- + T | None + The received item, or ``None`` if the channel is empty + + """ + + @abstractmethod + def flush(self, timeout: float = 0.1) -> list[T]: + """ + Flush the contents of the channel. + + Parameters + ---------- + timeout + The timeout for item retrieval + + Returns + ------- + list[T] + List of unserialized channel contents + """ + + +_FilePathType = list[Path] | dict[Any, Path] | Path + +_T_PATH, _T_LIST, _T_DICT = 0, 1, 2 + + +class FileChannel(Channel[_FilePathType]): + """ + A communication channel for data in the form of files. Files must be represented + by `Path` objects, and can either be passed alone, as a list, or as a dictionary. + + When sending a file, it is first transferred (depending on `mode`) to an escrow + directory specific to the channel (typically a temporary directory). Upon calling + `receive`, this file is transferred to an input-directory for the receiving node. + + Parameters + ---------- + mode + Whether to ``copy`` (default), ``link``, or ``move`` files from node to node. + + See Also + -------- + DataChannel : Channel for arbitrary serializable data + + """ + + _destination_path: Path + + # FileChannel allows single Path objects, as well as lists and dictionaries + # of paths to be sent. 
To make this possible, we always send a dictionary of + # paths, but convert to and from dictionaries while sending and receiving, + # and communicate the type of data through a shared value object. + @staticmethod + def _convert(items: _FilePathType) -> dict[Any, Path]: + if isinstance(items, list): + items = {i: item for i, item in enumerate(items)} + elif isinstance(items, Path): + items = {0: items} + return items + + def __init__(self, mode: Literal["copy", "link", "move"] = "copy") -> None: + ctx = get_context(DEFAULT_CONTEXT) + self._channel_dir = Path(mkdtemp()) + self._payload: "Queue[dict[Any, Path]]" = ctx.Queue(maxsize=1) + self._file_trigger = ctx.Event() # File available trigger + self._shutdown_signal = ctx.Event() # Shutdown trigger + # FIXME temporary Any type hint until this gets solved + # https://github.com/python/typeshed/issues/8799 + self._transferred_type: Any = ctx.Value("i", _T_DICT) # Type of data sent, see _T_* + self.mode = mode + + @property + def active(self) -> bool: + return not self._shutdown_signal.is_set() + + @property + def ready(self) -> bool: + return self._file_trigger.is_set() + + @property + def size(self) -> int: + return 1 if self._file_trigger.is_set() else 0 + + def setup(self, destination: Path) -> None: + """ + Setup the file channel directories. + + Parameters + ---------- + destination + Path to the destination directory for the input port + + """ + destination.mkdir(exist_ok=True) + self._destination_path = destination.absolute() + if existing := list(self._destination_path.rglob("*")): + self.preload(existing) + + def close(self) -> None: + while self._file_trigger.is_set() and has_file(self._channel_dir): + time.sleep(DEFAULT_TIMEOUT) + self._shutdown_signal.set() + self._payload.cancel_join_thread() + shutil.rmtree(self._channel_dir, ignore_errors=True) + + def kill(self) -> None: + # No need to do anything special here + pass + + def preload(self, items: _FilePathType) -> None: + """Load a file into the channel without explicitly sending.""" + item = FileChannel._convert(items) + + self._update_type(items) + try: + self._payload.put(item, timeout=DEFAULT_TIMEOUT) + except queue.Full as full: + raise ChannelFull("Attempting to preload an already filled channel") from full + self._file_trigger.set() + + def send(self, item: _FilePathType, timeout: float | None = None) -> None: + items = FileChannel._convert(item) + items = {k: item.absolute() for k, item in items.items()} + if not all(file.exists() for file in items.values()) or self._file_trigger.is_set(): + # Give a time ultimatum, for the trigger + if timeout is not None: + time.sleep(timeout) + if self._file_trigger.is_set(): + raise ChannelFull("File channel already has data, are you sure it was received?") + + if not all(file.exists() for file in items.values()): + raise ChannelException( + f"Files at {common_parent(list(items.values())).as_posix()} not found" + ) + + self._update_type(item) + try: + self._payload.put( + sendtree(items, self._channel_dir, mode=self.mode), timeout=DEFAULT_TIMEOUT + ) + except queue.Full as full: + raise ChannelFull("Channel already has files as payload") from full + self._file_trigger.set() + + def receive(self, timeout: float | None = None) -> _FilePathType | None: + # Wait for the trigger and then check if we have the file + if not self._file_trigger.wait(timeout=timeout): + return None + + try: + files = self._payload.get(timeout=timeout) + except queue.Empty as empty: + raise ChannelException("Trigger was set, but no data to receive") from 
empty + + dest_files = sendtree(files, self._destination_path, mode="move") + self._file_trigger.clear() + return self._cast_type(dest_files) + + def flush(self, timeout: float = DEFAULT_TIMEOUT) -> list[_FilePathType]: + """ + Flush the contents of the channel. + + Parameters + ---------- + timeout + The timeout for item retrieval + + Returns + ------- + list[_FilePathType] + List with a paths to a file in the destination + directory or an empty list. This is to be consistent + with the signature of `DataChannel`. + + """ + files = self.receive(timeout=timeout) + if files is None: + return [] + return [files] + + def _update_type(self, data: _FilePathType) -> None: + """Updates the type of data being transferred to allow the correct type to be returned.""" + with self._transferred_type.get_lock(): + if isinstance(data, dict): + self._transferred_type.value = _T_DICT + elif isinstance(data, list): + self._transferred_type.value = _T_LIST + else: + self._transferred_type.value = _T_PATH + + def _cast_type(self, data: dict[Any, Path]) -> _FilePathType: + """Cast received data to match the original sent type.""" + if self._transferred_type.value == _T_DICT: + return data + elif self._transferred_type.value == _T_LIST: + return list(data.values()) + else: + return data[0] + + +class DataChannel(Channel[T]): + """ + A communication channel for data in the form of python objects. + + Any item sent needs to be serializable using `dill`. + + Parameters + ---------- + size + Size of the item queue + + See Also + -------- + FileChannel : Channel for files + + """ + + def __init__(self, size: int) -> None: + ctx = get_context(DEFAULT_CONTEXT) + try: + self._queue: "Queue[bytes]" = ctx.Queue(size) + + # This will only happen for *very* large graphs + # (more than 2000 or so connections, it will depend on the OS) + except OSError as err: # pragma: no cover + msg = ( + "You have reached the maximum number of channels " + "supported by your operating system." + ) + raise ChannelException(msg) from err + + self._n_items: Any = ctx.Value("i", size) # Number of items in the channel + self._n_items.value = 0 + self._signal = ctx.Event() # False + + @property + def active(self) -> bool: + return not self._signal.is_set() + + @property + def ready(self) -> bool: + return cast(bool, self._n_items.value > 0) + + @property + def size(self) -> int: + return cast(int, self._n_items.value) + + def close(self) -> None: + already_signalled = self._signal.is_set() + self._signal.set() + + # A channel can be closed from both the sending and receiving side. If the + # sending side sends an item and immediately quits, closing the channel from + # its side, the receiving node still needs to be able to receive the data + # cleanly, meaning we can't call `cancel_join_thread()` prematurely, as it + # will clear the internal queue. So we only call it if we get a second closing + # signal, i.e. from the receiving side. This means we have already received + # any data we want, or exited for a different reason. + if already_signalled: + self._queue.cancel_join_thread() + + def kill(self) -> None: + self._queue.close() + + def preload(self, items: list[bytes] | bytes) -> None: + """ + Pre-load the channel with a serialized item. Used by restarts. 
+
+        Parameters
+        ----------
+        items
+            Serialized items to pre-load the channel with
+
+        """
+        if isinstance(items, bytes):
+            items = [items]
+        for item in items:
+            self._queue.put(item)
+            with self._n_items.get_lock():
+                self._n_items.value += 1
+
+    def send(self, item: T, timeout: float | None = None) -> None:
+        try:
+            pickled = dill.dumps(item)
+            self._queue.put(pickled, timeout=timeout)
+
+            # Sending an item and immediately polling will falsely result in
+            # a supposedly empty channel, so we wait a fraction of a second
+            time.sleep(DEFAULT_TIMEOUT)
+        except queue.Full as err:
+            raise ChannelFull("Channel queue is full") from err
+
+        with self._n_items.get_lock():
+            self._n_items.value += 1
+
+    def receive(self, timeout: float | None = None) -> T | None:
+        # We have no buffered data, so we try receiving now
+        try:
+            raw_item = self._queue.get(timeout=timeout)
+
+            # The cast here is because `dill.loads` *could* return `Any`,
+            # but because we only interact with it through the channel,
+            # it actually will always return `T`
+            val = cast(T, dill.loads(raw_item))
+        except queue.Empty:
+            val = None
+        else:
+            with self._n_items.get_lock():
+                self._n_items.value -= 1
+        return val
+
+    def flush(self, timeout: float = DEFAULT_TIMEOUT) -> list[T]:
+        """
+        Flush the contents of the channel.
+
+        Parameters
+        ----------
+        timeout
+            The timeout for item retrieval
+
+        Returns
+        -------
+        list[T]
+            List of unserialized channel contents
+        """
+        items: list[T] = []
+        while (item := self.receive(timeout=timeout)) is not None:
+            items.append(item)
+        return items
diff --git a/maize/core/component.py b/maize/core/component.py
new file mode 100644
index 0000000..5bfe432
--- /dev/null
+++ b/maize/core/component.py
@@ -0,0 +1,614 @@
+"""
+Component
+---------
+Provides a component class acting as the base for a graph or a component (node) thereof.
+
+`Component` is used as a base class for both `Node` and `Graph` and represents a
+hierarchical component with a ``parent`` (if it's not the root node). It should
+not be used directly. The workflow is internally represented as a tree, with the
+`Workflow` representing the root node owning all other nodes. Leaf nodes are termed
+`Node` and represent atomic workflow steps. Nodes with branches are (Sub-)`Graph`s,
+as they contain multiple nodes, but expose the same interface that a node would:
+
+.. 
code-block:: text + + Workflow + / \ + Subgraph Node + / \ + Node Node + +""" + +from collections.abc import Generator, Sequence +import inspect +import itertools +import logging +from multiprocessing import get_context +from pathlib import Path +from traceback import TracebackException +from typing import ( + TYPE_CHECKING, + Optional, + Any, + ClassVar, + TypeAlias, + Union, + Literal, + cast, + get_origin, +) +from typing_extensions import TypedDict + +from maize.core.interface import ( + FileParameter, + Input, + Interface, + MultiInput, + MultiOutput, + MultiParameter, + Output, + Port, + Parameter, + MultiPort, +) +import maize.core.interface as _in +from maize.core.runtime import Status, StatusHandler, StatusUpdate +from maize.utilities.execution import DEFAULT_CONTEXT +from maize.utilities.resources import Resources, cpu_count, gpu_count +from maize.utilities.utilities import Timer, unique_id +from maize.utilities.io import Config, NodeConfig, with_fields + +# https://github.com/python/typeshed/issues/4266 +if TYPE_CHECKING: + from logging import LogRecord + from multiprocessing import Queue + from multiprocessing.synchronize import Event as EventClass + from maize.core.graph import Graph + + MessageType: TypeAlias = StatusUpdate | None + + +class Component: + """ + Base class for all components. Should not be used directly. + + Parameters + ---------- + parent + Parent component, typically the graph in context + name + The name of the component + description + An optional additional description + fail_ok + If True, the failure in the component will + not trigger the whole network to shutdown + n_attempts + Number of attempts at executing the `run()` method + level + Logging level, if not given or ``None`` will use the parent logging level + cleanup_temp + Whether to remove any temporary directories after completion + resume + Whether to resume from a previous checkpoint + logfile + File to output all log messages to, defaults to STDOUT + max_cpus + Maximum number of CPUs to use, defaults to the number of available cores in the system + max_gpus + Maximum number of GPUs to use, defaults to the number of available GPUs in the system + loop + Whether to run the `run` method in a loop, as opposed to a single time + + See Also + -------- + node.Node + Node class for implementing custom tasks + graph.Graph + Graph class to group nodes together to form a subgraph + workflow.Workflow + Workflow class to group nodes and graphs together to an executable workflow + + """ + + # By default, we keep track of all defined nodes and subgraphs so that we can instantiate + # them from serialized workflow definitions in YAML, JSON, or TOML format. Importing the + # node / graph definition is enough to make them available. + __registry: ClassVar[dict[str, type["Component"]]] = {} + + def __init_subclass__(cls, name: str | None = None, register: bool = True): + if name is None: + name = cls.__name__.lower() + if register: + Component.__registry[name] = cls + super().__init_subclass__() + + @classmethod + def _generate_sample_config(cls, name: str) -> None: + """Generates a sample config to be concatenated with the docstring.""" + if cls.__doc__ is not None and cls.required_callables: + conf = NodeConfig().generate_template_toml(name, cls.required_callables) + conf = "\n".join(f" {c}" for c in conf.split("\n")) + cls.__doc__ += ( + f"\n .. rubric:: Config example\n\n Example configuration for {name}. " + f"Each required callable only requires either ``scripts`` " + f"OR ``commands``.\n\n .. 
code-block:: toml\n\n{conf}\n" + ) + + class _SerialType(TypedDict): + name: str + inputs: list[dict[str, Any]] + outputs: list[dict[str, Any]] + parameters: list[dict[str, Any]] + + @classmethod + def serialized_summary(cls) -> _SerialType: + """ + Provides a serialized representation of the component type. + + Returns + ------- + dict[str, Any] + Nested dictionary of the component type structure, including I/O and parameters. + + Examples + -------- + >>> Merge.serialized_summary() + {"name": "Merge", "inputs": [{"name": "inp", ...}]} + + """ + result: Component._SerialType = { + "name": cls.__name__, + "inputs": [], + "outputs": [], + "parameters": [], + } + for name, attr in cls.__dict__.items(): + if not isinstance(attr, Interface): + continue + data = {"name": name} | attr.serialized + match attr: + case Input() | MultiInput(): + result["inputs"].append(data) + case Output() | MultiOutput(): + result["outputs"].append(data) + case Parameter() | MultiParameter() | FileParameter(): + result["parameters"].append(data) + + return result + + @classmethod + def get_summary_line(cls) -> str: + """Provides a one-line summary of the node.""" + if cls.__doc__ is not None: + for line in cls.__doc__.splitlines(): + if line: + return line.lstrip() + return "" + + @classmethod + def get_interfaces( + cls, kind: Literal["input", "output", "parameter"] | None = None + ) -> set[str]: + """ + Returns all interfaces available to the node. + + Parameters + ---------- + kind + Kind of interface to retrieve + + Returns + ------- + set[str] + Interface names + + """ + inter = set() + for name, attr in (cls.__dict__ | cls.__annotations__).items(): + match get_origin(attr), kind: + case (_in.Input | _in.MultiInput), "input" | None: + inter.add(name) + case (_in.Output | _in.MultiOutput), "output" | None: + inter.add(name) + case (_in.Parameter | _in.MultiParameter | _in.FileParameter), "parameter" | None: + inter.add(name) + return inter + + @classmethod + def get_inputs(cls) -> set[str]: + """Returns all inputs available to the node.""" + return cls.get_interfaces(kind="input") + + @classmethod + def get_outputs(cls) -> set[str]: + """Returns all outputs available to the node.""" + return cls.get_interfaces(kind="output") + + @classmethod + def get_parameters(cls) -> set[str]: + """Returns all parameters available to the node.""" + return cls.get_interfaces(kind="parameter") + + __checked: bool = False + """Whether a node of a certain name has been checked for runnability with `prepare()`""" + + @classmethod + def set_checked(cls) -> None: + """Set a node as checked, to avoid duplicate runs of `prepare()`""" + cls.__checked = True + + @classmethod + def is_checked(cls) -> bool: + """``True`` if a node has been checked, to avoid duplicate runs of `prepare()`""" + return cls.__checked + + @staticmethod + def get_available_nodes() -> set[type["Component"]]: + """ + Returns all available and registered nodes. + + Returns + ------- + set[str] + All available node names + + """ + return set(Component.__registry.values()) + + @staticmethod + def get_node_class(name: str) -> type["Component"]: + """ + Returns the node class corresponding to the given name. + + Parameters + ---------- + name + Name of the component class to retrieve + + Returns + ------- + Type[Component] + The retrieved component class, can be passed to `add_node` + + """ + try: + return Component.__registry[name.lower()] + except KeyError as err: + raise KeyError( + f"Node of type {name.lower()} not found in the registry. 
" + "Have you imported the node class definitions?" + ) from err + + _COMPONENT_FIELDS = {"name", "description", "fail_ok", "n_attempts", "n_inputs", "n_outputs"} + + _tags: ClassVar[set[str]] = set() + """Tags to identify the component as exposing a particular kind of interface""" + + required_callables: ClassVar[list[str]] = [] + """List of external commandline programs that are required for running the component.""" + + required_packages: ClassVar[list[str]] = [] + """List of required python packages""" + + logger: logging.Logger + """Python logger for both the build and run procedures.""" + + run_timer: Timer + """Timer for the run duration, without waiting for resources or other nodes.""" + + full_timer: Timer + """Timer for the full duration, including waiting for resources or other nodes.""" + + work_dir: Path + """Working directory for the component.""" + + datatype: Any = None + """The component datatype if it's generic.""" + + status = StatusHandler() + """Current status of the component.""" + + def __init__( + self, + parent: Optional["Graph"] = None, + name: str | None = None, + description: str | None = None, + fail_ok: bool = False, + n_attempts: int = 1, + level: int | str | None = None, + cleanup_temp: bool = True, + scratch: Path | None = None, + resume: bool = False, + logfile: Path | None = None, + max_cpus: int | None = None, + max_gpus: int | None = None, + loop: bool | None = None, + ) -> None: + self.name = str(name) if name is not None else unique_id() + ctx = get_context(DEFAULT_CONTEXT) + + if parent is None: + # This is the queue for communication with the main process, it will + # allow exceptions to be raised, and shutdown tokens to be raised + # Pylint freaks out here, see: https://github.com/PyCQA/pylint/issues/3488 + self._message_queue: "Queue[MessageType]" = ctx.Queue() + + # This is the logging-only queue + self._logging_queue: "Queue[LogRecord | None]" = ctx.Queue() + + # This will signal the whole graph to shutdown gracefully + self.signal: "EventClass" = ctx.Event() + + # Special considerations apply when resuming, especially for channel handling + self.resume = resume + + # Global config + self.config = Config.from_default() + + # Cleanup all working directories + self.cleanup_temp = cleanup_temp + + # Overwrite default location + self.scratch = Path(scratch) if scratch is not None else self.config.scratch + + # Optional logfile + self.logfile = logfile + + # Whether to loop the `run` method for itself (nodes) or child nodes (graph) + self.looped = loop if loop is not None else False + + # Resource management + self.cpus = Resources(max_count=max_cpus or cpu_count(), parent=self) + self.gpus = Resources(max_count=max_gpus or gpu_count(), parent=self) + + else: + self._message_queue = parent._message_queue + self._logging_queue = parent._logging_queue + self.signal = parent.signal + self.resume = parent.resume + self.config = parent.config + self.cleanup_temp = parent.cleanup_temp + self.scratch = parent.scratch + self.logfile = parent.logfile + self.cpus = parent.cpus + self.gpus = parent.gpus + self.looped = parent.looped if loop is None else loop + + # Logging level follows the parent graph by default, but can be overridden + if level is None: + level = logging.INFO if parent is None or parent.level is None else parent.level + self.level: int | str = level.upper() if isinstance(level, str) else level + + # Temporary working directory path to allow some basic pre-execution uses + self.work_dir = Path("./") + + # Prepared commands + self.runnable: 
dict[str, str] = {}
+
+        self.parent = parent
+        self.description = description
+        self.fail_ok = fail_ok
+        self.n_attempts = n_attempts
+
+        # Both atomic components and subgraphs can have ports / parameters
+        self.inputs: dict[str, Input[Any] | MultiInput[Any]] = {}
+        self.outputs: dict[str, Output[Any] | MultiOutput[Any]] = {}
+        self.parameters: dict[str, Parameter[Any]] = {}
+        self.status = Status.READY
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(name='{self.name}', parent={self.parent})"
+
+    def __lshift__(self, other: "Component") -> Union["Component", None]:
+        if self.parent is not None:
+            self.parent.auto_connect(receiving=self, sending=other)
+            return self
+        return None  # pragma: no cover
+
+    def __rshift__(self, other: "Component") -> Union["Component", None]:
+        if self.parent is not None:
+            self.parent.auto_connect(receiving=other, sending=self)
+            return other
+        return None  # pragma: no cover
+
+    def _init_spec(self, cls: type["Component"]) -> set[str]:
+        """Gets initializable attributes for a class"""
+        spec = inspect.getfullargspec(cls)
+        if spec.defaults is None:
+            return set(spec.args[1:])
+        serializable = set()
+        for arg, default_value in zip(reversed(spec.args), reversed(spec.defaults)):
+            if hasattr(self, arg) and getattr(self, arg) != default_value:
+                serializable.add(arg)
+        return serializable
+
+    @property
+    def _serializable_attributes(self) -> set[str]:
+        """Gets all serializable attributes that have not been set to their default value"""
+        def _flatten(items: Sequence[Any]) -> Sequence[Any]:
+            res: list[Any] = []
+            for item in items:
+                if hasattr(item, "__len__"):
+                    res.extend(_flatten(item))
+                else:
+                    res.append(item)
+            return res
+
+        serializable = set()
+        for cls in _flatten(inspect.getclasstree([self.__class__])):
+            serializable |= self._init_spec(cls)
+        return serializable
+
+    if not TYPE_CHECKING:
+        @property
+        def scratch(self) -> Path:
+            """The directory to run the graph in"""
+            return self.config.scratch
+
+        @scratch.setter
+        def scratch(self, path: Path) -> None:
+            self.config.scratch = path
+
+    @property
+    def ports(self) -> dict[str, Port[Any]]:
+        """Provides a convenience iterator for all inputs and outputs."""
+        return cast(dict[str, Port[Any]], self.inputs | self.outputs)
+
+    @property
+    def flat_inputs(self) -> Generator[Input[Any], None, None]:
+        """Provides a flattened view of all node inputs, recursing into MultiInput"""
+        for port in self.inputs.values():
+            subports: list[Input[Any]] | MultiInput[Any] = (
+                port if isinstance(port, MultiInput) else [port]
+            )
+            yield from subports
+
+    @property
+    def flat_outputs(self) -> Generator[Output[Any], None, None]:
+        """Provides a flattened view of all node outputs, recursing into MultiOutput"""
+        for port in self.outputs.values():
+            subports: list[Output[Any]] | MultiOutput[Any] = (
+                port if isinstance(port, MultiOutput) else [port]
+            )
+            yield from subports
+
+    @property
+    def parents(self) -> tuple["Component", ...] 
| None: + """Provides all parent components.""" + if self.parent is None: + return None + if self.parent.parents is None: + return (self.parent,) + return *self.parent.parents, self.parent + + @property + def component_path(self) -> tuple[str, ...]: + """Provides the full path to the component as a tuple of names.""" + if self.parents is None: + return (self.name,) + _, *parent_names = tuple(p.name for p in self.parents) + return *parent_names, self.name + + @property + def root(self) -> "Graph": + """Provides the root workflow or graph instance.""" + if self.parents is None: + return cast("Graph", self) + root, *_ = self.parents + return cast("Graph", root) + + @property + def node_config(self) -> NodeConfig: + """Provides the configuration of the current node""" + return self.config.nodes.get(self.__class__.__name__.lower(), NodeConfig()) + + @property + def n_outbound(self) -> int: + """Returns the number of items waiting to be sent""" + return sum(out.size for out in self.outputs.values()) + + @property + def n_inbound(self) -> int: + """Returns the number of items waiting to be received""" + return sum(inp.size for inp in self.inputs.values()) + + @property + def all_parameters(self) -> dict[str, Input[Any] | MultiInput[Any] | Parameter[Any]]: + """Returns all settable parameters and unconnected inputs""" + inp_params = {name: inp for name, inp in self.inputs.items() if not inp.is_connected(inp)} + return self.parameters | inp_params + + @property + def mapped_parameters(self) -> list[Parameter[Any] | Input[Any]]: + """Returns all parameters that have been mapped""" + return list( + itertools.chain.from_iterable( + para._parameters + for para in self.all_parameters.values() + if isinstance(para, MultiParameter) + ) + ) + + def setup_directories(self, parent_path: Path | None = None) -> None: + """Sets up the required directories.""" + if parent_path is None: + parent_path = Path("./") + self.work_dir = Path(parent_path / f"comp-{self.name}") + self.work_dir.mkdir() + + def as_dict(self) -> dict[str, Any]: + """Provides a non-recursive dictionary view of the component.""" + data = with_fields(self, self._COMPONENT_FIELDS & self._serializable_attributes) + data["parameters"] = { + k: para.value + for k, para in self.parameters.items() + if para.is_set and not para.is_default + } + data["parameters"] |= { + k: inp.value + for k, inp in self.inputs.items() + if isinstance(inp, Input) and not inp.is_connected(inp) and inp.active + } + data["type"] = self.__class__.__name__ + data["status"] = self.status.name + + # Remove potentially unneccessary keys + if data["status"] == self.status.READY.name: + data.pop("status") + if data["parameters"] == {}: + data.pop("parameters") + return data + + def send_update(self, exception: TracebackException | None = None) -> None: + """Send a status update to the main process.""" + summary = StatusUpdate( + name=self.name, + parents=self.component_path, + status=self.status, + run_time=self.run_timer.elapsed_time, + full_time=self.full_timer.elapsed_time, + n_inbound=self.n_inbound, + n_outbound=self.n_outbound, + exception=exception, + ) + self._message_queue.put(summary) + + def update_parameters(self, **kwargs: dict[str, Any]) -> None: + """ + Update component parameters. 
+
+        Parameters
+        ----------
+        **kwargs
+            Name - value pairs supplied as keyword arguments
+
+        """
+        for key, value in kwargs.items():
+            if key not in self.all_parameters:
+                raise KeyError(
+                    f"Parameter '{key}' not found in component parameters\n"
+                    f" Available parameters: '{self.all_parameters.keys()}'"
+                )
+            if value is not None:
+                self.all_parameters[key].set(value)
+
+    def ports_active(self) -> bool:
+        """
+        Check if all required ports are active.
+
+        Can be overridden by the user to allow custom shutdown scenarios,
+        for example in the case of complex inter-port dependencies. By
+        default only checks if any mandatory ports are inactive.
+
+        Returns
+        -------
+        bool
+            ``True`` if all required ports are active, ``False`` otherwise.
+
+        """
+        for subport in itertools.chain(self.flat_inputs, self.flat_outputs):
+            if (not subport.active) and (not subport.optional):
+                return False
+        return True
diff --git a/maize/core/graph.py b/maize/core/graph.py
new file mode 100644
index 0000000..88001a5
--- /dev/null
+++ b/maize/core/graph.py
@@ -0,0 +1,950 @@
+"""
+Graph
+-----
+`Graph` is the container for any kind of graph, and can also act as an
+individual component, for example when used as a subgraph. It will contain
+multiple nodes, connected together using channels. It can also directly
+expose parameters. `Workflow` extends `Graph` by providing execution and
+serialization features. Building a graph can be done programmatically.
+
+"""
+
+from collections.abc import Callable
+import itertools
+from pathlib import Path
+import sys
+from typing import (
+    Generic,
+    Literal,
+    Optional,
+    TypeVar,
+    Any,
+    overload,
+    get_origin,
+    cast,
+    TYPE_CHECKING,
+)
+
+from maize.core.component import Component
+from maize.core.channels import DataChannel, Channel, FileChannel
+from maize.core.interface import (
+    Input,
+    Interface,
+    MultiPort,
+    Output,
+    Port,
+    Parameter,
+    MultiParameter,
+    MultiInput,
+    MultiOutput,
+)
+from maize.core.runtime import Status, setup_build_logging
+from maize.utilities.utilities import extract_type, graph_cycles, is_path_type, matching_types
+from maize.utilities.visual import HAS_GRAPHVIZ, nested_graphviz
+
+if TYPE_CHECKING:
+    from maize.core.node import Node
+
+T_co = TypeVar("T_co", covariant=True)
+S_co = TypeVar("S_co", covariant=True)
+U = TypeVar("U", bound=Component)
+ChannelKeyType = tuple[tuple[str, ...], tuple[str, ...]]
+
+
+class GraphBuildException(Exception):
+    """Exception raised for graph build issues."""
+
+
+class Graph(Component, register=False):
+    """
+    Represents a graph (or subgraph) consisting of individual components.
+
+    As a user, one will typically instantiate a `Graph` and then add
+    individual nodes or subgraphs and connect them together. To construct
+    custom subgraphs, create a custom subclass and overwrite the `build`
+    method, and add nodes and connections there as normal.
+ + Parameters + ---------- + parent + Parent component, typically the graph in context + name + The name of the component + description + An optional additional description + fail_ok + If True, the failure in the component will + not trigger the whole network to shutdown + n_attempts + Number of attempts at executing the `run()` method + level + Logging level, if not given or ``None`` will use the parent logging level + cleanup_temp + Whether to remove any temporary directories after completion + resume + Whether to resume from a previous checkpoint + logfile + File to output all log messages to, defaults to STDOUT + max_cpus + Maximum number of CPUs to use, defaults to the number of available cores in the system + max_gpus + Maximum number of GPUs to use, defaults to the number of available GPUs in the system + loop + Whether to run the `run` method in a loop, as opposed to a single time + strict + If ``True`` (default), will not allow generic node parameterisation and + raise an exception instead. You may want to switch this to ``False`` if + you're automating subgraph construction. + default_channel_size + The maximum number of items to allow for each channel connecting nodes + + Attributes + ---------- + nodes + Dictionary of nodes or subgraphs part of the `Graph` + channels + Dictionary of channels part of the `Graph` + + Raises + ------ + GraphBuildException + If there was an error building the subgraph, e.g. an unconnected port + + Examples + -------- + Defining a new subgraph wrapping an output-only example node with a delay node: + + >>> class SubGraph(Graph): + ... out: Output[int] + ... delay: Parameter[int] + ... + ... def build(self) -> None: + ... node = self.add(Example) + ... delay = self.add(Delay, parameters=dict(delay=2)) + ... self.connect(node.out, delay.inp) + ... self.out = self.map_port(delay.out) + ... self.delay = self.map(delay.delay) + + It can then be used just like any other node: + + >>> subgraph = g.add(SubGraph, name="subgraph", parameters={"delay": 10}) + >>> g.connect(subgraph.out, other.inp) + + """ + + _GRAPH_FIELDS = {"name", "description", "level", "scratch"} + + def __init__( + self, + parent: Optional["Graph"] = None, + name: str | None = None, + description: str | None = None, + fail_ok: bool = False, + n_attempts: int = 1, + level: int | str | None = None, + scratch: Path | None = None, + cleanup_temp: bool = True, + resume: bool = False, + logfile: Path | None = None, + max_cpus: int | None = None, + max_gpus: int | None = None, + loop: bool | None = None, + strict: bool = True, + default_channel_size: int = 10, + ) -> None: + super().__init__( + parent=parent, + name=name, + description=description, + fail_ok=fail_ok, + n_attempts=n_attempts, + level=level, + cleanup_temp=cleanup_temp, + scratch=scratch, + resume=resume, + logfile=logfile, + max_cpus=max_cpus, + max_gpus=max_gpus, + loop=loop, + ) + + # While nodes can be either a 'Node' or 'Graph' + # (as a subgraph), we flatten this topology at execution to + # just provide one large flat graph. 
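+        # ("Flattening" means that `flat_nodes` and `flat_components` recurse
+        # into any nested subgraphs, so every leaf node can be addressed
+        # directly regardless of nesting depth.)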
+ self.nodes: dict[str, Component] = {} + self.channels: dict[ChannelKeyType, Channel[Any]] = {} + self.logger = setup_build_logging(name=self.name, level=self.level) + self.default_channel_size = default_channel_size + self.strict = strict + + self.build() + self.check() + + @property + def flat_components(self) -> list[Component]: + """Flattened view of all components in the graph.""" + flat: list[Component] = [] + for node in self.nodes.values(): + flat.append(node) + if isinstance(node, Graph): + flat.extend(node.flat_components) + return flat + + @property + def flat_nodes(self) -> list["Node"]: + """Flattened view of all nodes in the graph.""" + flat: list["Node"] = [] + for node in self.nodes.values(): + if isinstance(node, Graph): + flat.extend(node.flat_nodes) + else: + flat.append(cast("Node", node)) + return flat + + @property + def flat_channels(self) -> set[ChannelKeyType]: + """Flattened view of all connections in the graph.""" + channels = set(self.channels.keys()) + for node in self.nodes.values(): + if isinstance(node, Graph): + channels |= node.flat_channels + return channels + + @property + def active_nodes(self) -> list["Node"]: + """Flattened view of all active nodes in the graph.""" + return [node for node in self.flat_nodes if node.status == Status.READY] + + def setup_directories(self, parent_path: Path | None = None) -> None: + """Create all work directories for the graph / workflow.""" + if parent_path is None: + if self.scratch is not None: + self.config.scratch = self.scratch + parent_path = self.config.scratch.absolute() + if not parent_path.exists(): + parent_path.mkdir(parents=True) + self.work_dir = Path(parent_path / f"graph-{self.name}") + + # If our graph directory already exists, we increment until we get a fresh one + i = 0 + while self.work_dir.exists(): + self.work_dir = Path(parent_path / f"graph-{self.name}-{i}") + i += 1 + + self.work_dir.mkdir() + for comp in self.nodes.values(): + comp.setup_directories(self.work_dir) + + # We have to defer the file channel folder creation until now + for (_, (*inp_node_path, inp_name)), chan in self.channels.items(): + if isinstance(chan, FileChannel): + inp_node = self.root.get_node(*inp_node_path) + + # Channel directory structure: graph/channel + if inp_node.parent is None: + raise GraphBuildException(f"Node {inp_node.name} has no parent") + chan.setup(destination=inp_node.parent.work_dir / f"{inp_node.name}-{inp_name}") + + def get_node(self, *names: str) -> "Component": + """ + Recursively find a node in the graph. + + Parameters + ---------- + names + Names of nodes leading up to the potentially nested target node + + Returns + ------- + Component + The target component + + Raises + ------ + KeyError + When the target cannot be found + + Examples + -------- + >>> g.get_node("subgraph", "subsubgraph", "foo") + Foo(name='foo', parent=SubSubGraph(...)) + + """ + root, *children = names + nested_node = self.nodes[root] + if isinstance(nested_node, Graph) and children: + return nested_node.get_node(*children) + return nested_node + + def get_parameter(self, *names: str) -> Parameter[Any] | Input[Any] | MultiInput[Any]: + """ + Recursively find a parameter in the graph. 
+ + Parameters + ---------- + names + Names of components leading up to the target parameter + + Returns + ------- + Parameter + The target parameter + + Raises + ------ + KeyError + When the parameter cannot be found + + """ + *path, name = names + node = self.get_node(*path) + if name not in node.all_parameters: + raise KeyError( + f"Can't find parameter '{name}' in node '{node.name}'. " + f"Available parameters: {list(node.parameters.keys())}" + ) + return node.all_parameters[name] + + def get_port(self, *names: str) -> Port[Any]: + """ + Recursively find a port in the graph. + + Parameters + ---------- + names + Names of components leading up to the target port + + Returns + ------- + Port + The target port + + Raises + ------ + KeyError + When the target cannot be found + + """ + *path, name = names + node = self.get_node(*path) + return node.ports[name] + + def check(self) -> None: + """ + Checks if the graph was built correctly and warns about possible deadlocks. + + A correctly built graph has no unconnected ports, + and all channel types are matched internally. + + Raises + ------ + GraphBuildException + If a port is unconnected + + Examples + -------- + >>> g = Workflow(name="foo") + >>> foo = g.add(Foo) + >>> bar = g.add(Bar) + >>> g.auto_connect(foo, bar) + >>> g.check() + + """ + # Check connectivity first + self.check_connectivity() + + # Check for deadlock potential by detecting cycles + self.check_cycles() + + def check_connectivity(self) -> None: + """Checks graph connectivity""" + for node in self.flat_nodes: + for name, port in node.ports.items(): + # Subgraphs can have unconnected ports at build time + if (not port.connected) and (self.parent is None): + # Inputs with default values do not need to be connected + if ( + isinstance(port, Input) + and (port.is_set or port.optional) + or port in self.mapped_parameters + ): + continue + raise GraphBuildException( + f"Subgraph '{self.name}' internal port '{name}' " + f"of node '{node.name}' was not connected" + ) + + def check_cycles(self) -> None: + """Check if the graph has cycles and provides information about them""" + cycles = graph_cycles(self) + if cycles: + msg = "Cycles found:\n" + for cycle in cycles: + msg += " " + " -> ".join("-".join(c for c in cyc) for cyc in cycle) + "\n" + self.logger.debug(msg) + + def check_dependencies(self) -> None: + """Check all contained node dependencies""" + for node in self.flat_nodes: + node.check_dependencies() + + def add( + self, + component: type[U], + name: str | None = None, + parameters: dict[str, Any] | None = None, + **kwargs: Any, + ) -> U: + """ + Add a component to the graph. 
+ + Parameters + ---------- + name + Unique name of the component + component + Node class or subgraph class + kwargs + Additional arguments passed to the component constructor + + Returns + ------- + Component + The initialized component + + Raises + ------ + GraphBuildException + If a node with the same name already exists + + Examples + -------- + >>> g = Graph(name="foo") + >>> foo = g.add(Foo, name="foo", parameters=dict(val=42)) + >>> bar = g.add(Bar) + + """ + if isinstance(component, Component): + raise GraphBuildException("Cannot add already instantiated node to the graph") + + name = component.__name__.lower() if name is None else str(name) + if name in self.nodes: + raise GraphBuildException( + f"Node with name {name} already exists in graph, use a different name" + ) + + # Check that generic nodes are correctly parameterized + component.datatype = extract_type(component) + if get_origin(component) is None and issubclass(component, Generic): # type: ignore + component.datatype = Any + msg = ( + f"Node of type '{component.__name__}' is a generic and should use explicit " + "parameterization. See the 'Generic Nodes' section in the maize user guide." + ) + if self.strict: + raise GraphBuildException(msg) + self.logger.warning(msg) + + comp = component(parent=self, name=name, **kwargs) + if parameters is not None: + comp.update_parameters(**parameters) + self.nodes[name] = comp + return comp + + # Come on Guido... + _T1 = TypeVar("_T1", bound=Component) + _T2 = TypeVar("_T2", bound=Component) + _T3 = TypeVar("_T3", bound=Component) + _T4 = TypeVar("_T4", bound=Component) + _T5 = TypeVar("_T5", bound=Component) + _T6 = TypeVar("_T6", bound=Component) + + @overload + def add_all( + self, c1: type[_T1], c2: type[_T2], c3: type[_T3], c4: type[_T4], / + ) -> tuple[_T1, _T2, _T3, _T4]: ... + + @overload + def add_all( + self, c1: type[_T1], c2: type[_T2], c3: type[_T3], c4: type[_T4], c5: type[_T5], / + ) -> tuple[_T1, _T2, _T3, _T4, _T5]: ... + + @overload + def add_all( + self, + c1: type[_T1], + c2: type[_T2], + c3: type[_T3], + c4: type[_T4], + c5: type[_T5], + c6: type[_T6], + /, + ) -> tuple[_T1, _T2, _T3, _T4, _T5, _T6]: ... + + # No way to type this at the moment :( + def add_all(self, *components: type[Component]) -> tuple[Component, ...]: + """ + Adds all specified components to the graph. + + Parameters + ---------- + components + All component classes to initialize + + Returns + ------- + tuple[U, ...] + The initialized component instances + + Examples + -------- + >>> g = Graph(name="foo") + >>> foo, bar = g.add_all(Foo, Bar) + + """ + return tuple(self.add(comp) for comp in components) + + def auto_connect(self, sending: Component, receiving: Component, size: int = 10) -> None: + """ + Connects component nodes together automatically, based + on port availability and datatype. + + This should really only be used in unambiguous cases, otherwise + this will lead to an only partially-connected graph. 
+ + Parameters + ---------- + sending + Sending node + receiving + Receiving node + size + Size (in items) of the queue used for communication + + Examples + -------- + >>> g = Graph(name="foo") + >>> foo = g.add(Foo) + >>> bar = g.add(Bar) + >>> g.auto_connect(foo, bar) + + """ + for out in sending.outputs.values(): + for inp in receiving.inputs.values(): + # We don't overwrite existing connections + if not (out.connected or inp.connected) and matching_types( + out.datatype, inp.datatype + ): + # The check for mismatched types in 'connect()' is + # redundant now, but it's easier this way + self.connect(sending=out, receiving=inp, size=size) + return + + def chain(self, *nodes: Component, size: int = 10) -> None: + """ + Connects an arbitrary number of nodes in sequence using `auto_connect`. + + Parameters + ---------- + nodes + Nodes to be connected in sequence + size + Size of each channel connecting the nodes + + Examples + -------- + >>> g = Graph(name="foo") + >>> foo = g.add(Foo) + >>> bar = g.add(Bar) + >>> baz = g.add(Baz) + >>> g.chain(foo, bar, baz) + + """ + for sending, receiving in itertools.pairwise(nodes): + self.auto_connect(sending=sending, receiving=receiving, size=size) + + P = TypeVar("P") + P1 = TypeVar("P1") + P2 = TypeVar("P2") + P3 = TypeVar("P3") + P4 = TypeVar("P4") + P5 = TypeVar("P5") + + @overload + def connect_all( + self, + p1: tuple[Output[P] | MultiOutput[P], Input[P] | MultiInput[P]], + p2: tuple[Output[P1] | MultiOutput[P1], Input[P1] | MultiInput[P1]], + /, + ) -> None: ... + + @overload + def connect_all( + self, + p1: tuple[Output[P] | MultiOutput[P], Input[P] | MultiInput[P]], + p2: tuple[Output[P1] | MultiOutput[P1], Input[P1] | MultiInput[P1]], + p3: tuple[Output[P2] | MultiOutput[P2], Input[P2] | MultiInput[P2]], + /, + ) -> None: ... + + @overload + def connect_all( + self, + p1: tuple[Output[P] | MultiOutput[P], Input[P] | MultiInput[P]], + p2: tuple[Output[P1] | MultiOutput[P1], Input[P1] | MultiInput[P1]], + p3: tuple[Output[P2] | MultiOutput[P2], Input[P2] | MultiInput[P2]], + p4: tuple[Output[P3] | MultiOutput[P3], Input[P3] | MultiInput[P3]], + /, + ) -> None: ... + + @overload + def connect_all( + self, + p1: tuple[Output[P] | MultiOutput[P], Input[P] | MultiInput[P]], + p2: tuple[Output[P1] | MultiOutput[P1], Input[P1] | MultiInput[P1]], + p3: tuple[Output[P2] | MultiOutput[P2], Input[P2] | MultiInput[P2]], + p4: tuple[Output[P3] | MultiOutput[P3], Input[P3] | MultiInput[P3]], + p5: tuple[Output[P4] | MultiOutput[P4], Input[P4] | MultiInput[P4]], + /, + ) -> None: ... + + @overload + def connect_all( + self, + p1: tuple[Output[P] | MultiOutput[P], Input[P] | MultiInput[P]], + p2: tuple[Output[P1] | MultiOutput[P1], Input[P1] | MultiInput[P1]], + p3: tuple[Output[P2] | MultiOutput[P2], Input[P2] | MultiInput[P2]], + p4: tuple[Output[P3] | MultiOutput[P3], Input[P3] | MultiInput[P3]], + p5: tuple[Output[P4] | MultiOutput[P4], Input[P4] | MultiInput[P4]], + p6: tuple[Output[P5] | MultiOutput[P5], Input[P5] | MultiInput[P5]], + /, + ) -> None: ... + + # Same as for `add_all`: no way to type this + def connect_all( + self, *ports: tuple[Output[Any] | MultiOutput[Any], Input[Any] | MultiInput[Any]] + ) -> None: + """ + Connect multiple pairs of ports together. 
+ + Parameters + ---------- + ports + Output - Input pairs to connect + + Examples + -------- + >>> g = Graph(name="foo") + >>> foo = g.add(Foo) + >>> bar = g.add(Bar) + >>> baz = g.add(Baz) + >>> g.connect_all((foo.out, bar.inp), (bar.out, baz.inp)) + + """ + for out, inp in ports: + self.connect(sending=out, receiving=inp) + + def check_port_compatibility( + self, sending: Output[T_co] | MultiOutput[T_co], receiving: Input[T_co] | MultiInput[T_co] + ) -> None: + """ + Checks if two ports can be connected. + + Parameters + ---------- + sending + Output port for sending items + receiving + Input port for receiving items + + Raises + ------ + GraphBuildException + If the port types don't match, or the maximum number + of channels supported by your OS has been reached + + """ + + if not matching_types(sending.datatype, receiving.datatype): + msg = ( + f"Incompatible ports: " + f"'{sending.parent.name}.{sending.name}' expected '{sending.datatype}', " + f"'{receiving.parent.name}.{receiving.name}' got '{receiving.datatype}'" + ) + raise GraphBuildException(msg) + + if sending.parent.root is not receiving.parent.root: + msg = ( + "Attempting to connect nodes from separate workflows, " + f"'{sending.parent.root.name}' sending, '{receiving.parent.root.name}' receiving" + ) + raise GraphBuildException(msg) + + # Check for accidental duplicate assignments + for port in (sending, receiving): + if not isinstance(port, MultiPort) and port.connected: + raise GraphBuildException( + f"Port '{port.name}' of node '{port.parent.name}' is already connected" + ) + + def connect( + self, + sending: Output[T_co] | MultiOutput[T_co], + receiving: Input[T_co] | MultiInput[T_co], + size: int | None = None, + mode: Literal["copy", "link", "move"] | None = None, + ) -> None: + """ + Connects component inputs and outputs together. + + Parameters + ---------- + sending + Output port for sending items + receiving + Input port for receiving items + size + Size (in items) of the queue used for communication, only for serializable data + mode + Whether to link, copy or move files, overrides value specified for the port + + Raises + ------ + GraphBuildException + If the port types don't match, or the maximum number + of channels supported by your OS has been reached + + Examples + -------- + >>> g = Graph(name="foo") + >>> foo = g.add(Foo) + >>> bar = g.add(Bar) + >>> g.connect(foo.out, bar.inp) + + """ + self.check_port_compatibility(sending, receiving) + + # FIXME This heuristic fails when chaining multiple `Any`-parameterised generic + # nodes after a `Path`-based node, as the information to use a `FileChannel` will + # be lost. This originally cropped in the `parallel` macro. 
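+        # In short: if either endpoint deals in paths, files are staged on disk
+        # through a FileChannel (preferring "copy" over "link" over "move"),
+        # otherwise items are pickled through a fixed-size DataChannel.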
+ if is_path_type(sending.datatype) or is_path_type(receiving.datatype): + # Precedence should be copy > link > move, if one + # port wants copies we should respect that + if mode is None: + if "copy" in (receiving.mode, sending.mode): + mode = "copy" + elif "link" in (receiving.mode, sending.mode): + mode = "link" + else: + mode = "move" + channel: Channel[Any] = FileChannel(mode=mode) + else: + size = size if size is not None else self.default_channel_size + channel = DataChannel(size=size) + + sending.set_channel(channel) + receiving.set_channel(channel) + + # Add info on which port of a MultiPort is connected + send_path = sending.path + if isinstance(sending, MultiOutput): + send_path = *send_path, str(len(sending) - 1) + recv_path = receiving.path + if isinstance(receiving, MultiInput): + recv_path = *recv_path, str(len(receiving) - 1) + + self.logger.debug( + "Connected '%s' -> '%s' using %s(%s)", + "-".join(send_path), + "-".join(recv_path), + channel.__class__.__name__, + size, + ) + self.channels[(sending.path, receiving.path)] = channel + + _P = TypeVar("_P", bound=Port[Any]) + + def map_port(self, port: _P, name: str | None = None) -> _P: + """ + Maps a port of a component to the graph. + + This will be required when creating custom subgraphs, + ports of individual component nodes will need to be + mapped to the subgraph. This method also handles setting + a graph attribute with the given name. + + Parameters + ---------- + port + The component port + name + Name for the port to be registered as + + Returns + ------- + _P + Mapped port + + Examples + -------- + >>> def build(self): + ... node = self.add(Example) + ... self.map_port(node.output, name="output") + + """ + if name is None: + name = port.name + if name in self.ports: + raise KeyError(f"Port with name '{name}' already exists in graph '{self.name}'") + + if isinstance(port, Input | MultiInput): + self.inputs[name] = port + elif isinstance(port, Output | MultiOutput): + self.outputs[name] = port + setattr(self, name, port) + return port + + def combine_parameters( + self, + *parameters: Parameter[T_co] | Input[T_co], + name: str | None = None, + default: S_co | None = None, + optional: bool | None = None, + hook: Callable[[S_co], T_co] | None = None, + ) -> MultiParameter[S_co, T_co]: + """ + Maps multiple low-level parameters to one high-level one. + + This can be useful when a single parameter needs to be + supplied to multiple nodes within a subgraph. This method + also handles setting a graph attribute with the given name. + + Parameters + ---------- + parameters + Low-level parameters of component nodes + name + Name of the high-level parameter + default + The default parameter value + optional + Whether the mapped parameters should be considered optional + hook + Optional hook to be called on the corresponding value before setting the + contained parameter(s). This allows the use of different datatypes for the + parent and child parameters, making more complex parameter behaviour possible. + + Returns + ------- + MultiParameter + The combined parameter object + + Examples + -------- + >>> def build(self): + ... foo = self.add(Foo) + ... bar = self.add(Bar) + ... self.map_parameters( + ... 
foo.param, bar.param, name="param", default=42) + + """ + if name is None: + name = parameters[0].name + + if name in self.parameters: + raise GraphBuildException(f"Parameter with name '{name}' already exists in graph") + + multi_param: MultiParameter[S_co, T_co] = MultiParameter( + parameters=parameters, default=default, optional=optional, hook=hook + ).build(name=name, parent=self) + self.parameters[name] = multi_param + setattr(self, name, multi_param) + return multi_param + + def map(self, *interfaces: Interface[Any]) -> None: + """ + Map multiple child interfaces (ports or parameters) onto the current graph. + Will also set the graph attributes to the names of the mapped interfaces. + + Parameters + ---------- + interfaces + Any number of ports and parameters to map + + See also + -------- + Graph.map_parameters + If you want to map multiple parameters to a single high-level one + Graph.map_port + If you want more fine-grained control over naming + + Examples + -------- + >>> def build(self): + ... foo = self.add(Foo) + ... bar = self.add(Bar) + ... self.map(foo.inp, bar.out, foo.param) + + """ + for inter in interfaces: + if isinstance(inter, Parameter): + self.combine_parameters(inter) + elif isinstance(inter, Port): + self.map_port(inter) + else: + raise ValueError(f"'{inter}' is not a valid interface") + + def visualize( + self, + max_level: int = sys.maxsize, + coloring: Literal["nesting", "status"] = "nesting", + labels: bool = True, + ) -> Any: + """ + Visualize the graph using graphviz, if installed. + + Parameters + ---------- + max_level + Maximum nesting level to show, shows all levels by default + coloring + Whether to color nodes by nesting level or status + labels + Whether to show datatype labels + + Returns + ------- + dot + Graphviz `Dot` instance, in a Jupyter notebook + this will be displayed visually automatically + + """ + if HAS_GRAPHVIZ: + dot = nested_graphviz(self, max_level=max_level, coloring=coloring, labels=labels) + return dot + return None + + def build(self) -> None: + """ + Builds a subgraph. + + Override this method to construct a subgraph encapsulating + multiple lower-level nodes, by using the `add` and `connect` + methods. Additionally use the `map`, `map_port`, and `map_parameters` + methods to create a subgraph that can be used just like a node. + + Examples + -------- + >>> def build(self): + ... foo = self.add(Foo) + ... bar = self.add(Bar) + ... self.map(foo.inp, bar.out, foo.param) + + """ + + # This is for jupyter / ipython, see: + # https://ipython.readthedocs.io/en/stable/config/integrating.html#MyObject._repr_mimebundle_ + def _repr_mimebundle_(self, *args: Any, **kwargs: Any) -> Any: + if (dot := self.visualize()) is not None: + return dot._repr_mimebundle_(*args, **kwargs) # pylint: disable=protected-access + return None diff --git a/maize/core/interface.py b/maize/core/interface.py new file mode 100644 index 0000000..27c9e34 --- /dev/null +++ b/maize/core/interface.py @@ -0,0 +1,1449 @@ +""" +Interface +--------- +This module encompasses all possible node interfaces. This includes parameters, +such as configuration files (using `FileParameter`) and simple values (using `Parameter`), +but also `Input` (`MultiInput`) and `Output` (`MultiOutput`) ports allowing the attachment of +(multiple) channels to communicate with other nodes. All interfaces expose a ``datatype`` +attribute that is used to ensure the usage of correct types when constructing a workflow graph. 
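+
+As a rough (hypothetical) illustration, a node declares interfaces as class
+attributes and accesses them through ``self`` at runtime:
+
+.. code-block:: python
+
+    class Double(Node):
+        inp: Input[int] = Input()
+        out: Output[int] = Output()
+        factor: Parameter[int] = Parameter(default=2)
+
+        def run(self) -> None:
+            self.out.send(self.inp.receive() * self.factor.value)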
+ +""" + +from abc import abstractmethod +from collections.abc import Iterable, Callable, Sequence, Generator +from datetime import datetime +import inspect +import logging +import os +from pathlib import Path +from typing import ( + Annotated, + Literal, + TypeGuard, + TypeVar, + ClassVar, + Generic, + Union, + Any, + cast, + get_args, + get_origin, + TYPE_CHECKING, +) + +from maize.core.runtime import Status +from maize.core.channels import Channel, ChannelFull +from maize.utilities.utilities import ( + extract_superclass_type, + extract_type, + format_datatype, + tuple_to_nested_dict, + NestedDict, + typecheck, +) + +if TYPE_CHECKING: + # In commit #6ec9884 we introduced a setter into Component, this seems + # to have the effect of mypy not being able to import the full class + # definition at type-checking time. The results are various "Cannot + # determine type of ... [has-type]", and "Incompatible return value + # type (got "T", expected "T | None") [return-value]" errors. See + # https://github.com/python/mypy/issues/16337 for a potentially + # related error. + from maize.core.component import Component + +# I have temporarily removed dynamic typing for now, but if we +# reintroduce it intersection types will come in handy: +# https://github.com/python/typing/issues/213 + +T = TypeVar("T") + + +log = logging.getLogger(f"run-{os.getpid()}") + + +UPDATE_INTERVAL = 60 + + +class Suffix: + """ + Utility class to annotate paths with restrictions on possible suffixes. + + Parameters + ---------- + suffixes + Any number of possible file suffixes without the leading dot + + Examples + -------- + >>> path: Annotated[Path, Suffix("pdb", "gro", "tpr")] = Path("conf.xyz") + >>> pred = get_args(path)[1] + >>> pred(path) + False + + In practice it might look like this: + + >>> class Example(Node): + ... para = FileParameter[Annotated[Path, Suffix("pdb", "gro")]]() + ... out = Output[int]() + ... def run(self) -> None: ... + + This will then raise a `ValueError` when trying to set the + parameter with a file that doesn't have the correct suffix. + + See Also + -------- + FileParameter + Parameter subclass using these annotations + + """ + + def __init__(self, *suffixes: str) -> None: + self._valid = frozenset(suf.lstrip(".") for suf in suffixes) + + def __call__(self, path: Path) -> bool: + return path.suffix.lstrip(".") in self._valid + + def __repr__(self) -> str: + return f"Suffix({', '.join(self._valid)})" + + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, self.__class__): + return False + return len(self._valid.intersection(__value._valid)) > 0 + + def __hash__(self) -> int: + return hash(self._valid) + + +# Default port polling timeout, i.e. the frequency with which I/O for nodes is polled +DEFAULT_TIMEOUT = 0.5 + + +# The idea behind the interface base class is that we can declare it +# as a class attribute in the node, with type information and some +# settings. Then when the node constructor is executed, we create a +# copy of each interface to make sure each node instance has an +# individual instance of an interface. This is done by saving the +# constructor arguments (using `__new__`) and reinitializing the +# class by calling the constructor again in `build`. +class Interface(Generic[T]): + """ + Interface parent class, handles behaviour common to ports and parameters. + + Attributes + ---------- + name + Unique interface name + parent + Parent node instance this interface is part of + datatype + The datatype of the associated value. 
This may be a type from the ``typing`` + library and thus not always usable for checks with ``isinstance()`` + + """ + + _TInter = TypeVar("_TInter", bound="Interface[Any]") + + name: str + parent: "Component" + datatype: Any + doc: str | None + + # These allow `build` to create a shallow copy of the interface + _args: tuple[Any, ...] + _kwargs: dict[str, Any] + + # Where to place the interface in the parent component + _target: ClassVar[str] + + # See comment in `build`, this is for the typechecker + __orig_class__: ClassVar[Any] + + def __new__(cls: type[_TInter], *args: Any, **kwargs: Any) -> _TInter: + inst = super().__new__(cls) + inst._args = args + inst._kwargs = kwargs + return inst + + def __repr__(self) -> str: + if not hasattr(self, "name"): + return f"{self.__class__.__name__}()" + return ( + f"{self.__class__.__name__}[{self.datatype}]" + f"(name='{self.name}', parent='{self.parent.name}')" + ) + + @property + def path(self) -> tuple[str, ...]: + """Provides a unique path to the interface.""" + return *self.parent.component_path, self.name + + @staticmethod + def _update_generic_type(obj: _TInter) -> None: + """Update the contained generic datatype with parent information.""" + if isinstance(obj.datatype, TypeVar) and obj.parent.datatype is not None: + obj.datatype = obj.parent.datatype + + # Check for the case where the datatype is GenericAlias but the argument is + # a TypeVar e.g. List[~T] + if get_origin(obj.datatype) is not None: + args = get_args(obj.datatype) + if isinstance(args[0], TypeVar) and obj.parent.datatype is not None: + obj.datatype = get_origin(obj.datatype)[obj.parent.datatype] + + @property + def serialized(self) -> dict[str, Any]: + """Provides a serialized summary of the parameter""" + dtype = Any if not hasattr(self, "datatype") else self.datatype + return {"type": str(dtype), "kind": self.__class__.__name__} + + # What we would really need to type this correctly is a `TypeVar` + # that is bound to a generic, i.e. `TypeVar("U", bound=Interface[T])`, + # (or higher-kinded types) but this doesn't seem possible at the moment, see: + # https://github.com/python/mypy/issues/2756 + # https://github.com/python/typing/issues/548 + def build(self: _TInter, name: str, parent: "Component") -> _TInter: + """ + Instantiate an interface from the description. + + Parameters + ---------- + name + Name of the interface, will typically be the attribute name of the parent object + parent + Parent component instance + + Returns + ------- + _TInter + Copy of the current instance, with references to the name and the parent + + """ + inst = self.__class__(*self._args, **self._kwargs) + inst.name = name + inst.parent = parent + inst.datatype = None + inst.doc = None + + inst.datatype = extract_type(self) + if inst.datatype is None: + inst.datatype = extract_superclass_type(parent, name) + + # If the parent component is generic and was instantiated using + # type information, we can use that information to update the + # interface's datatype (in case it was also generic). This is + # unfortunately not a comprehensive solution, as we might have + # multiple typevars or containerized types, but those are quite + # tricky to handle properly. 
+ self._update_generic_type(inst) + + # Register in parent inputs/outputs/parameters dictionary + target_dict = getattr(parent, self._target) + + # We don't want to overwrite an already existing multiport instance + if name not in target_dict: + target_dict[name] = inst + return inst + + def is_file(self) -> bool: + """Returns if the interface is wrapping a file""" + dtype = self.datatype + if get_origin(dtype) == Annotated: + dtype = get_args(dtype)[0] + return isinstance(dtype, type) and issubclass(dtype, Path) + + def is_file_iterable(self) -> bool: + """Returns if the interface is wrapping an iterable containing files""" + dtype = self.datatype + if hasattr(dtype, "__iter__"): + if not get_args(dtype): + return False + for sub_dtype in get_args(dtype): + if get_origin(sub_dtype) == Annotated: + sub_dtype = get_args(sub_dtype)[0] + if not inspect.isclass(sub_dtype): + return False + if not issubclass(sub_dtype, Path): + return False + + # All sub-datatypes are paths + return True + + # Not iterable + return False + + def check(self, value: T) -> bool: + """ + Checks if a value is valid using type annotations. + + Parameters + ---------- + value + Value to typecheck + + Returns + ------- + bool + ``True`` if the value is valid, ``False`` otherwise + + """ + if self.is_file() and typecheck(value, str): + # At this point we know we can safely cast to Path and check the predicates + value = Path(value) # type: ignore + if self.is_file_iterable(): + value = type(value)(Path(v) for v in value) # type: ignore + return typecheck(value, self.datatype) + + +class ParameterException(Exception): + """Exception raised for parameter issues.""" + + +class Parameter(Interface[T]): + """ + Task parameter container. + + Parameters + ---------- + default + Default value for the parameter + default_factory + The default factory in case of mutable values + optional + If ``True``, the parent node will not check if this parameter has been set + + Attributes + ---------- + name + Name of the parameter + parent + Parent component + datatype + The datatype of the associated parameter. 
This may be a type from the ``typing`` + library and thus not always usable for checks with ``isinstance()`` + + See Also + -------- + FileParameter + Allows specifying a read-only file as a parameter + Parameter + Allows specifying any data value as a parameter + Flag + Alias for boolean parameters + + """ + + _target = "parameters" + + def __init__( + self, + default: T | None = None, + default_factory: Callable[[], T] | None = None, + optional: bool = False, + ) -> None: + self.default = default + if self.default is None and default_factory is not None: + self.default = default_factory() + self.optional = optional or self.default is not None + self._value: T | None = self.default + self._changed = False + + @property + def changed(self) -> bool: + """Returns whether the parameter was explicitly set""" + return self._changed + + @property + def is_default(self) -> bool: + """Returns whether the default value is set""" + comp = self.default == self._value + if isinstance(comp, Iterable): + return all(comp) + return comp + + @property + def is_set(self) -> bool: + """Indicates whether the parameter has been set to a non-``None`` value.""" + return self._value is not None + + @property + def value(self) -> T: + """Provides the value of the parameter.""" + if self._value is None: + raise ParameterException( + f"Parameter '{self.name}' of node '{self.parent.name}' must be set" + ) + return self._value + + @property + def serialized(self) -> dict[str, Any]: + return super().serialized | {"default": self.default, "optional": self.optional} + + @property + def skippable(self) -> bool: + """Indicates whether this parameter can be skipped""" + return self.optional and not self.changed + + def set(self, value: T) -> None: + """ + Set the parameter to a specified value. + + Parameters + ---------- + value + Value to set the parameter to + + Raises + ------ + ValueError + If the datatype doesn't match with the parameter type + + Examples + -------- + >>> foo.val.set(42) + + See Also + -------- + graph.Graph.add + Allows setting parameters at the point of node addition + component.Component.update_parameters + Sets multiple parameters over a node, graph, or workflow at once + + """ + if not self.check(value): + raise ValueError( + f"Error validating value '{value}' of type '{type(value)}'" + f" against parameter type '{self.datatype}'" + ) + self._changed = True + self._value = value + + +Flag = Parameter[bool] + + +P = TypeVar("P", bound=Path | list[Path]) + + +class FileParameter(Parameter[P]): + """ + Allows provision of files as parameters to nodes. 
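# Illustrative sketch (not part of the diff): declaring value parameters on a node using
# the `Parameter` and `Flag` interfaces above. `Scale` and its ports are hypothetical.
class Scale(Node):
    inp = Input[float]()
    out = Output[float]()
    factor = Parameter[float](default=2.0)  # has a default, so it is treated as optional
    verbose = Flag(default=False)           # `Flag` is an alias for `Parameter[bool]`

    def run(self) -> None:
        value = self.inp.receive()
        if self.verbose.value:
            self.logger.info("Scaling %s by %s", value, self.factor.value)
        self.out.send(value * self.factor.value)

# Parameters are typically adjusted before execution, e.g. scale.factor.set(3.5)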
+ + Parameters + ---------- + exist_required + If ``True`` will raise an exception if the specified file can't be found + + Raises + ------ + FileNotFoundError + If the input file doesn't exist + + See Also + -------- + Parameter + Allows specifying values as parameters + + """ + + def __init__( + self, default: P | None = None, exist_required: bool = True, optional: bool = False + ) -> None: + super().__init__(default, default_factory=None, optional=optional) + self.exist_required = exist_required + + @property + def serialized(self) -> dict[str, Any]: + return super().serialized | {"exist_required": self.exist_required} + + @property + def filepath(self) -> P: + """Provides the path to the file.""" + paths = self.value if isinstance(self.value, list) else [self.value] + if self.exist_required: + for path in paths: + if path is None or not path.exists(): + path_str = path.as_posix() if path is not None else "None" + raise FileNotFoundError(f"Parameter file at '{path_str}' not found") + return self.value + + def set(self, value: P) -> None: + if isinstance(value, list): + path: Path | list[Path] = [Path(val).absolute() for val in value] + elif isinstance(value, Path | str): + path = Path(value).absolute() + super().set(cast(P, path)) + + +ParameterMappingType = dict[str, str | T | list[NestedDict[str, str]] | None] + + +S = TypeVar("S") + + +class MultiParameter(Parameter[T], Generic[T, S]): + """ + Container for multiple parameters. Allows setting multiple + low-level parameters with a single high-level one. + + When constructing subgraphs, one will often want to map a + single parameter to multiple component nodes. This is where + `MultiParameter` is useful, as it will automatically set those + component parameters. If you wish to perform some more elaborate + processing instead, subclass `MultiParameter` and overwrite the + `MultiParameter.set` method. + + Do not use this class directly, instead make use of the + :meth:`maize.core.graph.Graph.combine_parameters` method. + + Parameters + ---------- + parameters + Sequence of `Parameter` instances that + will be updated with a call to `set` + default + Default value for the parameter + optional + Whether the parameter will be considered optional + hook + An optional hook function mapping from ``T`` to ``S`` + + Attributes + ---------- + name + Name of the parameter + parent + Parent component + + See Also + -------- + graph.Graph.combine_parameters + Uses `MultiParameter` in the background to combine multiple parameters + from separate nodes into a single parameter for a subgraph. 
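# Illustrative sketch (not part of the diff): a file parameter restricted to certain
# suffixes via the `Suffix` annotation and checked for existence when accessed through
# `filepath`. The node and parameter names are hypothetical.
class Dock(Node):
    grid = FileParameter[Annotated[Path, Suffix("pdb", "gro")]]()
    out = Output[Path]()

    def run(self) -> None:
        # Raises FileNotFoundError here if the file is missing (exist_required=True)
        self.out.send(self.grid.filepath)

# Strings are accepted and converted to absolute paths:
# dock.grid.set("structures/receptor.pdb")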
+ + """ + + _TInter = TypeVar("_TInter", bound="MultiParameter[Any, Any]") + + def __init__( + self, + parameters: Sequence[Parameter[S] | "Input[S]"], + default: T | None = None, + optional: bool | None = None, + hook: Callable[[T], S] | None = None, + ) -> None: + def _id(x: T) -> S: + return cast(S, x) + + super().__init__(default=default) + + if len(dtypes := {param.datatype for param in parameters}) > 1: + raise ParameterException(f"Inconsistent datatypes in supplied parameters: {dtypes}") + + self._parameters = list(parameters) + self.doc = parameters[0].doc + self.hook = _id if hook is None else hook + + if hook is None: + self.datatype = parameters[0].datatype + + if parameters[0].is_set and default is None and hook is None: + self._value = cast(T, parameters[0].value) + elif default is not None: + self.set(default) + + # This allows us to set parameters as optional on the workflow level + if optional is not None: + self.optional = optional + for para in self._parameters: + para.optional = optional + + @property + def parents(self) -> list["Component"]: + """Provides the original parent nodes of all contained parameters.""" + return [param.parent for param in self._parameters] + + @property + def serialized(self) -> dict[str, Any]: + return super().serialized | {"default": self.default} + + @property + def original_names(self) -> list[str]: + """Provides the original names of the contained parameters""" + return [para.name for para in self._parameters] + + def build(self: _TInter, name: str, parent: "Component") -> _TInter: + self.name = name + self.parent = parent + return self + + def set(self, value: T) -> None: + """ + Sets all contained parameters to the value. + + Parameters + ---------- + value + The value to be set for all contained parameters + check + Whether to check the value against the parameter datatype + + """ + self._value = value + for param in self._parameters: + param.set(self.hook(value)) + + def as_dict(self) -> ParameterMappingType[T]: + """ + Provide a dictionary representation of the parameter mapping. + + Returns + ------- + ParameterMappingType[T] + Dictionary representation of the `MultiParameter` + + """ + val = None if not self.is_set else self.value + data: ParameterMappingType[T] = dict( + name=self.name, value=val, type=format_datatype(self.datatype) + ) + mapping: list[NestedDict[str, str]] = [] + for para in self._parameters: + mapping.append(tuple_to_nested_dict(para.parent.name, para.name)) + data["map"] = mapping + return data + + +class PortException(Exception): + """Exception raised for channel issues.""" + + +_PortType = TypeVar("_PortType", bound="Port[Any]") + + +class Port(Interface[T]): + """ + Port parent class, use the `Input` or `Output` classes to specify user connections. + + Parameters + ---------- + timeout + Timeout used for continuously polling the connection + for data on a blocking `receive` call. + optional + Whether this port is required for the process to stay alive. + If the connection to an optional port is closed by a neighbouring + process, the current node will not shutdown. + mode + Whether to ``'link'``, ``'move'`` or ``'copy'`` (default) files. + hook + An optional function to be called on the data being sent or received. 
+ + Attributes + ---------- + name + Unique port name + parent + Parent node instance this port is part of + + Raises + ------ + PortException + If the port is used without a connected channel + + """ + + def __init__( + self, + timeout: float = DEFAULT_TIMEOUT, + optional: bool = False, + mode: Literal["copy", "link", "move"] = "copy", + hook: Callable[[T], T] = lambda x: x, + ) -> None: + self.timeout = timeout + self.optional = optional + self.mode = mode + self.hook = hook + self.channel: Channel[T] | None = None + + @property + def serialized(self) -> dict[str, Any]: + return super().serialized | {"optional": self.optional, "mode": self.mode} + + @property + def active(self) -> bool: + """Specifies whether the port is active or not.""" + return Port.is_connected(self) and self.channel.active + + @property + def connected(self) -> bool: + """Specifies whether the port is connected.""" + return self.is_connected(self) + + @property + def size(self) -> int: + """Returns the approximate number of items waiting in the channel""" + return 0 if self.channel is None else self.channel.size + + @staticmethod + def is_connected(port: _PortType) -> TypeGuard["_PortChannel[T]"]: + """Specifies whether the port is connected.""" + return port.channel is not None + + def set_channel(self, channel: Channel[T]) -> None: + """ + Set the channel associated with the port. This needs + to be called when connecting two ports together. + + Parameters + ---------- + channel + An instantiated subclass of `Channel` + + """ + self.channel = channel + + def close(self) -> None: + """ + Closes the port. + + This can be detected by neighbouring nodes waiting on + the port, and subsequently cause them to shut down. + + Raises + ------ + PortException + If the port is not connected + + """ + if self.channel is not None: + self.channel.close() + + +# Mypy type guard, see: https://stackoverflow.com/questions/71805426/how-to-tell-a-python-type-checker-that-an-optional-definitely-exists +class _PortChannel(Port[T]): + channel: Channel[T] + _value: T | None + _preloaded: bool + cached: bool + + +class Output(Port[T]): + """ + Output port to allow sending arbitrary data. + + Parameters + ---------- + timeout + Timeout used for continuously polling the connection + for sending data into a potentially full channel. + mode + Whether to ``'link'``, ``'move'`` or ``'copy'`` (default) files. + hook + An optional function to be called on the data being sent or received. + + Attributes + ---------- + name + Unique port name + parent + Parent node instance this port is part of + + Raises + ------ + PortException + If the port is used without a connected channel + + """ + + _target = "outputs" + + # Convenience function allowing connections by + # using the shift operator on ports like so: + # a.out >> b.in + def __rshift__(self, other: "Input[T]") -> None: + if self.parent.parent is not None: + self.parent.parent.connect(receiving=other, sending=self) + + def send(self, item: T) -> None: + """ + Sends data through the channel. 
+ + Parameters + ---------- + item + Item to send through the channel + + Raises + ------ + PortException + If trying to send through an unconnected port + + Examples + -------- + In the `run` method of the sending node: + + >>> self.output.send(42) + + """ + if not Port.is_connected(self): + raise PortException("Attempting send through an inactive port") + + # Get the current time to allow regular status update communication + current_time = datetime.now() + + self.parent.status = Status.WAITING_FOR_OUTPUT + while not self.parent.signal.is_set(): + # Send an update if it's taking a while + delta = datetime.now() - current_time + if delta.seconds > UPDATE_INTERVAL: + self.parent.send_update() + current_time = datetime.now() + + # If our connection partner shuts down while we are trying to + # send something, that means we should shutdown too + if not self.active: + self.parent.status = Status.STOPPED + raise PortInterrupt("Port is dead, stopping node", name=self.name) + + try: + self.channel.send(self.hook(cast(Any, item)), timeout=self.timeout) + + # This essentially allows the FBP concept of 'back-pressure' + except ChannelFull: + continue + else: + self.parent.status = Status.RUNNING + return + + +class PortInterrupt(KeyboardInterrupt): + """ + Interrupt raised to quit a process immediately to avoid + propagating ``None`` values to downstream nodes. + + """ + + def __init__(self, *args: object, name: str | None = None) -> None: + self.name = name + super().__init__(*args) + + +class Input(Port[T]): + """ + Input port to allow receiving or parameterising data. + + An input port can either be connected to another node's output + port to dynamically receive data, or set to a value (with optional + default) before workflow execution to obtain a static value, + making it behave analogously to `Parameter`. + + Parameters + ---------- + default + Default value for the parameter + default_factory + The default factory in case of mutable values + timeout + Timeout used for continuously polling the connection + for data on a blocking `receive` call. + optional + If ``True``, this port does not have to be connected. Nodes should + then check if data is available first (by a call to ``ready()``) + before accessing it. Also determines whether this port is required + for the process to stay alive. If the connection to an optional port + is closed by a neighbouring process, the current node will not shutdown. + mode + Whether to ``'link'``, ``'move'`` or ``'copy'`` (default) files. + cached + If ``True``, will cache the latest received value and immediately return + this value when calling `receive` while the channel is empty. This is useful + in cases where a node will run in a loop, but some inputs stay constant, as + those constant inputs will only need to receive a value a single time. + hook + An optional function to be called on the data being sent or received. 
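# Illustrative sketch (not part of the diff): the send / receive pattern described above.
# `send` blocks while the channel is full, `receive` blocks until data arrives or the
# connection closes. Node names are hypothetical.
class Producer(Node):
    out = Output[int]()

    def run(self) -> None:
        self.out.send(42)

class Consumer(Node):
    inp = Input[int]()

    def run(self) -> None:
        value = self.inp.receive()
        self.logger.info("Received %s", value)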
+ + Attributes + ---------- + name + Unique port name + parent + Parent node instance this port is part of + + Raises + ------ + PortException + If the port is used without a connected channel + + """ + + _target = "inputs" + _value: T | None + _preloaded: bool + + def __init__( + self, + default: T | None = None, + default_factory: Callable[[], T] | None = None, + timeout: float = DEFAULT_TIMEOUT, + optional: bool = False, + mode: Literal["copy", "link", "move"] = "copy", + cached: bool = False, + hook: Callable[[T], T] = lambda x: x, + ) -> None: + super().__init__(timeout, optional, mode) + self.default = default + self.cached = cached + self.hook = hook + if self.default is None and default_factory is not None: + self.default = default_factory() + self._value = self.default + self._changed = False + self._preloaded = False + + # Convenience function allowing connections by + # using the shift operator on ports like so: + # a.in << b.out + def __lshift__(self, other: "Output[T]") -> None: + if self.parent.parent is not None: + self.parent.parent.connect(receiving=self, sending=other) + + @property + def changed(self) -> bool: + """Returns whether the input was explicitly set""" + return self._changed + + @property + def is_default(self) -> bool: + """Returns whether the default value is set""" + return self.default == self._value + + @property + def is_set(self) -> bool: + """Indicates whether the input has been set to a non-``None`` value or is an active port.""" + return self._value is not None or self.active + + @property + def active(self) -> bool: + return super().active or self.ready() or (self.cached and self._value is not None) + + @property + def value(self) -> T: + """Receives a value. Alias for `receive()`""" + return self.receive() + + @property + def serialized(self) -> dict[str, Any]: + return super().serialized | {"default": self.default, "cached": self.cached} + + @property + def skippable(self) -> bool: + """Indicates whether this input can be skipped""" + return self.optional and not self.is_set + + def set(self, value: T) -> None: + """ + Set the input to a static value. + + Parameters + ---------- + value + Value to set the input to + + Raises + ------ + ValueError + If the datatype doesn't match with the parameter type + ParameterException + When attempting to set a value for a connected port + + """ + if not self.check(value): + raise ValueError( + f"Error validating value '{value}' of type '{type(value)}'" + f" against parameter type '{self.datatype}'" + ) + if Port.is_connected(self): + raise ParameterException("Can't set an Input that is already connected") + + if self.is_file(): + value = Path(value).absolute() # type: ignore + if self.is_file_iterable(): + value = type(value)(p.absolute() for p in value) # type: ignore + + self._changed = True + self._value = value + + def preload(self, value: T) -> None: + """ + Preload a value on an input. Internal use only. + + Parameters + ---------- + value + Value to set the parameter to + + Raises + ------ + ValueError + If the datatype doesn't match with the parameter type + + """ + if not self.check(value): + raise ValueError( + f"Error validating value '{value}' of type '{type(value)}'" + f" against parameter type '{self.datatype}'" + ) + self._preloaded = True + self._value = value + + def dump(self) -> list[T]: + """ + Dump any data contained in the channel. 
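# Illustrative sketch (not part of the diff): inside a hypothetical `build` method, one
# input is connected dynamically while another unconnected input is given a static,
# parameter-like value. `Producer`, `Consumer`, and `threshold` are hypothetical.
def build(self) -> None:
    producer = self.add(Producer)
    consumer = self.add(Consumer)
    consumer.inp << producer.out  # shorthand for self.connect(receiving=..., sending=...)
    consumer.threshold.set(0.5)   # unconnected Input: `set` makes it behave like a Parameter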
+ + Returns + ------- + list[T] + List or all items contained in the channel + + """ + if self.channel is None: + if self._value is None: + raise PortException("Cannot dump from unconnected ports") + return [self.hook(self._value)] + return [self.hook(val) for val in self.channel.flush()] + + def ready(self) -> bool: + """ + Specifies whether the input has data available to read. + + This allows checking for data without a blocking + receive, thus allowing nodes to use optional inputs. + + Returns + ------- + bool + ``True`` if there is data in the channel ready to be read + + Examples + -------- + >>> if self.input.ready(): + ... val = self.input.receive() + ... else: + ... val = 42 + + """ + if self._preloaded: + return True + if Port.is_connected(self): + return self.channel.ready + return self._value is not None + + def receive(self) -> T: + """ + Receives data from the port and blocks. + + Returns + ------- + T + Item received from the channel. + + Raises + ------ + PortInterrupt + Special signal to immediately quit the node, + without any further processing + PortException + If trying to receive from an unconnected port, + or if the received value is ``None`` + + Examples + -------- + >>> self.input.receive() + 42 + + See Also + -------- + Input.receive_optional + Can potentially return 'None', use this + method when utilising optional inputs + + """ + val = self.receive_optional() + if val is None: + raise PortException("Received 'None', use 'receive_optional' with optional ports") + return val + + def receive_optional(self) -> T | None: + """ + Receives data from the port and blocks, potentially returning ``None``. + + In nearly all cases you will want to use `Input.receive` instead. This method + is intended to be used with optional inputs with upstream branches that may or + may not run. In those cases, use this method and handle a potential ``None`` + value indicating that the optional data is not available. If you expect the + data to always be available instantly, you can use `Input.ready` to check if + there's data in the channel to be read. + + Returns + ------- + T | None + Item received from the channel. + + Raises + ------ + PortInterrupt + Special signal to immediately quit the node, + without any further processing + PortException + If trying to receive from an unconnected port + + Examples + -------- + >>> self.input.receive() + 42 + + See Also + -------- + Input.receive + Raises a `PortException` instead of returning ``None`` + + """ + if not Port.is_connected(self): + if self._value is None: + if not self.optional: + raise PortException( + f"Attempting receive from unconnected port '{self.parent.name}-{self.name}'" + ) + else: + return None + return self.hook(self._value) + + if self._preloaded: + self._preloaded = False + if self._value is None: + return None + # See comment under the TYPE_CHECKING condition above, this seems to be a mypy quirk + return cast(T, self.hook(self._value)) + + # Warn the user if we receive multiple times from an unlooped node, as this is + # unusual and may cause unexpected premature graph shutdowns + if self._value is not None and not self.parent.looped and not self.optional: + log.warning( + "Receiving multiple times from the same port ('%s'), " + "but the node is not looping. 
This may cause unexpected " + "behaviour and earlier graph shutdowns if not accounted for.", + self.name, + ) + + # Get the current time to allow regular status update communication + current_time = datetime.now() + + # And now for the tricky bit + self.parent.status = Status.WAITING_FOR_INPUT + while not self.parent.signal.is_set(): + # First try to get data, even if the upstream process is dead + item = self.channel.receive(timeout=self.timeout) # type: ignore + if item is not None: + self.parent.status = Status.RUNNING + self._value = item + return cast(T, self.hook(item)) + + # If we attempt to receive a value, but we don't have a new one + # available and `cached` is True, we return the last cached value + if self.cached and self._value is not None: + return cast(T, self.hook(self._value)) + + # Check for channel termination signal, this should cause + # immediate port closure and component shutdown (in most cases) + if not self.active: + break + + # Send an update if it's taking a while + delta = datetime.now() - current_time + if delta.seconds > UPDATE_INTERVAL: + self.parent.send_update() + current_time = datetime.now() + + # Optional ports should really be probed for data with 'ready()' + if not self.optional: + self.parent.status = Status.STOPPED + raise PortInterrupt("Port is dead, stopping node", name=self.name) + + # This means the channel and by extension upstream port shut down. Returning `None` + # is okay provided that we have an optional port and the user handles this situation. + return None + + +class MultiPort(Port[T]): + """ + Aggregate Port parent class, allowing multiple + ports to be integrated into one instance. + + Parameters + ---------- + timeout + Timeout used for continuously polling the connection + for data on a blocking `receive` call. + optional + Whether this port is required for the process to stay alive. + If the connection to an optional port is closed by a neighbouring + process, the current node will not shutdown. + n_ports + The number of ports to instantiate, if not given will allow + dynamic creation of new ports when `set_channel` is called + mode + Whether to ``'link'``, ``'move'`` or ``'copy'`` (default) files. 
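# Illustrative sketch (not part of the diff): handling an optional input by probing with
# `ready` or accepting a possible `None` from `receive_optional`, as described above.
# Node and port names are hypothetical.
class Combine(Node):
    inp = Input[int]()
    inp_extra = Input[int](optional=True)
    out = Output[int]()

    def run(self) -> None:
        total = self.inp.receive()
        if self.inp_extra.ready():  # probe instead of blocking
            total += self.inp_extra.receive()
        # Alternatively, `extra = self.inp_extra.receive_optional()` may return None
        # if the upstream branch never produced data
        self.out.send(total)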
+ + Attributes + ---------- + name + Unique port name + parent + Parent node instance this port is part of + + Raises + ------ + PortException + If the port is used without a connected channel + + Examples + -------- + Accessing individual ports through indexing: + + >>> out[0].send(item) + + """ + + _ports: list[Port[T]] + _type: type[Port[T]] + _TInter = TypeVar("_TInter", bound="MultiPort[Any]") + + def __init__( + self, + timeout: float = DEFAULT_TIMEOUT, + optional: bool = False, + mode: Literal["copy", "link", "move"] = "copy", + ) -> None: + self._ports = [] + super().__init__(timeout=timeout, optional=optional, mode=mode) + + @abstractmethod + def __getitem__(self, key: int) -> Port[T]: # pragma: no cover + pass + + def __setitem__(self, key: int, value: Port[T]) -> None: + self._ports[key] = value + + @abstractmethod + def __iter__(self) -> Generator[Port[T], None, None]: # pragma: no cover + pass + + @property + def serialized(self) -> dict[str, Any]: + return super().serialized | {"n_ports": len(self._ports)} + + def __len__(self) -> int: + return len(self._ports) + + def set_channel(self, channel: Channel[T]) -> None: + # This occurs *after* `build`, we thus need to provide + # a full port with name and parent parameters + port = self._type(optional=self.optional, timeout=self.timeout).build( + name=self.name, parent=self.parent + ) + port.set_channel(channel=channel) + self._ports.append(port) + + @staticmethod + def is_connected(port: _PortType) -> TypeGuard["_PortChannel[T]"]: + if not isinstance(port, MultiPort): # pragma: no cover + return port.is_connected(port) + if len(port) == 0: + return False + return all(subport.is_connected(subport) for subport in port) + + def close(self) -> None: + for _, port in enumerate(self._ports): + port.close() + + +class MultiOutput(MultiPort[T]): + """ + Aggregation of multiple output ports into a single port. + + Index into the port to access a normal `Output` instance with a `send` method. + + Parameters + ---------- + timeout + Timeout used for continuously polling the connection + for free space on a potentially blocking `send` call. + optional + Whether this port is required for the process to stay alive. + If the connection to an optional port is closed by a neighbouring + process, the current node will not shutdown. + n_ports + The number of ports to instantiate, if not given will allow + dynamic creation of new ports when `set_channel` is called + mode + Whether to ``'link'``, ``'move'`` or ``'copy'`` (default) files. + + Attributes + ---------- + name + Unique port name + parent + Parent node instance this port is part of + + Raises + ------ + PortException + If the port is used without a connected channel + + Examples + -------- + >>> class Example(Node): + ... out = MultiOutput[int](n_ports=2) + ... + ... def run(self): + ... self.out[0].send(42) + ... 
self.out[1].send(69) + + """ + + _target = "outputs" + _type = Output + + # Typing this correctly will require higher-kinded types I think + # and won't be possible for the foreseeable future :( + # But this is enough to get correct static type checking at the + # graph assembly stage, and that's the most important thing + _ports: list[Output[T]] # type: ignore + + def __getitem__(self, key: int) -> Output[T]: + return self._ports[key] + + def __iter__(self) -> Generator[Output[T], None, None]: + yield from self._ports + + # Convenience function allowing connections by + # using the shift operator on ports like so: + # a.out >> b.in + def __rshift__(self, other: Union[Input[T], "MultiInput[T]"]) -> None: + if self.parent.parent is not None: + self.parent.parent.connect(receiving=other, sending=self) + + +class MultiInput(MultiPort[T]): + """ + Aggregation of multiple input ports into a single port. + + Index into the port to access a normal `Input` instance with a `receive` method. + + Parameters + ---------- + timeout + Timeout used for continuously polling the connection + for data on a blocking `receive` call. + optional + Whether this port is required for the process to stay alive. + If the connection to an optional port is closed by a neighbouring + process, the current node will not shutdown. + n_ports + The number of ports to instantiate, if not given will allow + dynamic creation of new ports when `set_channel` is called + mode + Whether to ``'link'``, ``'move'`` or ``'copy'`` (default) files. + + Attributes + ---------- + name + Unique port name + parent + Parent node instance this port is part of + + Raises + ------ + PortException + If the port is used without a connected channel + + Examples + -------- + >>> class Example(Node): + ... inp = MultiInput[int](n_ports=2) + ... + ... def run(self): + ... a = self.inp[0].receive() + ... 
b = self.inp[1].receive() + + """ + + _target = "inputs" + _type = Input + _ports: list[Input[T]] # type: ignore + + def __getitem__(self, key: int) -> Input[T]: + return self._ports[key] + + def __iter__(self) -> Generator[Input[T], None, None]: + yield from self._ports + + # Convenience function allowing connections by + # using the shift operator on ports like so: + # a.in << b.out + def __lshift__(self, other: Union[Output[T], "MultiOutput[T]"]) -> None: + if self.parent.parent is not None: + self.parent.parent.connect(receiving=self, sending=other) + + def __init__( + self, + timeout: float = DEFAULT_TIMEOUT, + optional: bool = False, + mode: Literal["copy", "link", "move"] = "copy", + cached: bool = False, + ) -> None: + self.cached = cached + super().__init__(timeout, optional, mode) + + @property + def is_set(self) -> bool: + """Indicates whether all inputs have been set""" + return all(inp.is_set for inp in self._ports) + + @property + def default(self) -> T | None: + """Provides the default value, if available""" + return self._ports[0].default + + @property + def skippable(self) -> bool: + """Indicates whether this input can be skipped""" + return self.optional and not self.is_set and not self.is_connected(self) + + def set(self, value: T) -> None: + """Set unconnected ports to a specified value""" + port = self._type(optional=self.optional, timeout=self.timeout).build( + name=self.name, parent=self.parent + ) + port.set(value) + self._ports.append(port) + + def set_channel(self, channel: Channel[T]) -> None: + # This occurs *after* `build`, we thus need to provide + # a full port with name and parent parameters + port = self._type(optional=self.optional, timeout=self.timeout, cached=self.cached).build( + name=self.name, parent=self.parent + ) + port.set_channel(channel=channel) + self._ports.append(port) + + def dump(self) -> list[list[T]]: + """Dump any data contained in any of the inputs.""" + return [port.dump() for port in self._ports] + + def preload(self, data: list[T]) -> None: + """Preload the input with data, to allow resuming from a checkpoint.""" + for port, datum in zip(self._ports, data): + port.preload(datum) diff --git a/maize/core/node.py b/maize/core/node.py new file mode 100644 index 0000000..650fb3c --- /dev/null +++ b/maize/core/node.py @@ -0,0 +1,828 @@ +""" +Node +---- +Nodes are the individual atomic components of workflow graphs and encapsulate arbitrary +computational behaviour. They communicate with other nodes and the environment +only through ports, and expose parameters to the user. Custom behaviour is +implemented by subclassing and defining the `Node.run` method. 
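# Illustrative sketch (not part of the diff): fanning in over a `MultiInput` by iterating
# over its sub-ports, using the indexing / iteration support shown in interface.py above.
# The node name is hypothetical.
class Sum(Node):
    inp = MultiInput[int]()
    out = Output[int]()

    def run(self) -> None:
        self.out.send(sum(port.receive() for port in self.inp))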
+ +""" + +from abc import abstractmethod +from collections.abc import Generator, Sequence +import importlib +import logging +import os +import random +from pathlib import Path +import shutil +import subprocess +import sys +import time +import traceback +from typing import Any, Optional, TYPE_CHECKING + +from maize.core.component import Component +from maize.core.interface import ( + Flag, + Input, + MultiInput, + Interface, + Parameter, + FileParameter, + PortInterrupt, +) +from maize.core.runtime import ( + Runnable, + Status, + StatusHandler, + init_signal, + setup_node_logging, +) +from maize.utilities.execution import CommandRunner, JobResourceConfig, run_single_process +from maize.utilities.resources import cpu_count +from maize.utilities.io import ScriptSpecType, expand_shell_vars, remove_dir_contents +from maize.utilities.utilities import ( + Timer, + change_environment, + extract_attribute_docs, + has_module_system, + load_modules, + set_environment, +) +from maize.utilities.validation import Validator + +if TYPE_CHECKING: + from maize.core.graph import Graph + + +log = logging.getLogger("build") + + +class NodeBuildException(Exception): + """Exception raised for faulty `build` methods.""" + + +class Node(Component, Runnable, register=False): + """ + Base class for all atomic (non-subgraph) nodes of a graph. + Create a subclass to implement your own custom tasks. + + Parameters + ---------- + parent + Parent component, typically the graph in context + name + The name of the component + description + An optional additional description + fail_ok + If True, the failure in the component will + not trigger the whole network to shutdown + n_attempts + Number of attempts at executing the `run()` method + level + Logging level, if not given or ``None`` will use the parent logging level + cleanup_temp + Whether to remove any temporary directories after completion + resume + Whether to resume from a previous checkpoint + logfile + File to output all log messages to, defaults to STDOUT + max_cpus + Maximum number of CPUs to use, defaults to the number of available cores in the system + max_gpus + Maximum number of GPUs to use, defaults to the number of available GPUs in the system + loop + Whether to run the `run` method in a loop, as opposed to a single time + initial_status + The initial status of the node, will be ``NOT_READY`` by default, but + can be set otherwise to indicate that the node should not be run. + This would be useful when starting from a partially completed graph. + max_loops + Run the internal `loop` method a maximum number of `max_loops` times + + Attributes + ---------- + cpus + Resource semaphore allowing the reservation of multiple CPUs + gpus + Resource semaphore allowing the reservation of multiple GPUs + + Examples + -------- + Subclassing can be done the following way: + + >>> class Foo(Node): + ... out: Output[int] = Output() + ... + ... def run(self): + ... self.out.send(42) + + """ + + active: Flag = Flag(default=True) + """Whether the node is active or can be shutdown""" + + python: FileParameter[Path] = FileParameter(default=Path(sys.executable)) + """The path to the python executable to use for this node, allows custom environments""" + + modules: Parameter[list[str]] = Parameter(default_factory=list) + """Modules to load in addition to ones defined in the configuration""" + + scripts: Parameter[ScriptSpecType] = Parameter(default_factory=dict) + """ + Additional script specifications require to run. 
+ + Examples + -------- + >>> node.scripts.set({"interpreter": /path/to/python, "script": /path/to/script}) + + """ + + commands: Parameter[dict[str, Path]] = Parameter(default_factory=dict) + """Custom paths to any commands""" + + batch_options: Parameter[JobResourceConfig | None] = Parameter(default=None, optional=True) + """If given, will run commands on the batch system instead of locally""" + + # Making status a descriptor allows us to log status updates and + # keep track of the timers when waiting on other nodes or resources + status = StatusHandler() + + def __init__( + self, + parent: Optional["Graph"] = None, + name: str | None = None, + description: str | None = None, + fail_ok: bool = False, + n_attempts: int = 1, + level: int | str | None = None, + cleanup_temp: bool = True, + resume: bool = False, + logfile: Path | None = None, + max_cpus: int | None = None, + max_gpus: int | None = None, + loop: bool | None = None, + max_loops: int = -1, + initial_status: Status = Status.NOT_READY, + ) -> None: + super().__init__( + parent=parent, + name=name, + description=description, + fail_ok=fail_ok, + n_attempts=n_attempts, + level=level, + cleanup_temp=cleanup_temp, + resume=resume, + logfile=logfile, + max_cpus=max_cpus, + max_gpus=max_gpus, + loop=loop, + ) + self.status = initial_status + + # Run loops a maximum number of times, mostly to simplify testing + self.max_loops = max_loops + + # For the signal handler + self.n_signals = 0 + + # Construct the node I/O and check it makes sense + self.build() + self.check() + + # The full timer should measure the full execution time + # no matter if there's a block or not + self.run_timer = Timer() + self.full_timer = Timer() + + @property + def user_parameters(self) -> dict[str, Input[Any] | MultiInput[Any] | Parameter[Any]]: + """Returns all settable parameters and unconnected inputs defined by the user""" + return { + name: para for name, para in self.parameters.items() if name not in Node.__dict__ + } | self.inputs + + def setup_directories(self, parent_path: Path | None = None) -> None: + """Sets up the required directories.""" + if parent_path is None: + parent_path = Path("./") + self.work_dir = Path(parent_path / f"node-{self.name}") + self.work_dir.mkdir() + + def build(self) -> None: + """ + Builds the node by instantiating all interfaces from descriptions. + + Examples + -------- + >>> class Foo(Node): + ... def build(self): + ... self.inp = self.add_input( + ... "inp", datatype="pdb", description="Example input") + ... self.param = self.add_parameter("param", default=42) + + """ + docs = extract_attribute_docs(self.__class__) + for name in dir(self): + attr = getattr(self, name) + if isinstance(attr, Interface): + interface = attr.build(name=name, parent=self) + interface.doc = docs.get(name, None) + setattr(self, name, interface) + + def check(self) -> None: + """ + Checks if the node was built correctly. 
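# Illustrative sketch (not part of the diff): adjusting the built-in node parameters
# declared above when assembling a graph. The node class, module name, and paths are
# hypothetical; `JobResourceConfig` is imported from maize.utilities.execution as in
# node.py above.
def build(self) -> None:
    node = self.add(SomeTool)
    node.python.set(Path("/opt/envs/tool-env/bin/python"))  # custom python environment
    node.modules.set(["sometool/1.0"])                      # extra modules to load
    node.commands.set({"sometool": Path("/opt/tools/bin/sometool")})
    node.batch_options.set(JobResourceConfig(nodes=1))      # run commands on the batch system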
+ + Raises + ------ + NodeBuildException + If the node didn't declare at least one port + + """ + if len(self.inputs) == 0 and len(self.outputs) == 0: + raise NodeBuildException(f"Node {self.name} requires at least one port") + + if self.status == Status.NOT_READY: + self.status = Status.READY + + def check_dependencies(self) -> None: + """ + Check if all node dependencies are met by running the `prepare` method + + Raises + ------ + NodeBuildException + If required callables were not found + ImportError + If required python packages were not found + + """ + if self.__class__.is_checked(): + log.debug("Already checked '%s', skipping...", self.name) + return + + log.debug("Checking if required dependencies are available for '%s'...", self.name) + try: + run_single_process(self._prepare, name=self.name, executable=self.python.filepath) + finally: + self.__class__.set_checked() + + def run_command( + self, + command: str | list[str], + working_dir: Path | None = None, + validators: Sequence[Validator] | None = None, + verbose: bool = False, + raise_on_failure: bool = True, + command_input: str | None = None, + pre_execution: str | list[str] | None = None, + batch_options: JobResourceConfig | None = None, + prefer_batch: bool = False, + timeout: float | None = None, + cuda_mps: bool = False, + ) -> subprocess.CompletedProcess[bytes]: + """ + Runs an external command. + + Parameters + ---------- + command + Command to run as a single string, or a list of strings + working_dir + Working directory for the command + validators + One or more `Validator` instances that will + be called on the result of the command. + verbose + If ``True`` will also log any STDOUT or STDERR output + raise_on_failure + Whether to raise an exception when encountering a failure + command_input + Text string used as input for command + pre_execution + Command to run directly before the main one + batch_options + Job options for the batch system, if given, + will attempt run on the batch system + prefer_batch + Whether to prefer submitting a batch job rather than running locally. Note that + supplying batch options directly will automatically set this to ``True``. 
+ timeout + Maximum runtime for the command in seconds, or unlimited if ``None`` + cuda_mps + Use the multi-process service to run multiple CUDA job processes on a single GPU + + Returns + ------- + subprocess.CompletedProcess[bytes] + Result of the execution, including STDOUT and STDERR + + Raises + ------ + ProcessError + If any of the validators failed or the returncode was not zero + + Examples + -------- + To run a single command: + + >>> self.run_command("echo foo", validators=[SuccessValidator("foo")]) + + To run on a batch system, if configured: + + >>> self.run_command("echo foo", batch_options=JobResourceConfig(nodes=1)) + + """ + self.status = Status.WAITING_FOR_COMMAND + cmd = CommandRunner( + working_dir=working_dir or self.work_dir, + validators=validators, + raise_on_failure=raise_on_failure, + prefer_batch=prefer_batch and (batch_options is not None or self.batch_options.is_set), + rm_config=self.config.batch_config, + ) + if batch_options is None and self.batch_options.is_set: + batch_options = self.batch_options.value + + if batch_options is not None: + self.logger.debug("Using batch options: %s", batch_options) + res = cmd.run_validate( + command=command, + verbose=verbose, + command_input=command_input, + config=batch_options, + pre_execution=pre_execution, + timeout=timeout, + cuda_mps=cuda_mps, + ) + self.status = Status.RUNNING + return res + + def run_multi( + self, + commands: Sequence[str | list[str]], + working_dirs: Sequence[Path] | None = None, + command_inputs: Sequence[str | None] | None = None, + validators: Sequence[Validator] | None = None, + verbose: bool = False, + raise_on_failure: bool = True, + n_jobs: int = 1, + pre_execution: str | list[str] | None = None, + batch_options: JobResourceConfig | None = None, + timeout: float | None = None, + cuda_mps: bool = False, + n_batch: int | None = None, + batchsize: int | None = None, + ) -> list[subprocess.CompletedProcess[bytes]]: + """ + Runs multiple commands in parallel. + + Parameters + ---------- + commands + Commands to run as a list of strings, or a nested list of strings + working_dirs + Working directories for each command + command_inputs + Text string used as input for each command + validators + One or more `Validator` instances that will + be called on the result of the command. + verbose + If ``True`` will also log any STDOUT or STDERR output + raise_on_failure + Whether to raise an exception when encountering a failure + n_jobs + Max number of processes to spawn at once, should generally be + compatible with the number of available CPUs + pre_execution + Command to run directly before the main one + batch_options + Job options for the batch system, if given, + will attempt run on the batch system + timeout + Maximum runtime for the command in seconds, or unlimited if ``None`` + cuda_mps + Use the multi-process service to run multiple CUDA job processes on a single GPU + n_batch + Number of batches to divide all the commands between. Incompatible with ``batchsize``. + batchsize + Number of commands to put into 1 batch. Incompatible with ``n_batch``. 
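# Illustrative sketch (not part of the diff): calling external tools from `run` using the
# `run_command` / `run_multi` methods documented above. The tool invocations are
# hypothetical, and `SuccessValidator` is assumed to come from maize.utilities.validation.
def run(self) -> None:
    # Single validated command
    self.run_command("mytool --version", validators=[SuccessValidator("1.0")], verbose=True)

    # Several commands, at most two at a time, each in its own working directory
    dirs = [self.work_dir / f"job-{i}" for i in range(4)]
    for directory in dirs:
        directory.mkdir()
    self.run_multi(
        [f"mytool --input {directory / 'in.dat'}" for directory in dirs],
        working_dirs=dirs,
        n_jobs=2,
    )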
+ + Returns + ------- + list[subprocess.CompletedProcess[bytes]] + Result of the execution, including STDOUT and STDERR + + Raises + ------ + ProcessError + If any of the validators failed or a returncode was not zero + + Examples + -------- + To run multiple commands, but only two at a time: + + >>> self.run_multi(["echo foo", "echo bar", "echo baz"], n_jobs=2) + + To run on a batch system, if configured (note that batch settings are per-command): + + >>> self.run_command(["echo foo", "echo bar"], batch_options=JobResourceConfig(nodes=1)) + + """ + self.status = Status.WAITING_FOR_COMMAND + batch = batch_options is not None or self.batch_options.is_set + if n_jobs > cpu_count() and not batch: + self.logger.warning( + "Requested number of jobs (%s) is higher than available cores (%s)", + n_jobs, + cpu_count(), + ) + + cmd = CommandRunner( + validators=validators, + raise_on_failure=raise_on_failure, + prefer_batch=batch, + rm_config=self.config.batch_config, + ) + reserved = n_jobs if not batch else 0 + with self.cpus(reserved): + res = cmd.run_parallel( + commands=commands, + working_dirs=working_dirs, + command_inputs=command_inputs, + verbose=verbose, + n_jobs=n_jobs, + validate=True, + config=(batch_options or self.batch_options.value) if batch else None, + pre_execution=pre_execution, + timeout=timeout, + cuda_mps=cuda_mps, + n_batch=n_batch, + batchsize=batchsize, + ) + self.status = Status.RUNNING + return res + + # No cover because we change the environment, which breaks pytest-cov + def _prepare(self) -> None: # pragma: no cover + """ + Prepares the execution environment for `run`. + + Performs the following: + + * Changing the python environment, if required + * Setting of environment variables + * Setting of parameters from the config + * Loading LMOD modules + * Importing python packages listed in `required_packages` + * Checking if software in `required_callables` is available + + """ + # Change environment based on python executable set by `RunPool` + python = self.node_config.python + if not self.python.is_default: + python = self.python.value + change_environment(expand_shell_vars(python)) + + # Custom preset parameters + config_params = self.node_config.parameters + for key, val in config_params.items(): + if key in self.parameters and not (param := self.parameters[key]).changed: + param.set(val) + + # Load any required modules if possible from the global config, + # they don't neccessarily have to contain the executable, but + # might be required for running it + if has_module_system(): + load_modules(*self.node_config.modules) + + # And then locally defined ones + for mod in self.modules.value: + load_modules(mod) + + # Environment variables + set_environment(self.config.environment) + + # Check we can import any required modules, now + # that we might be in a different environment + for package in self.required_packages: + importlib.import_module(package) + + for exe in self.required_callables: + # Prepare any interpreter - script pairs, prioritize local + if exe in (script_dic := self.node_config.scripts | self.scripts.value): + interpreter = os.path.expandvars(script_dic[exe].get("interpreter", "")) + loc_path = expand_shell_vars(script_dic[exe]["location"]) + location = loc_path.absolute().as_posix() if loc_path.exists() else loc_path.name + self.runnable[exe] = f"{interpreter} {location}" + + # Prepare custom command locations + elif exe in (com_dic := self.node_config.commands | self.commands.value): + self.runnable[exe] = 
expand_shell_vars(Path(com_dic[exe])).absolute().as_posix() + + # It's already in our $PATH + elif shutil.which(exe) is not None: + self.runnable[exe] = exe + + else: + raise NodeBuildException( + f"Could not find a valid executable for '{exe}'. Add an appropriate entry " + f"in your global configuration under '[{self.__class__.__name__.lower()}]', " + f"e.g. 'commands.{exe} = \"path/to/executable\"', " + f"'scripts.{exe}.interpreter = \"path/to/interpreter\"' and " + f"'scripts.{exe}.location = \"path/to/script\"' or " + f"load an appropriate module with 'modules = [\"module_with_{exe}\"]'" + ) + + # Run any required user setup + self.prepare() + + # No cover because we change the environment, which breaks pytest-cov + def execute(self) -> None: # pragma: no cover + """ + This is the main entrypoint for node execution. + + Raises + ------ + KeyboardInterrupt + If the underlying process gets interrupted or receives ``SIGINT`` + + """ + # Prepare environment + self._prepare() + + # This will hold a traceback-exception for sending to the main process + tbe = None + + # Signal handler for interrupts will make sure the process has a chance + # to shutdown gracefully, by setting the shutdown signal + init_signal(self) + + # This replaces the build-logger with the process-safe message based logger + self.logger = setup_node_logging( + name=self.name, + logging_queue=self._logging_queue, + level=self.level, + color=self.logfile is None, + ) + self.logger.debug("Using executable at") + self.logger.debug("'%s'", sys.executable) + + os.chdir(self.work_dir) + self.logger.debug("Running in '%s'", self.work_dir.as_posix()) + + # Wait a short random time to make testing more reliable, + # this shouldn't matter in production too much + time.sleep(random.random()) + + # The `run_timer` is controlled by the `StatusHandler` + # descriptor, so no need to start it here + self.full_timer.start() + self.logger.debug("Starting up") + try: + # Main execution + tbe = self._attempt_loop() + + finally: + # We exhausted all our attempts, we now set the shutdown signal + # (if the task is not allowed to fail, otherwise we don't care) + if self.status == Status.FAILED and not self.fail_ok: + self.signal.set() + if tbe is not None: + self.send_update(exception=tbe) + + run_time, full_time = self.run_timer.stop(), self.full_timer.stop() + self.logger.debug("Shutting down, runtime: %ss", run_time) + self.logger.debug("Shutting down, total time: %ss", full_time) + + # It's very important we shutdown all ports, + # so other processes can follow suit + self._shutdown() + + # The final update will have a completion status, indicating to + # the master process that this node has finished processing + self.send_update() + + def cleanup(self) -> None: + if self.cleanup_temp and self.work_dir.exists(): + shutil.rmtree(self.work_dir) + for inp in self.inputs.values(): + if inp.channel is not None: + inp.channel.kill() + + def prepare(self) -> None: + """ + Prepare the node for execution. + + This method is called just before :meth:`~maize.core.node.Node.run` + and can be used for setup code if desired. This can be useful with + looping nodes, as initial variables can be set before the actual execution. + + Examples + -------- + >>> class Foo(Node): + ... def prepare(self): + ... self.val = 0 + ... + ... def run(self): + ... val = self.inp.receive() + ... self.val += val + ... self.out.send(self.val) + + """ + + @abstractmethod + def run(self) -> None: + """ + This is the main high-level node execution point. 
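# Illustrative sketch (not part of the diff): a node depending on an external executable.
# `_prepare` above resolves entries in `required_callables` into `self.runnable`, either
# from the global config (e.g. 'commands.mytool = "/path/to/mytool"') or from $PATH.
# The class-attribute style for `required_callables` and all names here are hypothetical.
class RunTool(Node):
    required_callables = ["mytool"]
    inp = Input[Path]()
    out = Output[Path]()

    def run(self) -> None:
        infile = self.inp.receive()
        result = self.work_dir / "out.dat"
        self.run_command(f"{self.runnable['mytool']} {infile} {result}")
        self.out.send(result)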
+ + It should be overridden by the user to provide custom node functionality, + and should return normally at completion. Exception handling, log message passing, + and channel management are handled by the wrapping `execute` method. + + Examples + -------- + >>> class Foo(Node): + ... def run(self): + ... val = self.inp.receive() + ... new = val * self.param.value + ... self.out.send(new) + + """ + + def _shutdown(self) -> None: + """ + Shuts down the component gracefully. + + This should not be called by the user directly, + as it is called at node shutdown by `execute()`. + + """ + if self.status not in (Status.STOPPED, Status.FAILED): + self.status = Status.COMPLETED + + # Shutdown all ports, it's important that we do this for + # every port, not just the ones that appear active, as + # some port closures can only be performed on one side + # (e.g. file channels can only be closed after the receiver + # has moved the file out of the channel directory). + for name, port in self.ports.items(): + port.close() + self.logger.debug("Closed port %s", name) + + def _loop(self, step: float = 0.5) -> Generator[int, None, None]: + """ + Allows continuous looping of the main routine, it handles graceful + shutdown of the node and checks for changes in the run conditions. + Do not use this function directly, instead pass ``loop=True`` to + the component constructor. + + Parameters + ---------- + step + Timestep in seconds to take between iterations + + Returns + ------- + Generator[None, None, None] + Generator allowing infinite looping + + """ + i = 0 + while not self.signal.is_set(): + # Inactive but required ports should stop the process + if not self.ports_active(): + self.logger.debug("Shutting down due to inactive port") + self.status = Status.STOPPED + return + + # In a testing setup we will only execute a limited number + # of times, as we are testing the node in isolation + if self.max_loops > 0 and i >= self.max_loops: + self.logger.debug("Maximum loops reached (%s/%s)", i, self.max_loops) + return + + time.sleep(step) + yield i + i += 1 + + def _iter_run(self, cleanup: bool = False) -> None: + """ + Runs the node (in a loop if `self.looped` is set). + + Parameters + ---------- + cleanup + Whether to remove working directory contents between iterations + + """ + # In some cases we might have branches in our workflow dedicated to providing + # a potentially optional input for a downstream node. This means we don't want + # to force the user to set this value, so we want to be able to override the + # originating parameter to be optional. This should then cause the loading node + # to not run at all, which will cause this branch to shutdown. These conditions + # here check that this is the case, i.e. if all parameters are unset and optional, + # and none of the inputs are connected, then we return immediately. + if all(para.skippable for para in self.user_parameters.values()): + self.logger.warning( + "Inputs / parameters are unset, optional, or unconnected, not running node" + ) + return + + if self.looped: + for it in self._loop(): + if cleanup and it != 0: + self.logger.debug("Removing all items in '%s'", self.work_dir.absolute()) + remove_dir_contents(self.work_dir) + for inp in self.flat_inputs: + if self.parent is not None and not inp.cached: + remove_dir_contents(self.parent.work_dir / f"{self.name}-{inp.name}") + self.run() + else: + self.run() + + def _attempt_loop(self) -> traceback.TracebackException | None: + """ + Attempt to execute the `run` method multiple times. 
Internal use only. + + Returns + ------- + TracebackException | None + Object containing a traceback in case of an error encountered in `run` + + Raises + ------ + KeyboardInterrupt + If the underlying process gets interrupted or receives ``SIGINT`` + + """ + # Skip execution if inactive + if not self.active.value: + self.logger.info("Node inactive, stopping...") + self.status = Status.STOPPED + return None + + tbe = None + for attempt in range(self.n_attempts): + # Reset the status in case of failure + self.status = Status.RUNNING + try: + self._iter_run(cleanup=True) + + # Raised to immediately quit a node in the case of a dead input, + # as we want to avoid propagating ``None`` while the graph is + # shutting down + except PortInterrupt as inter: + self.logger.debug("Port '%s' shutdown, exiting now...", inter.name) + self.status = Status.STOPPED + break + + # This could come from the system or ctrl-C etc. and + # should always abort any attempts + except KeyboardInterrupt: + self.logger.info("Received interrupt") + self.status = Status.STOPPED + + # This should have been set if we're here, just making sure + self.signal.set() + raise + + # Error in run() + except Exception as err: # pylint: disable=broad-except + self.status = Status.FAILED + msg = "Attempt %s of %s failed due to exception" + self.logger.error(msg, attempt + 1, self.n_attempts, exc_info=err) + + # Save the traceback to send to the main process in the 'finally' + # block, as we don't yet know whether to raise or fail silently + # and (maybe) try again + tbe = traceback.TracebackException.from_exception(err) + + # Can we even start up again? Check if ports are still open + if not self.ports_active(): + self.logger.info("Cannot restart due to closed ports") + self.signal.set() + break + + # Success + else: + if self.n_attempts > 1: + self.logger.info("Attempt %s of %s succeeded", attempt + 1, self.n_attempts) + break + + return tbe + + +class LoopedNode(Node): + """Node variant that loops its `run` method by default""" + + def __init__( + self, max_loops: int = -1, initial_status: Status = Status.NOT_READY, **kwargs: Any + ): + kwargs["loop"] = True + super().__init__(max_loops=max_loops, initial_status=initial_status, **kwargs) diff --git a/maize/core/py.typed b/maize/core/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/maize/core/runtime.py b/maize/core/runtime.py new file mode 100644 index 0000000..69a2964 --- /dev/null +++ b/maize/core/runtime.py @@ -0,0 +1,636 @@ +""" +Run +--- +Graph running infrastructure enabling parallel execution of nodes. 
+ +""" + +from abc import ABC, abstractmethod +from collections.abc import Mapping, Iterable +from dataclasses import dataclass +import datetime +from enum import auto +import functools +import itertools +import logging +from logging.handlers import QueueHandler +from multiprocessing import get_context +import os +from pathlib import Path +import signal +import sys +import time +from traceback import TracebackException +from typing import TYPE_CHECKING, Any + +from maize.utilities.utilities import StrEnum +from maize.utilities.execution import DEFAULT_CONTEXT, JobHandler + +if TYPE_CHECKING: + from logging import _FormatStyle + from multiprocessing import Process, Queue + from multiprocessing.context import ForkProcess, ForkServerProcess, SpawnProcess + from multiprocessing.synchronize import Event as EventClass + from maize.core.component import Component + from typing import TextIO + + +_RESET = "\x1b[0m" +_GREEN = "\x1b[1;32m" +_YELLOW = "\x1b[1;33m" +_RED = "\x1b[1;31m" +_PURPLE = "\x1b[1;35m" +_BLUE = "\x1b[1;34m" +_LIGHT_BLUE = "\x1b[1;36m" + + +MAIZE_ISO = r""" + ___ ___ ___ ___ + /\__\ /\ \ ___ /\ \ /\ \ + /::| | /::\ \ /\ \ \:\ \ /::\ \ + /:|:| | /:/\:\ \ \:\ \ \:\ \ /:/\:\ \ + /:/|:|__|__ /::\~\:\ \ /::\__\ \:\ \ /::\~\:\ \ + /:/ |::::\__\ /:/\:\ \:\__\ __/:/\/__/ _______\:\__\ /:/\:\ \:\__\ + \/__/~~/:/ / \/__\:\/:/ / /\/:/ / \::::::::/__/ \:\~\:\ \/__/ + /:/ / \::/ / \::/__/ \:\~~\~~ \:\ \:\__\ + /:/ / /:/ / \:\__\ \:\ \ \:\ \/__/ + /:/ / /:/ / \/__/ \:\__\ \:\__\ + \/__/ \/__/ \/__/ \/__/ + +""" + + +log = logging.getLogger("build") + + +# ================================ +# Node execution utilities +# ================================ + + +class NodeException(Exception): + """ + Exception representing execution failure in a node. + Should not be instantiated directly, but by providing + a traceback and using the `from_traceback_exception` method. + + """ + + @classmethod + def from_traceback_exception(cls, tbe: TracebackException) -> "NodeException": + """Create an exception from a `TracebackException` instance.""" + return cls("\n\nOriginal traceback:\n" + "".join(tbe.format())) + + +class Runnable(ABC): + """Represents an object that can be run as a separate process.""" + + name: str + signal: "EventClass" + n_signals: int + + @abstractmethod + def execute(self) -> None: + """Main execution method, run in a separate process.""" + + @abstractmethod + def cleanup(self) -> None: + """Method run on component shutdown in the main process.""" + + +class RunPool: + """ + Context manager for running a pool of processes. + + Upon entering, will start one process for each item + with the `execute` method as the target. Exiting + will cause all processes to be joined within + `wait_time`, and any remaining ones will be terminated. 
+ + Parameters + ---------- + items + Items of type `Runnable` + wait_time + Time to wait when joining processes before issuing a termination signal + + """ + + def __init__(self, *items: Runnable, wait_time: float = 1) -> None: + self.items = list(items) + self.wait_time = wait_time + self._procs: list["Process" | "SpawnProcess" | "ForkProcess" | "ForkServerProcess"] = [] + + def __enter__(self) -> "RunPool": + ctx = get_context(DEFAULT_CONTEXT) + current_exec = sys.executable + for item in self.items: + if hasattr(item, "python"): + exec_path = item.python.value.as_posix() + log.debug("Setting executable for '%s' to '%s'", item.name, exec_path) + ctx.set_executable(exec_path) + else: + ctx.set_executable(current_exec) + proc = ctx.Process(target=item.execute, name=item.name) # type: ignore + log.debug("Launching '%s'", item.name) + proc.start() + self._procs.append(proc) + ctx.set_executable(current_exec) + return self + + def __exit__(self, *_: Any) -> None: + # Flush all loggers no matter what happened in the workflow + for hand in log.handlers: + hand.flush() + + end_time = time.time() + self.wait_time + for item, proc in zip(self.items, self._procs, strict=True): + item.cleanup() + join_time = max(0, min(end_time - time.time(), self.wait_time)) + proc.join(join_time) + log.debug( + "Joined runnable '%s' (PID %s) with exitcode %s", item.name, proc.pid, proc.exitcode + ) + + while self._procs: + proc = self._procs.pop() + if proc.is_alive(): # pragma: no cover + # FIXME this is drastic, not sure what side effects this might have. + # Calling `proc.terminate()` seems to, when running a larger number + # of processes in serial (such as in pytest), cause deadlocks. This + # might be due to queues being accessed, although all calls to queues + # are now non-blocking. + proc.kill() + log.debug( + "Killed '%s' (PID %s) with exitcode %s", proc.name, proc.pid, proc.exitcode + ) + + @property + def n_processes(self) -> int: + """Returns the total number of processes.""" + return len(self._procs) + + +# ================================ +# Signal handling +# ================================ + + +MAX_TERM_SIGNALS = 1 + + +def _default_signal_handler( + signal_object: Runnable, exception_type: type[Exception], *_: Any +) -> None: # pragma: no cover + signal_object.n_signals += 1 + signal_object.signal.set() + + JobHandler().cancel_all() + + if signal_object.n_signals == MAX_TERM_SIGNALS: + raise exception_type() + + +def init_signal(signal_object: Runnable) -> None: + """ + Initializes the signal handler for interrupts. 
+ + Parameters + ---------- + signal_object + Object with an `n_signals` attribute + + """ + int_handler = functools.partial(_default_signal_handler, signal_object, KeyboardInterrupt) + signal.signal(signal.SIGINT, int_handler) + signal.signal(signal.SIGTERM, int_handler) + signal.siginterrupt(signal.SIGINT, False) + signal.siginterrupt(signal.SIGTERM, False) + + +# ================================ +# Status handling +# ================================ + + +class Status(StrEnum): + """Component run status.""" + + NOT_READY = auto() + """Not ready / not initialized""" + + READY = auto() + """Ready to run""" + + RUNNING = auto() + """Currently running with no IO interaction""" + + COMPLETED = auto() + """Successfully completed everything""" + + FAILED = auto() + """Failed via exception""" + + STOPPED = auto() + """Stopped because channels closed, or other components completed""" + + WAITING_FOR_INPUT = auto() + """Waiting for task input""" + + WAITING_FOR_OUTPUT = auto() + """Waiting for task output (backpressure due to full queue)""" + + WAITING_FOR_RESOURCES = auto() + """Waiting for compute resources (blocked by semaphore)""" + + WAITING_FOR_COMMAND = auto() + """Waiting for an external command to complete""" + + +_STATUS_COLORS = { + Status.NOT_READY: _RED, + Status.READY: _YELLOW, + Status.RUNNING: _BLUE, + Status.COMPLETED: _GREEN, + Status.FAILED: _RED, + Status.STOPPED: _GREEN, + Status.WAITING_FOR_INPUT: _YELLOW, + Status.WAITING_FOR_OUTPUT: _YELLOW, + Status.WAITING_FOR_RESOURCES: _PURPLE, + Status.WAITING_FOR_COMMAND: _LIGHT_BLUE, +} + + +class StatusHandler: + """Descriptor logging any component status updates.""" + + public_name: str + private_name: str + + def __set_name__(self, owner: type["Component"], name: str) -> None: + self.public_name = name + self.private_name = "_" + name + + def __get__(self, obj: "Component", objtype: Any = None) -> Any: + return getattr(obj, self.private_name) + + def __set__(self, obj: "Component", value: Status) -> None: + setattr(obj, self.private_name, value) + + if hasattr(obj, "logger"): + obj.logger.debug("Status changed to %s", value.name) + + # We start and pause the timer here depending on whether + # the node is actually doing computation (RUNNING) or + # just blocked (WAITING) + if hasattr(obj, "run_timer"): + # Completed nodes get a separate explicit update during shutdown + if value not in (Status.COMPLETED, Status.FAILED, Status.STOPPED): + obj.send_update() + if value in (Status.RUNNING, Status.WAITING_FOR_COMMAND): + obj.run_timer.start() + elif obj.run_timer.running: + obj.run_timer.pause() + + +@dataclass +class StatusUpdate: + """ + Summarizes the node status at completion of execution. + + Attributes + ---------- + name + Name of the node + parents + Names of all parents + status + Node status at completion, will be one of ('FAILED', 'STOPPED', 'COMPLETED') + run_time + Time spent in status 'RUNNING' + full_time + Time spent for the full node execution, including 'WAITING' for others + n_inbound + Number of items waiting to be received + n_outbound + Number of items waiting to be sent + note + Additional message to be printed at completion + exception + Exception if status is 'FAILED' + + """ + + name: str + parents: tuple[str, ...] 
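+    # Note: `run_time` and `full_time` below are not part of the `__eq__`
+    # comparison, so two updates that differ only in elapsed time compare equal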
+ status: Status + run_time: datetime.timedelta = datetime.timedelta(seconds=0) + full_time: datetime.timedelta = datetime.timedelta(seconds=0) + n_inbound: int = 0 + n_outbound: int = 0 + note: str | None = None + exception: TracebackException | None = None + + def __eq__(self, other: object) -> bool: + if isinstance(other, StatusUpdate): + return ( + self.name, + self.parents, + self.status, + self.n_inbound, + self.n_outbound, + self.note, + self.exception, + ) == ( + other.name, + other.parents, + other.status, + other.n_inbound, + other.n_outbound, + other.note, + other.exception, + ) + return NotImplemented + + +def format_summaries(summaries: list[StatusUpdate], runtime: datetime.timedelta) -> str: + """ + Create a string containing interesting information from all execution summaries. + + Parameters + ---------- + summaries + List of summaries to be aggregated and formatted + runtime + The total runtime to format + + Returns + ------- + str + Formatted summaries + + """ + n_success = sum(s.status == Status.COMPLETED for s in summaries) + n_stopped = sum(s.status == Status.STOPPED for s in summaries) + n_failed = sum(s.status == Status.FAILED for s in summaries) + wall_time: datetime.timedelta = sum( + (s.full_time for s in summaries), start=datetime.timedelta(seconds=0) + ) + blocked_time: datetime.timedelta = sum( + (s.full_time - s.run_time for s in summaries), start=datetime.timedelta(seconds=0) + ) + smilie = ":(" if n_failed > 0 else ":)" + msg = f"Execution completed {smilie} total runtime: {runtime}" + msg += f"\n\t{n_success} nodes completed successfully" + msg += f"\n\t{n_stopped} nodes stopped due to closing ports" + msg += f"\n\t{n_failed} nodes failed" + msg += f"\n\t{wall_time} total walltime" + msg += f"\n\t{blocked_time} spent waiting for resources or other nodes" + + return msg + + +def _item_count(items: Iterable[Any], target: Any) -> int: + count = 0 + for item in items: + count += target == item + return count + + +def format_update(summaries: dict[tuple[str, ...], StatusUpdate], color: bool = True) -> str: + """ + Create a string containing a summary of waiting nodes. 
+ + Parameters + ---------- + summaries + Dictionary of `StatusUpdate` + color + Whether to add color to the update + + Returns + ------- + str + A formatted string with information + + """ + msg = "Workflow status" + all_names = [path[-1] for path in summaries] + for (*_, name), node_sum in summaries.items(): + if _item_count(all_names, name) > 1: + name = "-".join(node_sum.parents) + + if color: + stat_color = _STATUS_COLORS[node_sum.status] + msg += f"\n{'':>34} | {name:>16} | {stat_color}{node_sum.status.name}{_RESET}" + else: + msg += f"\n{'':>34} | {name:>16} | {node_sum.status.name}" + + # Show queued item information, but only if the node is running + # (otherwise this value will never be updated) + if (node_sum.n_inbound > 0 or node_sum.n_outbound > 0) and node_sum.status not in ( + Status.COMPLETED, + Status.FAILED, + Status.STOPPED, + ): + msg += f" ({node_sum.n_inbound} | {node_sum.n_outbound})" + return msg + + +class Spinner: + """Provides an indication that the workflow is currently running""" + def __init__(self, interval: int) -> None: + self.icon = list("|/-\\") + self.cycler = itertools.cycle(self.icon) + self.update_time = datetime.datetime.now() + self.interval = interval + + def __call__(self) -> None: + if (datetime.datetime.now() - self.update_time).seconds > self.interval: + print(next(self.cycler), end="\r") + self.update_time = datetime.datetime.now() + + +# ================================ +# Logging +# ================================ + + +LOG_FORMAT = "%(asctime)s | %(levelname)8s | %(comp_name)16s | %(message)s" +_LEVELS = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} +VALID_LEVELS = {getattr(logging, name) for name in _LEVELS} | _LEVELS + + +class CustomFormat(logging.Formatter): + """Custom formatter to show colors in log messages""" + + LOG_FORMAT = "%(asctime)s | {color}%(levelname)8s{reset} | %(comp_name)16s | %(message)s" + + formats = { + logging.DEBUG: LOG_FORMAT.format(color=_LIGHT_BLUE, reset=_RESET), + logging.INFO: LOG_FORMAT.format(color=_BLUE, reset=_RESET), + logging.WARNING: LOG_FORMAT.format(color=_YELLOW, reset=_RESET), + logging.ERROR: LOG_FORMAT.format(color=_RED, reset=_RESET), + logging.CRITICAL: LOG_FORMAT.format(color=_RED, reset=_RESET), + } + + formats_no_color = {k: LOG_FORMAT.format(color="", reset="") for k in formats} + + def __init__( + self, + fmt: str | None = None, + datefmt: str | None = None, + style: "_FormatStyle" = "%", + validate: bool = True, + *, + defaults: Mapping[str, Any] | None = None, + color: bool = True, + ) -> None: + self.defaults = defaults + self.color = color + super().__init__(fmt, datefmt, style, validate, defaults=defaults) + + def format(self, record: logging.LogRecord) -> str: + if self.color: + format_string = self.formats.get(record.levelno) + else: + format_string = self.formats_no_color.get(record.levelno) + formatter = logging.Formatter(format_string, defaults=self.defaults) + return formatter.format(record) + + +def setup_build_logging( + name: str, level: str | int = logging.INFO, file: Path | None = None +) -> logging.Logger: + """ + Sets up the build-time logging functionality, running in the main process. 
+ + Parameters + ---------- + name + Name of the component being logged + level + Logging level + file + File to log to + + Returns + ------- + logging.Logger + Logger customized to a specific component + + """ + logger = logging.getLogger("build") + + # If we don't do this check we might get duplicate build log messages + if logger.hasHandlers(): + logger.handlers.clear() + + # When logging to a file we don't want to add the ASCII color codes + handler: logging.FileHandler | "logging.StreamHandler[TextIO]" + handler = logging.FileHandler(file) if file is not None else logging.StreamHandler() + stream_formatter = CustomFormat(defaults=dict(comp_name=name), color=file is None) + handler.setFormatter(stream_formatter) + logger.addHandler(handler) + if level not in VALID_LEVELS: + raise ValueError(f"Logging level '{level}' is not valid. Valid levels: {_LEVELS}") + logger.setLevel(level) + return logger + + +def setup_node_logging( + name: str, + logging_queue: "Queue[logging.LogRecord | None]", + level: str | int = logging.INFO, + color: bool = True, +) -> logging.Logger: + """ + Sets up the node logging functionality, running in a child process. + + Parameters + ---------- + name + Name of the component being logged + logging_queue + Global messaging queue + level + Logging level + color + Whether to log in color + + Returns + ------- + logging.Logger + Logger customized to a specific component + + """ + stream_formatter = CustomFormat(defaults=dict(comp_name=name), color=color) + logger = logging.getLogger(f"run-{os.getpid()}") + if logger.hasHandlers(): + logger.handlers.clear() + handler = QueueHandler(logging_queue) + handler.setFormatter(stream_formatter) + logger.addHandler(handler) + + if level not in VALID_LEVELS: + raise ValueError(f"Logging level '{level}' is not valid. Valid levels: {_LEVELS}") + logger.setLevel(level) + return logger + + +class Logger(Runnable): + """ + Logging node for all nodes in the graph. + + Parameters + ---------- + message_queue + Queue to send logging messages to + name + Name of the logger + level + Logging level + file + File to log to + + """ + + name: str + + def __init__( + self, + message_queue: "Queue[logging.LogRecord | None]", + name: str | None = None, + level: int = logging.INFO, + file: Path | None = None, + ) -> None: + self.name = name if name is not None else self.__class__.__name__ + self.queue = message_queue + self.level = level + self.file = file + + def execute(self) -> None: + """Main logging process, receiving messages from a global queue.""" + logger = logging.getLogger("run") + handler: logging.FileHandler | "logging.StreamHandler[TextIO]" + handler = logging.StreamHandler() if self.file is None else logging.FileHandler(self.file) + logger.addHandler(handler) + logger.setLevel(self.level) + + while True: + message = self.queue.get() + + # Sentinel to quit the logging process, means we are stopping the graph + if message is None: + break + + logger.handle(message) + + for hand in logger.handlers: + hand.flush() + + # Clear all items so the process can shut down normally + self.queue.cancel_join_thread() + + def cleanup(self) -> None: + self.queue.put(None) diff --git a/maize/core/workflow.py b/maize/core/workflow.py new file mode 100644 index 0000000..a7963ae --- /dev/null +++ b/maize/core/workflow.py @@ -0,0 +1,1107 @@ +""" +Workflow +-------- +The top-level graph class, i.e. the root node. Subclasses from `Graph` +to add checkpointing, file IO, and execution orchestration. 
+ +""" + +import argparse +import builtins as _b +from collections.abc import Callable +from dataclasses import dataclass +from datetime import datetime +import functools +import inspect +import itertools +import logging +from pathlib import Path +import queue +import shutil +import time +from traceback import TracebackException +import typing +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + ClassVar, + Literal, + TypeVar, + cast, + get_args, + get_origin, +) + +import dill +import toml + +from maize.core.component import Component +from maize.core.graph import Graph, GraphBuildException +from maize.core.interface import ( + Input, + Output, + MultiParameter, + MultiInput, + MultiOutput, + ParameterException, +) +from maize.core.runtime import ( + Logger, + RunPool, + NodeException, + Spinner, + StatusUpdate, + Status, + format_summaries, + format_update, + MAIZE_ISO, + setup_build_logging, +) +from maize.utilities.utilities import ( + Timer, + nested_dict_to_tuple, + tuple_to_nested_dict, + NestedDict, +) +from maize.utilities.io import ( + DictAction, + NodeConfig, + args_from_function, + create_default_parser, + parse_groups, + with_fields, + with_keys, + read_input, + write_input, +) + +from maize.utilities.execution import ( + JobResourceConfig, + WorkflowStatus, + ProcessError, + CommandRunner, + BatchSystemType, + job_from_id, +) + +if TYPE_CHECKING: + from multiprocessing import Queue + from maize.core.component import MessageType + + +TIME_STEP_SEC = 0.1 +CHECKPOINT_INTERVAL_MIN = 15 +STATUS_INTERVAL_MIN = 0.5 + +T = TypeVar("T") + + +def _get_message(message_queue: "Queue[T]", timeout: float = 0.1) -> T | None: + try: + return message_queue.get(timeout=timeout) + except queue.Empty: + return None + + +class ParsingException(Exception): + """Exception raised for graph definition parsing errors.""" + + +class CheckpointException(Exception): + """Exception raised for checkpointing issues.""" + + +def expose(factory: Callable[..., "Workflow"]) -> Callable[[], None]: + """ + Exposes the workflow definition and sets it up for execution. + + Parameters + ---------- + factory + A function that returns a fully defined workflow + + Returns + ------- + Callable[[], None] + Runnable workflow + + """ + name = factory.__name__.lower() + Workflow.register(name=name, factory=factory) + + def wrapped() -> None: + # Argument parsing - we create separate groups for + # global settings and workflow specific options + parser = create_default_parser(help=False) + + # We first construct the parser for the workflow factory - this allows + # us to parse args before constructing the workflow object, thus + # allowing us to influence the construction process + parser = args_from_function(parser, factory) + + # Parse only the factory args, everything else will be for the workflow. 
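+        # (`parse_known_args` collects everything it recognises into `known` and
+        # leaves the remainder in `rem`, which is passed to `parse_groups` below
+        # once the workflow-specific argument group has been added)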
+ # General maize args will show up in `known`, so they need to be filtered + known, rem = parser.parse_known_args() + func_args = inspect.getfullargspec(factory).annotations + func_args.pop("return") + + # Create the workflow, using only the relevant factory args + workflow = factory( + **{key: vars(known)[key] for key in func_args if vars(known)[key] is not None} + ) + workflow.description = Workflow.get_workflow_summary(name=name) + parser.description = workflow.description + + # Now create workflow-specific settings + flow = parser.add_argument_group(workflow.name) + flow = workflow.add_arguments(flow) + + # Help is the last arg, this allows us to create a help message dynamically + parser.add_argument("-h", "--help", action="help") + groups = parse_groups(parser, extra_args=rem) + + # Finally parse global settings + args = groups["maize"] + workflow.update_settings_with_args(args) + workflow.update_settings_with_args(known) + workflow.update_parameters(**vars(groups[workflow.name])) + + # Execution + workflow.check() + if args.check: + workflow.logger.info("Workflow compiled successfully") + return + + workflow.execute() + + return wrapped + + +@dataclass +class FutureWorkflowResult: + """ + Represents the result of the workflow over the duration of its execution and afterwards. + + Parameters + ---------- + id + The native ID of the job given by the batch system upon submission + folder + The folder the workflow is running in + workflow + The workflow definition in serialized form + backend + The type of batch system used for running the job + + """ + + id: str + folder: Path + workflow: dict[str, Any] + backend: BatchSystemType + stdout_path: Path | None + stderr_path: Path | None + + def __post_init__(self) -> None: + # To be able to interact with the associated job after serializing + # this object, we need to re-create the psij `Job` object + self._job = job_from_id(self.id, backend=self.backend) + + def to_dict(self) -> dict[str, Any]: + """ + Serializes the result to a dictionary. + + Returns + ------- + dict[str, Any] + A dictionary representation of the result object + + """ + return { + "id": self.id, + "folder": self.folder, + "workflow": self.workflow, + "backend": self.backend, + } + + @classmethod + def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> "FutureWorkflowResult": + """ + Create a result object from a dictionary representation. + + Parameters + ---------- + data + A dictionary representation of the result object + + Returns + ------- + FutureWorkflowResult + The result object + + """ + return cls(**data, **kwargs) + + def wait(self) -> WorkflowStatus: + """ + Wait for workflow completion. + + Returns + ------- + WorkflowStatus + The status of the workflow at the final state + + """ + status = self._job.wait() + if status is None: + return WorkflowStatus.UNKNOWN + return WorkflowStatus.from_psij(status.state) + + def query(self) -> WorkflowStatus: + """ + Query the workflow status from the batch system. + + Returns + ------- + WorkflowStatus + The status of the workflow + + """ + return WorkflowStatus.from_psij(self._job.status.state) + + def done(self) -> bool: + """ + Checks whether the workflow has finished. 
+ + Returns + ------- + bool + ``True`` if the workflow is in a completed or failed state + + """ + return self.query() in (WorkflowStatus.COMPLETED, WorkflowStatus.FAILED) + + def cancel(self) -> None: + """Cancel the workflow execution""" + self._job.cancel() + + +def wait_for_all(results: list[FutureWorkflowResult], timeout: float | None = None) -> None: + """ + Waits for all listed workflows to complete. + + Parameters + ---------- + results + A list of ``FutureWorkflowResult`` objects + timeout + Maximum time to wait for completion + + """ + start = time.time() + while any(not res.done() for res in results): + time.sleep(5) + if timeout is not None and (time.time() - start) >= timeout: + for res in results: + res.cancel() + + +class Workflow(Graph, register=False): + """ + Represents a workflow graph consisting of individual components. + + As a user, one will typically instantiate a `Workflow` and then add + individual nodes or subgraphs and connect them together. + + """ + + __registry: ClassVar[dict[str, Callable[[], "Workflow"]]] = {} + + @classmethod + def register(cls, name: str, factory: Callable[[], "Workflow"]) -> None: + """ + Register a workflow for global access. + + Parameters + ---------- + name + Name to register the workflow under + factory + Function returning an initialized and built workflow + + """ + Workflow.__registry[name.lower()] = factory + + @staticmethod + def get_workflow_summary(name: str) -> str: + """Provides a one-line summary of the workflow.""" + factory = Workflow.__registry[name.lower()] + if factory.__doc__ is not None: + for line in factory.__doc__.splitlines(): + if line: + return line.lstrip() + return "" + + @staticmethod + def get_available_workflows() -> set[Callable[[], "Workflow"]]: + """ + Returns all available and registered / exposed workflows. + + Returns + ------- + set[Callable[[], "Workflow"]] + All available workflow factories + + """ + return set(Workflow.__registry.values()) + + @classmethod + def from_name(cls, name: str) -> "Workflow": + """ + Create a predefined workflow from a previously registered name. + + Parameters + ---------- + name + Name the workflow is registered under + + Returns + ------- + Workflow + The constructed workflow, with all nodes and channels + + Raises + ------ + KeyError + If a workflow under that name cannot be found + + """ + try: + flow = Workflow.__registry[name.lower()]() + except KeyError as err: + raise KeyError( + f"Workflow with name '{name.lower()}' not found in the registry. " + "Have you imported the workflow definitions?" + ) from err + return flow + + @classmethod + def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> "Workflow": + """ + Read a graph definition from a dictionary parsed from a suitable serialization format. + + Parameters + ---------- + data + Tree structure as a dictionary containing graph and node parameters, + as well as connectivity. For format details see the `read_input` method. 
+ kwargs + Additional arguments passed to the `Component` constructor + + Returns + ------- + Graph + The constructed graph, with all nodes and channels + + Raises + ------ + ParsingException + If the input dictionary doesn't conform to the expected format + + """ + # Create the graph and fill it with nodes + graph = cls(**with_keys(data, cls._GRAPH_FIELDS), strict=False, **kwargs) + for node_data in data["nodes"]: + if "type" not in node_data: + raise ParsingException("Node specification requires at least 'type'") + + node_type: type[Component] = Component.get_node_class(node_data["type"]) + init_data = with_keys( + node_data, + cls._COMPONENT_FIELDS + | { + "parameters", + }, + ) + if "name" not in init_data: + init_data["name"] = node_type.__name__.lower() + graph.add(**init_data, component=node_type) + + # Check the parameter field for potential mappings from node-specific + # parameters to global ones, these can then be set on the commandline + for param_data in data.get("parameters", []): + if any(field not in param_data for field in ("name", "map")): + raise ParsingException("Parameter specification requires at least 'name' and 'map'") + + # Each defined parameter can map to several nodes, + # so we first collect those parameters + parameters_to_map = [] + for entry in param_data["map"]: + path = nested_dict_to_tuple(entry) + param = graph.get_parameter(*path) + if isinstance(param, MultiInput): + raise ParsingException("Mapping MultiInputs is not currently supported") + parameters_to_map.append(param) + + # And then map them using `MultiParameter` + graph.combine_parameters( + *parameters_to_map, + name=param_data["name"], + optional=param_data.get("optional", None), + ) + if "value" in param_data: + graph.update_parameters(**{param_data["name"]: param_data["value"]}) + + # Now connect all ports together based on channel specification + for channel_data in data["channels"]: + if any(field not in channel_data for field in ("sending", "receiving")): + raise ParsingException( + "Channel specification requires at least 'sending' and 'receiving'" + ) + + input_path = nested_dict_to_tuple(channel_data["receiving"]) + output_path = nested_dict_to_tuple(channel_data["sending"]) + input_port = graph.get_port(*input_path) + output_port = graph.get_port(*output_path) + + # Mostly for mypy :) + if not isinstance(input_port, Input | MultiInput) or not isinstance( + output_port, Output | MultiOutput + ): + raise ParsingException( + "Port doesn't correspond to the correct type, are you sure" + " 'sending' and 'receiving' are correct?" + ) + graph.connect(sending=output_port, receiving=input_port) + + # At this point a lot of parameters will not have been set, so we only check connectivity + # (also applies to software dependencies, as those can be set via parameters as well) + super(cls, graph).check() + return graph + + @classmethod + def from_file(cls, file: Path | str) -> "Workflow": + """ + Reads in a graph definition in JSON, YAML, or TOML format + and creates a runnable workflow graph. + + This is an example input: + + .. 
code:: yaml + + name: graph + description: An optional description for the workflow + level: INFO + nodes: + - name: foo + type: ExampleNode + + # Below options are optional + description: An optional description + fail_ok: false + n_attempts: 1 + parameters: + val: 40 + + channels: + - sending: foo: out + receiving: bar: input + + # Optional + parameters: + - name: val + map: + foo: val + + Parameters + ---------- + file + File in JSON, YAML, or TOML format + + Returns + ------- + Workflow + The complete graph with all connections + + Raises + ------ + ParsingException + If the input file doesn't conform to the expected format + + """ + file = Path(file) + data = read_input(file) + + return cls.from_dict(data) + + @classmethod + def from_checkpoint(cls, file: Path | str) -> "Workflow": + """ + Initialize a graph from a checkpoint file. + + Checkpoints include two additional sections, `_data` for any data + stored in a channel at time of shutdown, and `_status` for node + status information. We need the data for the full graph and thus + use a nested implementation. + + .. code:: yaml + + _data: + - bar: input: + _status: + - foo: STOPPED + - subgraph: baz: FAILED + + Parameters + ---------- + file + Path to the checkpoint file + + Returns + ------- + Graph + The initialized graph, with statuses set and channels loaded + + Raises + ------ + ParsingException + If the input file doesn't conform to the expected format + + """ + file = Path(file) + data = read_input(file) + graph = cls.from_dict(data, resume=True) + graph.load_checkpoint(data) + return graph + + def _set_global_attribute(self, __name: str, __value: Any, /) -> None: + """Set an attribute for all contained components.""" + setattr(self, __name, __value) + for comp in self.flat_components: + setattr(comp, __name, __value) + + def load_checkpoint(self, data: dict[str, Any]) -> None: + """ + Load checkpoint data from a dictionary. + + Uses data as generated by `read_input` to access + the special `_data` and `_status` fields. + + Parameters + ---------- + data + Dictionary data to read in as a checkpoint. + Both `_data` and `_status` are optional. + + See also + -------- + from_checkpoint : Load a `Graph` from a checkpoint file + to_checkpoint : Save a `Graph` to a checkpoint file + + """ + for nested_data in data.get("_data", []): + path: list[str] + *path, input_name, raw_data = nested_dict_to_tuple(nested_data) + node = self.get_node(*path) + if len(preload := dill.loads(raw_data)) > 0: + node.inputs[input_name].preload(preload) + + for status_data in data.get("_status", []): + *path, status = nested_dict_to_tuple(status_data) + node = self.get_node(*path) + node.status = Status(status) + + def as_dict(self) -> dict[str, Any]: + return with_fields(self, self._GRAPH_FIELDS & self._serializable_attributes) + + def to_dict(self) -> dict[str, Any]: + """ + Create a dictionary from a graph, ready to be saved in a suitable format. + + Returns + ------- + dict[str, Any] + Nested dictionary equivalent to the input format + + Examples + -------- + >>> g = Workflow(name="foo") + ... foo = g.add(Foo) + ... bar = g.add(Bar) + ... g.auto_connect(foo, bar) + ... 
data = g.to_dict() + + """ + data = self.as_dict() + # For the purposes of writing checkpoints we treat subgraphs just like nodes + data["nodes"] = [comp.as_dict() for comp in self.nodes.values()] + data["channels"] = [] + for sender_path, receiver_path in self.channels: + channel_data: dict[str, NestedDict[str, str]] = dict( + sending=tuple_to_nested_dict(*sender_path), + receiving=tuple_to_nested_dict(*receiver_path), + ) + data["channels"].append(channel_data) + + data["parameters"] = [ + param.as_dict() + for param in self.parameters.values() + if isinstance(param, MultiParameter) + ] + + return data + + def to_file(self, file: Path | str) -> None: + """ + Save the graph to a file. The type is inferred from the suffix + and can be one of JSON, YAML, or TOML. + + Parameters + ---------- + file + Path to the file to save to + + Examples + -------- + >>> g = Workflow(name="foo") + ... foo = g.add(Foo) + ... bar = g.add(Bar) + ... g.auto_connect(foo, bar) + ... g.to_file("graph.yml") + + """ + file = Path(file) + + data = self.to_dict() + write_input(file, data) + + def to_checkpoint(self, path: Path | str | None = None, fail_ok: bool = True) -> None: + """ + Saves the current graph state, including channel data and node liveness to a file. + + Parameters + ---------- + path + Optional filename for the checkpoint + fail_ok + If ``True``, will only log a warning instead of raising + an exception when encountering a writing problem. + + Raises + ------ + CheckpointException + Raised for checkpoint writing errors when `fail_ok` is ``False`` + + """ + timestamp = time.strftime("%Y%m%d-%H%M%S") + if path is None: + path = self.work_dir / f"ckp-{self.name}-{timestamp}.yaml" + path = Path(path) + + try: + data = self.to_dict() + data |= self._checkpoint_data() + write_input(path, data) + self.logger.info("Wrote checkpoint to %s", path.as_posix()) + + # Problems with writing a checkpoint shouldn't + # cause the whole graph to necessarily crash + except Exception as err: # pylint: disable=broad-except + if not fail_ok: + raise CheckpointException("Unable to complete checkpoint") from err + self.logger.warning("Unable to save checkpoint", exc_info=err) + + def generate_config_template(self) -> str: + """ + Generates a global configuration template in TOML format. 
+ + Returns + ------- + str + The config template + + """ + conf = {} + for node in set(self.flat_nodes): + if node.required_callables: + conf[node.__class__.__name__] = NodeConfig().generate_template( + node.required_callables + ) + return toml.dumps(conf) + + AddableArgType = ( + argparse.ArgumentParser | argparse._ArgumentGroup # pylint: disable=protected-access + ) + _A = TypeVar("_A", bound=AddableArgType) + + def add_arguments(self, parser: _A) -> _A: + """ + Adds custom arguments to an existing parser for workflow parameters + + Parameters + ---------- + parser + Pre-initialized parser or group + + Returns + ------- + argparse.ArgumentParser | argparse._ArgumentGroup + An parser instance that can be used to read additional + commandline arguments specific to the workflow + + See Also + -------- + Workflow.update_with_args + Sets up a parser for the workflow, uses `add_arguments` to update it, + and parses all arguments with updated parameters for the workflow + Workflow.update_settings_with_args + Updates the workflow with non-parameter settings + + """ + for name, param in self.all_parameters.items(): + # If the `datatype` is annotated it might fail a subclass + # check, so we unpack it using `get_args` first + dtype = param.datatype + if get_origin(dtype) == Annotated: + dtype = get_args(dtype)[0] + if get_origin(dtype) is not None: + dargs = get_args(dtype) + dtype = get_origin(dtype) + + self.logger.debug("Matching parameter '%s' with datatype '%s'", name, dtype) + + match dtype: + # We want people to parameterise their generic nodes + case TypeVar(): + parent = param.parents[0] if isinstance(param, MultiParameter) else param.parent + raise GraphBuildException( + f"Parameter '{name}' is a generic. Did you specify the datatype " + f"for node '{parent.name}' ('{parent.__class__.__name__}' type)?" + ) + + # Just a simple flag + case _b.bool: + parser.add_argument( + f"--{name}", action=argparse.BooleanOptionalAction, help=param.doc + ) + + # File path needs special treatment, not quite sure why + case path_type if isinstance(dtype, type) and issubclass(dtype, Path): + parser.add_argument(f"--{name}", type=path_type, help=param.doc) + + # Several options + case typing.Literal: + parser.add_argument( + f"--{name}", type=str, help=param.doc, choices=get_args(dtype) + ) + + # Anything else should be a callable type, e.g. int, float... + case _b.int | _b.float | _b.complex | _b.str | _b.bytes: + doc = f"{param.doc} [{dtype.__name__}]" + parser.add_argument(f"--{name}", type=dtype, help=doc) # type: ignore + + # Multiple items + case _b.tuple | _b.list: + if get_origin(dargs[0]) == Literal: # pylint: disable=comparison-with-callable + parser.add_argument( + f"--{name}", nargs="+", help=param.doc, choices=get_args(dargs[0]) + ) + else: + doc = f"{param.doc} [{dargs[0].__name__}]" + parser.add_argument(f"--{name}", nargs=len(dargs), type=dargs[0], help=doc) + + case _b.dict: + parser.add_argument(f"--{name}", nargs="+", action=DictAction, help=param.doc) + + case _: + self.logger.warning( + "Parameter '%s' with datatype '%s' could " + "not be exposed as a commandline argument", + name, + dtype, + ) + + return parser + + def update_with_args( + self, extra_options: list[str], parser: argparse.ArgumentParser | None = None + ) -> None: + """ + Update the graph with additional options passed from the commandline. + + Parameters + ---------- + extra_options + List of option strings, i.e. 
the output of ``parse_args`` + parser + Optional parser to reuse + + Raises + ------ + ParsingException + Raised when encountering unexpected commandline options + + """ + if parser is None: + parser = argparse.ArgumentParser() + parser = self.add_arguments(parser) + + try: + param_args = parser.parse_args(extra_options) + except SystemExit as err: + raise ParsingException( + "Unexpected argument type in commandline args, " f"see:\n{parser.format_help()}" + ) from err + + self.update_parameters(**vars(param_args)) + + def update_settings_with_args(self, args: argparse.Namespace) -> None: + """ + Updates the workflow with global settings from the commandline. + + Parameters + ---------- + args + Namespace including the args to use. See `maize -h` for possible options. + + """ + # Update the config first, other explicitly set options should override it + if args.config: + if not args.config.exists(): + raise FileNotFoundError(f"Config under '{args.config.as_posix()}' not found") + self.config.update(args.config) + if args.quiet: + self._set_global_attribute("level", logging.WARNING) + self.logger.setLevel(logging.WARNING) + if args.debug: + self._set_global_attribute("level", logging.DEBUG) + self.logger.setLevel(logging.DEBUG) + if args.log: + self._set_global_attribute("logfile", args.log) + if args.keep or args.debug: + self._set_global_attribute("cleanup_temp", False) + if args.parameters: + data = read_input(args.parameters) + self.update_parameters(**data) + if args.scratch: + self._set_global_attribute("scratch", Path(args.scratch)) + + def check(self) -> None: + super().check() + super().check_dependencies() + + # These are the original names of all mapped parameters + multi_para_names = [ + para.original_names + for para in self.all_parameters.values() + if isinstance(para, MultiParameter) + ] + orig_names = set(itertools.chain(*multi_para_names)) + + for node in self.flat_nodes: + for name, param in node.parameters.items(): + # Parameters that have been mapped to the workflow level + # (self.parameters) should not raise an exception if not set, + # as we might set these after a check (e.g. on the commandline) + if not param.is_set and not param.optional and name not in orig_names: + raise ParameterException( + f"Parameter '{name}' of node '{node.name}' needs to be set explicitly" + ) + + # Remind the user of parameters that still have to be set + for name, para in self.all_parameters.items(): + if not para.is_set and not para.optional: + self.logger.warning("Parameter '%s' must be set to run the workflow", name) + + def execute(self) -> None: + """ + Run a given graph. + + This is the top-level entry for maize execution. It creates a separate + logging process and general message queue and then starts the `execute` + methods of all nodes. Any node may at some point signal for the full graph + to be shut down, for example after a failure. Normal termination of a node + is however signalled by an `runtime.StatusUpdate` instance with finished + status. Any exceptions raised in a node are passed through the message queue + and re-raised as a `runtime.NodeException`. 
+ + Raises + ------ + NodeException + If there was an exception in any node child process + + """ + # Import version here to avoid circular import + from maize.maize import __version__ # pylint: disable=import-outside-toplevel + + timer = Timer() + # We only know about the logfile now, so we can set all node + # logging and the main graph logging to use a file, if given + logger = Logger(message_queue=self._logging_queue, file=self.logfile) + self.logger = setup_build_logging(name=self.name, level=self.level, file=self.logfile) + self.logger.info(MAIZE_ISO) + self.logger.info( + "Starting Maize version %s (c) AstraZeneca %s", __version__, time.localtime().tm_year + ) + self.logger.info("Running workflow '%s' with parameters:", self.name) + for node in self.flat_nodes: + for name, param in node.parameters.items(): + if not param.is_default: + self.logger.info("%s = %s (from '%s')", name, param.value, node.name) + + # Setup directories recursively + self.setup_directories() + + # This is the control queue + receive = cast( + "Callable[[], MessageType]", + functools.partial( + _get_message, message_queue=self._message_queue, timeout=TIME_STEP_SEC + ), + ) + + # Visual run indication + spin = Spinner(interval=2) + + with RunPool(*self.active_nodes, logger) as pool: + timer.start() + update_time = datetime.now() + + # 'StatusUpdate' objects with a finished status represent our stop tokens + summaries: list[StatusUpdate] = [] + latest_status: dict[tuple[str, ...], StatusUpdate] = {} + while not self.signal.is_set() and len(summaries) < (pool.n_processes - 1): + delta = datetime.now() - update_time + spin() + + event: "MessageType" = receive() + match event: + case StatusUpdate(status=Status.COMPLETED | Status.STOPPED): + summaries.append(event) + latest_status[(*event.parents, event.name)] = event + + # Because the status is set in another process, we don't know the + # value, and thus need to set it again in the main process to allow + # saving the complete state of the graph as a checkpoint + comp = self.get_node(*event.parents, event.name) + comp.status = event.status + self.logger.info( + "Node '%s' finished (%s/%s)", + comp.name, + len(summaries), + pool.n_processes - 1, # Account for logger + ) + + # Pickling tracebacks is impossible, so we unfortunately have + # to go the ugly route via 'TracebackException' + case StatusUpdate(exception=TracebackException()): + summaries.append(event) + + # This might be a bug in mypy, the pattern match + # already implies that exception is not ``None`` + if event.exception is not None: + raise NodeException.from_traceback_exception(event.exception) + + case StatusUpdate(): + key = (*event.parents, event.name) + changed = ( + key not in latest_status or latest_status[key].status != event.status + ) + latest_status[key] = event + + if (delta.seconds > 1 and len(latest_status) > 0 and changed) or ( + delta.seconds > 600 and len(latest_status) > 0 + ): + self.logger.info( + format_update(latest_status, color=self.logfile is None) + ) + update_time = datetime.now() + + self.logger.debug("All nodes finished, stopping...") + elapsed = timer.stop() + self._cleanup() + + # Make sure we actually print the summary at the end, + # not while the logger is still going + time.sleep(0.5) + self.logger.info(format_summaries(summaries, elapsed)) + + def submit( + self, folder: Path, config: JobResourceConfig, maize_config: Path | None = None + ) -> FutureWorkflowResult: + """ + Submit this workflow to a batch system and exit. 
+ + Parameters + ---------- + folder + The directory to execute the workflow in + config + The batch submission configuration + maize_config + Path to an optional different maize configuration + + Returns + ------- + FutureWorkflowResult + An object representing a future workflow result. It can be serialized, + queried, and the underlying job can be cancelled or waited for. + + """ + folder.mkdir(exist_ok=True) + self.to_file(folder / "flow.yml") + runner = CommandRunner( + name="workflow-runner", prefer_batch=True, rm_config=self.config.batch_config + ) + + command = f"maize --scratch {folder} " + if not self.cleanup_temp: + command += "--keep " + if self.logfile is not None: + command += f"--log {self.logfile} " + if maize_config is not None: + command += f"--config {maize_config} " + command += f"{folder / 'flow.yml'}" + + results = runner.run_async(command=command, working_dir=folder, verbose=True, config=config) + + # do some debug statements here, none of these values should be 0 + if results.id is None: + raise ProcessError("Workflow submission failed, no ID received") + + workflow_result = FutureWorkflowResult( + id=results.id, + folder=folder, + workflow=self.to_dict(), + backend=self.config.batch_config.system, + stdout_path=results.stdout_path, + stderr_path=results.stderr_path, + ) + + return workflow_result + + def _cleanup(self) -> None: + """Cleans up the graph directory if required""" + self.logger.debug("Attempting to remove working directory %s", self.work_dir) + if self.cleanup_temp: + self.logger.debug("Removing %s", self.work_dir) + shutil.rmtree(self.work_dir) + + def _checkpoint_data(self) -> dict[str, list[NestedDict[str, T]]]: + # We save all channel data in a flat dump for each input... + data: dict[str, list[NestedDict[str, T]]] = {"_data": [], "_status": []} + for node in self.flat_nodes: + for name, inp in node.inputs.items(): + if (channel_dump := inp.dump()) is None: + continue + dump: NestedDict[str, T] = tuple_to_nested_dict( + *node.component_path, str(name), dill.dumps(channel_dump) + ) + data["_data"].append(dump) + + # ... and all the status data (including nodes in subgraphs) + # in a flat list, by referring to each node by its full path + for node in self.flat_nodes: + data["_status"].append(tuple_to_nested_dict(*node.component_path, node.status.name)) + return data diff --git a/maize/maize.py b/maize/maize.py new file mode 100644 index 0000000..3768f41 --- /dev/null +++ b/maize/maize.py @@ -0,0 +1,98 @@ +""" +maize +========== +maize is a graph-based workflow manager for computational chemistry pipelines. + +""" + +# pylint: disable=import-outside-toplevel, unused-import + +__version__ = "0.8.3" + +import argparse +from contextlib import suppress +import multiprocessing +from pathlib import Path +import sys + +# TODO This doesn't work for editable installs for some reason +# (ImportError when attempting to import from maize directly). 
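+# The imports below provide the core workflow and component machinery, plus the
+# execution and IO helpers used by the commandline entrypoint in `main`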
+from maize.core.workflow import Workflow +from maize.core.component import Component +from maize.utilities.execution import DEFAULT_CONTEXT +from maize.utilities.io import ( + Config, + create_default_parser, + get_plugins, +) + +# Importing these will trigger all contained nodes to be +# registered and discovered by `Component.get_available_nodes()` +import maize.steps.io +import maize.steps.plumbing + + +def main() -> None: + """Main maize execution entrypoint.""" + + # Import builtin steps to register them + + parser = create_default_parser() + + # We only add the file arg if we're not just showing available nodes + if "--list" not in sys.argv: + parser.add_argument( + "file", + type=Path, + help="Path to a JSON, YAML or TOML input file", + ) + + # extra_options are parameters passed to the graph itself + args, extra_options = parser.parse_known_args() + config = Config.from_default() + if args.config: + config.update(args.config) + + # Import namespace packages defined in config + for package in config.packages: + with suppress(ImportError): + get_plugins(package) + + if args.list: + # TODO Would be cool to partition this list into modules etc and get a tree-like overview + # of nodes, this would require keeping track of the __module__ at component registration + names = {comp.__name__: comp for comp in Component.get_available_nodes()} + min_length = max(len(name) for name in names) + 1 + print("The following node types are available:\n") + print( + "\n".join( + f" {name:<{min_length}} {names[name].get_summary_line():>{min_length}}" + for name in sorted(names) + ), + "\n", + ) + return + + graph = Workflow.from_file(args.file) + + # We create a separate parser for graph parameters + extra_parser = argparse.ArgumentParser() + extra_parser = graph.add_arguments(extra_parser) + if args.options: + print("The following workflow parameters are available:") + extra_parser.print_help() + return + + graph.update_settings_with_args(args) + graph.update_with_args(extra_options, extra_parser) + + graph.check() + if args.check: + print("Workflow compiled successfully") + return + graph.execute() + + +if __name__ == "__main__": + multiprocessing.set_start_method(DEFAULT_CONTEXT) + main() diff --git a/maize/steps/io.py b/maize/steps/io.py new file mode 100644 index 0000000..c153d47 --- /dev/null +++ b/maize/steps/io.py @@ -0,0 +1,267 @@ +"""General purpose tasks for data input / output.""" + +from pathlib import Path +from multiprocessing import get_context +import shutil +from typing import TYPE_CHECKING, Generic, TypeVar, Any, cast + +import dill + +from maize.core.node import Node, LoopedNode +from maize.core.interface import Parameter, FileParameter, Input, Output, Flag, MultiInput +from maize.utilities.execution import DEFAULT_CONTEXT + +if TYPE_CHECKING: + from multiprocessing import Queue + +T = TypeVar("T") + + +class Dummy(Node): + """A dummy node to connect to unused but optional input ports.""" + + out: Output[Any] = Output() + """Dummy output with type ``Any``, nothing will be sent""" + + def run(self) -> None: + # Do nothing + pass + + +class Void(LoopedNode): + """A node that swallows whatever input it receives.""" + + inp: MultiInput[Any] = MultiInput() + """Void input, a bit like ``/dev/null``""" + + def run(self) -> None: + for inp in self.inp: + if inp.ready(): + _ = inp.receive() + + +P = TypeVar("P", bound=Path) + + +class LoadFile(Node, Generic[P]): + """Provides a file specified as a parameter on an output.""" + + file: FileParameter[P] = FileParameter() + """Path to the input 
file""" + + out: Output[P] = Output(mode="copy") + """File output""" + + def run(self) -> None: + path = self.file.filepath + if not path.exists(): + raise FileNotFoundError(f"File at path {path.as_posix()} not found") + + self.out.send(path) + + +class LoadFiles(Node, Generic[P]): + """Provides multiple files specified as a parameter on an output channel.""" + + files: Parameter[list[P]] = Parameter() + """Paths to the input files""" + + out: Output[list[P]] = Output(mode="copy") + """File output channel""" + + def run(self) -> None: + paths = self.files.value + for path in paths: + if not path.exists(): + raise FileNotFoundError(f"File at path {path.as_posix()} not found") + + self.out.send(paths) + + +class LoadData(Node, Generic[T]): + """Provides data passed as a parameter to an output channel.""" + + data: Parameter[T] = Parameter() + """Data to be sent to the output verbatim""" + + out: Output[T] = Output() + """Data output""" + + def run(self) -> None: + data = self.data.value + self.out.send(data) + + +class LogResult(Node): + """Receives data and logs it.""" + + inp: Input[Any] = Input() + """Data input""" + + def run(self) -> None: + data = self.inp.receive() + self.logger.info("Received data: '%s'", data) + + +class Log(Node, Generic[T]): + """Logs any received data and sends it on""" + + inp: Input[T] = Input() + """Data input""" + + out: Output[T] = Output() + """Data output""" + + def run(self) -> None: + data = self.inp.receive() + msg = f"Handling data '{data}'" + if isinstance(data, Path): + msg += f", {'exists' if data.exists() else 'not found'}" + self.logger.info(msg) + self.out.send(data) + + +class SaveFile(Node, Generic[P]): + """ + Receives a file and saves it to a specified location. + + You must parameterise the node with the appropriate filetype, or just ``Path``. + If the destination doesn't exist, a folder will be created. + + """ + + inp: Input[P] = Input(mode="copy") + """File input""" + + destination: FileParameter[P] = FileParameter(exist_required=False) + """The destination file or folder""" + + overwrite: Flag = Flag(default=False) + """If ``True`` will overwrite any previously existing file in the destination""" + + def run(self) -> None: + file = self.inp.receive() + dest = self.destination.filepath + + # Inherit the received file name if not given + if dest.is_dir() or not dest.suffix: + dest = dest / file.name + self.logger.debug("Using path '%s'", dest.as_posix()) + + # Create parent directory if required + if not dest.parent.exists(): + self.logger.debug("Creating parent directory '%s'", dest.parent.as_posix()) + dest.parent.mkdir() + + if dest.exists() and not self.overwrite.value: + self.logger.warning( + "File already exists at destination '%s', set 'overwrite' to proceed anyway", + dest.as_posix(), + ) + else: + self.logger.info("Saving file to '%s'", dest.as_posix()) + shutil.copyfile(file, dest) + + +class SaveFiles(Node, Generic[P]): + """ + Receives multiple files and saves them to a specified location. + + You must parameterise the node with the appropriate filetype, or just ``Path``. 
+ + """ + + inp: Input[list[P]] = Input(mode="copy") + """Files input""" + + destination: FileParameter[P] = FileParameter(exist_required=False) + """The destination folder""" + + overwrite: Flag = Flag(default=False) + """If ``True`` will overwrite any previously existing file in the destination""" + + def run(self) -> None: + files = self.inp.receive() + folder = self.destination.filepath + + if not folder.is_dir(): + raise ValueError(f"Destination '{folder}' must be a directory") + + for file in files: + dest = folder / file.name + if dest.exists() and not self.overwrite.value: + raise FileExistsError( + f"File already exists at destination '{dest.as_posix()}'" + ", set 'overwrite' to proceed anyway" + ) + self.logger.info("Saving file to '%s'", dest.as_posix()) + shutil.copyfile(file, dest) + + +class FileBuffer(Node, Generic[P]): + """ + Dynamic file storage. + + If the file exists, sends it immediately. If it doesn't, + waits to receive it and saves it in the specified location. + + """ + + inp: Input[P] = Input() + """File input""" + + out: Output[P] = Output() + """File output""" + + file: FileParameter[P] = FileParameter(exist_required=False) + """Buffered file location""" + + def run(self) -> None: + if self.file.filepath.exists(): + self.out.send(self.file.filepath) + file = self.inp.receive() + self.logger.info("Saving file to '%s'", self.file.filepath.as_posix()) + shutil.copyfile(file, self.file.filepath) + + +class Return(Node, Generic[T]): + """ + Return a value from the input to a specialized queue to be captured by the main process. + + Examples + -------- + >>> save = graph.add(Return[float]) + >>> # Define your workflow as normal... + >>> graph.execute() + >>> save.get() + 3.14159 + + Note that ``get()`` will pop the item from the internal queue, + this means the item will be lost if not assigned to a variable. + + """ + + ret_queue: "Queue[bytes]" + ret_item: T | None = None + + inp: Input[T] = Input() + + def get(self) -> T | None: + """Returns the passed value""" + return self.ret_item + + def build(self) -> None: + super().build() + + # This is our return value that will be checked in the test + self.ret_queue = get_context(DEFAULT_CONTEXT).Queue() + + def cleanup(self) -> None: + if not self.ret_queue.empty(): + self.ret_item = cast(T, dill.loads(self.ret_queue.get())) + self.ret_queue.close() + return super().cleanup() + + def run(self) -> None: + if (val := self.inp.receive_optional()) is not None: + self.ret_queue.put(dill.dumps(val)) diff --git a/maize/steps/plumbing.py b/maize/steps/plumbing.py new file mode 100644 index 0000000..7cec383 --- /dev/null +++ b/maize/steps/plumbing.py @@ -0,0 +1,765 @@ +"""General purpose tasks for data-flow control.""" + +from copy import deepcopy +from collections.abc import Iterator +import itertools +import time +from typing import Generic, TypeVar, Any + +import numpy as np +from numpy.typing import NDArray + +from maize.core.node import Node, LoopedNode +from maize.core.interface import ( + Parameter, + Flag, + Input, + Output, + MultiInput, + MultiOutput, + PortInterrupt, +) +from maize.utilities.utilities import chunks + + +T = TypeVar("T") + +INACTIVE = "Inactive ports" + + +class Multiply(Node, Generic[T]): + """ + Creates a list of multiple of the same item + + .. 
graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + n [label = "Multiply", fillcolor = "#FFFFCC", color = "#B48000"] + 2 -> n [label = "1"] + n -> 1 [label = "[1, 1, 1]"] + } + + """ + + inp: Input[T] = Input() + """Data input""" + + out: Output[list[T]] = Output() + """Data output""" + + n_packages: Parameter[int] = Parameter() + """Number of times to multiply the data""" + + def run(self) -> None: + data = self.inp.receive() + out = [data for _ in range(self.n_packages.value)] + self.out.send(out) + + +class Yes(LoopedNode, Generic[T]): + """ + Sends a single received value multiple times + + .. graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + n [label = "Yes", fillcolor = "#FFFFCC", color = "#B48000"] + 2 -> n [label = "1"] + n -> 1 [label = "1, 1, 1, ..."] + } + + """ + + inp: Input[T] = Input() + """Data input""" + + out: Output[T] = Output() + """Data output""" + + data: T | None = None + + def ports_active(self) -> bool: + return self.out.active + + def run(self) -> None: + if self.data is None: + self.data = self.inp.receive() + self.out.send(self.data) + + +class Barrier(LoopedNode, Generic[T]): + """ + Only sends data onwards if a signal is received + + .. graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + n [label = "Barrier", fillcolor = "#FFFFCC", color = "#B48000"] + n -> 1 [style = dashed] + 2 -> n + e -> n + {rank=same e n} + } + + """ + + inp: Input[T] = Input() + """Data input""" + + out: Output[T] = Output() + """Data output""" + + inp_signal: Input[bool] = Input() + """Signal data, upon receiving ``True`` will send the held data onwards""" + + def run(self) -> None: + data = self.inp.receive() + self.logger.info("Waiting for signal...") + if self.inp_signal.receive(): + self.out.send(data) + + +class Batch(Node, Generic[T]): + """ + Create batches of data from a single large input + + .. graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + n [label = "Batch", fillcolor = "#FFFFCC", color = "#B48000"] + 2 -> n [label = "[1, 2, 3, 4]"] + n -> 1 [label = "[1, 2], [3, 4]"] + } + + """ + + inp: Input[list[T]] = Input() + """Input data""" + + out: Output[list[T]] = Output() + """Stream of output data chunks""" + + n_batches: Parameter[int] = Parameter() + """Number of chunks""" + + def run(self) -> None: + full_data = self.inp.receive() + for i, batch in enumerate(chunks(full_data, self.n_batches.value)): + self.logger.info("Sending batch %s/%s", i, self.n_batches.value) + self.out.send(list(batch)) + + +class Combine(Node, Generic[T]): + """ + Combine multiple batches of data into a single dataset + + .. 
graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + n [label = "Combine", fillcolor = "#FFFFCC", color = "#B48000"] + 2 -> n [label = "[1, 2], [3, 4]"] + n -> 1 [label = "[1, 2, 3, 4]"] + } + + """ + + inp: Input[list[T] | NDArray[Any]] = Input() + """Input data chunks to combine""" + + out: Output[list[T]] = Output() + """Single combined output data""" + + n_batches: Parameter[int] = Parameter() + """Number of chunks""" + + def run(self) -> None: + data: list[T] = [] + for i in range(self.n_batches.value): + data.extend(self.inp.receive()) + self.logger.info("Received batch %s/%s", i + 1, self.n_batches.value) + self.logger.debug("Sending full dataset of size %s", len(data)) + self.out.send(data) + + +class MergeLists(LoopedNode, Generic[T]): + """ + Collect lists of data from multiple inputs and merges them into a single list. + + .. graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, style = filled, + fontname = "Helvetica,Arial,sans-serif"] + MergeLists [label = "MergeLists", fillcolor = "#FFFFCC", color = "#B48000"] + 1 -> MergeLists [label = "[1, 2]"] + 2 -> MergeLists [label = "[3]"] + 3 -> MergeLists [label = "[4, 5]"] + MergeLists -> 4 [label = "[1, 2, 3, 4, 5]"] + } + + """ + + inp: MultiInput[list[T]] = MultiInput() + """Flexible number of input lists to be merged""" + + out: Output[list[T]] = Output() + """Single output for all merged data""" + + def run(self) -> None: + if not self.out.active or all(not port.active for port in self.inp): + raise PortInterrupt(INACTIVE) + + has_items = False + items: list[T] = [] + n_interrupt = 0 + for inp in self.inp: + try: + if not inp.optional: + items.extend(inp.receive()) + has_items = True + elif inp.ready() and (item := inp.receive_optional()) is not None: + items.extend(item) + has_items = True + + # Normally, interrupts are caught in the main maize execution loop. + # But when one of the inputs dies while we're trying to receive from + # it, we would still like to receive from any subsequent ports as + # normal. This means that this node will only shutdown once all + # inputs have signalled a shutdown. + except PortInterrupt: + n_interrupt += 1 + if n_interrupt >= len(self.inp): + raise + + # No check for an empty list, as nodes could conceivably explicitly send an empty + # list, in that case we want to honour this and send an empty list on too. + if has_items: + self.out.send(items) + + +class Merge(LoopedNode, Generic[T]): + """ + Collect inputs from multiple channels and send them to a + single output port on a first-in-first-out (FIFO) basis. + + .. 
graphviz:: + + digraph { + graph [rankdir = "LR"] + node [shape = rectangle, label = "", color = white, style = filled, + fontname = "Helvetica,Arial,sans-serif"] + Merge [label = "Merge", fillcolor = "#FFFFCC", color = "#B48000"] + 1 -> Merge + 2 -> Merge + 3 -> Merge + Merge -> 4 + } + + """ + + inp: MultiInput[T] = MultiInput(optional=True) + """Flexible number of input channels to be merged""" + + out: Output[T] = Output() + """Single output for all merged data""" + + def run(self) -> None: + if not self.out.active or all(not port.active for port in self.inp): + raise PortInterrupt(INACTIVE) + + for port in self.inp: + if port.ready() and (item := port.receive_optional()) is not None: + self.out.send(item) + + +class Multiplex(LoopedNode, Generic[T]): + """ + Receives items on multiple inputs, sends them to a single output, + receives those items, and sends them to the same output index. + + .. graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + n [label = "Multiplex", fillcolor = "#FFFFCC", color = "#B48000"] + n -> 1 [penwidth = 2] + n -> 2 [style = dashed] + n -> 3 [style = dashed] + 4 -> n [penwidth = 2] + 5 -> n [style = dashed] + 6 -> n [style = dashed] + e [fillcolor = "#FFFFCC", color = "#B48000", style = dashed] + n -> e + e -> n + {rank=same e n} + } + + """ + + inp: MultiInput[T] = MultiInput(optional=True) + """Multiple inputs, should be the same as the number of outputs""" + + out: MultiOutput[T] = MultiOutput(optional=True) + """Multiple outputs, should be the same as the number of inputs""" + + out_single: Output[T] = Output() + """Single output""" + + inp_single: Input[T] = Input() + """Single input""" + + def ports_active(self) -> bool: + return ( + any(inp.active for inp in self.inp) + and any(out.active for out in self.out) + and self.out_single.active + and self.inp_single.active + ) + + def run(self) -> None: + for inp, out in zip(self.inp, self.out, strict=True): + if inp.ready(): + item = inp.receive() + self.out_single.send(item) + new_item = self.inp_single.receive() + out.send(new_item) + + +class Copy(LoopedNode, Generic[T]): + """ + Copy a single input packet to multiple output channels. + + .. graphviz:: + + digraph { + graph [rankdir = "LR"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + Copy [label = "Copy", fillcolor = "#FFFFCC", color = "#B48000"] + Copy -> 1 + Copy -> 2 + Copy -> 3 + 4 -> Copy + } + + """ + + inp: Input[T] = Input(mode="copy") + """Single input to broadcast""" + + out: MultiOutput[T] = MultiOutput(optional=True, mode="copy") + """Multiple outputs to broadcast over""" + + def run(self) -> None: + if (not self.inp.active and not self.inp.ready()) or all( + not out.active for out in self.out + ): + raise PortInterrupt(INACTIVE) + + val = self.inp.receive() + self.logger.debug("Received %s", val) + for out in self.out: + self.logger.debug("Sending %s", val) + out.send(deepcopy(val)) + + +class RoundRobin(LoopedNode, Generic[T]): + """ + Outputs a single input packet to a single output port at a time, + cycling through output ports. + + .. 
graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + n [label = "RoundRobin", fillcolor = "#FFFFCC", color = "#B48000"] + n -> 1 [label = 1] + n -> 2 [style = dashed, label = 2] + n -> 3 [style = dashed, label = 3] + 4 -> n + } + + """ + + inp: Input[T] = Input() + """Single input to alternatingly send on""" + + out: MultiOutput[T] = MultiOutput(optional=True) + """Multiple outputs to distribute over""" + + _output_cycle: Iterator[Output[T]] + _current_output: Output[T] + + def prepare(self) -> None: + self._output_cycle = itertools.cycle(self.out) + self._current_output = next(self._output_cycle) + + def run(self) -> None: + if not self.inp.active or all(not out.active for out in self.out): + raise PortInterrupt(INACTIVE) + + self._current_output.send(self.inp.receive()) + self._current_output = next(self._output_cycle) + + +class IntegerMap(Node): + """ + Maps an integer to another range. + + Takes a pattern describing which integer to output based on a constantly incrementing + input integer. For example, with the pattern ``[0, 2, -1]`` an input of ``0`` will output + ``1`` (the index of the first element), an input of ``1`` will also output ``1`` (as we + indicated 2 iterations for this position), and an input of ``2`` and up will output ``2`` + (for the final index). This can be especially useful in conjunction with `IndexDistribute` + to perform different actions based on a global iteration counter. + + .. graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + n [label = "IntegerMap [0, 2, -1]", fillcolor = "#FFFFCC", color = "#B48000"] + n -> 1 [label = "2"] + 2 -> n [label = "1"] + } + + See Also + -------- + IndexDistribute + Allows distributing data to one of several outputs based on a separately supplied index. + + """ + + inp: Input[int] = Input() + """Input integer""" + + out: Output[int] = Output() + """Output integer""" + + pattern: Parameter[list[int]] = Parameter() + """Output pattern""" + + def run(self) -> None: + x = self.inp.receive() + pat = np.array(self.pattern.value) + outs = np.arange(len(pat)).repeat(abs(pat)) + x = min(x, len(outs) - 1) + self.out.send(outs[x]) + + +class Choice(Node, Generic[T]): + """ + Sends an item from a dynamically specified input to an output. + + Receives an index separately and indexes into the + inputs to decide from where to receive the item. + + .. 
graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + n [label = "IndexDistribute", fillcolor = "#FFFFCC", color = "#B48000"] + 1 -> n [style = dashed] + 2 -> n [label = "foo"] + n -> 3 [label = "foo"] + 4 -> n [label = "1"] + } + + """ + + inp: MultiInput[T] = MultiInput() + """Item inputs""" + + inp_index: Input[int] = Input() + """Which index of the output to send the item to""" + + out: Output[T] = Output() + """Item output""" + + clip: Flag = Flag(default=False) + """Whether to clip the index to the maximum number of outputs""" + + def run(self) -> None: + idx = self.inp_index.receive() + if self.clip.value: + idx = min(idx, len(self.inp) - 1) + item = self.inp[idx].receive() + self.out.send(item) + + +class IndexDistribute(Node, Generic[T]): + """ + Sends an item to an output specified by a dynamic index. + + Receives an index separately and indexes into the outputs to decide where to send the item. + + .. graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + n [label = "IndexDistribute", fillcolor = "#FFFFCC", color = "#B48000"] + n -> 1 [style = dashed] + n -> 2 [label = "foo"] + 3 -> n [label = "foo"] + 4 -> n [label = "1"] + } + + """ + + inp: Input[T] = Input() + """Item input""" + + inp_index: Input[int] = Input() + """Which index of the output to send the item to""" + + out: MultiOutput[T] = MultiOutput() + """Item outputs""" + + clip: Flag = Flag(default=False) + """Whether to clip the index to the maximum number of outputs""" + + def run(self) -> None: + idx = self.inp_index.receive() + if self.clip.value: + idx = min(idx, len(self.out) - 1) + item = self.inp.receive() + self.out[idx].send(item) + + +class TimeDistribute(LoopedNode, Generic[T]): + """ + Distributes items over multiple outputs depending on the current iteration. + + Can be seen as a generalized form of `RoundRobin`, with an additional specification + of which output to send data to how many times. For example, the pattern ``[2, 1, 10]`` + will send 2 items to the first output, 1 item to the second, and 10 items to the third. + + .. 
graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + n [label = "TimeDistribute [1, 3]", fillcolor = "#FFFFCC", color = "#B48000"] + n -> 1 [label = 1] + n -> 2 [style = dashed, label = "2, 3, 4"] + 4 -> n [label = "1, 2, 3, 4"] + } + + """ + + inp: Input[T] = Input() + """Item input""" + + out: MultiOutput[T] = MultiOutput(optional=True) + """Multiple outputs to distribute over""" + + pattern: Parameter[list[int]] = Parameter() + """How often to send items to each output, ``-1`` is infinite""" + + cycle: Flag = Flag(default=False) + """Whether to loop around the pattern""" + + _outputs: Iterator[Output[T]] + + def prepare(self) -> None: + pattern = self.pattern.value + if len(pattern) != len(self.out): + raise ValueError( + "The number of entries in the pattern must " + "be equal to the number of connected outputs" + ) + + if any(n < 0 for n in pattern[:-1]): + raise ValueError("Cannot infinitely send to a non-final output") + + self._outputs = itertools.chain.from_iterable( + itertools.repeat(out, n) if n >= 0 else itertools.repeat(out) + for out, n in zip(self.out, pattern) + ) + + if self.cycle.value: + self._outputs = itertools.cycle(self._outputs) + + def run(self) -> None: + if not self.inp.active or all(not out.active for out in self.out): + raise PortInterrupt(INACTIVE) + + next(self._outputs).send(self.inp.receive()) + + +class CopyEveryNIter(LoopedNode, Generic[T]): + """ + Copy a single input packet to multiple output channels, but only every ``n`` iterations. + + .. graphviz:: + + digraph { + graph [rankdir = "LR"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + edge [fontname = "Helvetica,Arial,sans-serif"] + Copy [label = "Copy", fillcolor = "#FFFFCC", color = "#B48000"] + Copy -> 1 [label = "1, 2, 3"] + Copy -> 2 [label = "2"] + 3 -> Copy [label = "1, 2, 3"] + } + + """ + + inp: Input[T] = Input(mode="copy") + """Single input to broadcast""" + + out: MultiOutput[T] = MultiOutput(optional=True, mode="copy") + """Multiple outputs to broadcast over""" + + freq: Parameter[int] = Parameter(default=-1) + """How often to send, ``-1`` is only on first iteration""" + + def prepare(self) -> None: + self._it = 1 + + def run(self) -> None: + if (not self.inp.active and not self.inp.ready()) or all( + not out.active for out in self.out + ): + raise PortInterrupt(INACTIVE) + + val = self.inp.receive() + self.logger.debug("Received %s", val) + outputs = iter(self.out) + + self.logger.debug("Sending %s", val) + next(outputs).send(deepcopy(val)) + + if (self.freq.value == -1 and self._it == 1) or ( + self.freq.value != -1 and self._it % self.freq.value == 0 + ): + for out in outputs: + self.logger.debug("Sending %s", val) + out.send(deepcopy(val)) + + self._it += 1 + + +class Accumulate(LoopedNode, Generic[T]): + """ + Accumulate multiple independent packets into one large packet. + + .. 
graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + n [label = "Accumulate", fillcolor = "#FFFFCC", color = "#B48000"] + 2 -> n [label = "1, 2, 3"] + n -> 1 [label = "[1, 2, 3]"] + } + + """ + + inp: Input[T] = Input() + """Packets to accumulate""" + + out: Output[list[T]] = Output() + """Output for accumulated packets""" + + n_packets: Parameter[int] = Parameter() + """Number of packets to receive before sending one large packet""" + + def run(self) -> None: + packet = [] + for _ in range(self.n_packets.value): + val = self.inp.receive() + packet.append(val) + self.out.send(packet) + + +class Scatter(LoopedNode, Generic[T]): + """ + Decompose one large packet into it's constituent items and send them separately. + + .. graphviz:: + + digraph { + graph [rankdir = "LR"] + edge [fontname = "Helvetica,Arial,sans-serif"] + node [shape = rectangle, label = "", color = white, + style = filled, fontname = "Helvetica,Arial,sans-serif"] + n [label = "Scatter", fillcolor = "#FFFFCC", color = "#B48000"] + 2 -> n [label = "[1, 2, 3]"] + n -> 1 [label = "1, 2, 3"] + } + + """ + + inp: Input[list[T]] = Input() + """Packets of sequences that allow unpacking""" + + out: Output[T] = Output() + """Unpacked data""" + + def run(self) -> None: + packet = self.inp.receive() + if not hasattr(packet, "__len__"): + raise ValueError(f"Packet of type {type(packet)} can not be unpacked") + + for item in packet: + self.out.send(item) + + +class Delay(LoopedNode, Generic[T]): + """Pass on a packet with a custom delay.""" + + inp: Input[T] = Input() + """Data input""" + + out: Output[T] = Output() + """Data output""" + + delay: Parameter[float | int] = Parameter(default=1) + """Delay in seconds""" + + def run(self) -> None: + item = self.inp.receive() + time.sleep(self.delay.value) + self.out.send(item) diff --git a/maize/steps/py.typed b/maize/steps/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/maize/utilities/execution.py b/maize/utilities/execution.py new file mode 100644 index 0000000..c18c3ba --- /dev/null +++ b/maize/utilities/execution.py @@ -0,0 +1,1305 @@ +"""Utilities to execute external software.""" + +from collections.abc import Callable, Sequence +from dataclasses import dataclass, field +from datetime import timedelta +from enum import auto +import io +import logging +from multiprocessing import get_context +import os +from pathlib import Path +from queue import Queue +import shlex +import stat +import subprocess +import sys +from tempfile import mkdtemp +import time +from typing import TYPE_CHECKING, Any, Concatenate, Literal, ParamSpec, TypeVar, cast +from typing_extensions import Self +import xml.etree.ElementTree as ET + +from psij import ( + InvalidJobException, + Job, + JobAttributes, + JobExecutor, + JobExecutorConfig, + JobSpec, + JobStatus, + ResourceSpecV1, + SubmitException, + JobState, +) +from psij.executors.batch.slurm import SlurmExecutorConfig + +from maize.utilities.utilities import ( + make_list, + split_list, + unique_id, + split_multi, + set_environment, + StrEnum, +) + +if TYPE_CHECKING: + from maize.utilities.validation import Validator + + +class ProcessError(Exception): + """Error called for failed commands.""" + + +# Maize can in theory use either 'fork' or 'spawn' for process handling. 
However some external +# dependencies (notably PSI/J) can cause subtle problems when changing this setting due to their +# own use of threading / multiprocessing. See this issue for more details: +# https://github.com/ExaWorks/psij-python/issues/387 +DEFAULT_CONTEXT = "fork" + + +def _format_command(result: subprocess.CompletedProcess[bytes] | subprocess.Popen[bytes]) -> str: + """Format a command from a completed process or ``Popen`` instance.""" + return " ".join(cast(list[str], result.args)) + + +def _log_command_output(stdout: bytes | None, stderr: bytes | None) -> str: + """Write any command output to the log.""" + msg = "Command output:\n" + if stdout is not None and len(stdout) > 0: + msg += "---------------- STDOUT ----------------\n" + msg += stdout.decode(errors="ignore") + "\n" + msg += "---------------- STDOUT ----------------\n" + if stderr is not None and len(stderr) > 0: + msg += "---------------- STDERR ----------------\n" + msg += stderr.decode(errors="ignore") + "\n" + msg += "---------------- STDERR ----------------\n" + return msg + + +def _simple_run(command: list[str] | str) -> subprocess.CompletedProcess[bytes]: + """Run a command and return a `subprocess.CompletedProcess` instance.""" + if isinstance(command, str): + command = shlex.split(command) + + return subprocess.run(command, check=False, capture_output=True) + + +class GPUMode(StrEnum): + DEFAULT = auto() + EXCLUSIVE_PROCESS = auto() + + +@dataclass +class GPU: + """ + Indicates the status of the system's GPUs, if available. + + Attributes + ---------- + available + Whether a GPU is available in the system + free + Whether the GPU can run a process right now + free_with_mps + Whether the GPU can run a process if the Nvidia multi-process service (MPS) is used + mode + The compute mode of the GPU, currently just one of (``DEFAULT``, + ``EXCLUSIVE_PROCESS``). In ``EXCLUSIVE_PROCESS`` only one process + can be run at a time, unless MPS is used. ``DEFAULT`` allows any + process to be run normally. 
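+
+    Examples
+    --------
+    A short sketch of how this information might be used before launching
+    a process on the GPU:
+
+    >>> gpus = GPU.from_system()
+    >>> if any(gpu.free for gpu in gpus):
+    ...     pass  # at least one GPU can accept a new process right now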
+ + """ + + available: bool + free: bool + free_with_mps: bool = False + mode: GPUMode | None = None + + @classmethod + def from_system(cls) -> list[Self]: + """Provides information on GPU capabilities""" + cmd = CommandRunner(raise_on_failure=False) + ret = cmd.run_only("nvidia-smi -q -x") + if ret.returncode != 0: + return [cls(available=False, free=False)] + + tree = ET.parse(io.StringIO(ret.stdout.decode())) + + devices = [] + for gpu in tree.findall("gpu"): + # No need for MPS in default compute mode :) + if (res := gpu.find("compute_mode")) is not None and res.text == "Default": + devices.append(cls(available=True, free=True, mode=GPUMode.DEFAULT)) + continue + + names = (name.text or "" for name in gpu.findall("processes/*/process_name")) + types = (typ.text or "" for typ in gpu.findall("processes/*/type")) + + # No processes implies an empty GPU + if not len(list(names)): + devices.append(cls(available=True, free=True, mode=GPUMode.EXCLUSIVE_PROCESS)) + continue + + mps_only = all( + "nvidia-cuda-mps-server" in id or typ == "M+C" for id, typ in zip(names, types) + ) + gpu_ok = all(typ == "G" for typ in types) + devices.append(cls( + available=True, free=gpu_ok, free_with_mps=mps_only, mode=GPUMode.EXCLUSIVE_PROCESS + )) + return devices + + +def gpu_info() -> tuple[bool, bool]: + """ + Provides information on GPU capabilities + + Returns + ------- + tuple[bool, bool] + The first boolean indicates whether the GPU is free to use, the second + indicates whether the GPU can be used, but only using the nvidia MPS daemon. + If no GPU is available returns ``(False, False)``. + + """ + cmd = CommandRunner(raise_on_failure=False) + ret = cmd.run_only("nvidia-smi -q -x") + if ret.returncode != 0: + return False, False + + tree = ET.parse(io.StringIO(ret.stdout.decode())) + + # No need for MPS in default compute mode :) + if (res := tree.find("gpu/compute_mode")) is not None and res.text == "Default": + return True, False + + names = (name.text or "" for name in tree.findall("gpu/processes/*/process_name")) + types = (typ.text or "" for typ in tree.findall("gpu/processes/*/type")) + + # No processes implies an empty GPU + if not len(list(names)): + return True, False + + mps_only = all("nvidia-cuda-mps-server" in id or typ == "M+C" for id, typ in zip(names, types)) + gpu_ok = all(typ == "G" for typ in types) + return gpu_ok, mps_only + + +def job_from_id(job_id: str, backend: str) -> Job: + """Creates a job object from an ID""" + job = Job() + job._native_id = job_id + executor = JobExecutor.get_instance(backend) + executor.attach(job, job_id) + return job + + +class WorkflowStatus(StrEnum): + RUNNING = auto() + """Currently running with no IO interaction""" + + COMPLETED = auto() + """Successfully completed everything""" + + FAILED = auto() + """Failed via exception""" + + QUEUED = auto() + """Queued for execution in the resource manager""" + + CANCELLED = auto() + """Cancelled by user""" + + UNKNOWN = auto() + """Unknown job status""" + + @classmethod + def from_psij(cls, status: JobState) -> "WorkflowStatus": + """Convert a PSIJ job status to a Maize workflow status""" + return PSIJSTATUSTRANSLATION[str(status)] + + +PSIJSTATUSTRANSLATION: dict[str, "WorkflowStatus"] = { + str(JobState.NEW): WorkflowStatus.QUEUED, + str(JobState.QUEUED): WorkflowStatus.QUEUED, + str(JobState.ACTIVE): WorkflowStatus.RUNNING, + str(JobState.FAILED): WorkflowStatus.FAILED, + str(JobState.CANCELED): WorkflowStatus.CANCELLED, + str(JobState.COMPLETED): WorkflowStatus.COMPLETED, +} + + +def 
check_executable(command: list[str] | str) -> bool: + """ + Checks if a command can be run. + + Parameters + ---------- + command + Command to execute + + Returns + ------- + bool + ``True`` if running the command was successfull, ``False`` otherwise + + """ + exe = make_list(command) + try: + # We cannot use `CommandRunner` here, as initializing the Exaworks PSI/J job executor in + # the main process (where this code will be run, as we're checking if the nodes have all + # the required tools to start) may cause the single process reaper to lock up all child + # jobs. See this related bug: https://github.com/ExaWorks/psij-python/issues/387 + res = _simple_run(exe) + res.check_returncode() + except (FileNotFoundError, subprocess.CalledProcessError): + return False + return True + + +def _parse_slurm_walltime(time_str: str) -> timedelta: + """ + Parses a SLURM walltime string + + Parameters + ---------- + time_str + String with the `SLURM time format `_. + + Returns + ------- + timedelta + Timedelta object with the parsed time interval + + """ + match split_multi(time_str, "-:"): + case [days, hours, minutes, seconds]: + delta = timedelta( + days=int(days), hours=int(hours), minutes=int(minutes), seconds=int(seconds) + ) + case [days, hours, minutes] if "-" in time_str: + delta = timedelta(days=int(days), hours=int(hours), minutes=int(minutes)) + case [days, hours] if "-" in time_str: + delta = timedelta(days=int(days), hours=int(hours)) + case [hours, minutes, seconds]: + delta = timedelta(hours=int(hours), minutes=int(minutes), seconds=int(seconds)) + case [minutes, seconds]: + delta = timedelta(minutes=int(minutes), seconds=int(seconds)) + case [minutes]: + delta = timedelta(minutes=int(minutes)) + return delta + + +# This would normally be part of `run_single_process`, but +# pickling restrictions force us to place it at the top level +def _wrapper( + func: Callable[[], Any], error_queue: "Queue[Exception]", *args: Any, **kwargs: Any +) -> None: # pragma: no cover + try: + func(*args, **kwargs) + except Exception as err: # pylint: disable=broad-except + error_queue.put(err) + + +def run_single_process( + func: Callable[[], Any], name: str | None = None, executable: Path | None = None +) -> None: + """ + Runs a function in a separate process. + + Parameters + ---------- + func + Function to call in a separate process + name + Optional name of the function + executable + Optional python executable to use + + """ + ctx = get_context(DEFAULT_CONTEXT) + + # In some cases we might need to change python environments to get the dependencies + exec_path = sys.executable if executable is None else executable.as_posix() + ctx.set_executable(exec_path) + + # The only way to reliably get raised exceptions in the main process + # is by passing them through a shared queue and re-raising. So we just + # wrap the function of interest to catch any exceptions and pass them on + queue = ctx.Queue() + + proc = ctx.Process( # type: ignore + target=_wrapper, + name=name, + args=( + func, + queue, + ), + ) + proc.start() + + proc.join(timeout=2.0) + if proc.is_alive(): + proc.terminate() + if not queue.empty(): + raise queue.get_nowait() + + +def check_returncode( + result: subprocess.CompletedProcess[bytes], + raise_on_failure: bool = True, + logger: logging.Logger | None = None, +) -> None: + """ + Check the returncode of the process and raise or log a warning. 
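+
+    For example, a schematic call that only logs a warning instead of raising::
+
+        result = subprocess.run(["some_tool", "--input", "file.dat"], capture_output=True)
+        check_returncode(result, raise_on_failure=False)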
+ + Parameters + ---------- + result + Completed process to check + raise_on_failure + Whether to raise an exception on failure + logger + Logger instance to use for command output + + """ + if logger is None: + logger = logging.getLogger() + + # Raise the expected FileNotFoundError if the command + # couldn't be found (to mimic subprocess) + if result.returncode == 127: + msg = f"Command {result.args[0]} not found (returncode {result.returncode})" + if raise_on_failure: + raise FileNotFoundError(msg) + logger.warning(msg) + elif result.returncode != 0: + msg = f"Command {_format_command(result)} failed with returncode {result.returncode}" + logger.warning(_log_command_output(result.stdout, result.stderr)) + if raise_on_failure: + raise ProcessError(msg) + logger.warning(msg) + + +class ProcessBase: + def __init__(self) -> None: + self.logger = logging.getLogger(f"run-{os.getpid()}") + + def check_returncode( + self, + result: subprocess.CompletedProcess[bytes], + raise_on_failure: bool = True, + ) -> None: + """ + Check the returncode of the process and raise or log a warning. + + Parameters + ---------- + result + Completed process to check + raise_on_failure + Whether to raise an exception on failure + + """ + return check_returncode(result, raise_on_failure=raise_on_failure, logger=self.logger) + + def _job_to_completed_process(self, job: Job) -> subprocess.CompletedProcess[bytes]: + """Converts a finished PSI/J job to a `CompletedProcess` instance""" + + # Basically just for mypy + if ( + job.spec is None + or job.spec.stderr_path is None + or job.spec.stdout_path is None + or job.spec.executable is None + or job.spec.arguments is None + ): # pragma: no cover + raise ProcessError("Job was not initialized correctly") + + command = [job.spec.executable, *job.spec.arguments] + + # There seem to be situations in which STDOUT / STDERR is not available + # (although it should just be an empty file in case of no output). Rather + # than crash here we let a potential validator check if the command maybe + # was successful anyway. 
+ stdout = stderr = b"" + if job.spec.stdout_path.exists(): + with job.spec.stdout_path.open("rb") as out: + stdout = out.read() + else: + self.logger.warning( + "STDOUT at %s not found, this may indicate problems writing output", + job.spec.stdout_path.as_posix(), + ) + + if job.spec.stdout_path.exists(): + with job.spec.stderr_path.open("rb") as err: + stderr = err.read() + else: + self.logger.warning( + "STDERR at %s not found, this may indicate problems writing output", + job.spec.stderr_path.as_posix(), + ) + + res = subprocess.CompletedProcess( + command, + # Exit code is 130 if we cancelled due to timeout + returncode=job.status.exit_code if job.status.exit_code is not None else 130, + stdout=stdout, + stderr=stderr, + ) + return res + + +BatchSystemType = Literal["slurm", "rp", "pbspro", "lsf", "flux", "cobalt", "local"] + + +@dataclass +class ResourceManagerConfig: + """Configuration for job resource managers""" + + system: BatchSystemType = "local" + max_jobs: int = 100 + queue: str | None = None + project: str | None = None + launcher: str | None = None + walltime: str = "24:00:00" + polling_interval: int = 120 + + +@dataclass +class JobResourceConfig: + """Configuration for job resources""" + + nodes: int | None = None + processes_per_node: int | None = None + processes: int | None = None + cores_per_process: int | None = None + gpus_per_process: int | None = None + exclusive_use: bool = False + queue: str | None = None + walltime: str | None = None + custom_attributes: dict[str, Any] = field(default_factory=dict) + + def format_custom_attributes(self, system: BatchSystemType) -> dict[str, Any]: + """Provides custom attributes formatted for PSI/J""" + return {f"{system}.{key}": value for key, value in self.custom_attributes.items()} + + +class _JobCounterSemaphore: + def __init__(self, val: int = 0) -> None: + self._val = val + self._total = 0 + + @property + def val(self) -> int: + """Value of the semaphore""" + return self._val + + @property + def total(self) -> int: + """Total number of submissions""" + return self._total + + def inc(self) -> None: + """Increment the counter""" + self._val += 1 + + def dec(self) -> None: + """Decrement the counter""" + self._val -= 1 + self._total += 1 + + +class _SingletonJobHandler(type): + _instances: dict[type, type] = {} + + def __call__(cls, *args: Any, **kwds: Any) -> Any: + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwds) + singleton = cls._instances[cls] + + # We always reset the counter, otherwise successive `run_multi` calls will fail + if hasattr(singleton, "_counter") and "counter" in kwds: + singleton._counter = kwds["counter"] + return singleton + + +class JobHandler(metaclass=_SingletonJobHandler): + """ + Handles safe job cancellation in case something goes wrong. + + This class is implemented as a singleton so that there is + only one `JobHandler` instance per process. This allows safe + job cancellation and shutdowns through the signal handler. 
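+
+    Examples
+    --------
+    Schematic usage, mirroring how `CommandRunner` uses it further below
+    (``job`` being an already submitted PSI/J job):
+
+    >>> with JobHandler() as handler:
+    ...     handler.add(job)
+    ...     jobs = handler.wait()  # blocks until all registered jobs are final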
+ + """ + + def __init__(self, counter: _JobCounterSemaphore | None = None) -> None: + self._jobs: dict[str, Job] = {} + self._counter = counter or _JobCounterSemaphore() + + def __enter__(self) -> Self: + return self + + def __exit__(self, *args: Any) -> None: + self.cancel_all() + + def add(self, job: Job) -> None: + """Register a job with the handler""" + self._counter.inc() + self._jobs[job.id] = job + + def wait_for_slot(self, timeout: float | None = None, n_jobs: int | None = None) -> None: + """Wait for a slot for a new job""" + if n_jobs is None: + return + + start = time.time() + while self._counter.val >= n_jobs: + time.sleep(0.1) + if timeout is not None and (time.time() - start) >= timeout: + self._stop_all() + + def wait(self, timeout: float | None = None) -> list[Job]: + """Wait for all jobs to complete""" + start = time.time() + while self._counter.val > 0: + time.sleep(0.1) + if timeout is not None and (time.time() - start) >= timeout: + self._stop_all() + + return list(self._jobs.values()) + + def cancel_all(self) -> None: + """Cancel all jobs""" + while self._jobs: + _, job = self._jobs.popitem() + if not job.status.final: + job.cancel() + + def _stop_all(self) -> None: + """Stop all jobs""" + for job in self._jobs.values(): + job.cancel() + job.wait() + return + + +class RunningProcess(ProcessBase): + def __init__(self, job: Job, verbose: bool = False, raise_on_failure: bool = True) -> None: + super().__init__() + self.job = job + self.verbose = verbose + self.raise_on_failure = raise_on_failure + self.stdout_path = None if job.spec is None else job.spec.stdout_path + self.stderr_path = None if job.spec is None else job.spec.stderr_path + + @property + def id(self) -> str | None: + """Provides the job ID of the batch system""" + return self.job.native_id + + def is_alive(self) -> bool: + """ + Whether the process is alive and running + + Returns + ------- + bool + ``True`` if the process is alive and running + + """ + return not self.job.status.final + + def kill(self, timeout: int = 0) -> None: + """ + Kill the process + + Parameters + ---------- + timeout + Timeout in seconds before forcefully killing + + """ + time.sleep(timeout) + if not self.job.status.final: + self.job.cancel() + + def wait(self) -> subprocess.CompletedProcess[bytes]: + """ + Wait for the process to complete + + Returns + ------- + subprocess.CompletedProcess[bytes] + Result of the execution, including STDOUT and STDERR + + """ + self.job.wait() + result = self._job_to_completed_process(self.job) + + self.check_returncode(result, raise_on_failure=self.raise_on_failure) + + if self.verbose: + self.logger.debug(_log_command_output(result.stdout, result.stderr)) + + return result + + +# This decorator exists to avoid always creating a new PSI/J `JobExecutor` instance with +# its own queue polling thread, causing potentially high loads on a batch system +_T = TypeVar("_T") +_P = ParamSpec("_P") +_memoized = {} + + +# I attempted to type this without explicitly adding the name keyword argument +# to `CommandRunner`, but concatenating keyword arguments with `ParamSpec`s is +# explicitly disallowed by PEP612 :( +def _memoize(cls: Callable[_P, _T]) -> Callable[_P, _T]: + def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T: + name = kwargs.get("name", None) + if name is None: + return cls(*args, **kwargs) + if name not in _memoized: + _memoized[name] = cls(*args, **kwargs) + return _memoized[name] + + return inner + + +class CommandRunnerPSIJ(ProcessBase): + """ + Command running utility based on PSI/J. 
+ + Instantiate with preferred options and use a `run` method with your command. + + .. danger:: + It is not recommended to instantiate this class in the main process + (i.e. outside of your nodes ``run()`` and ``prepare()`` methods). This is due to + `possible subtle threading problems `_ + from the interplay of maize and PSI/J. + + .. danger:: + The default user-facing class uses memoization. If you're instantiating multiple + `CommandRunner` instances in the same process, they will by default refer to the + same instance. To create separate instances with potentially different parameters, + supply a custom name. + + Parameters + ---------- + name + The name given to this instance. The user-accessible version of this class + is memoized, so the same name will refer to the same instance. + raise_on_failure + Whether to raise an exception on failure, or whether to just return `False`. + working_dir + The working directory to use for execution, will use the current one by default. + validators + One or more `Validator` instances that will be called on the result of the command. + prefer_batch + Whether to prefer running on a batch submission system such as SLURM, if available + rm_config + Configuration of the resource manager + max_retries + How often to reattempt job submission for batch systems + + """ + + def _make_callback( + self, count: _JobCounterSemaphore, max_count: int = 0 + ) -> Callable[[Job, JobStatus], None]: + def _callback(job: Job, status: JobStatus) -> None: + if status.final: + self.logger.info("Job completed (%s/%s)", count.total + 1, max_count) + if job.status.exit_code is not None and job.status.exit_code > 0: + res = self._job_to_completed_process(job) + self.logger.warning( + "Job %s failed with exit code %s (%s)", + job.native_id, + res.returncode, + job.status.state, + ) + self.logger.warning(_log_command_output(stdout=res.stdout, stderr=res.stderr)) + count.dec() + + return _callback + + def __init__( + self, + *, + name: str | None = None, + raise_on_failure: bool = True, + working_dir: Path | None = None, + validators: Sequence["Validator"] | None = None, + prefer_batch: bool = False, + rm_config: ResourceManagerConfig | None = None, + max_retries: int = 3, + ) -> None: + super().__init__() + self.name = name + self.raise_on_failure = raise_on_failure + self.working_dir = working_dir.absolute() if working_dir is not None else Path.cwd() + self._write_output_to_temp = working_dir is None + self.validators = validators or [] + self.config = ResourceManagerConfig() if rm_config is None else rm_config + self.max_retries = max_retries + + # We're going to be doing local execution most of the time, + # most jobs we run are going to be relatively short, and we'll + # already have a reservation for the main maize workflow job + system = "local" + exec_config = JobExecutorConfig() + if prefer_batch: + self.logger.debug("Attempting to run on batch system") + # FIXME check for any submission system + if check_executable("sinfo"): + system = self.config.system + + # We override the default 30s polling interval to go easy on delicate batch systems + exec_config = SlurmExecutorConfig( + queue_polling_interval=self.config.polling_interval + ) + elif self.config.system != "local": + self.logger.warning( + "'%s' was not found on your system, running locally", self.config.system + ) + self._executor = JobExecutor.get_instance(system, config=exec_config) + + def validate(self, result: subprocess.CompletedProcess[bytes]) -> None: + """ + Validate a process result. 
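+
+        For example, with a user-supplied validator callable (here a
+        hypothetical ``check``) passed at construction time::
+
+            runner = CommandRunner(name="validated-run", validators=[check])
+            result = runner.run_only("some_tool --input file.dat")
+            runner.validate(result)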
+ + Parameters + ---------- + result + Process result to validate + + Raises + ------ + ProcessError + If any of the validators failed or the returncode was not zero + + """ + + for validator in self.validators: + if not validator(result): + msg = ( + f"Validation failure for command '{_format_command(result)}' " + f"with validator '{validator}'" + ) + self.logger.warning(_log_command_output(result.stdout, result.stderr)) + if self.raise_on_failure: + raise ProcessError(msg) + self.logger.warning(msg) + + def run_async( + self, + command: list[str] | str, + verbose: bool = False, + working_dir: Path | None = None, + command_input: str | None = None, + pre_execution: list[str] | str | None = None, + cuda_mps: bool = False, + config: JobResourceConfig | None = None, + n_retries: int = 3, + ) -> RunningProcess: + """ + Run a command locally. + + Parameters + ---------- + command + Command to run as a single string, or a list of strings + verbose + If ``True`` will also log any STDOUT or STDERR output + working_dir + Optional working directory + command_input + Text string used as input for command + pre_execution + Command to run directly before the main one + cuda_mps + Use the multi-process service to run multiple CUDA job processes on a single GPU + config + Job submission config if using batch system + n_retries + Number of submission retries + + Returns + ------- + RunningProcess + Process handler allowing waiting, killing, or monitoring the running command + + """ + + job = self._create_job( + command, + config=config, + working_dir=working_dir, + verbose=verbose, + command_input=command_input, + pre_execution=pre_execution, + cuda_mps=cuda_mps, + ) + for i in range(n_retries): + self._attempt_submission(job) + proc = RunningProcess(job, verbose=verbose, raise_on_failure=self.raise_on_failure) + if proc.id is not None: + return proc + self.logger.warning( + "Job did not receive a valid ID, retrying... (%s/%s)", i + 1, n_retries + ) + raise ProcessError(f"Unable to successfully submit job after {n_retries} retries") + + def run_only( + self, + command: list[str] | str, + verbose: bool = False, + working_dir: Path | None = None, + command_input: str | None = None, + config: JobResourceConfig | None = None, + pre_execution: list[str] | str | None = None, + timeout: float | None = None, + cuda_mps: bool = False, + ) -> subprocess.CompletedProcess[bytes]: + """ + Run a command locally and block. 
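+
+        For example, a minimal blocking call (assuming an existing
+        `CommandRunner` instance ``runner`` created inside a node)::
+
+            result = runner.run_only("echo hello", verbose=True)
+            assert result.returncode == 0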
+ + Parameters + ---------- + command + Command to run as a single string, or a list of strings + verbose + If ``True`` will also log any STDOUT or STDERR output + working_dir + Optional working directory + command_input + Text string used as input for command + config + Resource configuration for jobs + pre_execution + Command to run directly before the main one + timeout + Maximum runtime for the command in seconds, or unlimited if ``None`` + cuda_mps + Use the multi-process service to run multiple CUDA job processes on a single GPU + + Returns + ------- + subprocess.CompletedProcess[bytes] + Result of the execution, including STDOUT and STDERR + + Raises + ------ + ProcessError + If the returncode was not zero + + """ + + job = self._create_job( + command, + working_dir=working_dir, + verbose=verbose, + command_input=command_input, + config=config, + pre_execution=pre_execution, + cuda_mps=cuda_mps, + ) + + with JobHandler(counter=_JobCounterSemaphore()) as hand: + self._attempt_submission(job) + hand.add(job) + job.wait(timeout=timedelta(seconds=timeout) if timeout is not None else None) + + result = self._job_to_completed_process(job) + + self.check_returncode(result, raise_on_failure=self.raise_on_failure) + if verbose: + self.logger.debug(_log_command_output(result.stdout, result.stderr)) + + return result + + def run_validate( + self, + command: list[str] | str, + verbose: bool = False, + working_dir: Path | None = None, + command_input: str | None = None, + config: JobResourceConfig | None = None, + pre_execution: list[str] | str | None = None, + timeout: float | None = None, + cuda_mps: bool = False, + ) -> subprocess.CompletedProcess[bytes]: + """ + Run a command and validate. + + Parameters + ---------- + command + Command to run as a single string, or a list of strings + verbose + If ``True`` will also log any STDOUT or STDERR output + working_dir + Optional working directory + command_input + Text string used as input for command + config + Resource configuration for jobs + pre_execution + Command to run directly before the main one + timeout + Maximum runtime for the command in seconds, or unlimited if ``None`` + cuda_mps + Use the multi-process service to run multiple CUDA job processes on a single GPU + + Returns + ------- + subprocess.CompletedProcess[bytes] + Result of the execution, including STDOUT and STDERR + + Raises + ------ + ProcessError + If any of the validators failed or the returncode was not zero + + """ + + result = self.run_only( + command=command, + verbose=verbose, + working_dir=working_dir, + command_input=command_input, + config=config, + pre_execution=pre_execution, + timeout=timeout, + cuda_mps=cuda_mps, + ) + self.validate(result=result) + return result + + # Convenience alias for simple calls to `run_validate` + run = run_validate + + def run_parallel( + self, + commands: Sequence[list[str] | str], + verbose: bool = False, + n_jobs: int = 1, + validate: bool = False, + working_dirs: Sequence[Path | None] | None = None, + command_inputs: Sequence[str | None] | None = None, + config: JobResourceConfig | None = None, + pre_execution: list[str] | str | None = None, + timeout: float | None = None, + cuda_mps: bool = False, + n_batch: int | None = None, + batchsize: int | None = None, + ) -> list[subprocess.CompletedProcess[bytes]]: + """ + Run multiple commands locally in parallel and block. 
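+
+        For example, a schematic invocation running three commands with up to
+        three concurrent local processes (``runner`` being an existing
+        `CommandRunner` instance)::
+
+            results = runner.run_parallel(["echo 1", "echo 2", "echo 3"], n_jobs=3)
+            outputs = [res.stdout.decode() for res in results]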
+ + Parameters + ---------- + commands + Commands to run as a list of single strings, or a list of lists + verbose + If ``True`` will also log any STDOUT or STDERR output + n_jobs + Number of processes to spawn at once + validate + Whether to validate the command execution + working_dirs + Directories to execute each command in + command_input + Text string used as input for each command, or ``None`` + config + Resource configuration for jobs + pre_execution + Command to run directly before the main one + timeout + Maximum runtime for the command in seconds, or unlimited if ``None`` + cuda_mps + Use the multi-process service to run multiple CUDA job processes on a single GPU + n_batch + Number of batches of commands to run together + sequentially. Incompatible with ``batchsize``. + batchsize + Size of command batches to run sequentially. Incompatible with ``n_batch``. + + Returns + ------- + list[subprocess.CompletedProcess[bytes]] + Result of the execution, including STDOUT and STDERR + + Raises + ------ + ProcessError + If the returncode was not zero + + """ + if command_inputs is not None and (n_batch is not None or batchsize is not None): + raise ValueError( + "Using command inputs and batching of commands is currently not compatible" + ) + + queue: Queue[tuple[list[str] | str, Path | None, str | None]] = Queue() + working_dirs = [None for _ in commands] if working_dirs is None else list(working_dirs) + if command_inputs is None: + command_inputs = [None for _ in commands] + + if n_batch is not None or batchsize is not None: + commands = [ + self._make_batch(comms) + for comms in split_list(list(commands), n_batch=n_batch, batchsize=batchsize) + ] + new_wds = [] + for wds in split_list(working_dirs, n_batch=n_batch, batchsize=batchsize): + if len(set(wds)) > 1: + raise ValueError("Working directories are not consistent within batches") + new_wds.append(wds[0]) + working_dirs = new_wds + command_inputs = [None for _ in commands] + + for command, work_dir, cmd_input in zip(commands, working_dirs, command_inputs): + queue.put((command, work_dir, cmd_input)) + + # If we're using a batch system we can submit as many jobs as we like, + # up to a reasonable maximum set in the system config + n_jobs = max(self.config.max_jobs, n_jobs) if self.config.system != "local" else n_jobs + + # Keeps track of completed jobs via a callback and a counter, + # and thus avoids using a blocking call to `wait` (or threads) + counter = _JobCounterSemaphore() + callback = self._make_callback(counter, max_count=len(commands)) + self._executor.set_job_status_callback(callback) + + with JobHandler(counter=counter) as hand: + while not queue.empty(): + hand.wait_for_slot(timeout=timeout, n_jobs=n_jobs) + command, work_dir, cmd_input = queue.get() + tentative_job = self._create_job( + command, + working_dir=work_dir, + verbose=verbose, + command_input=cmd_input, + config=config, + pre_execution=pre_execution, + cuda_mps=cuda_mps, + ) + + # Some failure modes necessitate recreating the job from the spec to + # play it safe, which is why we have a potentially new job instance here + job = self._attempt_submission(tentative_job) + hand.add(job) + + # Wait for all jobs to complete + jobs = hand.wait(timeout=timeout) + + # Collect all results + results: list[subprocess.CompletedProcess[bytes]] = [] + for job in jobs: + result = self._job_to_completed_process(job) + self.check_returncode(result, raise_on_failure=self.raise_on_failure) + if verbose: + self.logger.debug(_log_command_output(result.stdout, result.stderr)) + + 
if validate: + self.validate(result) + results.append(result) + return results + + def _attempt_submission(self, job: Job) -> Job: + """Attempts to submit a job""" + + # No special treatment for normal local commands + if self._executor.name == "local": + self._executor.submit(job) + self.logger.info( + "Running job with PSI/J id=%s, stdout=%s, stderr=%s", + job.id, + None if job.spec is None else job.spec.stdout_path, + None if job.spec is None else job.spec.stderr_path, + ) + return job + + # Batch systems like SLURM may have the occassional + # hiccup, we try multiple times to make sure + for i in range(self.max_retries): + try: + self._executor.submit(job) + except SubmitException as err: + self.logger.warning( + "Error submitting job, trying again (%s/%s)", + i + 1, + self.max_retries, + exc_info=err, + ) + time.sleep(10) + except InvalidJobException as err: + self.logger.warning( + "Error submitting, possible duplicate, trying again (%s/%s)", + i + 1, + self.max_retries, + exc_info=err, + ) + time.sleep(10) + + # Recreate job to be safe and get a new unique id + job = Job(job.spec) + else: + self.logger.info( + "Submitted %s job with PSI/J id=%s, native id=%s, stdout=%s, stderr=%s", + None if job.executor is None else job.executor.name, + job.id, + job.native_id, + None if job.spec is None else job.spec.stdout_path, + None if job.spec is None else job.spec.stderr_path, + ) + return job + + # Final try + self._executor.submit(job) + return job + + @staticmethod + def _make_batch(commands: Sequence[list[str] | str]) -> str: + """Batches commands together into a script""" + file = Path(mkdtemp(prefix="maize-")) / "commands.sh" + with file.open("w") as out: + for command in commands: + if isinstance(command, list): + command = " ".join(command) + out.write(command + "\n") + file.chmod(stat.S_IRWXU | stat.S_IRWXG) + return file.as_posix() + + def _create_job( + self, + command: list[str] | str, + working_dir: Path | None = None, + verbose: bool = False, + command_input: str | None = None, + config: JobResourceConfig | None = None, + pre_execution: list[str] | str | None = None, + cuda_mps: bool = False, + ) -> Job: + """Creates a PSI/J job description""" + + # We can't fully rely on `subprocess.run()` here, so we split it ourselves + if isinstance(command, str): + command = shlex.split(command) + + if verbose: + self.logger.debug("Running command: %s", " ".join(command)) + + exe, *arguments = command + + # General resource manager options, options that shouldn't change much from job to job + delta = _parse_slurm_walltime(self.config.walltime) + job_attr = JobAttributes( + queue_name=self.config.queue + if config is None or config.queue is None + else config.queue, + project_name=self.config.project, + duration=delta + if config is None or config.walltime is None + else _parse_slurm_walltime(config.walltime), + custom_attributes=None + if config is None + else config.format_custom_attributes(self.config.system), + ) + + # Resource configuration + if config is not None: + resource_attr = ResourceSpecV1( + node_count=config.nodes, + processes_per_node=config.processes_per_node, + process_count=config.processes, + cpu_cores_per_process=config.cores_per_process, + gpu_cores_per_process=config.gpus_per_process, + exclusive_node_use=config.exclusive_use, + ) + else: + resource_attr = None + + # Annoyingly, we can't access STDOUT / STDERR directly when + # using PSI/J, so we always have to create temporary files + base_dir = ( + Path(mkdtemp()) if self._write_output_to_temp else (working_dir or 
self.working_dir) + ) + base = base_dir / f"job-{unique_id()}" + stdout = base.with_name(base.name + "-out") + stderr = base.with_name(base.name + "-err") + + # Similarly, command inputs need to be written to a file and cannot be piped + stdin = None + if command_input is not None: + stdin = base.with_name(base.name + "-in") + with stdin.open("wb") as inp: + inp.write(command_input.encode()) + + env: dict[str, str] = {} + + # We may need to run the multi-process daemon to allow multiple processes on a single GPU + if cuda_mps: + mps_dir = mkdtemp(prefix="mps") + env |= { + "CUDA_MPS_LOG_DIRECTORY": mps_dir, + "CUDA_MPS_PIPE_DIRECTORY": mps_dir, + "MPS_DIR": mps_dir, + } + set_environment(env) + self.logger.debug("Spawning job with CUDA multi-process service in %s", mps_dir) + + # Pre-execution script, we generally avoid this but in some cases it can be required + pre_script = base.with_name(base.name + "-pre") + if pre_execution is not None and isinstance(pre_execution, str): + pre_execution = shlex.split(pre_execution) + + if cuda_mps: + if pre_execution is None: + pre_execution = ["nvidia-cuda-mps-control", "-d"] + else: + pre_execution.extend(["nvidia-cuda-mps-control", "-d"]) + + if pre_execution is not None: + if self.config.launcher == "srun": + self.logger.warning( + "When using launcher 'srun' pre-execution commands may not propagate" + ) + + with pre_script.open("w") as pre: + pre.write("#!/bin/bash\n") + pre.write(" ".join(pre_execution)) + self.logger.debug("Using pre-execution command: %s", " ".join(pre_execution)) + + # For some versions of PSI/J passing the environment explicitly will produce multiple + # `--export` statements, of which the last one will always override all the previous + # ones, resulting in an incomplete environment. So we instead rely on full environment + # inheritance for now. + spec = JobSpec( + executable=exe, + arguments=arguments, + directory=working_dir or self.working_dir, + stderr_path=stderr, + stdout_path=stdout, + stdin_path=stdin, + attributes=job_attr, + inherit_environment=True, + launcher=self.config.launcher, + resources=resource_attr, + pre_launch=pre_script if pre_execution is not None else None, + ) + return Job(spec) + + +CommandRunner = _memoize(CommandRunnerPSIJ) + +# This is required for testing purposes, as pytest will attempt to run many tests concurrently, +# leading to tests sharing `CommandRunner` instances due to memoization when they shouldn't +# (as we're testing different parameters). This is not a problem in actual use, since +# `CommandRunner` will generally only be used from inside separate processes. 
+_UnmemoizedCommandRunner = CommandRunnerPSIJ diff --git a/maize/utilities/io.py b/maize/utilities/io.py new file mode 100644 index 0000000..5cea5fd --- /dev/null +++ b/maize/utilities/io.py @@ -0,0 +1,708 @@ +"""Various input / output functionality.""" + +import argparse +import builtins as _b +from collections.abc import Sequence, Callable, Iterable +from dataclasses import dataclass, field +import importlib +import importlib.metadata +import importlib.resources +import importlib.util +import inspect +import json +import logging +from pathlib import Path, PosixPath +import pkgutil +import os +import shutil +import sys +from tempfile import mkdtemp +import time +from types import ModuleType +import typing +from typing import ( + Annotated, + Any, + Literal, + TypeVar, + TYPE_CHECKING, + TypedDict, + get_args, + get_origin, + cast, +) +from typing_extensions import Self, assert_never + +import toml +import yaml + +from maize.utilities.execution import ResourceManagerConfig + +if TYPE_CHECKING: + from maize.core.workflow import Workflow + + +class ScriptPairType(TypedDict): + interpreter: str + location: Path + + +ScriptSpecType = dict[str, ScriptPairType] + + +class NodeConfigDict(TypedDict): + """Dictionary form of `NodeConfig`""" + + python: str + modules: list[str] + scripts: dict[str, dict[Literal["interpreter", "location"], str]] + commands: dict[str, str] + + +T = TypeVar("T") + + +XDG = "XDG_CONFIG_HOME" +MAIZE_CONFIG_ENVVAR = "MAIZE_CONFIG" + + +log_build = logging.getLogger("build") +log_run = logging.getLogger(f"run-{os.getpid()}") + + +def _find_install_config() -> Path | None: + """Finds a potential config file in the maize package directory""" + # TODO This is ugly, not 100% sure we need the catch at this point + try: + for mod in importlib.resources.files("maize").iterdir(): + if (config := cast(Path, mod).parent.parent / "maize.toml").exists(): + return config + except NotADirectoryError: + log_build.debug("Problems reading importlib package directory data") + + log_build.debug("Config not found in package directory") + return None + + +# See: https://xdgbasedirectoryspecification.com/ +def _valid_xdg_path() -> bool: + """Checks the XDG path spec is valid""" + return XDG in os.environ and bool(os.environ[XDG]) and Path(os.environ[XDG]).is_absolute() + + +def expand_shell_vars(path: Path) -> Path: + """Expands paths containing shell variables to the full path""" + return Path(os.path.expandvars(path)) + + +def remove_dir_contents(path: Path) -> None: + """Removes all items contained in a directory""" + items = list(path.glob("*")) + log_run.debug("Found %s items to remove", len(items)) + for item in items: + log_run.debug("Removing '%s'", item) + if item.is_dir(): + shutil.rmtree(item) + else: + item.unlink(missing_ok=True) + + +def wait_for_file(path: Path, timeout: int | None = None, zero_byte_check: bool = True) -> None: + """ + Wait for a file to be created, or time out. + + Parameters + ---------- + path + Path to the file + timeout + Timeout in seconds, if not ``None`` will raise a `TimeoutError` + zero_byte_check + Whether to check if the generated file is empty + + """ + start = time.time() + while not path.exists() or (path.stat().st_size == 0 and zero_byte_check): + time.sleep(0.5) + if timeout is not None and (time.time() - start) >= timeout: + raise TimeoutError(f"Waiting for file {path} timed out") + + +def common_parent(files: Sequence[Path]) -> Path: + """ + Provides the common parent directory for a list of files. 
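For illustration, a minimal sketch of this helper (the paths are hypothetical):

    from pathlib import Path
    from maize.utilities.io import common_parent

    paths = [Path("/data/run-1/output.log"), Path("/data/run-2/output.log")]
    common_parent(paths)  # Path("/data"), the deepest directory shared by both files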
+ + Parameters + ---------- + files + List of paths + + Returns + ------- + Path + Common parent directory + + """ + files = [file.absolute() for file in files] + if len(files) == 1 or len(set(files)) == 1: + return files[0].parent + + common_parts: list[str] = [] + + # We take a "vertical slice" of all paths, starting from root, so first + # iteration will be ("/", "/", ...), followed by e.g. ("Users", "Users") etc. + for parts in zip(*(file.parts for file in files)): + if len(set(parts)) == 1: + common_parts.append(parts[0]) + + # This is where the paths diverge + else: + break + return Path(*common_parts) + + +def sendtree( + files: dict[T, Path], dest: Path, mode: Literal["move", "copy", "link"] = "copy" +) -> dict[T, Path]: + """ + Links, copies or moves multiple files to a destination directory and preserves the structure. + + Parameters + ---------- + files + Paths to link / copy + dest + Destination directory + mode + Whether to link, copy or move the files + + Returns + ------- + dict[Any, Path] + Created links / copies + + """ + files = {k: file.absolute() for k, file in files.items()} + common = common_parent(list(files.values())) + + results: dict[T, Path] = {} + for k, file in files.items(): + dest_path = dest.absolute() / file.relative_to(common) + if not dest_path.exists(): + dest_path.parent.mkdir(parents=True, exist_ok=True) + if mode == "link": + os.symlink(file, dest_path) + elif mode == "copy": + shutil.copy(file, dest_path) + elif mode == "move": + shutil.move(file, dest_path) + else: + assert_never(mode) + results[k] = dest_path + return results + + +@dataclass +class NodeConfig: + """ + Node specific configuration. + + Parameters + ---------- + python + Python interpreter to use to run the node + modules + Map from callables to modules + scripts + Map from callables to interpreter - script pairs + commands + Paths to specific commands + parameters + Default parameter settings + + """ + + python: Path = field(default_factory=lambda: Path(sys.executable)) + modules: list[str] = field(default_factory=list) + commands: dict[str, Path | str] = field(default_factory=dict) + scripts: ScriptSpecType = field(default_factory=dict) + parameters: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_dict(cls, data: Any) -> Self: + """ + Generate from a dictionary. 
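As a brief sketch (the section contents are hypothetical), a node entry read from a config file is converted like this:

    from maize.utilities.io import NodeConfig

    section = {
        "commands": {"mytool": "/opt/mytool/bin/mytool"},
        "scripts": {"prepare": {"location": "scripts/prepare.py"}},
    }
    config = NodeConfig.from_dict(section)
    # Script locations become `Path` objects, and scripts without an explicit
    # interpreter fall back to the current `sys.executable`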
+ + Parameters + ---------- + data + Dictionary read in from a config file + + """ + config = cls() + if "python" in data: + config.python = data["python"] + if "modules" in data: + config.modules = data["modules"] + if "scripts" in data: + config.scripts = data["scripts"] + + # Current python executable is the default + for dic in config.scripts.values(): + if "interpreter" not in dic: + dic["interpreter"] = sys.executable + + # Make sure we have paths + dic["location"] = Path(dic["location"]) + + if "commands" in data: + config.commands = {exe: path for exe, path in data["commands"].items()} + if "parameters" in data: + config.parameters = data["parameters"] + + return config + + def generate_template(self, required_callables: list[str]) -> NodeConfigDict: + """ + Generate a template configuration + + Parameters + ---------- + required_callables + The list of software to generate a template for + + Returns + ------- + NodeConfigDict + Dictionary that can be serialized or used directly + + """ + res: NodeConfigDict = { + "python": self.python.as_posix(), + "modules": self.modules, + "commands": {prog: f"path/to/{prog}" for prog in required_callables}, + "scripts": { + prog: {"interpreter": "path/to/python", "location": f"path/to/{prog}"} + for prog in required_callables + }, + } + return res + + def generate_template_toml(self, name: str, required_callables: list[str]) -> str: + """ + Generate a template configuration as a TOML string + + Parameters + ---------- + required_callables + The list of software to generate a template for + + Returns + ------- + str + TOML config string + + """ + return toml.dumps({name: self.generate_template(required_callables)}) + + +@dataclass +class Config: + """ + Global configuration options. + + Parameters + ---------- + packages + List of namespace packages to load + scratch + The directory the workflow should be created in. Uses a temporary directory by default. + batch_config + Default options to be passed to the batch submission system + environment + Any environment variables to be set in the execution context + nodes + Entries specific to each node + + Examples + -------- + Here's an example configuration file with all sections: + + .. 
literalinclude:: ../../../examples/config.toml + :language: toml + :linenos: + + """ + + packages: list[str] = field(default_factory=lambda: ["maize.steps.mai", "maize.graphs.mai"]) + scratch: Path = Path(mkdtemp()) + batch_config: ResourceManagerConfig = field(default_factory=ResourceManagerConfig) + environment: dict[str, str] = field(default_factory=dict) + nodes: dict[str, NodeConfig] = field(default_factory=dict) + + @classmethod + def from_default(cls) -> Self: + """ + Create a default configuration from (in this order of priorities): + + * A path specified using the ``MAIZE_CONFIG`` environment variable + * A config file at ``~/.config/maize.toml`` + * A config file in the current package directory + + """ + config = cls() + if MAIZE_CONFIG_ENVVAR in os.environ: + config_file = Path(os.environ[MAIZE_CONFIG_ENVVAR]) + config.update(config_file) + log_build.debug("Using config at %s", config_file.as_posix()) + elif _valid_xdg_path() and (config_file := Path(os.environ[XDG]) / "maize.toml").exists(): + config.update(config_file) + log_build.debug("Using '%s' config at %s", XDG, config_file.as_posix()) + elif (install_config := _find_install_config()) is not None: + config.update(install_config) + log_build.debug("Using installation config at %s", install_config.as_posix()) + else: + msg = "Could not find a config file" + if not _valid_xdg_path(): + msg += f" (${XDG} is not set)" + log_build.warning(msg) + return config + + def update(self, file: Path) -> None: + """ + Read a maize configuration file. + + Parameters + ---------- + file + Path to the configuration file + + """ + data = read_input(file) + log_build.debug("Updating config with %s", file.as_posix()) + + for key, item in data.items(): + match key: + case "batch": + self.batch_config = ResourceManagerConfig(**item) + case "scratch": + self.scratch = Path(item) + case "environment": + self.environment = item + case "packages": + self.packages.extend(item) + case _: + self.nodes[key.lower()] = NodeConfig.from_dict(item) + + +# It's enough to import the base namespace package and let importlib +# find all modules. Any defined custom nodes will then be registered +# internally, and we don't have to refer to the explicit module path +# for workflow definitions. See the namespace package discovery documentation: +# https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ +def get_plugins(package: ModuleType | str) -> dict[str, ModuleType]: + """ + Finds packages in a given namespace. + + Parameters + ---------- + package + Base namespace package to load + + Returns + ------- + dict[str, ModuleType] + Dictionary of module names and loaded modules + + """ + if isinstance(package, str): + package = importlib.import_module(package) + return { + name: importlib.import_module(name) + for _, name, _ in pkgutil.iter_modules(package.__path__, package.__name__ + ".") + } + + +def load_file(file: Path | str, name: str | None = None) -> ModuleType: + """ + Load a python file as a module. 
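For example (a minimal sketch; the file and class names are hypothetical):

    from maize.utilities.io import load_file

    module = load_file("my_nodes.py")        # registered in sys.modules as "my_nodes.py"
    custom_node = getattr(module, "MyNode")  # access definitions from the loaded file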
+ + Parameters + ---------- + file + Python file to load + name + Optional name to use for the module, will use the filename if not given + + Returns + ------- + ModuleType + The loaded module + + """ + file = Path(file) + name = file.name if name is None else name + spec = importlib.util.spec_from_file_location(name, file) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to import file '{file.as_posix()}'") + + module = importlib.util.module_from_spec(spec) + sys.modules[name] = module + spec.loader.exec_module(module) + return module + + +class DictAction(argparse.Action): # pragma: no cover + """Allows parsing of dictionaries from the commandline""" + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: str | Sequence[Any] | None, + option_string: str | None = None, + ) -> None: + if values is not None: + if isinstance(values, str): + values = [values] + + keywords = dict(token.split("=") for token in values) + setattr(namespace, self.dest, keywords) + + +def args_from_function( + parser: argparse.ArgumentParser, func: Callable[..., Any] +) -> argparse.ArgumentParser: + """ + Add function arguments to an `argparse.ArgumentParser`. + + Parameters + ---------- + parser + The argument parser object + func + The function whose arguments should be mapped to the parser. + The function must be fully annotated, and it is recommended + to always supply a default. Furthermore, only the following + types are allowed: `typing.Literal`, `bool`, `int`, `float`, + `complex`, `str`, and `bytes`. + + Returns + ------- + argparse.ArgumentParser + The parser object with added arguments + + Raises + ------ + TypeError + If a function argument doesn't match one of the types specified above + + """ + func_args = inspect.getfullargspec(func).annotations + func_args.pop("return") + for arg_name, arg_type in func_args.items(): + dargs = None + if get_origin(arg_type) == Annotated: + arg_type = get_args(arg_type)[0] + if get_origin(arg_type) is not None: + dargs = get_args(arg_type) + arg_type = get_origin(arg_type) + + match arg_type: + # Just a simple flag + case _b.bool: + parser.add_argument(f"--{arg_name}", action="store_true") + + # Several options + case typing.Literal: + parser.add_argument(f"--{arg_name}", type=str, choices=dargs) + + # Anything else should be a callable type, e.g. int, float... + case _b.int | _b.float | _b.complex | _b.str | _b.bytes: + parser.add_argument(f"--{arg_name}", type=arg_type) # type: ignore + + # No exotic types for now + case _: + raise TypeError( + f"Type '{arg_type}' of argument '{arg_name}' is" + "currently not supported for dynamic graph construction" + ) + return parser + + +def parse_groups( + parser: argparse.ArgumentParser, extra_args: list[str] | None = None +) -> dict[str, argparse.Namespace]: + """ + Parse commandline arguments into separate groups. 
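A short sketch of the grouping behaviour (the argument names are made up):

    import argparse
    from maize.utilities.io import parse_groups

    parser = argparse.ArgumentParser()
    conf = parser.add_argument_group("maize")
    conf.add_argument("--debug", action="store_true")
    flow = parser.add_argument_group("workflow")
    flow.add_argument("--temperature", type=float)

    # Invoked with `--debug --temperature 300`, groups["maize"] is
    # Namespace(debug=True) and groups["workflow"] is Namespace(temperature=300.0)
    groups = parse_groups(parser)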
+ + Parameters + ---------- + parser + Parser with initialised groups and added arguments + extra_args + Additional arguments to be parsed + + Returns + ------- + dict[str, dict[str, Any]] + Parsed arguments sorted by group + + """ + sys_args = sys.argv[1:] + if extra_args is not None: + sys_args += extra_args + args = parser.parse_args(sys_args) + groups = {} + for group in parser._action_groups: # pylint: disable=protected-access + if group.title is None: + continue + actions = { + arg.dest: getattr(args, arg.dest, None) + for arg in group._group_actions # pylint: disable=protected-access + } + groups[group.title] = argparse.Namespace(**actions) + return groups + + +def create_default_parser(help: bool = True) -> argparse.ArgumentParser: + """ + Creates the default maize commandline arguments. + + Returns + ------- + argparse.ArgumentParser + The created parser object + + """ + parser = argparse.ArgumentParser(description="Flow-based graph execution engine", add_help=help) + conf = parser.add_argument_group("maize") + conf.add_argument("--version", action="version", version=importlib.metadata.version("maize")) + conf.add_argument( + "-c", + "--check", + action="store_true", + default=False, + help="Check if the graph was built correctly and exit", + ) + conf.add_argument( + "-l", "--list", action="store_true", default=False, help="List all available nodes and exit" + ) + conf.add_argument( + "-o", + "--options", + action="store_true", + default=False, + help="List all exposed workflow parameters and exit", + ) + conf.add_argument( + "-d", "--debug", action="store_true", default=False, help="Provide debugging information" + ) + conf.add_argument( + "-q", + "--quiet", + action="store_true", + default=False, + help="Silence all output except errors and warnings", + ) + conf.add_argument("--keep", action="store_true", default=False, help="Keep all output files") + conf.add_argument("--config", type=Path, help="Global configuration file to use") + conf.add_argument("--scratch", type=Path, help="Workflow scratch location") + conf.add_argument("--parameters", type=Path, help="Additional parameters in JSON format") + conf.add_argument("--log", type=Path, help="Logfile to use") + return parser + + +def setup_workflow(workflow: "Workflow") -> None: + """ + Sets up an initialized workflow so that it can be run on the commandline as a script. 
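A typical workflow script then reduces to something like the following sketch (assuming `Workflow` is constructed and populated as usual; the name is a placeholder):

    from maize.core.workflow import Workflow
    from maize.utilities.io import setup_workflow

    flow = Workflow(name="example")
    # ... add nodes, connect ports, and map parameters here ...
    setup_workflow(flow)  # exposes parameters on the commandline, checks the graph, runs it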
+ + Parameters + ---------- + workflow + The built workflow object to expose + + """ + # Argument parsing - we create separate groups for + # global settings and workflow specific options + parser = create_default_parser() + parser.description = workflow.description + + # Workflow-specific settings + flow = parser.add_argument_group(workflow.name) + flow = workflow.add_arguments(flow) + groups = parse_groups(parser) + + # Global settings + args = groups["maize"] + workflow.update_settings_with_args(args) + workflow.update_parameters(**vars(groups[workflow.name])) + + # Execution + workflow.check() + if args.check: + workflow.logger.info("Workflow compiled successfully") + return + + workflow.execute() + + +def with_keys(data: dict[T, Any], keys: Iterable[T]) -> dict[T, Any]: + """Provide a dictionary subset based on keys.""" + return {k: v for k, v in data.items() if k in keys} + + +def with_fields(data: Any, keys: Iterable[T]) -> dict[T, Any]: + """Provide a dictionary based on a subset of object attributes.""" + return with_keys(data.__dict__, keys=keys) + + +class _PathEncoder(json.JSONEncoder): # pragma: no cover + def default(self, o: Any) -> Any: + if isinstance(o, Path): + return o.as_posix() + return json.JSONEncoder.default(self, o) + + +def read_input(path: Path) -> dict[str, Any]: + """Reads an input file in JSON, YAML or TOML format and returns a dictionary.""" + if not path.exists(): + raise FileNotFoundError(f"File at {path.as_posix()} not found") + + data: dict[str, Any] + with path.open("r") as file: + suffix = path.suffix.lower()[1:] + if suffix == "json": + data = json.load(file) + elif suffix in ("yaml", "yml"): + # FIXME Unsafe load should NOT be required here, (we are importing modules such + # as pathlib that would be required to reconstruct all python objects). This + # closed issue references this problem: https://github.com/yaml/pyyaml/issues/665 + data = yaml.unsafe_load(file.read()) + elif suffix == "toml": + data = toml.loads(file.read()) + else: + raise ValueError(f"Unknown type '{suffix}'. Valid types: 'json', 'yaml', 'toml'") + + return data + + +def write_input(path: Path, data: dict[str, Any]) -> None: + """Saves a dictionary in JSON, YAML or TOML format.""" + + # Force dumping Path objects as strings + def path_representer(dumper: yaml.Dumper, data: Path) -> yaml.ScalarNode: + return dumper.represent_str(f"{data.as_posix()}") + + yaml.add_representer(PosixPath, path_representer) + + with path.open("w") as file: + suffix = path.suffix.lower()[1:] + if suffix == "json": + file.write(json.dumps(data, indent=4, cls=_PathEncoder)) + elif suffix in ("yaml", "yml"): + file.write(yaml.dump(data, sort_keys=False)) + elif suffix == "toml": + file.write(toml.dumps(data)) + else: + raise ValueError(f"Unknown type '{suffix}'. 
Valid types: 'json', 'yaml', 'toml'") diff --git a/maize/utilities/macros.py b/maize/utilities/macros.py new file mode 100644 index 0000000..0edd7bb --- /dev/null +++ b/maize/utilities/macros.py @@ -0,0 +1,305 @@ +"""Workflow macros to allow certain node and subgraph modifications""" + +from collections.abc import Callable, Sequence +import inspect +from typing import Any, TypeVar + +from maize.core.component import Component +from maize.core.interface import Input, MultiInput, MultiParameter, Output, Parameter +from maize.core.node import Node +from maize.core.graph import Graph +from maize.core.runtime import setup_build_logging +from maize.steps.plumbing import Copy, Merge, RoundRobin, Yes +from maize.utilities.testing import MockChannel +from maize.utilities.utilities import unique_id + + +T = TypeVar("T") +G = TypeVar("G", bound=type[Component]) + + +def tag(name: str) -> Callable[[G], G]: + """ + Tag a `Node` or `Graph` with a particular attribute. + + Parameters + ---------- + name + The tag to use + + Returns + ------- + Callable[[G], G] + Tagging class decorator + + """ + + def decorator(cls: G) -> G: + cls._tags.add(name) + return cls + + return decorator + + +def parallel( + node_type: type[Component], + n_branches: int, + inputs: Sequence[str] | None = None, + constant_inputs: Sequence[str] | None = None, + outputs: Sequence[str] | None = None, + **kwargs: Any, +) -> type[Graph]: + """ + Workflow macro to parallelize a node. The created subgraph + will have the exact same interface as the wrapped node. + + Parameters + ---------- + node_type + The node class to parallelize + n_branches + The number of parallel branches to create + inputs + The names of all inputs to parallelize, will use ``'inp'`` as default + constant_inputs + The names of all inputs with constant loop/batch-invariant data + outputs + The names of all outputs to parallelize, will use ``'out'`` as default + kwargs + Additional arguments to be passed to the ``add`` method + + Returns + ------- + type[Graph] + A subgraph containing multiple parallel branches of the node + + Examples + -------- + >>> parallel_node = flow.add(parallel( + ... ExampleNode, + ... n_branches=3, + ... inputs=('inp', 'inp_other'), + ... constant_inputs=('inp_const',), + ... outputs=('out',) + ... )) + >>> flow.connect_all( + ... (input_node.out, parallel_node.inp), + ... (other_input_node.out, parallel_node.inp_other), + ... (const_input.out, parallel_node.inp_const), + ... (parallel_node.out, output_node.inp), + ... 
) + + """ + input_names = ["inp"] if inputs is None else inputs + output_names = ["out"] if outputs is None else outputs + constant_input_names = constant_inputs or [] + + # PEP8 convention dictates CamelCase for classes + new_name = _snake2camel(node_type.__name__) + "Parallel" + attrs: dict[str, Any] = {} + + def build(self: Graph) -> None: + # Because our node can have multiple inputs and outputs + # (which we want to all parallelize), we create one RR + # node for each input (and one Merge node for each output) + sowers = [self.add(RoundRobin[Any], name=f"sow-{name}") for name in input_names] + + # These are inputs that are constant and unchanging over all nodes + constant_sowers = [ + self.add(Copy[Any], name=f"sow-const-{name}") for name in constant_input_names + ] + + # Copy expects input for each loop, but because the content is constant, + # we can accept it once (and then send it over and over with `Yes`) + yes_relays = [self.add(Yes[Any], name=f"yes-{name}") for name in constant_input_names] + nodes = [ + self.add(node_type, name=f"{node_type.__name__}-{i}", **kwargs) + for i in range(n_branches) + ] + reapers = [self.add(Merge[Any], name=f"reap-{name}") for name in output_names] + + # Connect all ports to each separate RR / Merge node + for node in nodes: + for inp, sow in zip(input_names, sowers): + self.connect(sow.out, node.inputs[inp]) + for const_inp, const_sow in zip(constant_input_names, constant_sowers): + self.connect(const_sow.out, node.inputs[const_inp]) + for out, reap in zip(output_names, reapers): + self.connect(node.outputs[out], reap.inp) + + # Connect the `Yes` relay nodes + for const_sow, yes in zip(constant_sowers, yes_relays): + self.connect(yes.out, const_sow.inp) + + # Expose all I/O ports + for inp, sow in zip(input_names, sowers): + inp_port = self.map_port(sow.inp, name=inp) + setattr(self, inp, inp_port) + for out, reap in zip(output_names, reapers): + out_port = self.map_port(reap.out, name=out) + setattr(self, out, out_port) + for name, yes in zip(constant_input_names, yes_relays): + inp_port = self.map_port(yes.inp, name=name) + setattr(self, name, inp_port) + + # Expose all parameters + for name in nodes[0].parameters: + para: MultiParameter[Any, Any] = self.combine_parameters( + *(node.parameters[name] for node in nodes), name=name + ) + setattr(self, name, para) + + attrs["build"] = build + return type(new_name, (Graph,), attrs, register=False) + + +def lambda_node(func: Callable[[Any], Any]) -> type[Node]: + """ + Convert an anonymous function with single I/O into a node. + + Parameters + ---------- + func + Lambda function taking a single argument and producing a single output + + Returns + ------- + type[Node] + Custom lambda wrapper node + + Examples + -------- + >>> lam = flow.add(lambda_node(lambda x: 2 * x)) + >>> flow.connect_all((first.out, lam.inp), (lam.out, last.inp)) + + """ + new_name = f"lambda-{unique_id()}" + + def run(self: Node) -> None: + assert hasattr(self, "inp") + assert hasattr(self, "out") + data = self.inp.receive() + res = func(data) + self.out.send(res) + + attrs = {"inp": Input(), "out": Output(), "run": run} + return type(new_name, (Node,), attrs, register=False) + + +def _snake2camel(string: str) -> str: + """Converts a string from snake_case to CamelCase""" + return "".join(s.capitalize() for s in string.split("_")) + + +def function_to_node(func: Callable[..., Any]) -> type[Node]: + """ + Dynamically creates a new node type from an existing python function. 
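Positional arguments become input ports, keyword arguments become parameters, and the annotated return value becomes the single output, as in this sketch:

    from maize.utilities.macros import function_to_node

    def scale(x: float, factor: float = 2.0) -> float:
        return factor * x

    Scale = function_to_node(scale)
    # `Scale` has an input `inp` (from `x`), a parameter `factor`, and an output `out`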
+ + Parameters + ---------- + func + Function to convert, ``args`` will be converted to inputs, ``kwargs`` + will be converted to parameters. A single return will be converted to + one output port. + + Returns + ------- + type[Node] + Node class + + """ + # PEP8 convention dictates CamelCase for classes + new_name = _snake2camel(func.__name__) + + # Prepare all positional arguments (= inputs) and + # keyword arguments (= parameters with default) + sig = inspect.signature(func) + args, kwargs = {}, {} + for name, arg in sig.parameters.items(): + if arg.default == inspect._empty: # pylint: disable=protected-access + args[name] = arg.annotation + else: + kwargs[name] = (arg.annotation, arg.default) + + attrs: dict[str, Any] = {} + + # Prepare inputs + for name, dtype in args.items(): + input_name = f"inp_{name}" if len(args) > 1 else "inp" + inp: Input[Any] = Input() + inp.datatype = dtype + attrs[input_name] = inp + + # Prepare output + out: Output[Any] = Output() + out.datatype = sig.return_annotation + attrs["out"] = out + + # Prepare parameters + for name, (dtype, default) in kwargs.items(): + param: Parameter[Any] = Parameter(default=default) + param.datatype = dtype + attrs[name] = param + + def run(self: Node) -> None: + assert hasattr(self, "out") + # Inputs will not contain MultiInput in this context, so this is safe + args = [inp.receive() for inp in self.inputs.values()] # type: ignore + kwargs = { + name: param.value + for name, param in self.parameters.items() + if name not in Node.__dict__ + } + res = func(*args, **kwargs) + self.out.send(res) + + attrs["run"] = run + new = type(new_name, (Node,), attrs) + return new + + +def node_to_function(cls: type[Node], **constructor_args: Any) -> Callable[..., dict[str, Any]]: + """ + Convert a node class to a function that takes + inputs and parameters as function arguments. 
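A short sketch, reusing the hypothetical `Scale` node from the `function_to_node` example above:

    from maize.utilities.macros import node_to_function

    scale_func = node_to_function(Scale)
    result = scale_func(inp=3.0, factor=4.0)  # expected: {"out": 12.0}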
+ + Parameters + ---------- + cls + The node class (not instance) + + Returns + ------- + Callable[..., dict[str, Any]] + A function taking inputs and parameters as keyword arguments + and returning a dictionary with outputs + + """ + node = cls(name="test", parent=Graph(), **constructor_args) + + def inner(**kwargs: Any) -> dict[str, Any]: + for name, inp in node.inputs.items(): + items = kwargs[name] + if isinstance(inp, MultiInput) and isinstance(items, list): + for item in items: + channel: MockChannel[Any] = MockChannel(items=item) + inp.set_channel(channel) + else: + channel: MockChannel[Any] = MockChannel(items=items) # type: ignore + inp.set_channel(channel) + + for name, parameter in node.parameters.items(): + if name in kwargs: + parameter.set(kwargs[name]) + + outputs: dict[str, Any] = {} + for name, out in node.outputs.items(): + channel = MockChannel() + out.set_channel(channel) + outputs[name] = channel + + node.logger = setup_build_logging(name="test") + node.run() + return {k: chan.get() for k, chan in outputs.items()} + + return inner diff --git a/maize/utilities/py.typed b/maize/utilities/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/maize/utilities/resources.py b/maize/utilities/resources.py new file mode 100644 index 0000000..84e82ed --- /dev/null +++ b/maize/utilities/resources.py @@ -0,0 +1,163 @@ +"""Computational resource management""" + +import logging +from multiprocessing import get_context +import os +import subprocess +import time +from typing import TYPE_CHECKING, TypeVar, Any + +from maize.core.runtime import Status +from maize.utilities.execution import DEFAULT_CONTEXT + +if TYPE_CHECKING: + from maize.core.component import Component + +T = TypeVar("T") + + +_ctx = get_context(DEFAULT_CONTEXT) + + +def gpu_count() -> int: + """Finds the number of GPUs in the current system.""" + try: + result = subprocess.run( + "nvidia-smi -L | wc -l", + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + check=True, + ) + return len(result.stdout.decode().split("\n")) - 1 + except OSError: + return 0 + + +# https://stackoverflow.com/a/55423170/6257823 +def cpu_count() -> int: + """ + Finds the number of CPUs in the current system, while respecting + the affinity to make it compatible with cluster management systems. + + """ + return len(os.sched_getaffinity(0)) + +# Only reference to something like this that I've found is: +# https://stackoverflow.com/questions/69366850/counting-semaphore-that-supports-multiple-acquires-in-an-atomic-call-how-would +class ChunkedSemaphore: + """ + Semaphore allowing chunked atomic resource acquisition. + + Parameters + ---------- + max_count + Initial value representing the maximum resources + sleep + Time to wait in between semaphore checks + + """ + + def __init__(self, max_count: int, sleep: int = 5) -> None: + # FIXME temporary Any type hint until this gets solved + # https://github.com/python/typeshed/issues/8799 + self._count: Any = _ctx.Value("i", max_count) + self._max_count = max_count + self.sleep = sleep + + def acquire(self, value: int) -> None: + """ + Acquire resources. 
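For example, guarding a pool of CPU cores (a minimal sketch):

    from maize.utilities.resources import ChunkedSemaphore

    cores = ChunkedSemaphore(max_count=8)
    cores.acquire(3)      # blocks until three of the eight slots are free
    try:
        ...               # run work that needs three cores
    finally:
        cores.release(3)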
+ + Parameters + ---------- + value + Amount to acquire + + Raises + ------ + ValueError + If the amount requested exceeds the available resources + + """ + if value > self._max_count: + raise ValueError(f"Acquire ({value}) exceeds available resources ({self._max_count})") + + while True: + with self._count.get_lock(): + if (self._count.value - value) >= 0: + self._count.value -= value + break + + # Sleep outside lock to make sure others can access + time.sleep(self.sleep) + + def release(self, value: int) -> None: + """ + Release resources. + + Parameters + ---------- + value + Amount to release + + Raises + ------ + ValueError + If the amount released exceeds the available resources + + """ + with self._count.get_lock(): + if (self._count.value + value) > self._max_count: + raise ValueError( + f"Release ({value}) exceeds original resources ({self._max_count})" + ) + self._count.value += value + + +class Resources: + """ + Acquire computational resources in the form of a context manager. + + Parameters + ---------- + max_count + The maximum resource count + parent + Parent component object for status updates + + Examples + -------- + >>> cpus = Resources(max_count=5, parent=graph) + ... with cpus(3): + ... do_something(n_jobs=3) + + """ + + def __init__(self, max_count: int, parent: "Component") -> None: + self._sem = ChunkedSemaphore(max_count=max_count) + self._val = 0 + self.parent = parent + + def __call__(self, value: int) -> "Resources": + """ + Acquire resources. + + Parameters + ---------- + value + Amount to acquire + + """ + self._val = value + return self + + def __enter__(self) -> "Resources": + self.parent.status = Status.WAITING_FOR_RESOURCES + self._sem.acquire(self._val) + self.parent.status = Status.RUNNING + return self + + def __exit__(self, *_: Any) -> None: + self._sem.release(self._val) + self._val = 0 diff --git a/maize/utilities/testing.py b/maize/utilities/testing.py new file mode 100644 index 0000000..8aa9b9c --- /dev/null +++ b/maize/utilities/testing.py @@ -0,0 +1,252 @@ +"""Testing utilities""" + +import logging +import queue +from typing import Any, TypeVar, cast + +from maize.core.interface import MultiInput, MultiOutput +from maize.core.node import Node +from maize.core.channels import Channel +from maize.core.workflow import Workflow +from maize.utilities.io import Config + + +T = TypeVar("T") + + +class MockChannel(Channel[T]): + """ + Mock channel class presenting a `Channel` interface for testing. + Can be loaded with an item to simulate a node receiving it from + a neighbour. Can also be used to retrieve a sent item for assertions. 
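A quick sketch of both directions of use:

    from maize.utilities.testing import MockChannel

    channel = MockChannel(items=[1, 2])  # simulate data arriving from upstream
    assert channel.receive() == 1
    assert channel.receive() == 2        # the channel closes once it runs empty
    channel.send(3)                      # simulate a node sending a result
    assert channel.get() == 3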
+ + Parameters + ---------- + items + Items to be loaded into the channel + + """ + + def __init__(self, items: T | list[T] | None = None) -> None: + self._locked = False + self._queue: queue.Queue[T] = queue.Queue() + if items is not None: + if not isinstance(items, list): + items = [items] + for item in items: + self._queue.put(item) + + @property + def active(self) -> bool: + return not self._locked + + @property + def ready(self) -> bool: + return not self._queue.empty() + + @property + def size(self) -> int: + return self._queue.qsize() + + def close(self) -> None: + self._locked = True + + def kill(self) -> None: + # No multiprocessing here, so no special killing needed + pass + + def receive(self, timeout: float | None = None) -> T | None: + if not self._queue.empty(): + item = self._queue.get(timeout=timeout) + if self._queue.empty(): + self.close() + return item + return None + + def send(self, item: T, timeout: float | None = None) -> None: + self._queue.put(item, timeout=timeout) + + def get(self) -> T | None: + """Unconditionally get the stored item for assertions.""" + return self._queue.get() + + def flush(self, timeout: float = 0.1) -> list[T]: + data = [] + while not self._queue.empty(): + data.append(self._queue.get(timeout=timeout)) + return data + + +class TestRig: + """ + Test rig for user `Node` tasks. Can be loaded with parameters and inputs. + + Parameters + ---------- + cls + `Node` child class to be wrapped + config + Global configuration for the parent workflow + + Attributes + ---------- + inputs + Dictionary of input values + parameters + Dictionary of parameter values + node + `Node` instance + + See Also + -------- + MockChannel : Class simulating the behaviour of normal channels + without the issues associated with multiple processes + + Examples + -------- + >>> rig = TestRig(Foo) + ... rig.set_inputs(inputs=dict(inp="bar")) + ... rig.set_parameters(parameters=dict(param=42)) + ... rig.setup(n_outputs=2) + ... 
outputs = rig.run() + + """ + + inputs: dict[str, Any] + parameters: dict[str, Any] + node: Node + + def __init__(self, cls: type[Node], config: Config | None = None) -> None: + self._cls = cls + self._mock_parent = Workflow(level=logging.DEBUG) + self._mock_parent._message_queue = queue.Queue() # type: ignore + if config is not None: + self._mock_parent.config = config + + def set_inputs(self, inputs: dict[str, Any] | None) -> None: + """Set the task input values""" + self.inputs = inputs if inputs is not None else {} + + def set_parameters(self, parameters: dict[str, Any] | None) -> None: + """Set the task parameter values""" + self.parameters = parameters if parameters is not None else {} + + def setup( + self, n_outputs: int | dict[str, int] | None = None, **kwargs: Any + ) -> dict[str, MockChannel[Any] | list[MockChannel[Any]]]: + """Instantiate the node and create mock interfaces.""" + self.node = self._cls(name="test", parent=self._mock_parent, **kwargs) + self.node.logger = self._mock_parent.logger + for name, inp in self.node.inputs.items(): + if name not in self.inputs and (inp.default is not None or inp.optional): + continue + items = self.inputs[name] + if isinstance(inp, MultiInput) and isinstance(items, list): + for item in items: + channel: MockChannel[Any] = MockChannel(items=item) + inp.set_channel(channel) + else: + channel = MockChannel(items=items) + inp.set_channel(channel) + + for name, parameter in self.node.parameters.items(): + if name in self.parameters: + parameter.set(self.parameters[name]) + + outputs: dict[str, MockChannel[Any] | list[MockChannel[Any]]] = {} + for name, out in self.node.outputs.items(): + out_channel: list[MockChannel[Any]] | MockChannel[Any] + if isinstance(out, MultiOutput) and n_outputs is not None: + n_out = n_outputs if isinstance(n_outputs, int) else n_outputs[name] + out_channel = [] + for _ in range(n_out): + chan = MockChannel[Any]() + out.set_channel(chan) + out_channel.append(chan) + else: + out_channel = MockChannel() + out.set_channel(out_channel) + outputs[name] = out_channel + + return outputs + + def run(self) -> None: + """Run the node with inputs and parameters previously set.""" + self.node._prepare() + self.node._iter_run(cleanup=False) # pylint: disable=protected-access + + def setup_run_multi( + self, + inputs: dict[str, Any] | None = None, + parameters: dict[str, Any] | None = None, + n_outputs: int | dict[str, int] | None = None, + **kwargs: Any, + ) -> dict[str, MockChannel[Any] | list[MockChannel[Any]]]: + """ + Instantiate and run the node with a specific set of parameters and inputs. + Note that this method will potentially return a mix of `MockChannel` and + `list[MockChannel]`, so your receiving side needs to handle both types correctly. 
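For instance, with a hypothetical `Splitter` node that has a single input ``inp`` and a `MultiOutput` named ``out``:

    from maize.utilities.testing import TestRig

    rig = TestRig(Splitter)
    outputs = rig.setup_run_multi(inputs={"inp": [1, 2, 3]}, n_outputs=2)
    first, second = outputs["out"]                    # a list of MockChannel here
    received = [chan.flush() for chan in (first, second)]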
+ + Parameters + ---------- + inputs + Inputs for your node + parameters + Parameters for your node + n_outputs + How many outputs to create for `MultiOutput` + kwargs + Any additional arguments to pass to the node constructor + + Returns + ------- + dict[str, MockChannel[Any] | list[MockChannel[Any]]] + The outputs of the node in the form of channels potentially containing data + + See Also + -------- + TestRig.setup_run + Testing for nodes with a fixed number of outputs + + """ + self.set_inputs(inputs) + self.set_parameters(parameters) + outputs = self.setup(n_outputs=n_outputs, **kwargs) + self.run() + return outputs + + def setup_run( + self, + inputs: dict[str, Any] | None = None, + parameters: dict[str, Any] | None = None, + **kwargs: Any, + ) -> dict[str, MockChannel[Any]]: + """ + Instantiate and run the node with a specific set of parameters and inputs. + If you need variable outputs, use `setup_run_multi` instead. + + Parameters + ---------- + inputs + Inputs for your node + parameters + Parameters for your node + kwargs + Any additional arguments to pass to the node constructor + + Returns + ------- + dict[str, MockChannel[Any]] + The outputs of the node in the form of channels potentially containing data + + See Also + -------- + TestRig.setup_run_multi + Testing for nodes with a variable number of outputs + + """ + self.set_inputs(inputs) + self.set_parameters(parameters) + outputs = self.setup(**kwargs) + self.run() + return cast(dict[str, MockChannel[Any]], outputs) diff --git a/maize/utilities/utilities.py b/maize/utilities/utilities.py new file mode 100644 index 0000000..97f306c --- /dev/null +++ b/maize/utilities/utilities.py @@ -0,0 +1,670 @@ +"""Various unclassifiable utilities.""" + +import ast +from collections.abc import Generator, Callable +from contextlib import redirect_stderr +import contextlib +import datetime +from enum import Enum +import functools +import inspect +import io +import itertools +import math +import os +from pathlib import Path +import random +import re +import shlex +import time +import string +import sys +from types import UnionType +from typing import ( + TYPE_CHECKING, + AnyStr, + Literal, + TypeVar, + Any, + Annotated, + Union, + TypeAlias, + cast, + get_args, + get_origin, +) +from typing_extensions import assert_never +import warnings + +from beartype.door import is_subhint +import networkx as nx + +if TYPE_CHECKING: + from maize.core.graph import Graph + from maize.core.component import Component + +T = TypeVar("T") +U = TypeVar("U") + + +def unique_id(length: int = 6) -> str: + """ + Creates a somewhat unique identifier. + + Parameters + ---------- + length + Length of the generated ID + + Returns + ------- + str + A random string made up of lowercase ASCII letters and digits + + """ + # This should be safer than truncating a UUID + return "".join(random.choices(string.ascii_lowercase + string.digits, k=length)) + + +class StrEnum(Enum): + """Allows use of enum names as auto string values. See `StrEnum` in Python 3.11.""" + + @staticmethod + def _generate_next_value_(name: str, *_: Any) -> str: + return name + + +# See this SO answer: +# https://stackoverflow.com/questions/39372708/spawn-multiprocessing-process-under-different-python-executable-with-own-path +def change_environment(exec_path: str | Path) -> None: # pragma: no cover + """ + Changes the python environment based on the executable. 
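For example (the interpreter path is hypothetical):

    from maize.utilities.utilities import change_environment

    # Points PATH, sys.prefix and sys.path at the other interpreter's environment
    change_environment("/home/user/miniconda3/envs/other/bin/python")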
+ + Parameters + ---------- + exec_path + Path to the python executable (normally ``sys.executable``) + + """ + old_path = os.environ.get("PATH", "") + exec_abs = Path(exec_path).parent.absolute() + os.environ["PATH"] = exec_abs.as_posix() + os.pathsep + old_path + base = exec_abs.parent + site_packages = base / "lib" / f"python{sys.version[:4]}" / "site-packages" + old_sys_path = list(sys.path) + + import site # pylint: disable=import-outside-toplevel + + site.addsitedir(site_packages.as_posix()) + sys.prefix = base.as_posix() + new_sys_path = [] + for item in list(sys.path): + if item not in old_sys_path: + new_sys_path.append(item) + sys.path.remove(item) + sys.path[:0] = new_sys_path + + +def set_environment(env: dict[str, str]) -> None: + """ + Set global system environment variables. + + Parameters + ---------- + env + Dictionary of name-value pairs + + """ + for key, value in env.items(): + os.environ[key] = value + + +def has_module_system() -> bool: + """Checks whether the system can use modules.""" + return "LMOD_DIR" in os.environ + + +# https://stackoverflow.com/questions/5427040/loading-environment-modules-within-a-python-script +def load_modules(*names: str) -> None: + """ + Loads environment modules using ``lmod``. + + Parameters + ---------- + names + Module names to load + + """ + lmod = Path(os.environ["LMOD_DIR"]) + sys.path.insert(0, (lmod.parent / "init").as_posix()) + from env_modules_python import module # pylint: disable=import-outside-toplevel,import-error + + for name in names: + out = io.StringIO() # pylint: disable=no-member + with redirect_stderr(out): + module("load", name) + if "error" in out.getvalue(): + raise OSError(f"Error loading module '{name}'") + + +def _extract_single_class_docs(node_type: type["Component"]) -> dict[str, str]: + """Extracts docs from a class definition in the style of :pep:`258`.""" + docs = {} + + # A lot of things can go wrong here, so we err on the side of caution + with contextlib.suppress(Exception): + source = inspect.getsource(node_type) + for node in ast.iter_child_nodes(ast.parse(source)): + for anno, expr in itertools.pairwise(ast.iter_child_nodes(node)): + match anno, expr: + case ast.AnnAssign( + target=ast.Name(name), + annotation=ast.Subscript( + value=ast.Name( + id="Input" | "Output" | "Parameter" | "FileParameter" | "Flag" + ) + ), + ), ast.Expr(value=ast.Constant(doc)): + docs[name] = doc + return docs + + +# This is *slightly* evil, but it allows us to get docs that are +# both parseable by sphinx and usable for commandline help messages +def extract_attribute_docs(node_type: type["Component"]) -> dict[str, str]: + """ + Extracts attribute docstrings in the style of :pep:`258`. + + Parameters + ---------- + node_type + Node class to extract attribute docstrings from + + Returns + ------- + dict[str, str] + Dictionary of attribute name - docstring pairs + + """ + docs: dict[str, str] = {} + for cls in (node_type, *node_type.__bases__): + docs |= _extract_single_class_docs(cls) + return docs + + +def typecheck(value: Any, datatype: Any) -> bool: + """ + Checks if a value is valid using type annotations. + + Parameters + ---------- + value + Value to typecheck + datatype + Type to check against. Can also be a Union or Annotated. + + Returns + ------- + bool + ``True`` if the value is valid, ``False`` otherwise + + """ + # This means we have probably omitted types + if datatype is None or isinstance(datatype, TypeVar) or datatype == Any: + return True + + # In some cases (e.g. 
file types) we'll hopefully have an + # `Annotated` type, using a custom callable predicate for + # validation. This could be a check for the correct file + # extension or a particular range for a numerical input + if get_origin(datatype) == Annotated: + cls, *predicates = get_args(datatype) + return typecheck(value, cls) and all(pred(value) for pred in predicates) + + if get_origin(datatype) == Literal: # pylint: disable=comparison-with-callable + options = get_args(datatype) + return value in options + + # Any kind of container type + if len(anno := get_args(datatype)) > 0: + if get_origin(datatype) == UnionType: + return any(typecheck(value, cls) for cls in anno) + if get_origin(datatype) in (tuple, list): + return all(typecheck(val, arg) for val, arg in zip(value, get_args(datatype))) + if get_origin(datatype) == dict: + key_type, val_type = get_args(datatype) + return all(typecheck(key, key_type) for key in value) and all( + typecheck(val, val_type) for val in value.values() + ) + + # Safe fallback in case we don't know what this is, this should avoid false positives + datatype = get_origin(datatype) + return isinstance(value, datatype) + + +_U = TypeVar("_U", Callable[[Any], Any], type[Any]) + + +def deprecated( + msg: str | None = None, +) -> Callable[[_U], _U]: + """Inserts a deprecation warning for a class or a function""" + + msg = "." if msg is None else ", " + msg + + def _warn(obj: Any) -> None: + warnings.simplefilter("always", DeprecationWarning) + warnings.warn( + f"{obj.__name__} is deprecated" + msg, + category=DeprecationWarning, + stacklevel=2, + ) + warnings.simplefilter("default", DeprecationWarning) + + def deprecator(obj: _U) -> _U: + if inspect.isclass(obj): + orig_init = obj.__init__ + + def __init__(self: T, *args: Any, **kwargs: Any) -> None: + _warn(obj) + orig_init(self, *args, **kwargs) + + # Mypy complains, recommended workaround seems to be to just ignore: + # https://github.com/python/mypy/issues/2427 + obj.__init__ = __init__ + return obj + + # Order matters here, as classes are also callable + elif callable(obj): + + def inner(*args: Any, **kwargs: Any) -> Any: + _warn(obj) + return obj(*args, **kwargs) + + return cast(_U, inner) + + else: + assert_never(obj) + + return deprecator + + +class Timer: + """ + Timer with start, pause, and stop functionality. + + Examples + -------- + >>> t = Timer() + ... t.start() + ... do_something() + ... t.pause() + ... do_something_else() + ... 
print(t.stop()) + + """ + + def __init__(self) -> None: + self._start = 0.0 + self._running = False + self._elapsed_time = 0.0 + + @property + def elapsed_time(self) -> datetime.timedelta: + """Returns the elapsed time.""" + if self.running: + self.pause() + self.start() + return datetime.timedelta(seconds=self._elapsed_time) + + @property + def running(self) -> bool: + """Returns whether the timer is currently running.""" + return self._running + + def start(self) -> None: + """Start the timer.""" + self._start = time.time() + self._running = True + + def pause(self) -> None: + """Temporarily pause the timer.""" + self._elapsed_time += time.time() - self._start + self._running = False + + def stop(self) -> datetime.timedelta: + """Stop the timer and return the elapsed time in seconds.""" + if self.running: + self.pause() + return datetime.timedelta(seconds=self._elapsed_time) + + +def graph_cycles(graph: "Graph") -> list[list[str]]: + """Returns whether the graph contains cycles.""" + mdg = graph_to_nx(graph) + return list(nx.simple_cycles(mdg)) + + +def graph_to_nx(graph: "Graph") -> nx.MultiDiGraph: + """ + Converts a workflow graph to a ``networkx.MultiDiGraph`` object. + + Parameters + ---------- + graph + Workflow graph + + Returns + ------- + nx.MultiDiGraph + Networkx graph instance + + """ + mdg = nx.MultiDiGraph() + unique_nodes = {node.component_path for node in graph.flat_nodes} + unique_channels = {(inp[:-1], out[:-1]) for inp, out in graph.flat_channels} + mdg.add_nodes_from(unique_nodes) + mdg.add_edges_from(unique_channels) + return mdg + + +# `Unpack` just seems completely broken with 3.10 and mypy 0.991 +# See PEP646 on tuple unpacking: +# https://peps.python.org/pep-0646/#unpacking-unbounded-tuple-types +NestedDict: TypeAlias = dict[T, Union["NestedDict[T, U]", U]] + + +def tuple_to_nested_dict(*data: Any) -> NestedDict[T, U]: + """Convert a tuple into a sequentially nested dictionary.""" + out: NestedDict[T, U] + ref: NestedDict[T, U] + out = ref = {} + *head, semifinal, final = data + for token in head: + ref[token] = ref = {} + ref[semifinal] = final + return out + + +def nested_dict_to_tuple(__data: NestedDict[T, U]) -> tuple[T | U, ...]: + """Convert a sequentially nested dictionary into a tuple.""" + out: list[T | U] = [] + data: U | NestedDict[T, U] = __data + while data is not None: + if not isinstance(data, dict): + out.append(data) + break + first, data = next(iter(data.items())) + out.append(first) + return tuple(out) + + +def has_file(path: Path) -> bool: + """Returns whether the specified directory contains files.""" + return path.exists() and len(list(path.iterdir())) > 0 + + +def make_list(item: list[T] | set[T] | tuple[T, ...] | T) -> list[T]: + """Makes a single item or sequence of items into a list.""" + if isinstance(item, list | set | tuple): + return list(item) + return [item] + + +def chunks(data: list[T], n: int) -> Generator[list[T], None, None]: + """Splits a dataset into ``n`` chunks""" + size, rem = divmod(len(data), n) + for i in range(n): + si = (size + 1) * (i if i < rem else rem) + size * (0 if i < rem else i - rem) + yield data[si : si + (size + 1 if i < rem else size)] + + +def split_list( + arr: list[T], *, n_batch: int | None = None, batchsize: int | None = None +) -> list[list[T]]: + """ + Splits a list into smaller lists. 
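For example:

    from maize.utilities.utilities import split_list

    split_list(list(range(7)), n_batch=3)    # [[0, 1, 2], [3, 4, 5], [6]]
    split_list(list(range(7)), batchsize=2)  # [[0, 1], [2, 3], [4, 5], [6]]
    # Passing both keywords (or neither) raises a ValueError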
+ + Parameters + ---------- + arr + List to split + n_batch + Number of batches to generate + batchsize + Size of the batches + + Returns + ------- + list[list[T]] + List of split lists + + """ + if batchsize is None and n_batch is not None: + batchsize = math.ceil(len(arr) / n_batch) + elif (batchsize is None and n_batch is None) or n_batch is not None: + raise ValueError("You must specify either the number of batches or the batchsize") + assert batchsize is not None # Only needed to shutup mypy + return [arr[i : i + batchsize] for i in range(0, len(arr), batchsize)] + + +def split_multi(string: str, chars: str) -> list[str]: + """ + Split string on multiple characters + + Parameters + ---------- + string + String to split + chars + Characters to split on + + Returns + ------- + list[str] + List of string splitting results + + """ + if not chars: + return [string] + splits = [] + for chunk in string.split(chars[0]): + splits.extend(split_multi(chunk, chars[1:])) + return splits + + +def format_datatype(dtype: Any) -> str: + """ + Formats a datatype to print nicely. + + Parameters + ---------- + dtype + Datatype to format + + Returns + ------- + str + Formatted datatype + + """ + if hasattr(dtype, "__origin__"): + return str(dtype) + return str(dtype.__name__) + + +def get_all_annotations(cls: type, visited: set[type] | None = None) -> dict[str, Any]: + """ + Recursively collect all annotations from a class and its superclasses. + + Parameters + ---------- + cls + The class to find annotations from + + visited + Set of visited classes + + Returns + ------- + dict[str, Any]: + Dictionary of annotations + """ + if visited is None: + visited = set() + if cls in visited: + return {} + visited.add(cls) + # Start with an empty dictionary for annotations + annotations: dict[str, Any] = {} + # First, recursively collect and merge annotations from base classes + for base in cls.__bases__: + annotations |= get_all_annotations(base, visited) + # Then, merge those with the current class's annotations, ensuring they take precedence + annotations |= cls.__annotations__ if hasattr(cls, "__annotations__") else {} + return annotations + + +# If the type is not given in the constructor (e.g. `out = Output[int]()`), +# it's hopefully given as an annotation (e.g. `out: Output[int] = Output()`) +def extract_superclass_type(owner: Any, name: str) -> Any: + """ + Extract type annotations from superclasses. + + Parameters + ---------- + owner + Parent object with annotations in the form of ``x: str = ...`` + name + Name of the variable + + Returns + ------- + Any + The annotated type, ``None`` if there wasn't one found + + """ + # __annotations__ does not include super class + # annotations, so we combine all of them first + annotations = get_all_annotations(owner.__class__) + annotations |= owner.__annotations__ + if name in annotations: + return get_args(annotations[name])[0] + return None + + +def extract_type(obj: Any) -> Any: + """ + Extract type annotations from an object or class. + + Parameters + ---------- + obj + Object with type annotations in the form of ``Object[...]`` + + Returns + ------- + Any + The annotated type, ``None`` if there wasn't one found + + """ + # This is an undocumented implementation detail to retrieve type + # arguments from an instance to use for dynamic / runtime type checking, + # and could break at some point. We should revisit this in the future, + # but for now it conveniently removes a bit of `Node` definition boilerplate. 
+    # See also this SO question:
+    # https://stackoverflow.com/questions/57706180/generict-base-class-how-to-get-type-of-t-from-within-instance
+    if hasattr(obj, "__orig_class__"):
+        return get_args(obj.__orig_class__)[0]
+    if hasattr(obj, "__args__"):
+        return obj.__args__[0]
+    return None
+
+
+def is_path_type(arg1: Any) -> bool:
+    """
+    Checks if type is a `Path`-like type.
+
+    Parameters
+    ----------
+    arg1
+        The datatype to check
+
+    Returns
+    -------
+    bool
+        ``True`` if the type is a `Path`-like, ``False`` otherwise
+
+    """
+    return is_subhint(arg1, Path | list[Path] | Annotated[Path, ...] | dict[Any, Path])
+
+
+# FIXME There seems to be an issue with python 3.11 and NDArray typehints. We're sticking with
+# python 3.10 for now, but ideally we would drop the dependency on beartype at some point.
+def matching_types(arg1: Any, arg2: Any, strict: bool = False) -> bool:
+    """
+    Checks if two types are matching.
+
+    Parameters
+    ----------
+    arg1, arg2
+        The datatypes to compare
+    strict
+        If set, will only output ``True`` if the types match exactly
+
+    Returns
+    -------
+    bool
+        ``True`` if the types are compatible, ``False`` otherwise
+
+    """
+    if strict:
+        return is_subhint(arg1, arg2) and is_subhint(arg2, arg1)
+    if None in (arg1, arg2):
+        return True
+    return is_subhint(arg1, arg2) or is_subhint(arg2, arg1)
+
+
+def find_probable_files_from_command(command: str | list[str]) -> list[Path]:
+    """
+    Finds possible files from a command string.
+
+    Should not be fully relied upon, as a file located in the current
+    directory with no suffix will not be easily identifiable as such.
+
+    Parameters
+    ----------
+    command
+        String or list of tokens to check for files
+
+    Returns
+    -------
+    list[Path]
+        Listing of `Path` objects, or an empty list if no files were found
+
+    """
+    if isinstance(command, str):
+        command = shlex.split(command)
+    # The `Path` constructor will never throw an exception
+    # as long as we supply `str` (or a `str` subtype)
+    return [Path(token) for token in command if any(c in token for c in ("/", "."))]
+
+
+def match_context(match: re.Match[AnyStr], chars: int = 100) -> AnyStr:
+    """
+    Provides context to a regular expression match.
+
+    Parameters
+    ----------
+    match
+        Regular expression match object
+    chars
+        Number of characters of context to provide
+
+    Returns
+    -------
+    AnyStr
+        Match context
+
+    """
+    # Clamp the lower bound so matches near the start of the string don't wrap around
+    return match.string[max(match.start() - chars, 0) : match.end() + chars]
diff --git a/maize/utilities/validation.py b/maize/utilities/validation.py
new file mode 100644
index 0000000..7381371
--- /dev/null
+++ b/maize/utilities/validation.py
@@ -0,0 +1,177 @@
+"""Defines some simple process run validators."""
+
+
+from abc import abstractmethod
+from pathlib import Path
+import subprocess
+from maize.utilities.io import wait_for_file
+
+from maize.utilities.utilities import make_list
+
+
+class Validator:
+    """
+    Validate if a command has run successfully or not.
+
+    Calling an instance with a ``subprocess.CompletedProcess`` object
+    should return a boolean indicating the success of the command.
+
+    """
+
+    @abstractmethod
+    def __call__(self, result: subprocess.CompletedProcess[bytes]) -> bool:
+        """
+        Calls the validator with the instantiated search string on a command result.
+
+        Parameters
+        ----------
+        result
+            Object containing `stdout` and `stderr` of the completed command
+
+        Returns
+        -------
+        bool
+            ``True`` if the command succeeded, ``False`` otherwise
+
+        """
+
+
+class OutputValidator(Validator):
+    """
+    Validate the STDOUT / STDERR of an external command.
+ + Calling an instance with a ``subprocess.CompletedProcess`` object + should return boolean indicating the success of the command. + + Parameters + ---------- + expect + String or list of strings to look for in the output + + """ + + def __init__(self, expect: str | list[str]) -> None: + self.expect = make_list(expect) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({', '.join(self.expect)})" + + @abstractmethod + def __call__(self, result: subprocess.CompletedProcess[bytes]) -> bool: + pass + + +class FailValidator(OutputValidator): + """ + Check if an external command has failed by searching for strings in the output. + + Parameters + ---------- + expect + String or list of strings to look for in the + output, any match will indicate a failure. + + """ + + def __call__(self, result: subprocess.CompletedProcess[bytes]) -> bool: + for exp in self.expect: + if (result.stderr is not None and exp in result.stderr.decode(errors="ignore")) or ( + exp in result.stdout.decode(errors="ignore") + ): + return False + return True + + +class SuccessValidator(OutputValidator): + """ + Check if an external command has succeeded by searching for strings in the output. + + Parameters + ---------- + expect + If all strings specified here are found, the command was successful + + """ + + def __call__(self, result: subprocess.CompletedProcess[bytes]) -> bool: + for exp in self.expect: + if not ( + (result.stderr is not None and exp in result.stderr.decode(errors="ignore")) + or (exp in result.stdout.decode(errors="ignore")) + ): + return False + return True + + +class FileValidator(Validator): + """ + Check if an external command has succeeded by searching for one or more generated files. + + Parameters + ---------- + expect + If all files specified here are found, the command was successful + zero_byte_check + Whether to check if the generated file is empty + timeout + Will wait ``timeout`` seconds for the file to appear + + """ + + def __init__( + self, expect: Path | list[Path], zero_byte_check: bool = True, timeout: int = 5 + ) -> None: + self.expect = make_list(expect) + self.zero_byte_check = zero_byte_check + self.timeout = timeout + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({', '.join(p.as_posix() for p in self.expect)}, " + f"zero_byte_check={self.zero_byte_check})" + ) + + def __call__(self, _: subprocess.CompletedProcess[bytes]) -> bool: + for exp in self.expect: + try: + wait_for_file(exp, timeout=self.timeout, zero_byte_check=self.zero_byte_check) + except TimeoutError: + return False + return True + + +class ContentValidator(Validator): + """ + Check if an external command has succeeded by searching + contents in one or more generated files. Empty files fail. 
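+
+    For example, ``ContentValidator({Path("out.log"): ["converged"]})`` (with an
+    illustrative file name) passes only if ``out.log`` appears within the timeout
+    and contains the string ``"converged"``.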
+ + Parameters + ---------- + expect + dictionary of file paths and list of strings to check for in each file + timeout + Will wait ``timeout`` seconds for the file to appear + """ + + def __init__(self, expect: dict[Path, list[str]], timeout: int = 5) -> None: + self.expect = expect + self.timeout = timeout + + def __repr__(self) -> str: + arg_descr = ", ".join( + f"{p.as_posix()}:{','.join(s for s in self.expect[p])}" for p in self.expect + ) + return f"{self.__class__.__name__}({arg_descr})" + + def __call__(self, _: subprocess.CompletedProcess[bytes]) -> bool: + for file_path, expected in self.expect.items(): + try: + wait_for_file(file_path, timeout=self.timeout, zero_byte_check=True) + with file_path.open() as f: + file_contents = f.read() + for exp_string in expected: + if exp_string and exp_string not in file_contents: + return False + except TimeoutError: + return False + return True diff --git a/maize/utilities/visual.py b/maize/utilities/visual.py new file mode 100644 index 0000000..ed65af7 --- /dev/null +++ b/maize/utilities/visual.py @@ -0,0 +1,166 @@ +"""Utilities for graph and workflow visualization.""" + +import sys +from typing import Any, Literal, Union, cast, TYPE_CHECKING, get_args, get_origin + +from matplotlib.colors import to_hex + +from maize.utilities.execution import check_executable +from maize.core.runtime import Status + +try: + import graphviz +except ImportError: + HAS_GRAPHVIZ = False +else: + HAS_GRAPHVIZ = True + if not check_executable(["dot", "-V"]): + HAS_GRAPHVIZ = False + +if TYPE_CHECKING: + from maize.core.graph import Graph + +# AZ colors +_COLORS = { + "mulberry": (131, 0, 81), + "navy": (0, 56, 101), + "purple": (60, 16, 83), + "gold": (240, 171, 0), + "lime-green": (196, 214, 0), + "graphite": (63, 68, 68), + "light-blue": (104, 210, 223), + "magenta": (208, 0, 111), + "platinum": (157, 176, 172), +} + +_STATUS_COLORS = { + Status.NOT_READY: _COLORS["platinum"], + Status.READY: _COLORS["graphite"], + Status.RUNNING: _COLORS["light-blue"], + Status.COMPLETED: _COLORS["lime-green"], + Status.FAILED: _COLORS["magenta"], + Status.STOPPED: _COLORS["purple"], + Status.WAITING_FOR_INPUT: _COLORS["gold"], + Status.WAITING_FOR_OUTPUT: _COLORS["gold"], + Status.WAITING_FOR_RESOURCES: _COLORS["navy"], + Status.WAITING_FOR_COMMAND: _COLORS["navy"], +} + + +def _rgb(red: int, green: int, blue: int) -> Any: + return to_hex((red / 255, green / 255, blue / 255)) + + +def _pprint_dtype(dtype: Any) -> str: + """Prints datatypes in a concise way""" + ret = "" + if (origin := get_origin(dtype)) is not None: + ret += f"{origin.__name__}" + if args := get_args(dtype): + ret += f"[{', '.join(_pprint_dtype(arg) for arg in args)}]" + elif hasattr(dtype, "__name__"): + ret += dtype.__name__ + return ret + + +HEX_COLORS = {k: _rgb(*c) for k, c in _COLORS.items()} +HEX_STATUS_COLORS = {k: _rgb(*c) for k, c in _STATUS_COLORS.items()} +COLOR_SEQ = list(HEX_COLORS.values()) +GRAPHVIZ_STYLE = dict( + node_attr={ + "fillcolor": "#66666622", + "fontname": "Consolas", + "fontsize": "11", + "shape": "box", + "style": "rounded,filled", + "penwidth": "2.0", + }, + graph_attr={"bgcolor": "#ffffff00"}, + edge_attr={ + "fontname": "Consolas", + "fontsize": "9", + "penwidth": "2.0", + "color": HEX_COLORS["graphite"], + }, +) + + +def nested_graphviz( + flow: "Graph", + max_level: int = sys.maxsize, + coloring: Literal["nesting", "status"] = "nesting", + labels: bool = True, + _dot: Union["graphviz.Digraph", None] = None, + _level: int = 0, +) -> "graphviz.Digraph": + """ + Create a 
graphviz digraph instance from a workflow or graph. + + Parameters + ---------- + flow + Workflow to convert + _dot + Graphviz dot for recursive internal passing + _level + Current nesting level for internal passing + + Returns + ------- + graphviz.Digraph + Graphviz object for visualization + + """ + if _dot is None: + _dot = graphviz.Digraph(flow.name, **GRAPHVIZ_STYLE) + for name, node in flow.nodes.items(): + # Can't check for graph due to circular imports + if hasattr(node, "nodes") and _level < max_level: + # We need to prefix the name with "cluster" to make + # sure graphviz recognizes it as a subgraph + with _dot.subgraph(name="cluster-" + name) as subgraph: + color = COLOR_SEQ[_level] if coloring == "nesting" else HEX_COLORS["platinum"] + subgraph.attr(label=name) + subgraph.attr(**GRAPHVIZ_STYLE["node_attr"]) + subgraph.attr(color=color) + nested_graphviz( + cast("Graph", node), + max_level=max_level, + coloring=coloring, + labels=labels, + _dot=subgraph, + _level=_level + 1, + ) + else: + # Because we can have duplicate names in subgraphs, we need to refer + # to each node by its full path (and construct the edges this way too) + color = COLOR_SEQ[_level] if coloring == "nesting" else HEX_STATUS_COLORS[node.status] + unique_name = "-".join(node.component_path) + _dot.node(unique_name, label=name, color=color) + + for (*out_path, out_port_name), (*inp_path, inp_port_name) in flow.channels: + root = flow.root + out = root.get_port(*out_path, out_port_name) + inp = root.get_port(*inp_path, inp_port_name) + dtype_label = _pprint_dtype(out.datatype) + out_name = "-".join(out_path[: max_level + 1]) + inp_name = "-".join(inp_path[: max_level + 1]) + headlabel = ( + inp_port_name.removeprefix("inp_") + if len(inp.parent.inputs) > 1 and inp_port_name != "inp" + else None + ) + taillabel = ( + out_port_name.removeprefix("out_") + if len(out.parent.outputs) > 1 and out_port_name != "out" + else None + ) + _dot.edge( + out_name, + inp_name, + label=dtype_label if labels else None, + headlabel=headlabel if labels else None, + taillabel=taillabel if labels else None, + **GRAPHVIZ_STYLE["edge_attr"], + ) + return _dot diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..3d302b2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,87 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "maize" +description = "Graph-based workflow manager for computational chemistry" +dynamic = ["version"] +readme = "README.md" +license = {file = "LICENSE"} +authors = [{name = "AstraZeneca"}] +maintainers = [{name = "Thomas Löhr", email = "thomas.lohr@astrazeneca.com"}] +requires-python = ">=3.10" +dependencies = [ + "networkx>=3.1", + "dill>=0.3.6", + "numpy>=1.24.3", + "pyyaml>=0.2.5", + "toml>=0.10.2", + "matplotlib>=3.7.1", + "beartype>=0.13.1", + "psij-python @ git+https://github.com/ExaWorks/psij-python.git@9e1a777", + "mypy>=1.2.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.3.1", + "pytest-datadir>=1.4.1", + "pytest-mock>=3.10.0", + "pylint>=2.17.3", + "sphinx>=6.2.0", + "pandoc>=3.1.1", + "nbsphinx>=0.9.1", + "furo>=2023.3.27", +] + +[project.scripts] +maize = "maize.maize:main" + +[tool.setuptools.packages.find] +include = ["maize*"] + +[tool.setuptools.package-data] +"*" = ["py.typed", "../maize.toml"] + +[tool.setuptools.dynamic] +version = {attr = "maize.maize.__version__"} + +[tool.pytest.ini_options] +log_cli = true +log_cli_level = "DEBUG" +addopts = "-vv" +markers = ["random: run randomised workflows"] + 
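+# A sketch of how the "random" marker registered above is meant to be used
+# (the test name below is hypothetical); such tests can then be deselected
+# with `pytest -m "not random"`:
+#
+#     import pytest
+#
+#     @pytest.mark.random
+#     def test_randomised_workflow():
+#         ...
+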
+[tool.coverage.report] +exclude_also = [ + "def __repr__", + "if TYPE_CHECKING:", + "if HAS_GRAPHVIZ:", + "def _repr_mimebundle_", + "@(abc\\.)?abstractmethod", + "assert_never" +] + +[tool.mypy] +# Note that exclude will not be honored by VS Code +# https://github.com/microsoft/vscode-python/issues/16511 +exclude = 'tests/' +follow_imports = "silent" +ignore_missing_imports = true +strict = true +disable_error_code = "unused-ignore" + +[tool.black] +line-length = 100 + +[tool.ruff] +select = ["E", "F", "W", "UP", "SIM", "PTH", "PL", "NPY"] +ignore = ["PLR", "PLW", "F401"] +line-length = 100 + +[tool.pylint.main] +source-roots = "./maize/*" + +[tool.coverage.run] +relative_files = true \ No newline at end of file diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 0000000..5e7b0a8 --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,2 @@ +sonar.projectKey=maize +sonar.organizations=devops diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3cb472d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,170 @@ +"""Global fixtures""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init, unused-import, unused-variable, unused-argument + +from pathlib import Path +import shutil +from typing import Annotated, Any + +import pytest + +from maize.core.interface import Input, Output, Parameter, FileParameter, Suffix +from maize.core.graph import Graph +from maize.core.node import Node +from maize.core.workflow import Workflow +from maize.steps.plumbing import Copy, Merge, Delay +from maize.steps.io import Return, Void + + +@pytest.fixture +def temp_working_dir(tmp_path: Any, monkeypatch: Any) -> None: + monkeypatch.chdir(tmp_path) + + +class A(Node): + out = Output[int]() + val = Parameter[int](default=3) + file = FileParameter[Annotated[Path, Suffix("pdb")]](default=Path("./fake")) + flag = Parameter[bool](default=False) + + def run(self): + self.out.send(self.val.value) + + +@pytest.fixture +def example_a(): + return A + + +class B(Node): + fail: bool = False + inp = Input[int]() + out = Output[int]() + out_final = Output[int]() + + def run(self): + if self.fail: + self.fail = False + raise RuntimeError("This is a test exception") + + val = self.inp.receive() + self.logger.debug("%s received %s", self.name, val) + if val > 48: + self.logger.debug("%s stopping", self.name) + self.out_final.send(val) + return + self.out.send(val + 2) + + +@pytest.fixture +def example_b(): + return B + + +@pytest.fixture +def node_with_file(): + class NodeFile(Node): + inp: Input[int] = Input() + + def run(self) -> None: + file = Path("test.out") + file.unlink(missing_ok=True) + data = self.inp.receive() + self.logger.debug("received %s", data) + with file.open("w") as f: + f.write(str(data)) + return NodeFile + + +class SubSubGraph(Graph): + def build(self): + a = self.add(A, "a", parameters=dict(val=36)) + d = self.add(Delay[int], "delay", parameters=dict(delay=1)) + self.connect(a.out, d.inp) + self.out = self.map_port(d.out, name="out") + self.combine_parameters(a.val, name="val") + + +@pytest.fixture +def subsubgraph(): + return SubSubGraph + + +class SubGraph(Graph): + def build(self): + a = self.add(SubSubGraph, "ssg", parameters=dict(val=36)) + d = self.add(Delay[int], "delay", parameters=dict(delay=1)) + self.connect(a.out, d.inp) + self.out = 
self.map_port(d.out, "out") + + +@pytest.fixture +def subgraph(): + return SubGraph + + +@pytest.fixture +def subgraph_multi(): + class SubgraphMulti(Graph): + def build(self): + a = self.add(A, parameters=dict(val=36)) + copy = self.add(Copy[int]) + void = self.add(Void) + self.connect(a.out, copy.inp) + self.connect(copy.out, void.inp) + self.map(copy.out) + return SubgraphMulti + + +@pytest.fixture +def nested_graph(subgraph, example_b): + g = Workflow() + sg = g.add(subgraph, "sg") + b = g.add(example_b, "b", loop=True) + m = g.add(Merge[int], "m") + t = g.add(Return[int], "t") + g.connect(sg.out, m.inp) + g.connect(b.out, m.inp) + g.connect(m.out, b.inp) + g.connect(b.out_final, t.inp) + return g + + +@pytest.fixture +def nested_graph_with_params(subsubgraph, example_b): + g = Workflow() + sg = g.add(subsubgraph, "sg") + b = g.add(example_b, "b", loop=True) + m = g.add(Merge[int], "m") + t = g.add(Return[int], "t") + g.connect(sg.out, m.inp) + g.connect(b.out, m.inp) + g.connect(m.out, b.inp) + g.connect(b.out_final, t.inp) + g.combine_parameters(sg.parameters["val"], name="val") + return g + + +class NewGraph(Graph): + def build(self): + a = self.add(A, "a") + b = self.add(B, "b", loop=True) + t = self.add(Return[int], "t") + self.connect(a.out, b.inp) + self.connect(b.out_final, t.inp) + self.out = self.map_port(b.out, "out") + +class NewGraph2(Graph): + def build(self): + d = self.add(Delay[int], "d") + t = self.add(Return[int], "t") + self.connect(d.out, t.inp) + self.inp = self.map_port(d.inp, "inp") + +@pytest.fixture +def newgraph(): + return NewGraph + +@pytest.fixture +def newgraph2(): + return NewGraph2 diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/core/conftest.py b/tests/core/conftest.py new file mode 100644 index 0000000..499c775 --- /dev/null +++ b/tests/core/conftest.py @@ -0,0 +1,217 @@ +"""Core testing data""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init, unused-import + +import dill +import pytest + +from maize.core.channels import DataChannel, FileChannel +from maize.core.component import Component +from maize.core.interface import Output, Input + + +@pytest.fixture +def datafile(shared_datadir): + return shared_datadir / "testorigin.abc" + + +@pytest.fixture +def datafile2(shared_datadir): + return shared_datadir / "nested" / "testorigin2.abc" + + +@pytest.fixture +def nested_datafiles(datafile, datafile2): + return [datafile, datafile2] + + +@pytest.fixture +def nested_datafiles_dict(datafile, datafile2): + return {"foo": datafile, 42: datafile2} + + +@pytest.fixture +def data(): + return 17 + + +@pytest.fixture +def filedata(tmp_path): + return tmp_path / "chan-test" + + +@pytest.fixture +def loaded_datachannel(): + channel = DataChannel(size=1) + channel.preload(dill.dumps(42)) + return channel + + +@pytest.fixture +def loaded_datachannel2(): + channel = DataChannel(size=2) + channel.preload(dill.dumps(42)) + channel.preload(dill.dumps(-42)) + return channel + + +@pytest.fixture +def empty_datachannel(): + channel = DataChannel(size=3) + return channel + + +@pytest.fixture +def empty_filechannel(tmp_path): + channel = FileChannel() + channel.setup(destination=tmp_path / "chan-test") + return channel + + +@pytest.fixture +def empty_filechannel2(tmp_path): + channel = FileChannel(mode="link") + channel.setup(destination=tmp_path / "chan-test") + return channel + + 
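+# Note: `empty_filechannel2` mirrors `empty_filechannel` but uses `mode="link"`,
+# presumably so fixtures below (e.g. `loaded_filechannel3`) can exercise
+# link-based file transfer alongside the default mode.
+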
+@pytest.fixture +def loaded_filechannel(empty_filechannel, datafile): + empty_filechannel.preload(datafile) + return empty_filechannel + + +@pytest.fixture +def loaded_filechannel2(empty_filechannel, nested_datafiles): + empty_filechannel.preload(nested_datafiles) + return empty_filechannel + + +@pytest.fixture +def loaded_filechannel3(empty_filechannel2, nested_datafiles): + empty_filechannel2.preload(nested_datafiles) + return empty_filechannel2 + + +@pytest.fixture +def loaded_filechannel4(empty_filechannel, nested_datafiles_dict): + empty_filechannel.preload(nested_datafiles_dict) + return empty_filechannel + + +@pytest.fixture +def loaded_filechannel5(tmp_path, datafile): + channel = FileChannel() + chan_path = tmp_path / "chan-test" + chan_path.mkdir() + (chan_path / datafile.name).touch() + channel.setup(destination=chan_path) + return channel + + +# If this breaks, check this: +# https://stackoverflow.com/questions/42014484/pytest-using-fixtures-as-arguments-in-parametrize +@pytest.fixture(params=["loaded_datachannel", "loaded_filechannel"]) +def loaded_channel(request): + return request.getfixturevalue(request.param) + + +@pytest.fixture(params=["empty_datachannel", "empty_filechannel"]) +def empty_channel(request): + return request.getfixturevalue(request.param) + + +@pytest.fixture +def mock_component(): + return Component(name="mock") + + +@pytest.fixture +def connected_output(empty_datachannel, mock_component): + out = Output().build(name="Test", parent=mock_component) + out.set_channel(empty_datachannel) + return out + + +@pytest.fixture +def connected_input(loaded_datachannel, mock_component): + inp = Input().build(name="Test", parent=mock_component) + inp.set_channel(loaded_datachannel) + return inp + + +@pytest.fixture +def connected_input_default(loaded_datachannel, data, mock_component): + inp = Input(default=data).build(name="Test", parent=mock_component) + inp.set_channel(loaded_datachannel) + return inp + + +@pytest.fixture +def connected_input_default_factory(data, mock_component): + inp = Input[list[int]](default_factory=lambda: [data]).build(name="Test", parent=mock_component) + return inp + + +@pytest.fixture +def connected_input_multi(loaded_datachannel2, mock_component): + inp = Input().build(name="Test", parent=mock_component) + inp.set_channel(loaded_datachannel2) + return inp + + +@pytest.fixture +def unconnected_input(mock_component): + inp = Input().build(name="Test", parent=mock_component) + return inp + + +@pytest.fixture +def unconnected_input_default(mock_component, data): + inp = Input(default=data).build(name="Test", parent=mock_component) + return inp + + +@pytest.fixture +def unconnected_output(mock_component): + out = Output().build(name="Test", parent=mock_component) + return out + + +@pytest.fixture +def connected_pair(empty_datachannel, mock_component): + inp = Input().build(name="Test", parent=mock_component) + out = Output().build(name="Test", parent=mock_component) + out.set_channel(empty_datachannel) + inp.set_channel(empty_datachannel) + return inp, out + + +@pytest.fixture +def connected_file_output(empty_filechannel, mock_component): + out = Output().build(name="Test", parent=mock_component) + out.set_channel(empty_filechannel) + return out + + +@pytest.fixture +def connected_file_output_full(loaded_filechannel, mock_component): + out = Output().build(name="Test", parent=mock_component) + out.set_channel(loaded_filechannel) + return out + + +@pytest.fixture +def connected_file_input(loaded_filechannel, mock_component): + inp = 
Input().build(name="Test", parent=mock_component) + inp.set_channel(loaded_filechannel) + return inp + + +@pytest.fixture +def connected_file_pair(empty_filechannel, mock_component): + inp = Input().build(name="Test", parent=mock_component) + out = Output().build(name="Test", parent=mock_component) + out.set_channel(empty_filechannel) + inp.set_channel(empty_filechannel) + return inp, out diff --git a/tests/core/data/checkpoint-nested.yaml b/tests/core/data/checkpoint-nested.yaml new file mode 100644 index 0000000..468d674 --- /dev/null +++ b/tests/core/data/checkpoint-nested.yaml @@ -0,0 +1,82 @@ +_data: +- sg: + ssg: + delay: + inp: !!binary | + gARdlC4= +- sg: + delay: + inp: !!binary | + gARdlC4= +- b: + inp: !!binary | + gARdlC4= +- m: + inp: !!binary | + gARdlC4= +- t: + inp: !!binary | + gARdlC4= +_status: +- sg: + ssg: + a: READY +- sg: + ssg: + delay: READY +- sg: + delay: READY +- b: READY +- m: READY +- t: READY +channels: +- receiving: + m: inp + sending: + sg: + delay: out +- receiving: + m: inp + sending: + b: out +- receiving: + b: inp + sending: + m: out +- receiving: + t: inp + sending: + b: out_final +description: null +level: 20 +name: None +nodes: +- description: null + fail_ok: false + n_attempts: 1 + name: sg + parameters: {} + status: READY + type: SubGraph +- description: null + fail_ok: false + n_attempts: 1 + name: b + parameters: {} + status: READY + type: B +- description: null + fail_ok: false + n_attempts: 1 + name: m + parameters: {} + status: READY + type: Merge +- description: null + fail_ok: false + n_attempts: 1 + name: t + parameters: {} + status: READY + type: Return +parameters: [] diff --git a/tests/core/data/checkpoint.yaml b/tests/core/data/checkpoint.yaml new file mode 100644 index 0000000..c003921 --- /dev/null +++ b/tests/core/data/checkpoint.yaml @@ -0,0 +1,31 @@ +_data: +- term: + inp: !!binary | + gARdlC4= +_status: +- a: READY +- term: READY +channels: +- receiving: + term: inp + sending: + a: out +description: null +level: 20 +name: None +nodes: +- description: null + fail_ok: false + n_attempts: 1 + name: a + parameters: {} + status: READY + type: A +- description: null + fail_ok: false + n_attempts: 1 + name: term + parameters: {} + status: READY + type: Return +parameters: [] diff --git a/tests/core/data/graph-inp-map.yaml b/tests/core/data/graph-inp-map.yaml new file mode 100644 index 0000000..96ad6fc --- /dev/null +++ b/tests/core/data/graph-inp-map.yaml @@ -0,0 +1,46 @@ +channels: +- receiving: + m: inp + sending: + sg: + delay: out +- receiving: + m: inp + sending: + b: out +- receiving: + b: inp + sending: + m: out +- receiving: + m: inp + sending: + del: out +- receiving: + t: inp + sending: + b: out_final +level: 20 +name: None +nodes: +- name: sg + type: SubGraph +- name: b + type: B +- name: m + type: Merge +- name: del + type: Delay +- name: t + type: Return +parameters: +- name: val + value: 42 + map: + - sg: + ssg: + a: val +- name: delay + value: 2 + map: + - del: inp diff --git a/tests/core/data/graph-inp-para.yaml b/tests/core/data/graph-inp-para.yaml new file mode 100644 index 0000000..496ee93 --- /dev/null +++ b/tests/core/data/graph-inp-para.yaml @@ -0,0 +1,44 @@ +channels: +- receiving: + m: inp + sending: + sg: + delay: out +- receiving: + m: inp + sending: + b: out +- receiving: + b: inp + sending: + m: out +- receiving: + m: inp + sending: + del: out +- receiving: + t: inp + sending: + b: out_final +level: 20 +name: None +nodes: +- name: sg + type: SubGraph +- name: b + type: B +- name: m + type: Merge +- name: del + 
type: Delay + parameters: + inp: 2 +- name: t + type: Return +parameters: +- name: val + value: 42 + map: + - sg: + ssg: + a: val diff --git a/tests/core/data/graph-mp-fixed.yaml b/tests/core/data/graph-mp-fixed.yaml new file mode 100644 index 0000000..8dfb24d --- /dev/null +++ b/tests/core/data/graph-mp-fixed.yaml @@ -0,0 +1,23 @@ +name: Example +level: DEBUG +nodes: +- name: sg + type: SubGraph +- name: ex + type: Example +- name: concat + type: ConcatAndPrint +channels: +- receiving: + concat: inp + sending: + sg: out +- receiving: + concat: inp + sending: + ex: out +parameters: +- name: data + value: "World" + map: + - ex: data diff --git a/tests/core/data/graph.yaml b/tests/core/data/graph.yaml new file mode 100644 index 0000000..a9fb4cd --- /dev/null +++ b/tests/core/data/graph.yaml @@ -0,0 +1,36 @@ +channels: +- receiving: + m: inp + sending: + sg: + delay: out +- receiving: + m: inp + sending: + b: out +- receiving: + b: inp + sending: + m: out +- receiving: + t: inp + sending: + b: out_final +level: 20 +name: None +nodes: +- name: sg + type: SubGraph +- name: b + type: B +- name: m + type: Merge +- name: t + type: Return +parameters: +- name: val + value: 42 + map: + - sg: + ssg: + a: val diff --git a/tests/core/data/nested/testorigin2.abc b/tests/core/data/nested/testorigin2.abc new file mode 100644 index 0000000..e69de29 diff --git a/tests/core/data/testorigin.abc b/tests/core/data/testorigin.abc new file mode 100644 index 0000000..e69de29 diff --git a/tests/core/data/two-node.json b/tests/core/data/two-node.json new file mode 100644 index 0000000..b3c186e --- /dev/null +++ b/tests/core/data/two-node.json @@ -0,0 +1 @@ +{"level": 20, "name": "None", "description": null, "nodes": [{"name": "a", "description": null, "fail_ok": false, "n_attempts": 1, "parameters": {"val": null, "file": null, "flag": null}, "type": "A", "status": "READY"}, {"name": "term", "description": null, "fail_ok": false, "n_attempts": 1, "parameters": {}, "type": "Return", "status": "READY"}], "parameters": [], "channels": [{"sending": {"a": "out"}, "receiving": {"term": "input"}}]} \ No newline at end of file diff --git a/tests/core/data/two-node.toml b/tests/core/data/two-node.toml new file mode 100644 index 0000000..5b51bb3 --- /dev/null +++ b/tests/core/data/two-node.toml @@ -0,0 +1,23 @@ +level = 20 +name = "None" +parameters = [] +[[nodes]] +name = "a" +fail_ok = false +n_attempts = 1 +type = "A" +status = "READY" + +[[nodes]] +name = "term" +fail_ok = false +n_attempts = 1 +type = "Return" +status = "READY" + +[[channels]] + +[channels.sending] +a = "out" +[channels.receiving] +term = "input" diff --git a/tests/core/data/two-node.yml b/tests/core/data/two-node.yml new file mode 100644 index 0000000..c5fec03 --- /dev/null +++ b/tests/core/data/two-node.yml @@ -0,0 +1,27 @@ +channels: +- receiving: + term: input + sending: + a: out +description: null +level: 20 +name: None +nodes: +- description: null + fail_ok: false + n_attempts: 1 + name: a + parameters: + file: null + flag: null + val: null + status: READY + type: A +- description: null + fail_ok: false + n_attempts: 1 + name: term + parameters: {} + status: READY + type: Return +parameters: [] diff --git a/tests/core/test_channel.py b/tests/core/test_channel.py new file mode 100644 index 0000000..b49c45f --- /dev/null +++ b/tests/core/test_channel.py @@ -0,0 +1,211 @@ +"""Channel testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init, 
unused-import + +from multiprocessing import Process +from threading import Thread +import time +import pytest + +from maize.core.channels import ChannelFull, ChannelException +from maize.utilities.utilities import has_file + + +class Test_Channel: + def test_channel_active(self, loaded_channel): + assert loaded_channel.active + + def test_channel_ready(self, loaded_channel): + assert loaded_channel.ready + + def test_channel_ready_multi(self, loaded_channel): + assert loaded_channel.ready + assert loaded_channel.ready + + def test_channel_close(self, loaded_channel): + loaded_channel.close() + assert not loaded_channel.active + assert loaded_channel.ready + + def test_empty_channel_ready(self, empty_channel): + assert empty_channel.active + assert not empty_channel.ready + + def test_empty_channel_close(self, empty_channel): + empty_channel.close() + assert not empty_channel.active + assert not empty_channel.ready + + +class Test_DataChannel: + def test_channel_send(self, empty_datachannel): + empty_datachannel.send(42) + assert empty_datachannel.ready + + def test_channel_send_full(self, loaded_datachannel): + with pytest.raises(ChannelFull): + loaded_datachannel.send(42, timeout=1) + + def test_channel_receive(self, loaded_datachannel): + assert loaded_datachannel.receive() == 42 + + def test_channel_receive_empty(self, empty_datachannel): + assert empty_datachannel.receive(timeout=1) is None + + def test_channel_receive_buffer(self, loaded_datachannel): + assert loaded_datachannel.ready + assert loaded_datachannel.receive() == 42 + + def test_channel_receive_buffer_multi(self, loaded_datachannel): + assert loaded_datachannel.ready + assert loaded_datachannel.ready + assert loaded_datachannel.receive() == 42 + + def test_channel_close_receive(self, loaded_datachannel): + loaded_datachannel.close() + assert loaded_datachannel.receive() == 42 + + def test_channel_receive_multi(self, loaded_datachannel2): + assert loaded_datachannel2.receive() == 42 + assert loaded_datachannel2.receive() == -42 + + def test_channel_receive_multi_buffer(self, loaded_datachannel2): + assert loaded_datachannel2.ready + assert loaded_datachannel2.receive() == 42 + assert loaded_datachannel2.receive() == -42 + + def test_channel_receive_multi_buffer2(self, loaded_datachannel2): + assert loaded_datachannel2.ready + assert loaded_datachannel2.ready + assert loaded_datachannel2.receive() == 42 + assert loaded_datachannel2.receive() == -42 + + def test_channel_receive_multi_close(self, loaded_datachannel2): + loaded_datachannel2.close() + assert loaded_datachannel2.ready + assert loaded_datachannel2.receive() == 42 + assert loaded_datachannel2.receive() == -42 + + +class Test_FileChannel: + def test_channel_size_empty(self, empty_filechannel): + assert empty_filechannel.size == 0 + + def test_channel_size_full(self, loaded_filechannel): + assert loaded_filechannel.size == 1 + + def test_channel_send(self, empty_filechannel, datafile): + empty_filechannel.send(datafile) + assert empty_filechannel.ready + + def test_channel_send_no_file(self, empty_filechannel, shared_datadir): + with pytest.raises(ChannelException): + empty_filechannel.send(shared_datadir / "nonexistent.abc") + + def test_channel_send_copy(self, empty_filechannel, datafile): + empty_filechannel.copy = True + empty_filechannel.send(datafile) + assert empty_filechannel.ready + + def test_channel_send_full(self, loaded_filechannel, datafile): + with pytest.raises(ChannelFull): + loaded_filechannel.send(datafile, timeout=1) + + def 
test_channel_send_full_payload(self, empty_filechannel, datafile): + """ + Test situation where data was put in the payload queue, but the + trigger wasn't set. This shouldn't really happen in practice. + """ + empty_filechannel._payload.put(datafile) + with pytest.raises(ChannelFull): + empty_filechannel.send(datafile, timeout=1) + + def test_channel_send_dict(self, empty_filechannel, nested_datafiles_dict): + empty_filechannel.send(nested_datafiles_dict) + assert empty_filechannel.ready + + def test_channel_preload_full(self, loaded_filechannel, datafile): + with pytest.raises(ChannelFull): + loaded_filechannel.preload(datafile) + + def test_channel_receive(self, loaded_filechannel, datafile): + path = loaded_filechannel.receive() + assert path.exists() + assert path.name == datafile.name + + def test_channel_receive_trigger(self, empty_filechannel): + empty_filechannel._file_trigger.set() + with pytest.raises(ChannelException): + empty_filechannel.receive(timeout=1) + + def test_channel_receive_nested_multi(self, loaded_filechannel2, nested_datafiles): + paths = loaded_filechannel2.receive() + assert paths[1].parent.name == nested_datafiles[1].parent.name + for path, datafile in zip(paths, nested_datafiles): + assert path.exists() + assert path.name == datafile.name + + def test_channel_receive_nested_multi_link(self, loaded_filechannel3, nested_datafiles): + paths = loaded_filechannel3.receive() + assert paths[1].parent.name == nested_datafiles[1].parent.name + for path, datafile in zip(paths, nested_datafiles): + assert path.exists() + assert path.name == datafile.name + + def test_channel_receive_multi(self, loaded_filechannel, datafile2): + path = loaded_filechannel.receive() + loaded_filechannel.send(datafile2) + path = loaded_filechannel.receive() + assert path.exists() + assert path.name == datafile2.name + + def test_channel_receive_dict(self, loaded_filechannel4, nested_datafiles_dict): + paths = loaded_filechannel4.receive() + assert paths.keys() == nested_datafiles_dict.keys() + for key in nested_datafiles_dict: + assert paths[key].exists() + assert paths[key].name == nested_datafiles_dict[key].name + + def test_channel_receive_empty(self, empty_filechannel): + assert empty_filechannel.receive(timeout=1) is None + + def test_channel_receive_buffer(self, loaded_filechannel): + assert loaded_filechannel.ready + assert loaded_filechannel.receive().exists() + + def test_channel_receive_buffer_multi(self, loaded_filechannel): + assert loaded_filechannel.ready + assert loaded_filechannel.ready + assert loaded_filechannel.receive().exists() + + def test_channel_receive_auto_preload(self, loaded_filechannel5, datafile): + path, *_ = loaded_filechannel5.receive() + assert path.exists() + assert path.name == datafile.name + + def test_channel_close_receive(self, loaded_filechannel): + loaded_filechannel.close() + assert loaded_filechannel.receive().exists() + + def test_channel_close_timeout(self, empty_filechannel, datafile): + def _clear(): + time.sleep(5) + empty_filechannel._file_trigger.clear() + + empty_filechannel.send(datafile) + assert empty_filechannel._file_trigger.is_set() + assert has_file(empty_filechannel._channel_dir) + proc = Process(target=_clear) + proc.start() + empty_filechannel.close() + proc.join() + assert not empty_filechannel._file_trigger.is_set() + + def test_channel_flush(self, loaded_filechannel, datafile): + path = loaded_filechannel.flush() + assert path[0].exists() + assert path[0].name == datafile.name + + def test_channel_flush_empty(self, 
empty_filechannel): + path = empty_filechannel.flush() + assert path == [] diff --git a/tests/core/test_component.py b/tests/core/test_component.py new file mode 100644 index 0000000..01a29fa --- /dev/null +++ b/tests/core/test_component.py @@ -0,0 +1,123 @@ +"""Component testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init, unused-import + +import logging +import pytest + +from maize.core.component import Component +from maize.core.interface import Input, Output, Parameter +from maize.core.node import Node + + +@pytest.fixture +def parent_component(tmp_path): + comp = Component(name="master") + comp.setup_directories(tmp_path) + return comp + + +@pytest.fixture +def component(parent_component): + comp = Component(name="test", parent=parent_component) + comp.setup_directories(parent_component.work_dir) + return comp + + +@pytest.fixture +def nested_component(component): + return Component(name="leaf", parent=component) + + +@pytest.fixture +def nested_component_no_name(component): + return Component(parent=component) + + +@pytest.fixture +def example_comp_class(): + class Example(Node): + """Example docstring""" + required_callables = ["blah"] + inp: Input[int] = Input() + out: Output[int] = Output() + para: Parameter[int] = Parameter() + return Example + + +class Test_Component: + def test_no_subclass(self): + with pytest.raises(KeyError): + Component.get_node_class("nonexistent") + + def test_sample_config(self, example_comp_class): + example_comp_class._generate_sample_config(name="foo") + assert "Example configuration" in example_comp_class.__doc__ + + def test_serialized_summary(self, example_comp_class): + data = example_comp_class.serialized_summary() + assert data["name"] == "Example" + assert data["inputs"][0]["name"] == "inp" + assert data["outputs"][0]["name"] == "out" + assert data["parameters"][0]["name"] == "para" + + def test_summary_line(self, example_comp_class): + assert "Example docstring" in example_comp_class.get_summary_line() + + def test_get_interfaces(self, example_comp_class): + assert {"inp", "out", "para"} == example_comp_class.get_interfaces() + + def test_get_inputs(self, example_comp_class): + assert "inp" in example_comp_class.get_inputs() + + def test_get_outputs(self, example_comp_class): + assert "out" in example_comp_class.get_outputs() + + def test_get_parameters(self, example_comp_class): + assert "para" in example_comp_class.get_parameters() + + def test_get_available(self, example_comp_class): + assert example_comp_class in Component.get_available_nodes() + + def test_init(self, parent_component): + assert parent_component.level == logging.INFO + assert parent_component.name == "master" + + def test_init_child(self, component, parent_component): + assert component.parent is parent_component + assert component.name == "test" + parent_component.signal.set() + assert component.signal.is_set() + + def test_root(self, component, parent_component): + assert component.root is parent_component + assert parent_component.root is parent_component + + def test_parameter_fail(self, component): + with pytest.raises(KeyError): + component.update_parameters(non_existent=42) + + def test_component_path(self, nested_component): + assert nested_component.component_path == ("test", "leaf",) + + def test_work_dir(self, nested_component, tmp_path): + nested_component.setup_directories(tmp_path / nested_component.parent.work_dir) + assert 
"/comp-master/comp-test/comp-leaf" in nested_component.work_dir.as_posix() + assert nested_component.work_dir.exists() + + def test_work_dir_auto(self, nested_component_no_name, tmp_path): + nested_component_no_name.setup_directories( + tmp_path / nested_component_no_name.parent.work_dir) + assert "/comp-master/comp-test/comp-" in nested_component_no_name.work_dir.as_posix() + assert nested_component_no_name.work_dir.exists() + + def test_work_dir_default(self, nested_component, temp_working_dir): + nested_component.setup_directories() + assert "comp-leaf" in nested_component.work_dir.as_posix() + assert nested_component.work_dir.exists() + + def test_as_dict(self, nested_component): + dic = nested_component.as_dict() + assert dic["name"] == "leaf" + assert "description" not in dic + assert "status" not in dic diff --git a/tests/core/test_graph.py b/tests/core/test_graph.py new file mode 100644 index 0000000..11bb46e --- /dev/null +++ b/tests/core/test_graph.py @@ -0,0 +1,701 @@ +"""Graph testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init, unused-import, unused-variable, unused-argument + +import argparse +import json +import logging +from pathlib import Path +from typing import Generic, TypeVar + +import pytest + +from maize.core.channels import FileChannel +from maize.core.graph import Graph, GraphBuildException +from maize.core.interface import Input, Port +from maize.core.node import Node +from maize.core.runtime import Status +from maize.core.workflow import Workflow, CheckpointException, ParsingException +from maize.steps.plumbing import Merge, Delay +from maize.steps.io import LoadFile, Return, SaveFile +from maize.utilities.io import Config, read_input + + +@pytest.fixture +def graph_dict(shared_datadir): + return read_input(shared_datadir / "graph.yaml") + + +@pytest.fixture(params=["nodes", "parameters"]) +def graph_dict_missing_field(graph_dict, request): + graph_dict[request.param][0].pop("name") + return graph_dict + + +@pytest.fixture +def graph_dict_missing_channel_field(graph_dict): + del graph_dict["channels"][0]["sending"] + return graph_dict + + +@pytest.fixture +def graph_dict_missing_channel(graph_dict): + graph_dict["channels"].pop() + return graph_dict + + +@pytest.fixture +def two_node_graph(example_a): + g = Workflow() + a = g.add(example_a, "a") + t = g.add(Return[int], "term") + g.connect(a.out, t.inp) + return g + + +@pytest.fixture +def two_node_graph_param(example_a): + g = Workflow(level="DEBUG") + a = g.add(example_a, "a", parameters=dict(val=42)) + t = g.add(Return[int], "term") + g.connect(a.out, t.inp) + return g + + +@pytest.fixture +def graph_with_params(two_node_graph_param): + param = two_node_graph_param.nodes["a"].parameters["val"] + two_node_graph_param.combine_parameters(param, name="value") + return two_node_graph_param + + +@pytest.fixture +def graph_with_all_params(two_node_graph_param): + node = two_node_graph_param.nodes["a"] + for param in node.parameters.values(): + two_node_graph_param.combine_parameters(param, name=param.name) + return two_node_graph_param + + +class Test_Graph_init: + def test_graph_add(self, example_a): + g = Workflow() + g.add(example_a, "a") + assert len(g.nodes) == 1 + + def test_two_node_init(self, two_node_graph): + assert len(two_node_graph.channels) == 1 + assert len(two_node_graph.nodes) == 2 + assert len(two_node_graph.nodes["a"].parameters) == 9 + + def test_from_dict(self, graph_dict, example_a, 
example_b, subgraph): + g = Workflow.from_dict(graph_dict) + assert len(g.nodes) == 4 + assert len(g.channels) == 4 + assert len(g.parameters) == 1 + assert g.parameters["val"].value == 42 + + def test_from_dict_missing_node_field( + self, graph_dict_missing_field, example_a, example_b, subgraph + ): + A, B = example_a, example_b + with pytest.raises((ParsingException, KeyError)): + g = Workflow.from_dict(graph_dict_missing_field) + + def test_from_dict_missing_channel_field( + self, graph_dict_missing_channel_field, example_a, example_b, subgraph + ): + A, B = example_a, example_b + with pytest.raises(ParsingException): + g = Workflow.from_dict(graph_dict_missing_channel_field) + + def test_from_dict_missing_channel( + self, graph_dict_missing_channel, example_a, example_b, subgraph + ): + A, B = example_a, example_b + with pytest.raises(GraphBuildException): + g = Workflow.from_dict(graph_dict_missing_channel) + + def test_from_yaml(self, shared_datadir, example_a, example_b, subgraph): + A, B = example_a, example_b + g = Workflow.from_file(shared_datadir / "graph.yaml") + assert len(g.nodes) == 4 + assert len(g.channels) == 4 + assert len(g.parameters) == 1 + assert g.parameters["val"].value == 42 + + def test_from_yaml_inp_map(self, shared_datadir, example_a, example_b, subgraph): + A, B = example_a, example_b + g = Workflow.from_file(shared_datadir / "graph-inp-map.yaml") + assert len(g.nodes) == 5 + assert len(g.channels) == 5 + assert len(g.parameters) == 2 + assert g.parameters["val"].value == 42 + para = g.parameters["delay"] + assert para.value == 2 + assert para.is_set + assert not Port.is_connected(para._parameters[0]) + + def test_from_yaml_inp_para(self, shared_datadir, example_a, example_b, subgraph): + A, B = example_a, example_b + g = Workflow.from_file(shared_datadir / "graph-inp-para.yaml") + assert len(g.nodes) == 5 + assert len(g.channels) == 5 + assert len(g.parameters) == 1 + assert g.parameters["val"].value == 42 + para = g.nodes["del"].inputs["inp"] + assert para.value == 2 + assert para.is_set + assert not Port.is_connected(para) + + def test_to_dict(self, two_node_graph): + data = two_node_graph.to_dict() + assert len(data["nodes"]) == 2 + assert len(data["channels"]) == 1 + + def test_nested_to_dict(self, nested_graph): + data = nested_graph.to_dict() + assert len(data["nodes"]) == 4 + assert len(data["channels"]) == 4 + + def test_to_from_dict(self, two_node_graph): + data = two_node_graph.to_dict() + graph = Workflow.from_dict(data) + assert two_node_graph.nodes.keys() == graph.nodes.keys() + assert two_node_graph.channels.keys() == graph.channels.keys() + new_data = graph.to_dict() + assert new_data == data + + def test_to_from_dict2(self, two_node_graph): + two_node_graph.config.scratch = Path("bar") + two_node_graph.scratch = Path("foo") + data = two_node_graph.to_dict() + graph = Workflow.from_dict(data) + assert two_node_graph.nodes.keys() == graph.nodes.keys() + assert two_node_graph.channels.keys() == graph.channels.keys() + assert two_node_graph.scratch == Path("foo") + assert two_node_graph.config.scratch == Path("foo") + new_data = graph.to_dict() + assert new_data == data + + def test_to_from_dict_complex(self, nested_graph, example_a, example_b): + A, B = example_a, example_b + data = nested_graph.to_dict() + graph = Workflow.from_dict(data) + assert nested_graph.nodes.keys() == graph.nodes.keys() + assert nested_graph.channels.keys() == graph.channels.keys() + + @pytest.mark.parametrize("suffix,length", [("json", 25), ("yml", 14), ("toml", 18)]) + 
def test_to_file(self, two_node_graph, tmp_path, suffix, length): + file = tmp_path / f"two-node.{suffix}" + two_node_graph.to_file(file) + assert file.exists() + assert len(file.read_text().split("\n")) == length + + def test_to_checkpoint(self, two_node_graph): + file = two_node_graph.work_dir.glob(f"ckp-{two_node_graph.name}-*.yaml") + two_node_graph.to_checkpoint(fail_ok=False) + assert next(file).exists() + + def test_to_checkpoint_given_file(self, two_node_graph, tmp_path): + file = tmp_path / "checkpoint.yaml" + two_node_graph.to_checkpoint(file, fail_ok=False) + assert file.exists() + + def test_to_checkpoint_given_file_fail(self, two_node_graph, tmp_path): + file = tmp_path / "non-existent" / "checkpoint.yaml" + with pytest.raises(CheckpointException): + two_node_graph.to_checkpoint(file, fail_ok=False) + + def test_to_checkpoint_given_file_fail_ok(self, two_node_graph, tmp_path): + file = tmp_path / "non-existent" / "checkpoint.yaml" + two_node_graph.to_checkpoint(file, fail_ok=True) + + def test_from_checkpoint(self, shared_datadir, example_a, example_b): + A, B = example_a, example_b + file = shared_datadir / "checkpoint.yaml" + g = Workflow.from_checkpoint(file) + assert "a" in g.nodes + assert "term" in g.nodes + assert g.nodes["a"].status == Status.READY + assert not g.nodes["a"].fail_ok + + def test_to_checkpoint_nested(self, nested_graph, tmp_path): + file = tmp_path / "checkpoint-nested.yaml" + nested_graph.to_checkpoint(file) + assert file.exists() + assert len(file.read_text().split("\n")) == 62 + + def test_from_checkpoint_nested( + self, shared_datadir, example_a, example_b, subgraph, subsubgraph + ): + A, B = example_a, example_b + SubGraph, SubSubGraph = subgraph, subsubgraph + file = shared_datadir / "checkpoint-nested.yaml" + g = Workflow.from_checkpoint(file) + + @pytest.mark.parametrize("extra_options", [["--value", "42"], ["--value", "-2"]]) + def test_update_with_args(self, graph_with_params, extra_options): + setting = int(extra_options[-1]) + graph_with_params.update_with_args(extra_options) + assert graph_with_params.parameters["value"].value == setting + + def test_update_with_args_all(self, graph_with_all_params): + extra_options = ["--val", "42", "--flag", "--file", "file.pdb"] + graph_with_all_params.update_with_args(extra_options) + assert graph_with_all_params.parameters["val"].value == 42 + assert graph_with_all_params.parameters["flag"].value + assert graph_with_all_params.parameters["file"].value == Path("file.pdb") + + def test_update_with_args_all2(self, graph_with_all_params): + extra_options = ["--val", "42", "--no-flag", "--file", "file.pdb"] + graph_with_all_params.update_with_args(extra_options) + assert graph_with_all_params.parameters["val"].value == 42 + assert not graph_with_all_params.parameters["flag"].value + assert graph_with_all_params.parameters["file"].value == Path("file.pdb") + + def test_update_with_args_all3(self, graph_with_all_params): + extra_options = ["--val", "42"] + params = Path("para.json") + with params.open("w") as out: + json.dump({"flag": True}, out) + + config = Path("config.json") + with config.open("w") as out: + json.dump({"scratch": "scratch-folder"}, out) + + global_options = argparse.Namespace( + log=Path("log.log"), + quiet=False, + debug=False, + keep=False, + parameters=params, + config=config, + scratch=None, + ) + graph_with_all_params.update_with_args(extra_options) + graph_with_all_params.update_settings_with_args(global_options) + assert graph_with_all_params.parameters["val"].value == 42 + assert 
graph_with_all_params.logfile == Path("log.log") + assert graph_with_all_params.parameters["flag"].value + assert graph_with_all_params.config.scratch == Path("scratch-folder") + + def test_update_with_args_all4(self, graph_with_all_params): + extra_options = ["--val", "42"] + params = Path("para.json") + with params.open("w") as out: + json.dump({"flag": True}, out) + + config = Path("config.json") + with config.open("w") as out: + json.dump({"scratch": "scratch-folder"}, out) + + global_options = argparse.Namespace( + log=Path("log.log"), + quiet=False, + debug=False, + keep=False, + parameters=params, + config=config, + scratch=Path("other-scratch-folder"), + ) + graph_with_all_params.update_with_args(extra_options) + graph_with_all_params.update_settings_with_args(global_options) + assert graph_with_all_params.parameters["val"].value == 42 + assert graph_with_all_params.logfile == Path("log.log") + assert graph_with_all_params.parameters["flag"].value + assert graph_with_all_params.scratch == Path("other-scratch-folder") + + def test_update_with_args_all_fail1(self, graph_with_all_params): + extra_options = ["--val", "seven", "--flag", "--file", "file.pdb"] + with pytest.raises(ParsingException): + graph_with_all_params.update_with_args(extra_options) + + def test_update_with_args_all_fail2(self, graph_with_all_params): + extra_options = ["--val", "42", "--flag", "--blah", "file.pdb"] + with pytest.raises(ParsingException): + graph_with_all_params.update_with_args(extra_options) + + def test_update_with_args_all_fail3(self, graph_with_all_params): + extra_options = ["--val", "42", "--flag", "--file", "file.xyz"] + with pytest.raises(ValueError): + graph_with_all_params.update_with_args(extra_options) + + def test_update_with_args_all_fail4(self, graph_with_all_params): + extra_options = ["--val", "42"] + global_options = argparse.Namespace( + log=Path("log.log"), + quiet=False, + debug=False, + keep=False, + parameters=None, + config=Path("nofile.toml"), + scratch=None, + ) + graph_with_all_params.update_with_args(extra_options) + with pytest.raises(FileNotFoundError): + graph_with_all_params.update_settings_with_args(global_options) + + +class Test_Graph_properties: + def test_get_node(self, nested_graph): + assert nested_graph.get_node("b").name == "b" + assert nested_graph.get_node("sg", "delay").name == "delay" + assert nested_graph.get_node("sg", "ssg", "a").name == "a" + assert nested_graph.get_node("sg", "ssg", "delay").name == "delay" + assert nested_graph.get_node("sg", "delay") is not nested_graph.get_node( + "sg", "ssg", "delay" + ) + + def test_get_parameter(self, nested_graph): + assert nested_graph.get_parameter("sg", "ssg", "a", "val").name == "val" + assert nested_graph.get_parameter("sg", "delay", "delay").name == "delay" + with pytest.raises(KeyError): + nested_graph.get_parameter("sg", "delay", "nope") + + def test_get_port(self, nested_graph): + assert nested_graph.get_port("sg", "ssg", "out").name == "out" + assert nested_graph.get_port("m", "inp").name == "inp" + + def test_flat_nodes(self, nested_graph): + assert {node.name for node in nested_graph.flat_nodes} == {"a", "b", "delay", "t", "m"} + + def test_as_dict(self, nested_graph): + assert len(nested_graph.as_dict()["name"]) == 6 + + def test_directories(self, nested_graph, tmp_path): + nested_graph.setup_directories(tmp_path) + assert nested_graph.work_dir.name.startswith("graph-") + assert nested_graph.work_dir.exists() + assert nested_graph.get_node("b").work_dir.name == "node-b" + assert 
nested_graph.get_node("sg", "delay").work_dir.name == "node-delay" + assert nested_graph.get_node("sg", "delay").work_dir.parent.name == "graph-sg" + + def test_directories_parent(self, nested_graph, temp_working_dir): + nested_graph.config.scratch = Path("test") + nested_graph.setup_directories() + assert nested_graph.work_dir.name.startswith("graph-") + assert nested_graph.work_dir.exists() + + def test_directories_rename(self, nested_graph, tmp_path): + nested_graph.name = "test" + (tmp_path / "graph-test").mkdir() + nested_graph.setup_directories(tmp_path) + assert nested_graph.work_dir.name == "graph-test-0" + assert nested_graph.work_dir.exists() + + +class Test_Graph_build: + def test_add(self, example_a): + g = Workflow() + a = g.add(example_a, "a") + assert "a" in g.nodes + assert a in g.flat_nodes + + def test_add_all(self, example_a, example_b): + g = Workflow() + a, b = g.add_all(example_a, example_b) + assert "a" in g.nodes + assert "b" in g.nodes + assert a in g.flat_nodes + assert b in g.flat_nodes + + def test_graph_add_generic(self): + T = TypeVar("T") + + class Example(Node, Generic[T]): + inp: Input[T] = Input() + + g = Workflow(strict=True) + with pytest.raises(GraphBuildException): + g.add(Example, "a") + + def test_add_duplicate(self, example_a): + g = Workflow() + a = g.add(example_a, "a") + with pytest.raises(GraphBuildException): + a = g.add(example_a, "a") + + def test_add_instance(self, example_a): + g = Workflow() + with pytest.raises(GraphBuildException): + a = g.add(example_a(), "a") + + def test_add_param(self, example_a): + g = Workflow() + a = g.add(example_a, "a", parameters=dict(val=42, flag=True)) + assert "a" in g.nodes + assert a in g.flat_nodes + assert a.parameters["val"].value == 42 + assert a.parameters["flag"].value + + def test_add_param_fail(self, example_a): + g = Workflow() + with pytest.raises(KeyError): + a = g.add(example_a, "a", parameters=dict(nonexistent=1)) + + def test_add_subgraph_looped(self, subgraph, example_b): + g = Workflow() + sg = g.add(subgraph, "sg", loop=True) + b = g.add(example_b, "b", loop=True) + m = g.add(Merge[int], "m") + t = g.add(Return[int], "t") + g.connect(sg.out, m.inp) + g.connect(b.out, m.inp) + g.connect(m.out, b.inp) + g.connect(b.out_final, t.inp) + assert sg.nodes["delay"].looped + assert sg.nodes["ssg"].looped + assert sg.nodes["ssg"].nodes["delay"].looped + assert sg.nodes["ssg"].nodes["a"].looped + + def test_check(self, example_a): + g = Workflow() + a = g.add(example_a, "a") + t = g.add(Return[int], "t") + with pytest.raises(GraphBuildException): + g.check() + + def test_connect(self, example_a): + g = Workflow() + a = g.add(example_a, "a") + t = g.add(Return[int], "t") + g.connect(a.out, t.inp) + g.check() + assert a.status == Status.READY + assert t.status == Status.READY + + def test_connect_file(self, tmp_path): + g = Workflow(level=logging.DEBUG) + a = g.add(LoadFile[Path], "a") + t = g.add(SaveFile[Path], "t") + g.connect(a.out, t.inp) + a.file.set(Path("fake")) + t.destination.set(Path("fake")) + g.check() + g.setup_directories(tmp_path) + assert len(g.channels) == 1 + assert isinstance(g.channels.popitem()[1], FileChannel) + assert (t.parent.work_dir / f"{t.name}-{t.inp.name}").exists() + assert a.status == Status.READY + assert t.status == Status.READY + + def test_connect_file_link(self, tmp_path): + g = Workflow(level=logging.DEBUG) + a = g.add(LoadFile[Path], "a") + t = g.add(SaveFile[Path], "t") + a.out.mode = t.inp.mode = "link" + g.connect(a.out, t.inp) + a.file.set(Path("fake")) + 
t.destination.set(Path("fake")) + g.check() + g.setup_directories(tmp_path) + assert len(g.channels) == 1 + chan = g.channels.popitem()[1] + assert chan.mode == "link" + assert isinstance(chan, FileChannel) + assert (t.parent.work_dir / f"{t.name}-{t.inp.name}").exists() + assert a.status == Status.READY + assert t.status == Status.READY + + def test_connect_file_move(self, tmp_path): + g = Workflow(level=logging.DEBUG) + a = g.add(LoadFile[Path], "a") + t = g.add(SaveFile[Path], "t") + a.out.mode = t.inp.mode = "move" + g.connect(a.out, t.inp) + a.file.set(Path("fake")) + t.destination.set(Path("fake")) + g.check() + g.setup_directories(tmp_path) + assert len(g.channels) == 1 + chan = g.channels.popitem()[1] + assert chan.mode == "move" + assert isinstance(chan, FileChannel) + assert (t.parent.work_dir / f"{t.name}-{t.inp.name}").exists() + assert a.status == Status.READY + assert t.status == Status.READY + + def test_connect_bad_types(self, example_b): + g = Workflow() + a = g.add(LoadFile[Path], "a") + b = g.add(example_b, "b", loop=True) + t1 = g.add(Return[int], "t1") + t2 = g.add(Return[int], "t2") + with pytest.raises(GraphBuildException): + g.connect(a.out, b.inp) + g.connect(b.out, t1.inp) + g.connect(b.out_final, t2.inp) + + def test_connect_duplicate(self, example_a): + g = Workflow() + a = g.add(example_a, "a") + t = g.add(Return[int], "t") + g.connect(a.out, t.inp) + with pytest.raises(GraphBuildException): + g.connect(a.out, t.inp) + + def test_connect_bad_workflow(self, example_b): + g = Workflow() + a = g.add(LoadFile[int], "a") + + g1 = Workflow() + b = g1.add(example_b, "b", loop=True) + with pytest.raises(GraphBuildException): + g.connect(a.out, b.inp) + + def test_autoconnect(self, example_a): + g = Workflow() + a = g.add(example_a, "a") + t = g.add(Return[int], "t") + g.auto_connect(a, t) + g.check() + assert a.status == Status.READY + assert t.status == Status.READY + + def test_connect_large(self, subgraph, example_b): + g = Workflow() + sg = g.add(subgraph, "sg") + b = g.add(example_b, "b", loop=True) + m = g.add(Merge[int], "m") + t = g.add(Return[int], "t") + g.connect(sg.out, m.inp) + g.connect(b.out, m.inp) + g.connect(m.out, b.inp) + g.connect(b.out_final, t.inp) + assert b.status == Status.READY + assert m.status == Status.READY + assert t.status == Status.READY + assert len(g.nodes) == 4 + assert len(g.flat_nodes) == 6 + assert len(g.channels) == 4 + + def test_connect_large_shorthand(self, subgraph, example_b): + g = Workflow() + sg = g.add(subgraph, "sg") + b = g.add(example_b, "b", loop=True) + m = g.add(Merge[int], "m") + t = g.add(Return[int], "t") + sg >> m + m.inp << b.out + m.out >> b.inp + t << b + assert b.status == Status.READY + assert m.status == Status.READY + assert t.status == Status.READY + assert len(g.nodes) == 4 + assert len(g.flat_nodes) == 6 + assert len(g.channels) == 4 + + def test_chain(self, example_a): + g = Workflow() + a = g.add(example_a, "a") + d1 = g.add(Delay[int], "d1") + d2 = g.add(Delay[int], "d2") + d3 = g.add(Delay[int], "d3") + t = g.add(Return[int], "t") + g.chain(a, d1, d2, d3, t) + assert a.status == Status.READY + assert d1.status == Status.READY + assert d2.status == Status.READY + assert d3.status == Status.READY + assert t.status == Status.READY + assert len(g.nodes) == 5 + assert len(g.flat_nodes) == 5 + assert len(g.channels) == 4 + + def test_connect_all(self, example_a): + g = Workflow() + a = g.add(example_a, "a") + d1 = g.add(Delay[int], "d1") + d2 = g.add(Delay[int], "d2") + d3 = g.add(Delay[int], "d3") + t 
= g.add(Return[int], "t") + g.connect_all( + (a.out, d1.inp), + (d1.out, d2.inp), + (d2.out, d3.inp), + (d3.out, t.inp), + ) + assert a.status == Status.READY + assert d1.status == Status.READY + assert d2.status == Status.READY + assert d3.status == Status.READY + assert t.status == Status.READY + assert len(g.nodes) == 5 + assert len(g.flat_nodes) == 5 + assert len(g.channels) == 4 + + def test_map_parameters(self, example_a): + g = Workflow() + a = g.add(example_a, "a") + t = g.add(Return[int], "t") + g.connect(a.out, t.inp) + g.combine_parameters(a.val, name="val") + assert "val" in g.parameters + + def test_map_subgraph(self, subsubgraph): + g = Workflow() + a = g.add(subsubgraph, "a") + t = g.add(Return[int], "t") + g.connect(a.out, t.inp) + g.combine_parameters(a.val, name="val") + assert "val" in g.parameters + + def test_map_port_subgraph(self): + g = Workflow() + a = g.add(Delay[int], "a") + t = g.add(Return[int], "t") + g.connect(a.out, t.inp) + g.map_port(a.inp) + + def test_map_port_subgraph_existing(self): + g = Workflow() + a1 = g.add(Delay[int], "a1") + a2 = g.add(Delay[int], "a2") + t = g.add(Return[int], "t") + g.connect(a1.out, t.inp) + g.map_port(a1.inp) + with pytest.raises(KeyError): + g.map_port(a2.inp) + + def test_map_bad_interface(self, subsubgraph): + g = Workflow() + a = g.add(subsubgraph, "a") + t = g.add(Return[int], "t") + g.connect(a.out, t.inp) + with pytest.raises(ValueError): + g.map(a.status) + + def test_map_subgraph_duplicate(self, subsubgraph): + g = Workflow() + a = g.add(subsubgraph, "a") + b = g.add(Delay[int], "b") + c = g.add(Delay[int], "c") + t = g.add(Return[int], "t") + g.connect(a.out, b.inp) + g.connect(b.out, c.inp) + g.connect(c.out, t.inp) + with pytest.raises(GraphBuildException): + g.map(b.delay, c.delay) + + def test_map_subgraph_multi(self, subgraph_multi): + g = Workflow() + a = g.add(subgraph_multi, "a") + t = g.add(Return[int], "t") + g.connect(a.out, t.inp) + assert "out" in a.outputs + + def test_build(self, newgraph, newgraph2): + g = Workflow() + ng = g.add(newgraph, "ng") + ng2 = g.add(newgraph2, "ng2") + g.connect(ng.out, ng2.inp) + g.check() + assert "out" in ng.ports + assert "inp" in ng2.ports + assert len(ng.ports) == 1 + assert len(ng2.ports) == 1 + assert len(g.nodes) == 2 + assert len(g.flat_nodes) == 5 + assert len(g.channels) == 1 diff --git a/tests/core/test_node.py b/tests/core/test_node.py new file mode 100644 index 0000000..063a631 --- /dev/null +++ b/tests/core/test_node.py @@ -0,0 +1,294 @@ +"""Node testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init, unused-import + +from pathlib import Path +import time +from traceback import TracebackException +from types import MethodType +from typing import Any + +import pytest + +from maize.core.component import Component +from maize.core.interface import FileParameter, Parameter, PortException, Input, Output +from maize.core.node import Node, NodeBuildException +from maize.core.runtime import Status, setup_build_logging +from maize.utilities.testing import MockChannel +from maize.utilities.validation import SuccessValidator +from maize.utilities.execution import ProcessError + + +@pytest.fixture +def mock_parent(): + return Component(name="parent") + + +@pytest.fixture +def node(mock_parent): + class TestNode(Node): + inp: Input[Any] = Input() + out: Output[Any] = Output() + data: Parameter[Any] = Parameter() + file: FileParameter[Path] = FileParameter() + + 
def build(self): + super().build() + self.logger = setup_build_logging("build") + + def run(self): + pass + + return TestNode(parent=mock_parent) + + +@pytest.fixture +def node_hybrid(mock_parent): + class TestNode(Node): + inp: Input[int] = Input() + inp_default: Input[int] = Input(default=17) + inp_optional: Input[int] = Input(optional=True) + out: Output[int] = Output() + + def run(self): + a = self.inp.receive() + b = self.inp_default.receive() + if self.inp_optional.ready(): + c = self.inp_optional.receive() + self.out.send(a + b + c) + + return TestNode(parent=mock_parent) + + +@pytest.fixture +def node_with_channel(node): + node.inp.set_channel(MockChannel()) + node.out.set_channel(MockChannel()) + node.logger = setup_build_logging(name="test") + return node + + +@pytest.fixture +def node_with_run_fail(node_with_channel): + def fail(self): + raise Exception("test") + node_with_channel.run = MethodType(fail, node_with_channel) + return node_with_channel + + +@pytest.fixture +def node_with_run_fail_first(node_with_channel): + def fail(self): + i = self.i + self.i += 1 + if i == 0: + raise Exception("test") + node_with_channel.i = 0 + node_with_channel.run = MethodType(fail, node_with_channel) + return node_with_channel + + +@pytest.fixture +def node_with_run_int(node_with_channel): + def fail(self): + raise KeyboardInterrupt("test") + node_with_channel.run = MethodType(fail, node_with_channel) + return node_with_channel + + +@pytest.fixture +def invalid_node_class(): + class InvalidTestNode(Node): + inp: Input[Any] = Input() + + return InvalidTestNode + + +@pytest.fixture +def invalid_node_class_no_ports(): + class InvalidTestNode(Node): + def build(self): + pass + + def run(self): + pass + + return InvalidTestNode + + +class Test_Node: + def test_init(self, node, mock_parent): + assert len(node.inputs) == 1 + assert len(node.outputs) == 1 + assert len(node.ports) == 2 + assert len(node.parameters) == 8 + assert node.data.name == "data" + assert node.file.name == "file" + assert not node.ports_active() + assert node.parent is mock_parent + assert node.status == Status.READY + + def test_setup(self, node): + node.setup_directories() + assert Path(f"node-{node.name}").exists() + + def test_setup_path(self, node, tmp_path): + node.setup_directories(tmp_path) + assert (tmp_path / f"node-{node.name}").exists() + + def test_init_fail_abc(self, invalid_node_class): + with pytest.raises(TypeError): + invalid_node_class() + + def test_init_fail_no_ports(self, invalid_node_class_no_ports): + with pytest.raises(NodeBuildException): + invalid_node_class_no_ports() + + def test_node_user_parameters(self, node): + assert "inp" in node.user_parameters + assert "data" in node.user_parameters + assert "file" in node.user_parameters + assert "python" not in node.user_parameters + assert "modules" not in node.user_parameters + assert "scripts" not in node.user_parameters + + def test_shutdown(self, node_with_channel): + node_with_channel._shutdown() + assert node_with_channel.status == Status.COMPLETED + + def test_ports(self, node_with_channel): + assert node_with_channel.ports_active() + + def test_run_command(self, node): + assert node.run_command("echo hello").returncode == 0 + val = SuccessValidator("hello") + assert node.run_command("echo hello", validators=[val]).returncode == 0 + with pytest.raises(ProcessError): + assert node.run_command("echo other", validators=[val]).returncode == 0 + + def test_run_multi(self, node): + for ret in node.run_multi(["echo hello", "echo foo"]): + assert 
ret.returncode == 0 + results = node.run_multi(["echo hello" for _ in range(6)], n_batch=2) + assert len(results) == 2 + for res in results: + assert res.returncode == 0 + val = SuccessValidator("hello") + assert node.run_multi(["echo hello"], validators=[val])[0].returncode == 0 + with pytest.raises(ProcessError): + assert node.run_multi(["echo other"], validators=[val])[0].returncode == 0 + + def test_loop(self, node_with_channel): + node_with_channel.max_loops = 2 + start = time.time() + for _ in node_with_channel._loop(step=1.0): + continue + total = time.time() - start + assert 1.5 < total < 2.5 + + def test_loop_shutdown(self, node_with_channel): + looper = node_with_channel._loop() + next(looper) + assert node_with_channel.ports_active() + node_with_channel.inp.close() + assert not node_with_channel.inp.active + try: + next(looper) + except StopIteration: + pass + assert node_with_channel.status == Status.STOPPED + + def test_prepare(self, node_with_channel): + node_with_channel.required_packages = ["numpy"] + node_with_channel.required_callables = ["echo"] + node_with_channel._prepare() + with pytest.raises(ModuleNotFoundError): + node_with_channel.required_packages = ["nonexistentpackage"] + node_with_channel._prepare() + with pytest.raises(NodeBuildException): + node_with_channel.required_packages = [] + node_with_channel.required_callables = ["idontexist"] + node_with_channel._prepare() + + def test_execute(self, node_with_channel): + node_with_channel.execute() + assert node_with_channel.status == Status.COMPLETED + assert not node_with_channel.ports_active() + assert not node_with_channel.signal.is_set() + time.sleep(0.5) + updates = [] + while not node_with_channel._message_queue.empty(): + updates.append(node_with_channel._message_queue.get()) + summary = updates[-1] + assert summary.name == node_with_channel.name + assert summary.status == Status.COMPLETED + + def test_execute_run_fail(self, node_with_run_fail): + node_with_run_fail.execute() + assert node_with_run_fail.status == Status.FAILED + assert not node_with_run_fail.ports_active() + assert node_with_run_fail.signal.is_set() + time.sleep(0.5) + updates = [] + while not node_with_run_fail._message_queue.empty(): + updates.append(node_with_run_fail._message_queue.get()) + summary = updates[-1] + assert isinstance(updates[-2].exception, TracebackException) + assert summary.name == node_with_run_fail.name + assert summary.status == Status.FAILED + + def test_execute_inactive(self, node_with_channel): + node_with_channel.active.set(False) + node_with_channel.execute() + assert node_with_channel.status == Status.STOPPED + assert not node_with_channel.ports_active() + assert not node_with_channel.signal.is_set() + time.sleep(0.5) + updates = [] + while not node_with_channel._message_queue.empty(): + updates.append(node_with_channel._message_queue.get()) + summary = updates[-1] + assert summary.name == node_with_channel.name + assert summary.status == Status.STOPPED + + def test_execute_run_int(self, node_with_run_int): + with pytest.raises(KeyboardInterrupt): + node_with_run_int.execute() + assert node_with_run_int.status == Status.STOPPED + assert not node_with_run_int.ports_active() + assert node_with_run_int.signal.is_set() + time.sleep(0.5) + updates = [] + while not node_with_run_int._message_queue.empty(): + updates.append(node_with_run_int._message_queue.get()) + summary = updates[-1] + assert summary.name == node_with_run_int.name + assert summary.status == Status.STOPPED + + def test_execute_run_fail_ok(self, 
node_with_run_fail): + node_with_run_fail.fail_ok = True + node_with_run_fail.execute() + assert node_with_run_fail.status == Status.FAILED + assert not node_with_run_fail.ports_active() + assert not node_with_run_fail.signal.is_set() + time.sleep(0.5) + updates = [] + while not node_with_run_fail._message_queue.empty(): + updates.append(node_with_run_fail._message_queue.get()) + summary = updates[-1] + assert summary.name == node_with_run_fail.name + assert summary.status == Status.FAILED + + def test_execute_run_fail_2_attempts(self, node_with_run_fail_first): + node_with_run_fail_first.n_attempts = 3 + node_with_run_fail_first.execute() + assert node_with_run_fail_first.status == Status.COMPLETED + assert not node_with_run_fail_first.ports_active() + assert not node_with_run_fail_first.signal.is_set() + time.sleep(0.5) + updates = [] + while not node_with_run_fail_first._message_queue.empty(): + updates.append(node_with_run_fail_first._message_queue.get()) + summary = updates[-1] + assert summary.name == node_with_run_fail_first.name + assert summary.status == Status.COMPLETED diff --git a/tests/core/test_ports.py b/tests/core/test_ports.py new file mode 100644 index 0000000..435b87e --- /dev/null +++ b/tests/core/test_ports.py @@ -0,0 +1,637 @@ +"""Port testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init, unused-import + +from multiprocessing import Process +from pathlib import Path +import time +from typing import Annotated, Any + +import numpy as np +from numpy.typing import NDArray +import pytest + +from maize.core.interface import ( + PortInterrupt, + Interface, + Suffix, + Parameter, + FileParameter, + MultiParameter, + ParameterException, + MultiInput, + MultiOutput, + PortException, +) + + +@pytest.fixture +def annotated(): + return Annotated[Path, Suffix("abc", "xyz")] + + +@pytest.fixture +def annotated_list(): + return list[Annotated[Path, Suffix("abc", "xyz")]] + + +@pytest.fixture +def interface(): + return Interface() + + +@pytest.fixture +def interface_typed(): + return Interface[int]() + + +@pytest.fixture +def interface_annotated(annotated): + return Interface[annotated]() + + +@pytest.fixture +def parameter(mock_component): + return Parameter[int](default=42).build(name="test", parent=mock_component) + + +@pytest.fixture +def parameter2(mock_component): + return Parameter[int](default=42).build(name="test2", parent=mock_component) + + +@pytest.fixture +def parameter_array(mock_component): + return Parameter[NDArray[Any]](default=np.array([1, 2, 3])).build( + name="test2", parent=mock_component + ) + + +@pytest.fixture +def parameter_nodefault(mock_component): + return Parameter[int]().build(name="test", parent=mock_component) + + +@pytest.fixture +def file_parameter(mock_component, annotated): + return FileParameter[annotated]().build(name="test", parent=mock_component) + + +@pytest.fixture +def file_list_parameter(mock_component, annotated_list): + return FileParameter[annotated_list]().build(name="test", parent=mock_component) + + +@pytest.fixture +def multi_parameter(mock_component, parameter, parameter2): + return MultiParameter(parameters=(parameter, parameter2), default=3).build( + "foo", parent=mock_component + ) + + +@pytest.fixture +def multi_parameter_hook(mock_component, parameter, parameter2): + return MultiParameter[str, int](parameters=(parameter, parameter2), hook=int).build( + "foo", parent=mock_component + ) + + +@pytest.fixture +def 
multi_parameter_hook2(mock_component, parameter, parameter2): + return MultiParameter[str, int](parameters=(parameter, parameter2), hook=int, default="5").build( + "foo", parent=mock_component + ) + + +@pytest.fixture +def multi_input(mock_component): + return MultiInput[int]().build(name="test", parent=mock_component) + + +@pytest.fixture +def multi_output(mock_component): + return MultiOutput[int]().build(name="test", parent=mock_component) + + +@pytest.fixture +def multi_output_fixed(mock_component): + return MultiOutput[int](n_ports=2).build(name="test", parent=mock_component) + + +@pytest.fixture +def suffix1(): + return Suffix("pdb", "gro") + + +@pytest.fixture +def suffix2(): + return Suffix("pdb") + + +@pytest.fixture +def suffix3(): + return Suffix("xyz") + + +class Test_Suffix: + def test_suffix(self, suffix1): + assert suffix1(Path("a.pdb")) + assert suffix1(Path("a.gro")) + assert not suffix1(Path("a.xyz")) + + def test_eq(self, suffix1, suffix2, suffix3, mock_component): + assert suffix1 != mock_component + assert suffix1 == suffix2 + assert suffix1 != suffix3 + + +class Test_Interface: + def test_build(self, interface, mock_component): + interface._target = "parameters" + inter = interface.build(name="test", parent=mock_component) + assert inter.datatype is None + assert mock_component.parameters["test"] == inter + + def test_typed(self, interface_typed, mock_component): + interface_typed._target = "parameters" + inter = interface_typed.build(name="test", parent=mock_component) + assert inter.datatype == int + + def test_typed_check(self, interface_typed, mock_component): + interface_typed._target = "parameters" + inter = interface_typed.build(name="test", parent=mock_component) + assert inter.check(42) + assert not inter.check("foo") + + def test_annotated(self, interface_annotated, mock_component, annotated): + interface_annotated._target = "parameters" + inter = interface_annotated.build(name="test", parent=mock_component) + assert inter.datatype == annotated + + def test_annotated_check(self, interface_annotated, mock_component): + interface_annotated._target = "parameters" + inter = interface_annotated.build(name="test", parent=mock_component) + assert inter.check(Path("file.abc")) + assert not inter.check(Path("file.pdb")) + assert inter.check("foo.abc") + assert not inter.check("foo") + + def test_path(self, interface, mock_component): + interface._target = "parameters" + inter = interface.build(name="test", parent=mock_component) + assert inter.path == ("mock", "test") + + def test_serialized(self, interface): + assert interface.serialized == {"type": "typing.Any", "kind": "Interface"} + + +class Test_Parameter: + def test_set(self, parameter): + assert parameter.is_set + parameter.set(17) + assert parameter.is_set + assert parameter.value == 17 + with pytest.raises(ValueError): + parameter.set("foo") + + def test_default(self, parameter_array): + assert parameter_array.is_set + assert parameter_array.is_default + parameter_array.set(np.array([1, 2, 3])) + assert parameter_array.is_default + assert all(parameter_array.value == np.array([1, 2, 3])) + + def test_set_none(self, parameter_nodefault): + assert not parameter_nodefault.is_set + parameter_nodefault.set(17) + assert parameter_nodefault.is_set + + def test_changed(self, parameter): + assert not parameter.changed + parameter.set(17) + assert parameter.changed + + def test_serialized(self, parameter): + assert parameter.serialized == { + "type": "", + "kind": "Parameter", + "default": 42, + "optional": True, + } 
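+
+ # The file-parameter tests below use the `shared_datadir` fixture (assumed to come
+ # from the pytest-datadir plugin) for files that really exist on disk: `filepath`
+ # raises ParameterException while the parameter is unset, and FileNotFoundError
+ # once it has been set to a path that is missing from disk.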
+ + def test_file(self, file_parameter, shared_datadir): + with pytest.raises(ParameterException): + file_parameter.filepath + file_parameter.set(Path("nonexistent.abc")) + assert file_parameter.is_set + with pytest.raises(FileNotFoundError): + file_parameter.filepath + file_parameter.set(shared_datadir / "testorigin.abc") + assert file_parameter.filepath + + def test_file_list(self, file_list_parameter, shared_datadir): + with pytest.raises(ParameterException): + file_list_parameter.filepath + file_list_parameter.set([Path("nonexistent.abc"), Path("other_nonexistent.xyz")]) + assert file_list_parameter.is_set + with pytest.raises(FileNotFoundError): + file_list_parameter.filepath + file_list_parameter.set([shared_datadir / "testorigin.abc"]) + assert file_list_parameter.filepath + + def test_file_serialized(self, file_parameter): + assert file_parameter.serialized["kind"] == "FileParameter" + assert file_parameter.serialized["default"] is None + assert not file_parameter.serialized["optional"] + assert file_parameter.serialized["exist_required"] + + def test_multi(self, multi_parameter, mock_component): + assert multi_parameter.default == 3 + assert multi_parameter.datatype == int + assert multi_parameter.is_set + assert multi_parameter.value == 3 + multi_parameter.set(10) + assert multi_parameter.value == 10 + assert multi_parameter._parameters[0].value == 10 + assert multi_parameter.parents == [mock_component, mock_component] + + def test_multi_as_dict(self, multi_parameter): + assert multi_parameter.as_dict() == { + "name": "foo", + "value": 3, + "type": "int", + "map": [{"mock": "test"}, {"mock": "test2"}], + } + + def test_multi_serialized(self, multi_parameter): + assert multi_parameter.serialized == { + "type": "", + "kind": "MultiParameter", + "default": 3, + "optional": True, + } + + def test_multi_bad(self, parameter, file_parameter): + with pytest.raises(ParameterException): + MultiParameter(parameters=(parameter, file_parameter), default=3) + + def test_multi_hook(self, multi_parameter_hook, mock_component): + assert multi_parameter_hook.default is None + assert not multi_parameter_hook.is_set + multi_parameter_hook.set("10") + assert multi_parameter_hook.value == "10" + assert multi_parameter_hook._parameters[0].value == 10 + assert multi_parameter_hook._parameters[1].value == 10 + assert multi_parameter_hook.parents == [mock_component, mock_component] + + def test_multi_hook2(self, multi_parameter_hook2, mock_component): + assert multi_parameter_hook2.default == "5" + assert multi_parameter_hook2.is_set + assert multi_parameter_hook2.value == "5" + assert multi_parameter_hook2._parameters[0].value == 5 + assert multi_parameter_hook2._parameters[1].value == 5 + multi_parameter_hook2.set("10") + assert multi_parameter_hook2.value == "10" + assert multi_parameter_hook2._parameters[0].value == 10 + assert multi_parameter_hook2._parameters[1].value == 10 + assert multi_parameter_hook2.parents == [mock_component, mock_component] + + +class Test_Output: + def test_connection(self, connected_output): + assert connected_output.connected + + def test_serialized(self, connected_output): + assert connected_output.serialized == { + "type": "None", + "kind": "Output", + "optional": False, + "mode": "copy", + } + + def test_active(self, connected_output): + assert connected_output.active + + def test_close(self, connected_output): + connected_output.close() + assert not connected_output.active + + def test_send(self, connected_output): + connected_output.send(42) + + def 
test_send_hook(self, connected_output): + connected_output.hook = lambda x: x + 2 + connected_output.send(42) + + def test_send_file(self, connected_file_output, datafile): + connected_file_output.send(datafile) + + def test_send_unconnected(self, unconnected_output): + with pytest.raises(PortException): + unconnected_output.send(42) + + def test_send_closed(self, connected_output): + connected_output.close() + with pytest.raises(PortInterrupt): + connected_output.send(42) + + +class Test_Input: + def test_connection(self, connected_input): + assert connected_input.connected + + def test_serialized(self, connected_input): + assert connected_input.serialized == { + "type": "None", + "kind": "Input", + "optional": False, + "mode": "copy", + "default": None, + "cached": False + } + + def test_no_connection(self, unconnected_input): + assert not unconnected_input.connected + + def test_set(self, unconnected_input): + unconnected_input.set(42) + assert unconnected_input.value == 42 + unconnected_input.datatype = str + with pytest.raises(ValueError): + unconnected_input.set(42) + + def test_active(self, connected_input): + assert connected_input.active + + def test_close(self, connected_input): + connected_input.close() + assert connected_input.active + assert connected_input.ready() + connected_input.receive() + assert not connected_input.active + + def test_inactive(self, unconnected_input): + assert not unconnected_input.active + with pytest.raises(PortException): + unconnected_input.receive() + + def test_preload(self, connected_input): + assert connected_input.receive() == 42 + assert not connected_input.ready() + connected_input.preload(42) + assert connected_input.ready() + connected_input.datatype = int + with pytest.raises(ValueError): + connected_input.preload("foo") + + def test_preload_hook(self, connected_input): + connected_input.hook = lambda x: x + 2 + connected_input.preload(42) + assert connected_input.ready() + assert connected_input.receive() == 44 + + def test_default_factory(self, connected_input_default_factory): + assert connected_input_default_factory.ready() + assert connected_input_default_factory.receive() == [17] + + def test_dump(self, connected_input): + assert connected_input.dump() == [42] + + def test_dump_hook(self, connected_input): + connected_input.hook = lambda x: x + 2 + assert connected_input.dump() == [44] + + def test_dump_unconnected(self, unconnected_input): + with pytest.raises(PortException): + unconnected_input.dump() + + def test_receive(self, connected_input): + assert connected_input.receive() == 42 + + def test_receive_cached(self, connected_input): + connected_input.cached = True + assert connected_input.receive() == 42 + assert connected_input.receive() == 42 + + def test_receive_hook(self, connected_input): + connected_input.hook = lambda x: x + 2 + assert connected_input.receive() == 44 + + def test_receive_optional(self, connected_input): + connected_input.optional = True + connected_input.preload(None) + with pytest.raises(PortException): + connected_input.receive() + + def test_receive_optional_none(self, connected_input): + connected_input.optional = True + connected_input.preload(None) + assert connected_input.receive_optional() is None + + def test_receive_optional_unconnected(self, unconnected_input): + unconnected_input.optional = True + unconnected_input.set(None) + assert unconnected_input.receive_optional() is None + + def test_ready_receive(self, connected_input): + assert connected_input.ready() + assert 
connected_input.receive() == 42 + + def test_close_receive(self, connected_input): + connected_input.close() + assert connected_input.ready() + assert connected_input.receive() == 42 + assert not connected_input.ready() + + def test_receive_multi_attempt(self, connected_input): + assert connected_input.receive() == 42 + connected_input.close() + with pytest.raises(PortInterrupt): + connected_input.receive() + + def test_receive_multi_success(self, connected_input_multi): + assert connected_input_multi.receive() == 42 + assert connected_input_multi.receive() == -42 + assert not connected_input_multi.ready() + + def test_file_receive(self, connected_file_input): + assert connected_file_input.receive().exists() + + def test_file_ready_receive(self, connected_file_input): + assert connected_file_input.ready() + assert connected_file_input.receive().exists() + + def test_file_close_receive(self, connected_file_input): + connected_file_input.close() + assert connected_file_input.ready() + assert connected_file_input.receive().exists() + assert not connected_file_input.ready() + + def test_default(self, connected_input_default): + assert connected_input_default.is_set + assert connected_input_default.is_default + assert not connected_input_default.changed + assert connected_input_default.active + assert connected_input_default.ready() + assert connected_input_default.receive() == 42 + assert not connected_input_default.ready() + with pytest.raises(ParameterException): + connected_input_default.set(1) + + def test_unconnected_default(self, unconnected_input_default): + assert unconnected_input_default.is_set + assert unconnected_input_default.is_default + assert not unconnected_input_default.changed + assert unconnected_input_default.active + assert unconnected_input_default.ready() + assert unconnected_input_default.receive() == 17 + assert unconnected_input_default.ready() + assert unconnected_input_default.receive() == 17 + unconnected_input_default.set(39) + assert unconnected_input_default.receive() == 39 + assert not unconnected_input_default.is_default + assert unconnected_input_default.changed + assert unconnected_input_default.active + + +class Test_both: + def test_send_receive(self, connected_pair): + inp, out = connected_pair + out.send(42) + assert inp.receive() == 42 + + def test_send_receive_full(self, connected_pair): + inp, out = connected_pair + + def send(): + for _ in range(4): + out.send(42) + + p = Process(target=send) + p.start() + time.sleep(5) + assert inp.receive() == 42 + assert inp.receive() == 42 + assert inp.receive() == 42 + assert inp.receive() == 42 + p.join() + + def test_send_receive_wait(self, connected_pair): + inp, out = connected_pair + + def recv(): + inp.receive() + + p = Process(target=recv) + p.start() + time.sleep(5) + out.send(42) + p.join() + + def test_send_receive_hook(self, connected_pair): + inp, out = connected_pair + inp.hook = out.hook = lambda x: x + 2 + out.send(42) + assert inp.receive() == 46 + + def test_send_receive_multi(self, connected_pair): + inp, out = connected_pair + out.send(42) + out.send(101010) + out.send(-42) + assert inp.receive() == 42 + assert inp.receive() == 101010 + assert inp.ready() + assert inp.receive() == -42 + assert not inp.ready() + + def test_file_send_receive(self, connected_file_pair, datafile): + inp, out = connected_file_pair + out.send(datafile) + assert inp.ready() + assert inp.receive().exists() + assert not inp.ready() + + +class Test_MultiInput: + def test_properties(self, multi_input, empty_channel, 
loaded_channel): + assert not multi_input.connected + multi_input.set_channel(empty_channel) + multi_input.set_channel(loaded_channel) + assert multi_input.connected + assert multi_input[0].channel is empty_channel + assert multi_input[1].channel is loaded_channel + assert len(multi_input) == 2 + assert multi_input.default is None + + def test_serialized(self, multi_input): + assert multi_input.serialized == { + "type": "", + "kind": "MultiInput", + "n_ports": 0, + "optional": False, + "mode": "copy", + } + + def test_receive(self, multi_input, loaded_datachannel2): + multi_input.set_channel(loaded_datachannel2) + multi_input.set_channel(loaded_datachannel2) + assert multi_input[0].receive() == 42 + assert multi_input[1].receive() == -42 + + def test_receive_set(self, multi_input): + multi_input.set(39) + for port in multi_input: + assert port.receive() == 39 + + def test_close(self, multi_input, loaded_channel): + multi_input.set_channel(loaded_channel) + multi_input.set_channel(loaded_channel) + assert multi_input[0].active + assert multi_input[1].active + multi_input.close() + assert multi_input[0].active + assert multi_input[1].active + multi_input[0].receive() + assert not multi_input[0].active + assert not multi_input[1].active + + def test_dump(self, multi_input, loaded_datachannel2): + multi_input.set_channel(loaded_datachannel2) + multi_input.set_channel(loaded_datachannel2) + assert multi_input.dump() == [[42, -42], []] + + def test_preload(self, multi_input, empty_datachannel): + multi_input.set_channel(empty_datachannel) + multi_input.set_channel(empty_datachannel) + multi_input.preload([1, 2]) + assert multi_input[0].receive() == 1 + assert multi_input[1].receive() == 2 + + +class Test_MultiOutput: + def test_properties(self, multi_output, empty_channel, loaded_channel): + assert not multi_output.connected + multi_output.set_channel(empty_channel) + multi_output.set_channel(loaded_channel) + assert multi_output.connected + assert multi_output[0].channel is empty_channel + assert multi_output[1].channel is loaded_channel + assert len(multi_output) + + def test_send(self, multi_output, empty_datachannel): + multi_output.set_channel(empty_datachannel) + multi_output.set_channel(empty_datachannel) + multi_output[0].send(42) + multi_output[1].send(42) + assert empty_datachannel.receive() == 42 + assert empty_datachannel.receive() == 42 + + def test_close(self, multi_output, loaded_channel): + multi_output.set_channel(loaded_channel) + multi_output.set_channel(loaded_channel) + assert multi_output[0].active + assert multi_output[1].active + multi_output.close() + assert not multi_output[0].active + assert not multi_output[1].active diff --git a/tests/core/test_workflow.py b/tests/core/test_workflow.py new file mode 100644 index 0000000..100e561 --- /dev/null +++ b/tests/core/test_workflow.py @@ -0,0 +1,551 @@ +"""Workflow execution testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init, unused-import, unused-variable, unused-argument + +from dataclasses import dataclass +import logging +from pathlib import Path +import random +import sys +import time +from typing import Annotated + +import networkx +import numpy as np +import pytest + +from maize.core.interface import ( + Input, + Output, + Parameter, + FileParameter, + MultiInput, + MultiOutput, + Suffix, +) +from maize.core.graph import Graph +from maize.core.node import Node +from maize.core.runtime import 
DEFAULT_CONTEXT, NodeException +from maize.core.workflow import Workflow, FutureWorkflowResult, expose, wait_for_all +from maize.steps.plumbing import Merge, Delay +from maize.steps.io import LoadData, Return, LoadFile, SaveFile +from maize.utilities.execution import JobResourceConfig, WorkflowStatus, check_executable + + +ENV_EXEC = Path(sys.executable).parents[2] / "maize-test" / "bin" / "python" + + +class A(Node): + out = Output[int]() + val = Parameter[int]() + file = FileParameter[Annotated[Path, Suffix("pdb")]]() + flag = Parameter[bool]() + + def run(self): + self.out.send(self.val.value) + + +@pytest.fixture +def example_a(): + return A + + +class Aany(Node): + out = Output() + val = Parameter() + file = FileParameter() + flag = Parameter() + + def run(self): + self.out.send(self.val.value) + + +@pytest.fixture +def example_a_any(): + return Aany + + +class Aenv(Node): + out = Output[int]() + val = Parameter[int](default=42) + file = FileParameter[Annotated[Path, Suffix("pdb")]]() + flag = Parameter[bool]() + + def run(self): + import scipy + + self.out.send(self.val.value) + + +@pytest.fixture +def example_a_env(): + return Aenv + + +class B(Node): + fail: bool = False + inp = Input[int]() + out = Output[int]() + out_final = Output[int]() + + def run(self): + if self.fail: + self.fail = False + raise RuntimeError("This is a test exception") + + val = self.inp.receive() + self.logger.debug("%s received %s", self.name, val) + if val > 48: + self.logger.debug("%s stopping", self.name) + self.out_final.send(val) + return + self.out.send(val + 2) + + +@pytest.fixture +def example_b(): + return B + + +@pytest.fixture +def loopable_subgraph(example_b): + class Loopable(Graph): + def build(self) -> None: + dela = self.add(Delay[int], parameters={"delay": 1}) + b = self.add(example_b) + self.connect(dela.out, b.inp) + self.map(dela.inp, b.out, b.out_final) + + return Loopable + + +class SubSubGraph(Graph): + def build(self): + a = self.add(A, "a", parameters=dict(val=36)) + d = self.add(Delay[int], "delay", parameters=dict(delay=1)) + self.connect(a.out, d.inp) + self.out = self.map_port(d.out, "out") + self.combine_parameters(a.val, name="val") + + +@pytest.fixture +def subsubgraph(example_a): + return SubSubGraph + + +class Random(Node): + """Perform a random linear combination of inputs and send results to outputs.""" + + inp = MultiInput[float](optional=True) + out = MultiOutput[float](optional=True) + fail = Parameter[bool](default=False) + + def run(self): + n_inputs = len(self.inp) + n_outputs = len(self.out) + # Starting nodes + if n_inputs == 0: + ins = np.array([1]) + else: + ins = np.array([ip.receive() for ip in self.inp]) + weights = np.random.random(size=(max(n_outputs, 1), max(n_inputs, 1))) + res = weights @ ins + time.sleep(random.random()) + if self.fail.value and (random.random() < 0.01): + raise Exception("fail test") + for i, out in enumerate(self.out): + if out.active: + out.send(res[i]) + + +@pytest.fixture +def random_node(): + return Random + + +class RandomLoop(Node): + inp = MultiInput[float](optional=True) + out = MultiOutput[float](optional=True) + fail = Parameter[bool](default=False) + + def run(self): + n_inputs = len(self.inp) + n_outputs = len(self.out) + ins = np.ones(max(1, n_inputs)) + for i, inp in enumerate(self.inp): + if inp.ready(): + self.logger.debug("Receiving via input %s", i) + ins[i] = inp.receive() + weights = np.random.random(size=(max(n_outputs, 1), max(n_inputs, 1))) + res = weights @ ins + time.sleep(random.random()) + if 
self.fail.value and (random.random() < 0.1): + raise RuntimeError("This is a test exception") + if random.random() < 0.05: + return + + for i, out in enumerate(self.out): + if out.active: + self.logger.debug("Sending via output %s", i) + out.send(res[i]) + + +@pytest.fixture +def random_loop_node(): + return RandomLoop + + +class RandomResources(Node): + inp = MultiInput[float](optional=True) + out = MultiOutput[float](optional=True) + fail = Parameter[bool](default=False) + + def run(self): + n_inputs = len(self.inp) + n_outputs = len(self.out) + # Starting nodes + if n_inputs == 0: + ins = np.array([1]) + else: + ins = np.array([inp.receive() for inp in self.inp]) + + with self.cpus(8): + weights = np.random.random(size=(max(n_outputs, 1), max(n_inputs, 1))) + res = weights @ ins + time.sleep(random.random()) + if self.fail.value and (random.random() < 0.1): + raise Exception("fail test") + for i, out in enumerate(self.out): + if out.active: + out.send(res[i]) + + +@pytest.fixture +def random_resource_node(): + return RandomResources + + +@pytest.fixture +def graph_mp_fixed_file(shared_datadir): + return shared_datadir / "graph-mp-fixed.yaml" + + +@pytest.fixture +def simple_data(): + return 42 + + +@dataclass +class TestData: + x: int + y: list[str] + z: dict[tuple[int, int], bytes] + + +@pytest.fixture +def normal_data(): + return TestData(x=39, y=["hello", "göteborg"], z={(0, 1): b"foo"}) + + +@pytest.fixture +def weird_data(): + return {("foo", 4j): lambda x: print(x + 1)} + + +@pytest.fixture(params=[10, 50, 100]) +def random_dag_gen(request): + def _random_dag(fail, node_type): + dag = networkx.gnc_graph(request.param) + g = Workflow() + for i in range(dag.number_of_nodes()): + node = g.add(node_type, str(i), parameters=dict(fail=fail)) + + for sen, rec in dag.edges: + g.connect(sending=g.nodes[str(sen)].out, receiving=g.nodes[str(rec)].inp) + + term = g.add(Return[float], "term") + for node in g.nodes.values(): + if "out" in node.outputs: + g.connect(sending=node.outputs["out"], receiving=term.inp) + break + + g.ret = term.ret_queue + print({k: {k: v.connected for k, v in n.ports.items()} for k, n in g.nodes.items()}) + return g + + return _random_dag + + +@pytest.fixture(params=[10, 50, 100]) +def random_dcg_gen(request, random_loop_node): + def _random_dcg(fail): + dcg = networkx.random_k_out_graph(request.param, k=2, alpha=0.2) + g = Workflow(level=logging.DEBUG) + for i in range(dcg.number_of_nodes()): + _ = g.add(random_loop_node, str(i), loop=True, parameters=dict(fail=fail)) + + for sen, rec, _ in dcg.edges: + g.logger.info("Connecting %s -> %s", sen, rec) + g.connect( + sending=g.nodes[str(sen)].out, + receiving=g.nodes[str(rec)].inp, + size=request.param * 5, + ) + + return g + + return _random_dcg + + +class Test_workflow: + def test_register(self): + Workflow.register(name="test", factory=lambda: Workflow()) + assert Workflow.get_workflow_summary("test") == "" + assert callable(Workflow.get_available_workflows().pop()) + assert Workflow.from_name("test") + with pytest.raises(KeyError): + Workflow.from_name("nope") + + def test_template(self, random_dag_gen, random_node): + wf = random_dag_gen(fail=False, node_type=random_node) + assert wf.generate_config_template() == "" + + def test_expose(self, simple_data, example_a, mocker, shared_datadir): + @expose + def flow() -> Workflow: + g = Workflow() + a = g.add(example_a, "a", parameters=dict(val=simple_data)) + t = g.add(Return[int], "term") + g.connect(a.out, t.inp) + g.map(a.val, a.file, a.flag) + return g + + 
mocker.patch("sys.argv", ["testing", "--val", "foo", "--check"]) + with pytest.raises(SystemExit): + flow() + + file = Path("file.pdb") + file.touch() + mocker.patch( + "sys.argv", ["testing", "--val", "17", "--file", file.as_posix(), "--flag", "--check"] + ) + flow() + + +class Test_graph_run: + def test_single_execution_channel_simple(self, simple_data, example_a): + g = Workflow() + a = g.add(example_a, "a", parameters=dict(val=simple_data)) + t = g.add(Return[int], "term") + g.connect(a.out, t.inp) + g.execute() + assert t.get() == simple_data + + def test_single_execution_channel_simple_file(self, node_with_file): + g = Workflow(cleanup_temp=False, level="debug") + g.config.scratch = Path("./") + l = g.add(LoadData[int], parameters={"data": 42}, loop=True, max_loops=3) + a = g.add(node_with_file, loop=True) + g.connect(l.out, a.inp) + g.execute() + assert g.work_dir.exists() + assert (g.work_dir / "node-loaddata").exists() + with pytest.raises(StopIteration): + next((g.work_dir / "node-loaddata").iterdir()).exists() + + def test_single_execution_channel_normal(self, normal_data, example_a_any): + g = Workflow() + a = g.add(example_a_any, "a", parameters=dict(val=normal_data)) + t = g.add(Return[int], "term") + g.connect(a.out, t.inp) + g.execute() + assert t.get() == normal_data + + @pytest.mark.skipif( + DEFAULT_CONTEXT == "spawn", + reason=( + "Acceptable datatypes for channels are restricted " + "when using multiprocessing 'spawn' context" + ), + ) + def test_single_execution_channel_weird(self, weird_data, example_a_any): + g = Workflow() + a = g.add(example_a_any, "a", parameters=dict(val=weird_data)) + t = g.add(Return[int], "term") + g.connect(a.out, t.inp) + g.execute() + assert t.get().keys() == weird_data.keys() + + @pytest.mark.skipif( + not ENV_EXEC.exists(), + reason=( + "Testing alternative environment execution requires " + "the `maize-test` environment to be installed" + ), + ) + def test_single_execution_alt_env(self, example_a_env, simple_data): + g = Workflow(level=logging.DEBUG) + a = g.add(example_a_env, "a", parameters=dict(python=ENV_EXEC, val=simple_data)) + t = g.add(Return[int], "term") + g.connect(a.out, t.inp) + g.execute() + assert t.get() == simple_data + + def test_single_execution_complex_graph(self, example_a, example_b): + g = Workflow() + a = g.add(example_a, "a", parameters=dict(val=40)) + b = g.add(example_b, "b", loop=True) + m = g.add(Merge[int], "m") + t = g.add(Return[int], "t") + g.connect(a.out, m.inp) + g.connect(b.out, m.inp) + g.connect(m.out, b.inp) + g.connect(b.out_final, t.inp) + g.execute() + assert t.get() == 50 + + def test_single_execution_complex_graph_with_subgraph(self, subsubgraph, example_b): + g = Workflow(level="DEBUG") + a = g.add(subsubgraph, "a", parameters=dict(val=40)) + b = g.add(example_b, "b", loop=True) + m = g.add(Merge[int], "m") + t = g.add(Return[int], "t") + g.connect(a.out, m.inp) + g.connect(b.out, m.inp) + g.connect(m.out, b.inp) + g.connect(b.out_final, t.inp) + g.execute() + assert t.get() == 50 + + def test_execution_subgraph_looped(self, loopable_subgraph, example_a): + g = Workflow(level="DEBUG") + a = g.add(example_a, "a", parameters={"val": 36}) + sg = g.add(loopable_subgraph, "sg", loop=True) + m = g.add(Merge[int], "m") + t = g.add(Return[int], "t") + g.connect(a.out, m.inp) + g.connect(sg.out, m.inp) + g.connect(m.out, sg.inp) + g.connect(sg.out_final, t.inp) + g.execute() + assert t.get() == 50 + + def test_multi_execution_complex_graph(self, example_a, example_b): + g = Workflow() + a = 
g.add(example_a, "a", parameters=dict(val=40)) + b = g.add(example_b, "b", n_attempts=2, loop=True) + b.fail = True + m = g.add(Merge[int], "m") + t = g.add(Return[int], "t") + g.connect(a.out, m.inp) + g.connect(b.out, m.inp) + g.connect(m.out, b.inp) + g.connect(b.out_final, t.inp) + g.execute() + assert t.get() == 50 + + def test_multi_execution_complex_graph_fail(self, example_a, example_b): + g = Workflow() + a = g.add(example_a, "a", parameters=dict(val=40)) + b = g.add(example_b, "b", n_attempts=1, loop=True) + b.fail = True + m = g.add(Merge[int], "m") + t = g.add(Return[int], "t") + g.connect(a.out, m.inp) + g.connect(b.out, m.inp) + g.connect(m.out, b.inp) + g.connect(b.out_final, t.inp) + with pytest.raises(NodeException): + g.execute() + + @pytest.mark.skipif( + not check_executable("sinfo"), + reason="Testing slurm requires a functioning Slurm batch system", + ) + def test_submit(self, tmp_path, shared_datadir): + flow = Workflow() + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + out = flow.add(SaveFile[Path], parameters={"destination": tmp_path / "test.abc"}) + flow.connect(data.out, out.inp) + res = flow.submit(folder=tmp_path, config=JobResourceConfig(walltime="00:02:00")) + + assert res.id + assert res.query() in (WorkflowStatus.QUEUED, WorkflowStatus.RUNNING) + assert res.wait() == WorkflowStatus.COMPLETED + assert (tmp_path / "test.abc").exists() + + @pytest.mark.skipif( + not check_executable("sinfo"), + reason="Testing slurm requires a functioning Slurm batch system", + ) + def test_submit_wait_all(self, tmp_path, shared_datadir): + res = {} + for name in ("a", "b"): + flow = Workflow(name=name) + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + out = flow.add(SaveFile[Path], parameters={"destination": tmp_path / f"test-{name}.abc"}) + flow.connect(data.out, out.inp) + res[name] = flow.submit(folder=tmp_path / name, config=JobResourceConfig(walltime="00:02:00")) + + wait_for_all(list(res.values())) + for name in ("a", "b"): + assert res[name].id + assert res[name].wait() == WorkflowStatus.COMPLETED + assert (tmp_path / f"test-{name}.abc").exists() + + @pytest.mark.skipif( + not check_executable("sinfo"), + reason="Testing slurm requires a functioning Slurm batch system", + ) + def test_submit_cancel(self, tmp_path, shared_datadir): + flow = Workflow() + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + delay = flow.add(Delay[Path], parameters={"delay": 60}) + out = flow.add(SaveFile[Path], parameters={"destination": tmp_path / "test.abc"}) + flow.connect(data.out, delay.inp) + flow.connect(delay.out, out.inp) + res = flow.submit(folder=tmp_path, config=JobResourceConfig(walltime="00:02:00")) + + assert res.id + assert res.query() in (WorkflowStatus.QUEUED, WorkflowStatus.RUNNING) + res.cancel() + time.sleep(10) # Wait a bit for slow queueing systems + assert res.query() == WorkflowStatus.CANCELLED + assert not (tmp_path / "test.abc").exists() + + @pytest.mark.skipif( + not check_executable("sinfo"), + reason="Testing slurm requires a functioning Slurm batch system", + ) + def test_submit_serdes(self, tmp_path, shared_datadir): + flow = Workflow() + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + delay = flow.add(Delay[Path], parameters={"delay": 30}) + out = flow.add(SaveFile[Path], parameters={"destination": tmp_path / "test.abc"}) + flow.connect(data.out, delay.inp) + flow.connect(delay.out, out.inp) + 
res = flow.submit(folder=tmp_path, config=JobResourceConfig(walltime="00:02:00")) + + assert res.id + assert res.query() in (WorkflowStatus.QUEUED, WorkflowStatus.RUNNING) + serdes = FutureWorkflowResult.from_dict(res.to_dict()) + assert serdes.query() in (WorkflowStatus.QUEUED, WorkflowStatus.RUNNING) + assert serdes.wait() == WorkflowStatus.COMPLETED + assert (tmp_path / "test.abc").exists() + + @pytest.mark.random + def test_random_dags(self, random_dag_gen, random_node): + random_dag_gen(fail=False, node_type=random_node).execute() + + @pytest.mark.random + def test_random_dags_resources(self, random_dag_gen, random_resource_node): + random_dag_gen(fail=False, node_type=random_resource_node).execute() + + @pytest.mark.random + def test_random_dcgs(self, random_dcg_gen): + random_dcg_gen(fail=False).execute() + + @pytest.mark.random + def test_random_dcgs_with_fail(self, random_dcg_gen): + with pytest.raises(NodeException): + random_dcg_gen(fail=True).execute() diff --git a/tests/data/testorigin.abc b/tests/data/testorigin.abc new file mode 100644 index 0000000..e69de29 diff --git a/tests/data/testorigin2.abc b/tests/data/testorigin2.abc new file mode 100644 index 0000000..e69de29 diff --git a/tests/steps/__init__.py b/tests/steps/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/steps/data/testorigin.abc b/tests/steps/data/testorigin.abc new file mode 100644 index 0000000..e69de29 diff --git a/tests/steps/data/testorigin2.abc b/tests/steps/data/testorigin2.abc new file mode 100644 index 0000000..e69de29 diff --git a/tests/steps/test_io.py b/tests/steps/test_io.py new file mode 100644 index 0000000..ee848bd --- /dev/null +++ b/tests/steps/test_io.py @@ -0,0 +1,136 @@ +"""Tests for IO nodes""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init + +from pathlib import Path +import shutil + +import dill +import pytest + +from maize.utilities.testing import TestRig +from maize.steps.io import ( + Dummy, + FileBuffer, + LoadFile, + LoadData, + LoadFiles, + Log, + LogResult, + SaveFile, + Return, + SaveFiles, + Void, +) + + +@pytest.fixture +def simple_file(shared_datadir): + return shared_datadir / "testorigin.abc" + + +@pytest.fixture +def simple_file2(shared_datadir): + return shared_datadir / "testorigin2.abc" + + +class Test_io: + def test_dummy(self): + t = TestRig(Dummy) + t.setup_run() + + def test_void(self): + t = TestRig(Void) + t.setup_run(inputs={"inp": 42}) + + @pytest.mark.skip(reason="Fails when run together with all other tests, passes alone") + def test_loadfile(self, simple_file): + t = TestRig(LoadFile[Path]) + out = t.setup_run(parameters={"file": simple_file}) + assert out["out"].get() == simple_file + with pytest.raises(FileNotFoundError): + t.setup_run(parameters={"file": Path("nofile")}) + + def test_loadfiles(self, simple_file, simple_file2): + t = TestRig(LoadFiles[Path]) + out = t.setup_run(parameters={"files": [simple_file, simple_file2]}) + assert out["out"].get() == [simple_file, simple_file2] + with pytest.raises(FileNotFoundError): + t.setup_run(parameters={"files": [Path("nofile"), simple_file]}) + + def test_loaddata(self): + t = TestRig(LoadData[int]) + out = t.setup_run(parameters={"data": 42}) + assert out["out"].get() == 42 + + def test_logresult(self): + t = TestRig(LogResult) + t.setup_run(inputs={"inp": 42}) + + def test_log(self): + t = TestRig(Log[int]) + out = t.setup_run(inputs={"inp": 42}) + assert out["out"].get() == 42 
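+
+ # As in the surrounding tests, TestRig wraps a single node class: `setup_run`
+ # feeds the given inputs/parameters, runs the node, and returns its outputs as
+ # queue-like handles, so `out["out"].get()` yields whatever was sent on "out".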
+ + def test_log2(self, simple_file): + t = TestRig(Log[Path]) + out = t.setup_run(inputs={"inp": simple_file}) + assert out["out"].get() == simple_file + + def test_savefile(self, simple_file, tmp_path): + t = TestRig(SaveFile[Path]) + t.setup_run( + inputs={"inp": simple_file}, parameters={"destination": tmp_path / simple_file.name} + ) + assert (tmp_path / simple_file.name).exists() + + def test_savefile_dir(self, simple_file, tmp_path): + t = TestRig(SaveFile[Path]) + t.setup_run(inputs={"inp": simple_file}, parameters={"destination": tmp_path}) + assert (tmp_path / simple_file.name).exists() + + def test_savefile_parent(self, simple_file, tmp_path): + t = TestRig(SaveFile[Path]) + t.setup_run( + inputs={"inp": simple_file}, + parameters={"destination": tmp_path / "folder" / simple_file.name}, + ) + assert (tmp_path / "folder" / simple_file.name).exists() + + def test_savefile_overwrite(self, simple_file, tmp_path): + dest = tmp_path / simple_file.name + shutil.copy(simple_file, dest) + first_time = dest.stat().st_mtime + t = TestRig(SaveFile[Path]) + t.setup_run( + inputs={"inp": simple_file}, parameters={"destination": dest, "overwrite": True} + ) + assert dest.stat().st_mtime > first_time + + def test_savefiles(self, simple_file, simple_file2, tmp_path): + t = TestRig(SaveFiles[Path]) + t.setup_run( + inputs={"inp": [[simple_file, simple_file2]]}, parameters={"destination": tmp_path} + ) + assert (tmp_path / simple_file.name).exists() + assert (tmp_path / simple_file2.name).exists() + + def test_filebuffer(self, simple_file, tmp_path): + t = TestRig(FileBuffer[Path]) + location = tmp_path / simple_file.name + t.setup_run(inputs={"inp": [simple_file]}, parameters={"file": location}) + assert location.exists() + + def test_filebuffer_receive(self, simple_file, tmp_path): + location = tmp_path / simple_file.name + shutil.copy(simple_file, location) + t = TestRig(FileBuffer[Path]) + res = t.setup_run(inputs={"inp": [simple_file]}, parameters={"file": location}) + file = res["out"].get() + assert file is not None + assert file.exists() + + def test_return(self): + t = TestRig(Return[int]) + t.setup_run(inputs={"inp": 42}) + assert dill.loads(t.node.ret_queue.get()) == 42 diff --git a/tests/steps/test_plumbing.py b/tests/steps/test_plumbing.py new file mode 100644 index 0000000..6f0ddb5 --- /dev/null +++ b/tests/steps/test_plumbing.py @@ -0,0 +1,236 @@ +"""Tests for plumbing nodes""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init + +import time + +import pytest + +from maize.utilities.testing import TestRig +from maize.steps.plumbing import ( + Accumulate, + Batch, + Choice, + Combine, + CopyEveryNIter, + Delay, + IndexDistribute, + IntegerMap, + Merge, + Copy, + MergeLists, + Multiplex, + Multiply, + RoundRobin, + Scatter, + Barrier, + TimeDistribute, + Yes, +) + + +class Test_plumbing: + def test_multiply(self): + t = TestRig(Multiply) + out = t.setup_run(inputs={"inp": [42]}, parameters={"n_packages": 3}) + assert out["out"].get() == [42, 42, 42] + + def test_yes(self): + t = TestRig(Yes) + out = t.setup_run(inputs={"inp": [42]}, max_loops=3) + assert out["out"].get() == 42 + assert out["out"].get() == 42 + assert out["out"].get() == 42 + + def test_barrier(self): + t = TestRig(Barrier) + out = t.setup_run(inputs={"inp": [[1, 2], [3, 4]], "inp_signal": [False, True]}) + assert out["out"].get() == [3, 4] + + def test_batch(self): + t = TestRig(Batch) + out = 
t.setup_run(inputs={"inp": [[1, 2, 3, 4]]}, parameters={"n_batches": 3}) + assert out["out"].get() == [1, 2] + assert out["out"].get() == [3] + assert out["out"].get() == [4] + + def test_combine(self): + t = TestRig(Combine) + out = t.setup_run(inputs={"inp": [[1, 2], [3], [4]]}, parameters={"n_batches": 3}) + assert out["out"].get() == [1, 2, 3, 4] + + def test_merge_lists(self): + t = TestRig(MergeLists) + out = t.setup_run(inputs={"inp": [[[1, 2]], [[3]], [[4, 5]]]}, max_loops=1) + assert out["out"].get() == [1, 2, 3, 4, 5] + + def test_merge(self): + t = TestRig(Merge) + out = t.setup_run(inputs={"inp": [1, 2]}, max_loops=1) + assert out["out"].get() == 1 + assert out["out"].get() == 2 + + def test_multiplex(self): + t = TestRig(Multiplex) + out = t.setup_run_multi( + inputs={"inp": [0, 1, 2], "inp_single": [3, 4, 5]}, n_outputs=3, max_loops=1 + ) + assert out["out_single"].get() == 0 + assert out["out"][0].get() == 3 + assert out["out_single"].get() == 1 + assert out["out"][1].get() == 4 + assert out["out_single"].get() == 2 + assert out["out"][2].get() == 5 + + def test_copy(self): + t = TestRig(Copy) + out = t.setup_run_multi(inputs={"inp": 1}, n_outputs=2, max_loops=1) + assert out["out"][0].get() == 1 + assert out["out"][1].get() == 1 + + def test_round_robin(self): + t = TestRig(RoundRobin) + out = t.setup_run_multi(inputs={"inp": [1, 2]}, n_outputs=2, max_loops=2) + assert out["out"][0].get() == 1 + assert out["out"][1].get() == 2 + + def test_integer_map(self): + t = TestRig(IntegerMap) + out = t.setup_run( + inputs={"inp": [0, 1, 2, 3]}, parameters={"pattern": [0, 2, -1]}, max_loops=4, loop=True + ) + assert out["out"].get() == 1 + assert out["out"].get() == 1 + assert out["out"].get() == 2 + assert out["out"].get() == 2 + + def test_choice(self): + t = TestRig(Choice) + out = t.setup_run( + inputs={"inp": ["foo", "bar"], "inp_index": [0]}, parameters={"clip": False} + ) + assert out["out"].get() == "foo" + + def test_choice2(self): + t = TestRig(Choice) + out = t.setup_run( + inputs={"inp": ["foo", "bar"], "inp_index": [1]}, parameters={"clip": False} + ) + assert out["out"].get() == "bar" + + def test_choice_clip(self): + t = TestRig(Choice) + out = t.setup_run( + inputs={"inp": ["foo", "bar"], "inp_index": [3]}, parameters={"clip": True} + ) + assert out["out"].get() == "bar" + + def test_index_distribute(self): + t = TestRig(IndexDistribute) + out = t.setup_run_multi( + inputs={"inp": ["foo"], "inp_index": [0]}, + parameters={"clip": False}, + n_outputs=3, + ) + assert out["out"][0].get() == "foo" + assert not out["out"][1].ready + assert not out["out"][2].ready + + def test_index_distribute2(self): + t = TestRig(IndexDistribute) + out = t.setup_run_multi( + inputs={"inp": ["foo"], "inp_index": [1]}, + parameters={"clip": True}, + n_outputs=3, + ) + assert not out["out"][0].ready + assert out["out"][1].get() == "foo" + assert not out["out"][2].ready + + def test_index_distribute_clip(self): + t = TestRig(IndexDistribute) + out = t.setup_run_multi( + inputs={"inp": ["foo"], "inp_index": [4]}, + parameters={"clip": True}, + n_outputs=3, + ) + assert not out["out"][0].ready + assert not out["out"][1].ready + assert out["out"][2].get() == "foo" + + def test_time_distribute(self): + t = TestRig(TimeDistribute) + out = t.setup_run_multi( + inputs={"inp": [1, 2, 3, 4, 5, 6, 7, 8]}, + parameters={"pattern": [2, 1, 5, 0]}, + n_outputs=4, + ) + for i in range(2): + assert out["out"][0].get() == i + 1 + assert out["out"][1].get() == 3 + for i in range(5): + assert 
out["out"][2].get() == 4 + i + assert not out["out"][3].flush() + + def test_time_distribute_inf(self): + t = TestRig(TimeDistribute) + out = t.setup_run_multi( + inputs={"inp": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}, + parameters={"pattern": [2, 1, 4, -1]}, + n_outputs=4, + ) + for i in range(2): + assert out["out"][0].get() == i + 1 + assert out["out"][1].get() == 3 + for i in range(4): + assert out["out"][2].get() == 4 + i + for i in range(3): + assert out["out"][3].get() == 8 + i + assert not out["out"][3].flush() + + def test_time_distribute_cycle(self): + t = TestRig(TimeDistribute) + out = t.setup_run_multi( + inputs={"inp": [1, 2, 3, 1, 2, 3]}, + parameters={"pattern": [2, 1], "cycle": True}, + n_outputs=2, + max_loops=6, + ) + for _ in range(2): + assert out["out"][0].get() == 1 + assert out["out"][0].get() == 2 + assert out["out"][1].get() == 3 + + def test_copy_every_n_iter(self): + t = TestRig(CopyEveryNIter) + out = t.setup_run_multi( + inputs={"inp": [1, 2, 3, 4]}, parameters={"freq": 2}, n_outputs=2, max_loops=4 + ) + for i in range(4): + assert out["out"][0].get() == i + 1 + assert out["out"][1].get() == 2 + assert out["out"][1].get() == 4 + + def test_accumulate(self): + t = TestRig(Accumulate) + out = t.setup_run(inputs={"inp": [1, 2]}, parameters={"n_packets": 2}, max_loops=1) + assert out["out"].get() == [1, 2] + + def test_scatter(self): + t = TestRig(Scatter) + out = t.setup_run(inputs={"inp": [[1, 2]]}, max_loops=1) + assert out["out"].get() == 1 + assert out["out"].get() == 2 + + def test_scatter_fail(self): + t = TestRig(Scatter) + with pytest.raises(ValueError): + _ = t.setup_run(inputs={"inp": [1]}, max_loops=1) + + def test_delay(self): + t = TestRig(Delay) + start_time = time.time() + out = t.setup_run(inputs={"inp": 1}, parameters={"delay": 5.0}, max_loops=1) + end_time = time.time() - start_time + assert out["out"].get() == 1 + assert 4.5 < end_time < 10 diff --git a/tests/test_misc.py b/tests/test_misc.py new file mode 100644 index 0000000..ed93f66 --- /dev/null +++ b/tests/test_misc.py @@ -0,0 +1,423 @@ +"""Miscallaneous testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init, unused-import + +from pathlib import Path +import shutil +import time +import pytest + +from maize.core.component import Component +from maize.core.interface import Input, Output, Parameter +from maize.core.runtime import Status, setup_build_logging +from maize.core.node import Node +from maize.core.workflow import Workflow +from maize.steps.io import LoadData, LoadFile, LoadFiles, Log, Return, SaveFile, SaveFiles +from maize.steps.plumbing import Copy, Delay, RoundRobin, Scatter, Multiply +from maize.utilities.testing import TestRig, MockChannel +from maize.utilities.macros import node_to_function, function_to_node, parallel, lambda_node + + +@pytest.fixture +def mock_parent(): + return Component(name="parent") + + +@pytest.fixture +def node_hybrid(): + class TestNode(Node): + inp: Input[int] = Input() + inp_default: Input[int] = Input(default=17) + inp_optional: Input[int] = Input(optional=True) + out: Output[int] = Output() + + def run(self): + a = self.inp.receive() + b = self.inp_default.receive() + c = 0 + if self.inp_optional.ready(): + c = self.inp_optional.receive() + self.out.send(a + b + c) + + return TestNode + + +@pytest.fixture +def node_optional(): + class TestNode(Node): + inp: Input[Path] = Input(mode="copy") + inp_opt: Input[Path] = Input(mode="copy", 
optional=True) + out: Output[Path] = Output(mode="copy") + + def build(self): + super().build() + self.logger = setup_build_logging("build") + + def run(self) -> None: + file = self.inp.receive() + time.sleep(1) + if self.inp_opt.ready(): + file_extra = self.inp_opt.receive() + self.out.send(file_extra) + else: + self.out.send(file) + + return TestNode + + +@pytest.fixture +def node_receive_optional(): + class TestNode(Node): + inp: Input[Path] = Input(mode="copy") + inp_opt: Input[Path] = Input(mode="copy", optional=True) + out: Output[Path] = Output(mode="copy") + + def build(self): + super().build() + self.logger = setup_build_logging("build") + + def run(self) -> None: + file = self.inp.receive() + file_extra = self.inp_opt.receive_optional() + if file_extra is not None: + self.out.send(file_extra) + else: + self.out.send(file) + + return TestNode + + +@pytest.fixture +def node_input(): + class TestNode(Node): + inp: Input[Path] = Input(mode="copy", optional=True) + out: Output[Path] = Output(mode="copy") + + def build(self): + super().build() + self.logger = setup_build_logging("build") + + def run(self) -> None: + file = self.inp.receive() + self.out.send(file) + + return TestNode + + +@pytest.fixture +def node_optional_with_channel(node_optional, shared_datadir, mock_parent): + node = node_optional(parent=mock_parent) + channel = MockChannel(shared_datadir / "testorigin.abc") + node.inp.set_channel(channel) + node.out.set_channel(channel) + node.logger = setup_build_logging(name="test") + return node + + +class Test_interfaces: + def test_input(self, node_hybrid): + rig = TestRig(node_hybrid) + res = rig.setup_run(inputs={"inp": [42]}) + assert res["out"].get() == 42 + 17 + + rig = TestRig(node_hybrid) + res = rig.setup_run(inputs={"inp": [42], "inp_optional": [2]}) + assert res["out"].get() == 42 + 17 + 2 + + rig = TestRig(node_hybrid) + res = rig.setup_run(inputs={"inp": [42], "inp_default": [16], "inp_optional": [2]}) + assert res["out"].get() == 42 + 16 + 2 + + +def test_lambda_node(): + rig = TestRig(lambda_node(lambda x: x + 2)) + res = rig.setup_run(inputs={"inp": 42}) + assert res["out"].get() == 44 + + +def test_function_to_node(): + def func(a: int, b: bool = True, c: str = "foo") -> int: + return a + 1 if b else a + + rig = TestRig(function_to_node(func)) + res = rig.setup_run(inputs={"inp": [42]}) + assert res["out"].get() == 43 + + res = rig.setup_run(inputs={"inp": [42]}, parameters={"b": False}) + assert res["out"].get() == 42 + + +def test_node_to_function(node_hybrid): + func = node_to_function(node_hybrid) + assert func(inp=42, inp_default=17, inp_optional=0)["out"] == 42 + 17 + + +def test_multi_file_load_save(shared_datadir, tmp_path): + flow = Workflow(level="debug") + data = flow.add( + LoadFiles[Path], + parameters={ + "files": [shared_datadir / "testorigin.abc", shared_datadir / "testorigin2.abc"] + }, + ) + save = flow.add(SaveFiles[Path], parameters={"destination": tmp_path}) + flow.connect(data.out, save.inp) + flow.execute() + assert (tmp_path / "testorigin.abc").exists() + assert (tmp_path / "testorigin2.abc").exists() + + +def test_multi_file_load_copy_save(shared_datadir, tmp_path): + dest1 = tmp_path / "save1" + dest2 = tmp_path / "save2" + dest1.mkdir(), dest2.mkdir() + flow = Workflow(level="debug") + data = flow.add( + LoadFiles[Path], + parameters={ + "files": [shared_datadir / "testorigin.abc", shared_datadir / "testorigin2.abc"] + }, + ) + copy = flow.add(Copy[list[Path]]) + log = flow.add(Log[list[Path]]) + save1 = flow.add(SaveFiles[Path], 
name="save1", parameters={"destination": dest1}) + save2 = flow.add(SaveFiles[Path], name="save2", parameters={"destination": dest2}) + flow.connect(data.out, log.inp) + flow.connect(log.out, copy.inp) + flow.connect(copy.out, save1.inp) + flow.connect(copy.out, save2.inp) + flow.execute() + assert (dest1 / "testorigin.abc").exists() + assert (dest1 / "testorigin2.abc").exists() + assert (dest2 / "testorigin.abc").exists() + assert (dest2 / "testorigin2.abc").exists() + + +def test_parallel_multi_file(shared_datadir, tmp_path): + + class Dictionize(Node): + inp: Input[list[Path]] = Input() + out: Output[dict[int, Path]] = Output() + + def run(self) -> None: + files = self.inp.receive() + for i, file in enumerate(files): + self.out.send({i: file}) + + class DeDictionize(Node): + inp: Input[dict[int, Path]] = Input() + out: Output[list[Path]] = Output() + n: Parameter[int] = Parameter(default=2) + + def run(self) -> None: + files = [] + for _ in range(self.n.value): + files.extend(list(self.inp.receive().values())) + self.out.send(files) + + dest = tmp_path / "save" + dest.mkdir() + flow = Workflow(level="debug") + data = flow.add( + LoadFiles[Path], + parameters={ + "files": [shared_datadir / "testorigin.abc", shared_datadir / "testorigin2.abc"] + }, + ) + dicz = flow.add(Dictionize) + log = flow.add(parallel(Log[dict[int, Path]], n_branches=2, loop=True)) + dedi = flow.add(DeDictionize) + save = flow.add(SaveFiles[Path], parameters={"destination": dest}) + flow.connect(data.out, dicz.inp) + flow.connect(dicz.out, log.inp) + flow.connect(log.out, dedi.inp) + flow.connect(dedi.out, save.inp) + flow.execute() + assert (dest / "testorigin.abc").exists() + assert (dest / "testorigin2.abc").exists() + + +def test_parallel_file(shared_datadir, tmp_path): + flow = Workflow(level="debug") + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + mult = flow.add(Multiply[Path], parameters={"n_packages": 2}) + scat = flow.add(Scatter[Path]) + dela = flow.add(parallel(Log[Path], n_branches=2, loop=True)) + roro = flow.add(RoundRobin[Path]) + out1 = flow.add(SaveFile[Path], name="out1", parameters={"destination": tmp_path / "test1.abc"}) + out2 = flow.add(SaveFile[Path], name="out2", parameters={"destination": tmp_path / "test2.abc"}) + flow.connect(data.out, mult.inp) + flow.connect(mult.out, scat.inp) + flow.connect(scat.out, dela.inp) + flow.connect(dela.out, roro.inp) + flow.connect(roro.out, out1.inp) + flow.connect(roro.out, out2.inp) + flow.execute() + assert (tmp_path / "test1.abc").exists() + assert (tmp_path / "test2.abc").exists() + + +def test_optional_file(shared_datadir, tmp_path, node_optional): + flow = Workflow(level="debug", cleanup_temp=False) + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + data_opt = flow.add(LoadFile[Path], name="opt") + data_opt.file.optional = True + test = flow.add(node_optional) + out = flow.add(SaveFile[Path], parameters={"destination": tmp_path / "test.abc"}) + + flow.connect(data.out, test.inp) + flow.connect(data_opt.out, test.inp_opt) + flow.connect(test.out, out.inp) + flow.execute() + assert (tmp_path / "test.abc").exists() + + +def test_receive_optional_file(shared_datadir, tmp_path, node_receive_optional): + flow = Workflow(level="debug", cleanup_temp=False) + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + data_opt = flow.add(LoadFile[Path], name="opt") + delay = flow.add(Delay[Path], name="del", parameters={"delay": 2}) + 
data_opt.file.optional = True + test = flow.add(node_receive_optional) + out = flow.add(SaveFile[Path], parameters={"destination": tmp_path}) + + flow.connect(data.out, test.inp) + flow.connect(data_opt.out, delay.inp) + flow.connect(delay.out, test.inp_opt) + flow.connect(test.out, out.inp) + flow.execute() + assert (tmp_path / "testorigin.abc").exists() + + +def test_receive_optional_file_alt(shared_datadir, tmp_path, node_receive_optional): + flow = Workflow(level="debug", cleanup_temp=False) + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + data_opt = flow.add( + LoadFile[Path], name="opt", parameters={"file": shared_datadir / "testorigin2.abc"} + ) + delay = flow.add(Delay[Path], name="del", parameters={"delay": 2}) + test = flow.add(node_receive_optional) + out = flow.add(SaveFile[Path], parameters={"destination": tmp_path}) + + flow.connect(data.out, test.inp) + flow.connect(data_opt.out, delay.inp) + flow.connect(delay.out, test.inp_opt) + flow.connect(test.out, out.inp) + flow.execute() + assert (tmp_path / "testorigin2.abc").exists() + + +def test_optional_file_input(shared_datadir, tmp_path, node_optional, node_input): + flow = Workflow(level="debug", cleanup_temp=False) + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + data_opt = flow.add(node_input, name="opt") + data_opt.inp.set(shared_datadir / "testorigin2.abc") + test = flow.add(node_optional) + out = flow.add(SaveFile[Path], parameters={"destination": tmp_path}) + + flow.connect(data.out, test.inp) + flow.connect(data_opt.out, test.inp_opt) + flow.connect(test.out, out.inp) + flow.execute() + assert (tmp_path / "testorigin2.abc").exists() + + +def test_execute_optional(node_optional_with_channel): + node_optional_with_channel.execute() + assert node_optional_with_channel.status == Status.STOPPED + assert not node_optional_with_channel.ports_active() + assert not node_optional_with_channel.signal.is_set() + + +def test_parallel_file_many(shared_datadir, tmp_path): + class _Test(Node): + inp: Input[Path] = Input(mode="copy") + out: Output[Path] = Output(mode="copy") + + def run(self) -> None: + file = self.inp.receive() + out = Path("local.abc") + shutil.copy(file, out) + self.out.send(out) + + n_files = 2 + flow = Workflow(level="debug", cleanup_temp=False) + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + mult = flow.add(Multiply[Path], parameters={"n_packages": n_files}) + scat = flow.add(Scatter[Path]) + dela = flow.add(parallel(_Test, n_branches=4, loop=True)) + roro = flow.add(RoundRobin[Path]) + for i in range(n_files): + out = flow.add( + SaveFile[Path], name=f"out{i}", parameters={"destination": tmp_path / f"test{i}.abc"} + ) + flow.connect(roro.out, out.inp) + + flow.connect(data.out, mult.inp) + flow.connect(mult.out, scat.inp) + flow.connect(scat.out, dela.inp) + flow.connect(dela.out, roro.inp) + flow.execute() + for i in range(n_files): + assert (tmp_path / f"test{i}.abc").exists() + + +def test_file_copy(shared_datadir, tmp_path): + flow = Workflow() + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + copy = flow.add(Copy[Path]) + out1 = flow.add(SaveFile[Path], name="out1", parameters={"destination": tmp_path / "test1.abc"}) + out2 = flow.add(SaveFile[Path], name="out2", parameters={"destination": tmp_path / "test2.abc"}) + flow.connect(data.out, copy.inp) + flow.connect(copy.out, out1.inp) + flow.connect(copy.out, out2.inp) + flow.execute() + 
assert (tmp_path / "test1.abc").exists() + assert (tmp_path / "test2.abc").exists() + + +def test_file_copy_delay(shared_datadir, tmp_path): + flow = Workflow() + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + copy = flow.add(Copy[Path]) + del1 = flow.add(Delay[Path], name="del1", parameters={"delay": 2}) + del2 = flow.add(Delay[Path], name="del2", parameters={"delay": 5}) + out1 = flow.add(SaveFile[Path], name="out1", parameters={"destination": tmp_path / "test1.abc"}) + out2 = flow.add(SaveFile[Path], name="out2", parameters={"destination": tmp_path / "test2.abc"}) + flow.connect(data.out, del1.inp) + flow.connect(del1.out, copy.inp) + flow.connect(copy.out, out2.inp) + flow.connect(copy.out, del2.inp) + flow.connect(del2.out, out1.inp) + flow.execute() + assert (tmp_path / "test1.abc").exists() + assert (tmp_path / "test2.abc").exists() + + +def test_nested_file_copy(shared_datadir, tmp_path): + flow = Workflow() + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + copy1 = flow.add(Copy[Path], name="copy1") + copy2 = flow.add(Copy[Path], name="copy2") + out1 = flow.add(SaveFile[Path], name="out1", parameters={"destination": tmp_path / "test1.abc"}) + out2 = flow.add(SaveFile[Path], name="out2", parameters={"destination": tmp_path / "test2.abc"}) + out3 = flow.add(SaveFile[Path], name="out3", parameters={"destination": tmp_path / "test3.abc"}) + flow.connect(data.out, copy1.inp) + flow.connect(copy1.out, out1.inp) + flow.connect(copy1.out, copy2.inp) + flow.connect(copy2.out, out2.inp) + flow.connect(copy2.out, out3.inp) + flow.execute() + assert (tmp_path / "test1.abc").exists() + assert (tmp_path / "test2.abc").exists() + assert (tmp_path / "test3.abc").exists() + + +def test_file_load(shared_datadir, tmp_path): + flow = Workflow() + data = flow.add(LoadFile[Path], parameters={"file": shared_datadir / "testorigin.abc"}) + out = flow.add(SaveFile[Path], parameters={"destination": tmp_path / "test.abc"}) + flow.connect(data.out, out.inp) + flow.execute() + assert (tmp_path / "test.abc").exists() + assert (shared_datadir / "testorigin.abc").exists() diff --git a/tests/utilities/__init__.py b/tests/utilities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utilities/conftest.py b/tests/utilities/conftest.py new file mode 100644 index 0000000..4f17714 --- /dev/null +++ b/tests/utilities/conftest.py @@ -0,0 +1,12 @@ +"""Utility testing data""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init, unused-import + +import pytest + +from maize.core.component import Component + + +@pytest.fixture +def mock_component(): + return Component() diff --git a/tests/utilities/data/dict.json b/tests/utilities/data/dict.json new file mode 100644 index 0000000..1794219 --- /dev/null +++ b/tests/utilities/data/dict.json @@ -0,0 +1 @@ +{"a": 42, "b": "foo", "c": {"d": ["a", "b"]}} diff --git a/tests/utilities/data/dict.toml b/tests/utilities/data/dict.toml new file mode 100644 index 0000000..93e8802 --- /dev/null +++ b/tests/utilities/data/dict.toml @@ -0,0 +1,3 @@ +a = 42 +b = "foo" +c.d = ["a", "b"] diff --git a/tests/utilities/data/dict.yaml b/tests/utilities/data/dict.yaml new file mode 100644 index 0000000..b0ed33e --- /dev/null +++ b/tests/utilities/data/dict.yaml @@ -0,0 +1,4 @@ +a: 42 +b: foo +c: + d: ["a", "b"] \ No newline at end of file diff --git a/tests/utilities/data/testorigin.abc 
b/tests/utilities/data/testorigin.abc new file mode 100644 index 0000000..e69de29 diff --git a/tests/utilities/test_execution.py b/tests/utilities/test_execution.py new file mode 100644 index 0000000..6ceaff5 --- /dev/null +++ b/tests/utilities/test_execution.py @@ -0,0 +1,268 @@ +"""Execution testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name + +import time + +import pytest +from maize.utilities.execution import ( + ProcessError, + ResourceManagerConfig, + check_executable, + _UnmemoizedCommandRunner as CommandRunner, +) +from maize.utilities.validation import FailValidator, SuccessValidator +from maize.utilities.utilities import load_modules, set_environment + + +@pytest.fixture +def nonexistent_command(): + return "thiscommanddoesntexist" + + +@pytest.fixture +def echo_command(): + return "echo 'test'" + + +@pytest.fixture +def sleep_command(): + return "sleep 2" + + +@pytest.fixture +def slurm_config(): + return ResourceManagerConfig(system="slurm", queue="core", launcher="srun", walltime="00:05:00") + + +class Test_CommandRunner: + def test_run_only(self, echo_command): + cmd = CommandRunner() + res = cmd.run_only(echo_command, verbose=True) + assert "test" in res.stdout.decode() + + def test_run_only_fail(self, nonexistent_command): + with pytest.raises(FileNotFoundError): + cmd = CommandRunner() + cmd.run_only(nonexistent_command) + + def test_run_only_fail2(self): + with pytest.raises(ProcessError): + cmd = CommandRunner(name="default") + res = cmd.run_only("cat nonexistent") + cmd = CommandRunner(raise_on_failure=False) + res = cmd.run_only("cat nonexistent") + assert res.returncode == 1 + + def test_run_timeout(self, sleep_command): + cmd = CommandRunner(raise_on_failure=False) + res = cmd.run_only(sleep_command, timeout=1) + assert res.returncode == 130 + + def test_run_timeout_error(self, sleep_command): + cmd = CommandRunner() + with pytest.raises(ProcessError): + cmd.run_only(sleep_command, timeout=1) + + def test_run_only_input(self): + cmd = CommandRunner() + res = cmd.run_only("cat", verbose=True, command_input="foo") + assert "foo" in res.stdout.decode() + + def test_run_async(self, echo_command): + cmd = CommandRunner() + proc = cmd.run_async(echo_command) + assert "test" in proc.wait().stdout.decode() + + def test_run_async_sleep(self, sleep_command): + cmd = CommandRunner() + start = time.time() + proc = cmd.run_async(sleep_command) + assert proc.is_alive() + proc.wait() + assert not proc.is_alive() + assert (time.time() - start) < 3 + + def test_run_async_sleep_kill(self, sleep_command): + cmd = CommandRunner() + start = time.time() + proc = cmd.run_async(sleep_command) + assert proc.is_alive() + proc.kill() + assert not proc.is_alive() + assert (time.time() - start) < 2 + + def test_run_validate_no_validators(self, echo_command): + cmd = CommandRunner() + res = cmd.run(echo_command) + assert "test" in res.stdout.decode() + + def test_run_validate(self, echo_command): + cmd = CommandRunner(validators=[SuccessValidator("test")]) + res = cmd.run(echo_command) + assert "test" in res.stdout.decode() + + def test_run_validate_fail(self, echo_command): + cmd = CommandRunner(validators=[FailValidator("test")]) + with pytest.raises(ProcessError): + cmd.run(echo_command) + + def test_run_multi(self, echo_command): + cmd = CommandRunner() + results = cmd.run_parallel([echo_command, echo_command]) + for result in results: + assert result.returncode == 0 + + def test_run_multi_batch(self, 
echo_command): + cmd = CommandRunner() + results = cmd.run_parallel([echo_command for _ in range(5)], n_batch=2) + assert len(results) == 2 + for result in results: + assert result.returncode == 0 + + def test_run_multi_batch2(self, echo_command): + cmd = CommandRunner() + results = cmd.run_parallel([echo_command for _ in range(5)], batchsize=2) + assert len(results) == 3 + for result in results: + assert result.returncode == 0 + + def test_run_multi_batch3(self, echo_command, tmp_path): + cmd = CommandRunner() + results = cmd.run_parallel( + [echo_command for _ in range(5)], working_dirs=[tmp_path for _ in range(5)], batchsize=2 + ) + assert len(results) == 3 + for result in results: + assert result.returncode == 0 + + def test_run_multi_batch_fail(self, echo_command): + cmd = CommandRunner() + with pytest.raises(ValueError): + cmd.run_parallel([echo_command for _ in range(5)], batchsize=2, n_batch=2) + + def test_run_multi_batch_fail2(self, echo_command): + cmd = CommandRunner() + with pytest.raises(ValueError): + cmd.run_parallel( + [echo_command for _ in range(5)], + command_inputs=["foo" for _ in range(5)], + batchsize=2, + ) + + def test_run_multi_batch_fail3(self, echo_command, tmp_path): + cmd = CommandRunner() + with pytest.raises(ValueError): + cmd.run_parallel( + [echo_command for _ in range(5)], + working_dirs=[tmp_path, "foo", tmp_path, tmp_path, tmp_path], + batchsize=2, + ) + + def test_run_multi_input(self): + cmd = CommandRunner() + inputs = ["foo", "bar"] + results = cmd.run_parallel(["cat", "cat"], verbose=True, command_inputs=inputs) + for inp, res in zip(inputs, results): + assert inp in res.stdout.decode() + + def test_run_multi_time(self, sleep_command): + cmd = CommandRunner() + start = time.time() + results = cmd.run_parallel([sleep_command, sleep_command], n_jobs=2) + assert (time.time() - start) < 3 + for result in results: + assert result.returncode == 0 + + def test_run_multi_timeout(self, sleep_command): + cmd = CommandRunner(raise_on_failure=False) + start = time.time() + results = cmd.run_parallel([sleep_command, sleep_command], n_jobs=2, timeout=1) + assert (time.time() - start) < 2.5 + for result in results: + assert result.returncode == 130 + + def test_run_multi_timeout_error(self, sleep_command): + cmd = CommandRunner() + with pytest.raises(ProcessError): + cmd.run_parallel([sleep_command, sleep_command], n_jobs=2, timeout=1) + + def test_run_multi_seq(self, sleep_command): + cmd = CommandRunner() + start = time.time() + results = cmd.run_parallel([sleep_command, sleep_command], n_jobs=1) + assert (time.time() - start) > 3 + for result in results: + assert result.returncode == 0 + + def test_run_multi_validate(self, echo_command): + cmd = CommandRunner(validators=[SuccessValidator("test")]) + results = cmd.run_parallel([echo_command, echo_command], validate=True) + for res in results: + assert "test" in res.stdout.decode() + + def test_run_preexec(self): + cmd = CommandRunner() + res = cmd.run("env", pre_execution="export BLAH=foo") + assert "BLAH=foo" in res.stdout.decode() + + +class Test_check_executable: + def test_fail(self, nonexistent_command): + assert not check_executable(nonexistent_command) + + def test_success(self): + command = "ls" + assert check_executable(command) + + +@pytest.mark.skipif( + not check_executable("sinfo"), reason="Testing slurm requires a functioning Slurm batch system" +) +class Test_batch: + def test_run_only(self, echo_command, slurm_config): + cmd = CommandRunner(prefer_batch=True, rm_config=slurm_config) + res = 
cmd.run_only(echo_command, verbose=True) + assert "test" in res.stdout.decode() + + def test_run_only_timeformat(self, echo_command, slurm_config): + slurm_config.walltime = "02-05:00" + cmd = CommandRunner(prefer_batch=True, rm_config=slurm_config) + res = cmd.run_only(echo_command, verbose=True) + assert "test" in res.stdout.decode() + + def test_run_only_fail(self, nonexistent_command, slurm_config): + with pytest.raises((FileNotFoundError, ProcessError)): + cmd = CommandRunner(prefer_batch=True, rm_config=slurm_config) + cmd.run_only(nonexistent_command) + + def test_run_modules(self, slurm_config): + load_modules("GCC") + cmd = CommandRunner(prefer_batch=True, rm_config=slurm_config) + res = cmd.run_only("gcc --version", verbose=True) + assert "(GCC)" in res.stdout.decode() + + def test_run_env(self, slurm_config): + set_environment({"MAIZE_TEST": "MAIZE_TEST_SUCCESS"}) + cmd = CommandRunner(prefer_batch=True, rm_config=slurm_config) + res = cmd.run_only("env", verbose=True) + assert "MAIZE_TEST_SUCCESS" in res.stdout.decode() + + def test_run_validate(self, echo_command, slurm_config): + cmd = CommandRunner( + validators=[SuccessValidator("test")], prefer_batch=True, rm_config=slurm_config + ) + res = cmd.run(echo_command) + assert "test" in res.stdout.decode() + + def test_run_multi(self, echo_command, slurm_config): + cmd = CommandRunner(prefer_batch=True, rm_config=slurm_config) + results = cmd.run_parallel([echo_command, echo_command]) + for result in results: + assert result.returncode == 0 + + def test_run_multi_many(self, echo_command, slurm_config): + cmd = CommandRunner(prefer_batch=True, rm_config=slurm_config) + results = cmd.run_parallel([echo_command for _ in range(50)]) + for result in results: + assert result.returncode == 0 diff --git a/tests/utilities/test_io.py b/tests/utilities/test_io.py new file mode 100644 index 0000000..d86e973 --- /dev/null +++ b/tests/utilities/test_io.py @@ -0,0 +1,234 @@ +"""IO testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init, unused-import + +import argparse +from dataclasses import dataclass +from pathlib import Path +import subprocess +from types import ModuleType +import typing +from typing import Literal + +import dill +import pytest +from maize.utilities.execution import ResourceManagerConfig + +from maize.utilities.io import ( + args_from_function, + common_parent, + get_plugins, + load_file, + remove_dir_contents, + sendtree, + parse_groups, + setup_workflow, + wait_for_file, + with_keys, + with_fields, + read_input, + write_input, + Config +) + + +@pytest.fixture +def parser_groups(): + parser = argparse.ArgumentParser() + a = parser.add_argument_group("a") + a.add_argument("-a", type=int) + b = parser.add_argument_group("b") + b.add_argument("-b", type=str) + return parser + + +@pytest.fixture +def obj(): + @dataclass + class obj: + a: int = 42 + b: str = "foo" + + return obj + + +@pytest.fixture +def files(tmp_path): + paths = [ + Path("a/b/c/d"), + Path("a/b/c/e"), + Path("a/b/f/g"), + Path("a/g/h/e"), + ] + return [tmp_path / path for path in paths] + + +def test_remove_dir_contents(tmp_path): + path = tmp_path / "dir" + path.mkdir() + file = path / "file" + file.touch() + remove_dir_contents(path) + assert not file.exists() + assert path.exists() + + +def test_wait_for_file(tmp_path): + path = tmp_path / "file.dat" + subprocess.Popen(f"sleep 5 && echo 'foo' > {path.as_posix()}", shell=True) + wait_for_file(path) 
+ assert path.exists() + + +def test_wait_for_file2(tmp_path): + path = tmp_path / "file.dat" + subprocess.Popen(f"sleep 1 && echo 'foo' > {path.as_posix()}", shell=True) + wait_for_file(path, timeout=2) + assert path.exists() + + +def test_wait_for_file3(tmp_path): + path = tmp_path / "file.dat" + subprocess.Popen(f"sleep 1 && touch {path.as_posix()}", shell=True) + wait_for_file(path, timeout=2, zero_byte_check=False) + assert path.exists() + + +def test_wait_for_file_to(tmp_path): + path = tmp_path / "file.dat" + subprocess.Popen(f"sleep 5 && touch {path.as_posix()}", shell=True) + with pytest.raises(TimeoutError): + wait_for_file(path, timeout=2) + assert path.exists() + + +def test_common_parent(files): + com = common_parent(files) + assert com.relative_to(com.parent) == Path("a") + com = common_parent(files[:3]) + assert com.relative_to(com.parent) == Path("b") + com = common_parent(files[:2]) + assert com.relative_to(com.parent) == Path("c") + com = common_parent([files[1], files[3]]) + assert com.relative_to(com.parent) == Path("a") + + +class TestConfig: + def test_env_init(self, mocker): + conf = Path("conf.toml") + with conf.open("w") as out: + out.write("scratch = 'test'") + mocker.patch("os.environ", {"MAIZE_CONFIG": conf.absolute().as_posix()}) + config = Config.from_default() + assert config.scratch == Path("test") + + def test_default(self, mocker): + mocker.patch("os.environ", {}) + config = Config.from_default() + assert config.nodes == {} + assert config.environment == {} + assert config.batch_config == ResourceManagerConfig() + + +def test_get_plugins(): + assert "pytest.__main__" in get_plugins("pytest") + + +def test_load_file(tmp_path): + module_file = tmp_path / "module.py" + with module_file.open("w") as mod: + mod.write("CONSTANT = 42\n") + assert isinstance(load_file(module_file), ModuleType) + + +def test_args_from_function(mocker): + def func(foo: int, flag: bool, lit: typing.Literal["foo", "bar"]) -> str: + return "baz" + + mocker.patch("sys.argv", ["testing", "--foo", "42", "--flag", "--lit", "foo"]) + parser = argparse.ArgumentParser() + parser = args_from_function(parser, func) + args = parser.parse_args() + assert args.foo == 42 + assert args.flag + assert args.lit == "foo" + + +def test_sendtree_link(files, tmp_path): + for file in files: + file.parent.mkdir(parents=True, exist_ok=True) + file.touch() + + res = sendtree({i: file for i, file in enumerate(files)}, tmp_path, mode="link") + for file in res.values(): + assert (tmp_path / file).exists() + assert (tmp_path / file).is_symlink() + + +def test_sendtree_copy(files, tmp_path): + for file in files: + file.parent.mkdir(parents=True, exist_ok=True) + file.touch() + + res = sendtree({i: file for i, file in enumerate(files)}, tmp_path, mode="copy") + for file in res.values(): + assert (tmp_path / file).exists() + assert not (tmp_path / file).is_symlink() + + +def test_sendtree_move(files, tmp_path): + for file in files: + file.parent.mkdir(parents=True, exist_ok=True) + file.touch() + + res = sendtree({i: file for i, file in enumerate(files)}, tmp_path, mode="move") + for file in res.values(): + assert (tmp_path / file).exists() + assert not (tmp_path / file).is_symlink() + + +def test_parse_groups(parser_groups, mocker): + mocker.patch("sys.argv", ["testing", "-a", "42", "-b", "foo"]) + groups = parse_groups(parser_groups) + assert vars(groups["a"]) == {"a": 42} + assert vars(groups["b"]) == {"b": "foo"} + + +def test_setup_workflow(nested_graph_with_params, mocker): + mocker.patch("sys.argv", 
["testing", "--val", "38", "--quiet", "--debug"]) + setup_workflow(nested_graph_with_params) + t = nested_graph_with_params.nodes["t"] + assert t.get() == 50 + + +def test_setup_workflow_check(nested_graph_with_params, mocker): + mocker.patch("sys.argv", ["testing", "--val", "38", "--check"]) + setup_workflow(nested_graph_with_params) + + +def test_with_keys(): + assert with_keys({"a": 42, "b": 50}, {"a"}) == {"a": 42} + + +def test_with_fields(obj): + assert with_fields(obj, ("a",)) == {"a": 42} + + +def test_read_input(shared_datadir): + res = read_input(shared_datadir / "dict.yaml") + assert res == {"a": 42, "b": "foo", "c": {"d": ["a", "b"]}} + res = read_input(shared_datadir / "dict.json") + assert res == {"a": 42, "b": "foo", "c": {"d": ["a", "b"]}} + res = read_input(shared_datadir / "dict.toml") + assert res == {"a": 42, "b": "foo", "c": {"d": ["a", "b"]}} + with pytest.raises(ValueError): + read_input(shared_datadir / "testorigin.abc") + with pytest.raises(FileNotFoundError): + read_input(shared_datadir / "nothere.yaml") + + +def test_write_input(tmp_path): + data = {"a": 42, "b": "foo", "c": {"d": ["a", "b"]}} + write_input(tmp_path / "dict.yaml", data) + res = read_input(tmp_path / "dict.yaml") + assert res == data diff --git a/tests/utilities/test_resources.py b/tests/utilities/test_resources.py new file mode 100644 index 0000000..c2040c5 --- /dev/null +++ b/tests/utilities/test_resources.py @@ -0,0 +1,32 @@ +"""Resources testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name + +import pytest + +from maize.utilities.resources import gpu_count, ChunkedSemaphore, Resources + + +def test_gpu_count(): + assert gpu_count() in range(1024) + + +def test_ChunkedSemaphore(): + sem = ChunkedSemaphore(10, sleep=1) + sem.acquire(5) + sem.acquire(5) + with pytest.raises(ValueError): + sem.acquire(20) + sem.release(5) + sem.release(5) + with pytest.raises(ValueError): + sem.release(5) + +def test_Resources(mock_component): + sem = Resources(10, parent=mock_component) + with sem(5): + pass + + with pytest.raises(ValueError): + with sem(20): + pass diff --git a/tests/utilities/test_testing.py b/tests/utilities/test_testing.py new file mode 100644 index 0000000..afbf52d --- /dev/null +++ b/tests/utilities/test_testing.py @@ -0,0 +1,44 @@ +"""Testing testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name, attribute-defined-outside-init, unused-import + +import pytest + +from maize.utilities.testing import MockChannel +from maize.utilities.macros import node_to_function + + +@pytest.fixture +def empty_mock_channel(): + channel = MockChannel() + return channel + + +@pytest.fixture +def loaded_mock_channel(): + channel = MockChannel(items=[1, 2, 3]) + return channel + + +class Test_MockChannel: + def test_channel_send(self, empty_mock_channel): + empty_mock_channel.send(42) + assert empty_mock_channel.ready + + def test_channel_receive(self, loaded_mock_channel): + assert loaded_mock_channel.receive() == 1 + + def test_channel_receive_empty(self, empty_mock_channel): + assert empty_mock_channel.receive(timeout=1) is None + + def test_channel_close_receive(self, loaded_mock_channel): + loaded_mock_channel.close() + assert loaded_mock_channel.receive() == 1 + + def test_flush(self, loaded_mock_channel): + assert loaded_mock_channel.flush() == [1, 2, 3] + + +def test_node_to_function(example_a): + afunc = node_to_function(example_a) + assert 
afunc(val=42) == {"out": 42} diff --git a/tests/utilities/test_utilities.py b/tests/utilities/test_utilities.py new file mode 100644 index 0000000..0acaf80 --- /dev/null +++ b/tests/utilities/test_utilities.py @@ -0,0 +1,137 @@ +"""Utilities testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name + +from pathlib import Path +import re +import string +from typing import Annotated, Any, Literal + +from maize.core.graph import Graph +from maize.core.workflow import Workflow +from maize.steps.plumbing import Delay, Merge +from maize.utilities.utilities import ( + deprecated, + unique_id, + graph_cycles, + typecheck, + Timer, + tuple_to_nested_dict, + nested_dict_to_tuple, + has_file, + make_list, + matching_types, + find_probable_files_from_command, + match_context, +) + + +def test_unique_id(): + assert len(unique_id()) == 6 + assert set(unique_id()).issubset(set(string.ascii_lowercase + string.digits)) + + +def test_graph_cycles(example_a): + class SubGraph(Graph): + def build(self): + a = self.add(example_a, "a") + m = self.add(Merge[int], "m") + a.out >> m.inp + self.inp = self.map_port(m.inp, "inp") + self.out = self.map_port(m.out, "out") + + g = Workflow() + a = g.add(SubGraph, "a") + b = g.add(Delay[int], "b") + c = g.add(Delay[int], "c") + a >> b >> c + c.out >> a.inp + cycles = graph_cycles(g) + assert len(cycles) > 0 + + +def test_typecheck(): + assert typecheck(42, int) + assert not typecheck("foo", int) + assert typecheck(42, Annotated[int, lambda x: x > 10]) + assert typecheck(42, Annotated[int, lambda x: x > 10, lambda x: x < 50]) + assert not typecheck(8, Annotated[int, lambda x: x > 10]) + assert typecheck(42, Annotated[int | float, lambda x: x > 10]) + assert typecheck(42, int | str) + assert typecheck(42, None) + assert typecheck(42, Any) + assert typecheck(42, Literal[42, 17]) + assert typecheck({"foo": 42}, dict[str, int]) + assert typecheck({"foo": 42, "bar": 39}, dict[str, int]) + + +def test_deprecated(): + def func(a: int) -> int: + return a + 1 + + new = deprecated("func is deprecated")(func) + assert new(17) == 18 + + class cls: + def func(self, a: int) -> int: + return a + 1 + + new = deprecated("cls is deprecated")(cls) + assert new().func(17) == 18 + + +def test_Timer(): + t = Timer() + ini_time = t.elapsed_time + assert not t.running + t.start() + assert t.running + assert t.elapsed_time > ini_time + assert t.running + t.pause() + assert not t.running + assert t.stop() > ini_time + + +def test_tuple_to_nested_dict(): + res = tuple_to_nested_dict("a", "b", "c", 42) + assert res == {"a": {"b": {"c": 42}}} + + +def test_nested_dict_to_tuple(): + res = nested_dict_to_tuple({"a": {"b": {"c": 42}}}) + assert res == ("a", "b", "c", 42) + + +def test_has_file(shared_datadir, tmp_path): + assert has_file(shared_datadir) + empty = tmp_path / "empty" + empty.mkdir() + assert not has_file(empty) + + +def test_make_list(): + assert make_list(42) == [42] + assert make_list([42]) == [42] + assert make_list({42, 17}) == sorted([42, 17]) + assert make_list((42, 17)) == [42, 17] + + +def test_matching_types(): + assert matching_types(int, int) + assert matching_types(bool, int) + assert not matching_types(bool, int, strict=True) + assert not matching_types(str, int) + assert matching_types(str, None) + assert matching_types(Annotated[int, "something"], int) + + +def test_find_probable_files_from_command(): + assert find_probable_files_from_command("cat ./test") == [Path("./test")] + assert 
find_probable_files_from_command("cat test") == [] + assert find_probable_files_from_command("cat test.abc") == [Path("test.abc")] + + +def test_match_context(): + match = re.search("foo", "-----foo-----") + assert match_context(match, chars=3) == "---foo---" diff --git a/tests/utilities/test_validation.py b/tests/utilities/test_validation.py new file mode 100644 index 0000000..f8edea8 --- /dev/null +++ b/tests/utilities/test_validation.py @@ -0,0 +1,77 @@ +"""Validation testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name + +import pytest +from pathlib import Path +from maize.utilities.execution import CommandRunner +from maize.utilities.validation import ( + FailValidator, + SuccessValidator, + FileValidator, + ContentValidator, +) + + +@pytest.fixture +def echo_command(): + return "echo 'test'" + + +@pytest.fixture +def process_result(echo_command): + cmd = CommandRunner() + return cmd.run_only(echo_command) + + +@pytest.fixture +def tmp_file_gen(tmp_path): + return tmp_path / "test.xyz" + + +@pytest.fixture +def toml_file(shared_datadir: Path) -> Path: + return shared_datadir / "dict.toml" + + +@pytest.fixture +def yaml_file(shared_datadir: Path) -> Path: + return shared_datadir / "dict.yaml" + + +@pytest.fixture +def process_result_file(tmp_file_gen): + cmd = CommandRunner() + return cmd.run_only(["touch", tmp_file_gen.as_posix()]) + + +class Test_Validator: + def test_fail_validator(self, process_result): + val = FailValidator("test") + assert not val(process_result) + val = FailValidator(["test", "something"]) + assert not val(process_result) + val = FailValidator(["something"]) + assert val(process_result) + + def test_success_validator(self, process_result): + val = SuccessValidator("test") + assert val(process_result) + val = SuccessValidator(["test", "something"]) + assert not val(process_result) + val = SuccessValidator(["something"]) + assert not val(process_result) + + def test_file_validator(self, process_result_file, tmp_file_gen): + val = FileValidator(tmp_file_gen, zero_byte_check=False) + assert val(process_result_file) + val = FileValidator(tmp_file_gen, zero_byte_check=True) + assert not val(process_result_file) + val = FileValidator(tmp_file_gen / "fake", zero_byte_check=False) + assert not val(process_result_file) + + def test_content_validator(self, toml_file, yaml_file): + val = ContentValidator({toml_file: ["foo", "c.d"], yaml_file: ["foo"]}) + assert val(process_result_file) + val = ContentValidator({toml_file: ["bar"]}) + assert not val(process_result_file) diff --git a/tests/utilities/test_visual.py b/tests/utilities/test_visual.py new file mode 100644 index 0000000..b510140 --- /dev/null +++ b/tests/utilities/test_visual.py @@ -0,0 +1,14 @@ +"""Visualization testing""" + +# pylint: disable=redefined-outer-name, import-error, missing-function-docstring, missing-class-docstring, invalid-name + +import pytest +from maize.utilities.visual import nested_graphviz + + +class Test_Visual: + def test_nested_graphviz(self, nested_graph): + dot = nested_graphviz(nested_graph) + assert "subgraph" in dot.body[0] + assert "cluster-sg" in dot.body[0] + assert "cluster-ssg" in dot.body[4]