diff --git a/tests/decryption/test_decrypt.py b/tests/decryption/test_decrypt.py index e502718..c579eb6 100644 --- a/tests/decryption/test_decrypt.py +++ b/tests/decryption/test_decrypt.py @@ -1,6 +1,8 @@ """Tests for decrypt.py""" +import os from pathlib import Path import shutil +from unittest import mock from crypt4gh.keys import get_private_key as get_sk_bytes, get_public_key as get_pk_bytes import pytest @@ -8,8 +10,10 @@ from crypt4gh_middleware.decrypt import ( get_private_keys, decrypt_files, - move_files + move_files, + get_args ) +from tests.utils import patch_cli INPUT_DIR = Path(__file__).parents[2]/"inputs" INPUT_TEXT = "hello world from the input!" @@ -145,3 +149,41 @@ def test_permission_error(self, tmp_path): output_dir.chmod(0o400) with pytest.raises(PermissionError): move_files(file_paths=[INPUT_DIR/"hello.txt"], output_dir=output_dir) + + +class TestGetArgs: + """Test get_args.""" + + def test_get_args(self): + """Test that the arguments are parsed correctly.""" + with patch_cli(["decrypt.py", "--output-dir", "dir", "file.txt"]): + args = get_args() + assert args.output_dir == Path("dir") + assert args.file_paths == [Path("file.txt")] + + def test_multiple_files(self): + """Test that multiple file paths can be passed.""" + files = ["file.txt", "file2.txt", "file3.txt"] + with patch_cli(["decrypt.py"] + files): + args = get_args() + assert args.file_paths == [Path(file) for file in files] + + def test_default_output_dir(self): + """Test that output_dir defaults to $TMPDIR when no directory is passed.""" + with (patch_cli(["decrypt.py", "file.txt"]), + mock.patch.dict(os.environ, {"TMPDIR": "/mock/tmpdir"})): + args = get_args() + assert args.output_dir == Path("/mock/tmpdir") + assert args.file_paths == [Path("file.txt")] + + def test_invalid_argument(self): + """Test that a system exit occurs when an invalid argument is passed.""" + with (patch_cli(["decrypt.py", "--bad-argument", "dir", "file.txt"]), + pytest.raises(SystemExit)): + get_args() + + def test_no_file_paths(self): + """Test that a system exit occurs when no file paths are passed.""" + with (patch_cli(["decrypt.py", "--output-dir", "dir"]), + pytest.raises(SystemExit)): + get_args() diff --git a/tests/utils.py b/tests/utils.py index fbc38d6..3a37831 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,15 @@ """ Utility functions for tests.""" from functools import wraps +import contextlib import signal +from unittest import mock + + +@contextlib.contextmanager +def patch_cli(args): + """Context manager that patches sys.argv.""" + with mock.patch("sys.argv", args): + yield def timeout(time_limit):