Skip to content

Commit

Permalink
Add tests for Scheduler job and job definition creation with input fo…
Browse files Browse the repository at this point in the history
…lder, refactor execution manager test (#513) (#515)

* use fixtures, rename job used to job-4

* test scheduler job creation and job def creation with input folder

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove duplciate fixture

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
andrii-i and pre-commit-ci[bot] authored Apr 30, 2024
1 parent 4e7fdc3 commit d3734fd
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 33 deletions.
6 changes: 4 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
DB_FILE_PATH = f"{HERE}/jupyter_scheduler/tests/testdb.sqlite"
DB_URL = f"sqlite:///{DB_FILE_PATH}"

TEST_ROOT_DIR = f"{HERE}/jupyter_scheduler/tests/test_root_dir"


@pytest.fixture
def jp_server_config(jp_server_config):
Expand Down Expand Up @@ -44,7 +46,7 @@ def jp_scheduler_db():


@pytest.fixture
def jp_scheduler(jp_data_dir):
def jp_scheduler():
return Scheduler(
db_url=DB_URL, root_dir=str(jp_data_dir), environments_manager=MockEnvironmentManager()
db_url=DB_URL, root_dir=str(TEST_ROOT_DIR), environments_manager=MockEnvironmentManager()
)
53 changes: 23 additions & 30 deletions jupyter_scheduler/tests/test_execution_manager.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,43 @@
import os
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import PropertyMock, patch

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from conftest import DB_URL
from jupyter_scheduler.executors import DefaultExecutionManager
from jupyter_scheduler.orm import Base, Job
from jupyter_scheduler.orm import Job

NOTEBOOK_DIR = Path(__file__).resolve().parent / "test_staging_dir" / "job-3"
JOB_ID = "69856f4e-ce94-45fd-8f60-3a587457fce7"
NOTEBOOK_NAME = "side_effects.ipynb"
SIDE_EFECT_FILE_NAME = "output_side_effect.txt"

NOTEBOOK_DIR = Path(__file__).resolve().parent / "test_staging_dir" / "job-4"
NOTEBOOK_PATH = NOTEBOOK_DIR / NOTEBOOK_NAME
SIDE_EFFECT_FILE = NOTEBOOK_DIR / "output_side_effect.txt"
SIDE_EFFECT_FILE = NOTEBOOK_DIR / SIDE_EFECT_FILE_NAME


def test_execution_manager_with_side_effects():
db_url = "sqlite://"
engine = create_engine(db_url, echo=False)
Base.metadata.create_all(engine)
db_session = sessionmaker(bind=engine)
with db_session() as session:
@pytest.fixture
def load_job(jp_scheduler_db):
with jp_scheduler_db() as session:
job = Job(
runtime_environment_name="abc",
input_filename=NOTEBOOK_NAME,
job_id="123",
job_id=JOB_ID,
)
session.add(job)
session.commit()

manager = DefaultExecutionManager(
job_id="123",
root_dir=str(NOTEBOOK_DIR),
db_url=db_url,
staging_paths={"input": str(NOTEBOOK_PATH)},
)

with patch.object(
DefaultExecutionManager,
"db_session",
new_callable=PropertyMock,
) as mock_db_session:
mock_db_session.return_value = db_session
manager.add_side_effects_files(str(NOTEBOOK_DIR))

assert (
"output_side_effect.txt" in job.packaged_files
), "Side effect file was not added to packaged_files"
def test_add_side_effects_files(jp_scheduler_db, load_job):
manager = DefaultExecutionManager(
job_id=JOB_ID,
root_dir=str(NOTEBOOK_DIR),
db_url=DB_URL,
staging_paths={"input": str(NOTEBOOK_PATH)},
)
manager.add_side_effects_files(str(NOTEBOOK_DIR))

with jp_scheduler_db() as session:
job = session.query(Job).filter(Job.job_id == JOB_ID).one()
assert SIDE_EFECT_FILE_NAME in job.packaged_files
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Hello world!
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "dbbca65f-4b6a-4490-94b6-f7cdd87e7023",
"metadata": {},
"outputs": [],
"source": [
"file_path = 'a/b/helloworld.txt'\n",
"\n",
"with open(file_path, 'r') as file:\n",
" content = file.read()\n",
"\n",
"print(content)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
46 changes: 45 additions & 1 deletion jupyter_scheduler/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import pytest

from jupyter_scheduler.models import (
CreateJob,
CreateJobDefinition,
ListJobDefinitionsQuery,
SortDirection,
SortField,
UpdateJobDefinition,
)
from jupyter_scheduler.orm import JobDefinition
from jupyter_scheduler.orm import Job, JobDefinition


def test_create_job_definition(jp_scheduler):
Expand All @@ -39,6 +40,49 @@ def test_create_job_definition(jp_scheduler):
assert "hello world" == definition.name


def test_create_job_definition_with_input_folder(jp_scheduler):
job_definition_id = jp_scheduler.create_job_definition(
CreateJobDefinition(
input_uri="job-5/import-helloworld.ipynb",
runtime_environment_name="default",
name="import hello world",
output_formats=["ipynb"],
package_input_folder=True,
)
)

with jp_scheduler.db_session() as session:
definitions = session.query(JobDefinition).all()
assert 1 == len(definitions)
definition = definitions[0]
assert job_definition_id
assert job_definition_id == definition.job_definition_id
assert "import hello world" == definition.name
assert "a/b/helloworld.txt" in definition.packaged_files


def test_create_job_with_input_folder(jp_scheduler):
job_id = jp_scheduler.create_job(
CreateJob(
input_uri="job-5/import-helloworld.ipynb",
runtime_environment_name="default",
name="import hello world",
output_formats=["ipynb"],
package_input_folder=True,
)
)

with jp_scheduler.db_session() as session:
jobs = session.query(Job).all()
assert 1 == len(jobs)
job = jobs[0]
assert job_id
assert job_id == job.job_id
assert "import hello world" == job.name
assert "default" == job.runtime_environment_name
assert "a/b/helloworld.txt" in job.packaged_files


job_definition_1 = {
"job_definition_id": "f4f8c8a9-f539-429a-b69e-b567f578646e",
"name": "hello world 1",
Expand Down

0 comments on commit d3734fd

Please sign in to comment.