Skip to content

Commit

Permalink
PyTorch adapter: add a way to disable cache updates (cvat-ai#5549)
Browse files Browse the repository at this point in the history
This will let users run their PyTorch code without network access,
provided that they have already cached the data.

### How has this been tested?
<!-- Please describe in detail how you tested your changes.
Include details of your testing environment, and the tests you ran to
see how your change affects other areas of the code, etc. -->
Unit tests.
  • Loading branch information
SpecLad authored Jan 6, 2023
1 parent fd7d802 commit 33c624a
Show file tree
Hide file tree
Showing 7 changed files with 317 additions and 111 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/opencv/cvat/pull/5535>)
- \[SDK\] Class to represent a project as a PyTorch dataset
(<https://github.com/opencv/cvat/pull/5523>)
- \[SDK\] A PyTorch adapter setting to disable cache updates
(<https://github.com/opencv/cvat/pull/5549>)

### Changed
- The Docker Compose files now use the Compose Specification version
Expand Down
1 change: 1 addition & 0 deletions cvat-sdk/cvat_sdk/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: MIT

from .caching import UpdatePolicy
from .common import FrameAnnotations, Target, UnsupportedDatasetError
from .project_dataset import ProjectVisionDataset
from .task_dataset import TaskVisionDataset
Expand Down
222 changes: 222 additions & 0 deletions cvat-sdk/cvat_sdk/pytorch/caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import base64
import json
import shutil
from abc import ABCMeta, abstractmethod
from enum import Enum, auto
from pathlib import Path
from typing import Callable, Mapping, Type, TypeVar

import cvat_sdk.models as models
from cvat_sdk.api_client.model_utils import OpenApiModel, to_json
from cvat_sdk.core.client import Client
from cvat_sdk.core.proxies.projects import Project
from cvat_sdk.core.proxies.tasks import Task
from cvat_sdk.core.utils import atomic_writer


class UpdatePolicy(Enum):
    """
    Policies that control when cached data is refreshed from the CVAT server.
    """

    IF_MISSING_OR_STALE = auto()
    """
    Refresh the cache whenever data is absent locally or a newer version
    exists on the server.
    """

    NEVER = auto()
    """
    Never refresh the cache. Operations that need data that has not been
    cached will fail.
    With this policy, no network access is performed.
    """


_ModelType = TypeVar("_ModelType", bound=OpenApiModel)


class CacheManager(metaclass=ABCMeta):
    """
    Abstract base for the PyTorch adapter's local cache of server data.

    Concrete subclasses decide whether data may be fetched from the server
    or must already be present on disk (see `make_cache_manager`).
    """

    def __init__(self, client: Client) -> None:
        self._client = client
        self._logger = client.logger

        self._server_dir = client.config.cache_dir / f"servers/{self.server_dir_name}"

    @property
    def server_dir_name(self) -> str:
        # Base64-encode the host name so that the directory name contains
        # no FS-unsafe characters (such as slashes).
        encoded = base64.urlsafe_b64encode(self._client.api_map.host.encode())
        return encoded.rstrip(b"=").decode()

    def task_dir(self, task_id: int) -> Path:
        return self._server_dir / "tasks" / str(task_id)

    def task_json_path(self, task_id: int) -> Path:
        return self.task_dir(task_id) / "task.json"

    def chunk_dir(self, task_id: int) -> Path:
        return self.task_dir(task_id) / "chunks"

    def project_dir(self, project_id: int) -> Path:
        return self._server_dir / "projects" / str(project_id)

    def project_json_path(self, project_id: int) -> Path:
        return self.project_dir(project_id) / "project.json"

    def load_model(self, path: Path, model_type: Type[_ModelType]) -> _ModelType:
        # Deserialize a cached model from its JSON representation on disk.
        with open(path, "rb") as f:
            loaded = json.load(f)
        return model_type._new_from_openapi_data(**loaded)

    def save_model(self, path: Path, model: OpenApiModel) -> None:
        # Write atomically so that a crash cannot leave a truncated file behind.
        with atomic_writer(path, "w", encoding="UTF-8") as f:
            json.dump(to_json(model), f, indent=4)
            print(file=f)  # add final newline

    @abstractmethod
    def retrieve_task(self, task_id: int) -> Task:
        """Return the task with the given ID, honoring the cache policy."""

    @abstractmethod
    def ensure_task_model(
        self,
        task_id: int,
        filename: str,
        model_type: Type[_ModelType],
        downloader: Callable[[], _ModelType],
        model_description: str,
    ) -> _ModelType:
        """Return the model stored as `filename` for the task, obtaining it via `downloader` if allowed."""

    @abstractmethod
    def ensure_chunk(self, task: Task, chunk_index: int) -> None:
        """Make sure the given chunk of the task is present in the cache."""

    @abstractmethod
    def retrieve_project(self, project_id: int) -> Project:
        """Return the project with the given ID, honoring the cache policy."""


class _CacheManagerOnline(CacheManager):
    """
    Cache manager implementing `UpdatePolicy.IF_MISSING_OR_STALE`:
    downloads data from the server whenever the cached copy is missing,
    corrupted, or older than the server's version.
    """

    def retrieve_task(self, task_id: int) -> Task:
        """Fetch the task from the server and (re)initialize its cache directory."""
        self._logger.info(f"Fetching task {task_id}...")
        task = self._client.tasks.retrieve(task_id)

        self._initialize_task_dir(task)
        return task

    def _initialize_task_dir(self, task: Task) -> None:
        # Create (or purge and recreate) the cache directory for the task,
        # then save the up-to-date task model into it.
        task_dir = self.task_dir(task.id)
        task_json_path = self.task_json_path(task.id)

        try:
            saved_task = self.load_model(task_json_path, models.TaskRead)
        except Exception:
            self._logger.info(f"Task {task.id} is not yet cached or the cache is corrupted")

            # If the cache was corrupted, the directory might already be there; clear it.
            if task_dir.exists():
                shutil.rmtree(task_dir)
        else:
            # Purge stale cached data so it is re-downloaded on demand.
            if saved_task.updated_date < task.updated_date:
                self._logger.info(
                    f"Task {task.id} has been updated on the server since it was cached; purging the cache"
                )
                shutil.rmtree(task_dir)

        task_dir.mkdir(exist_ok=True, parents=True)
        self.save_model(task_json_path, task._model)

    def ensure_task_model(
        self,
        task_id: int,
        filename: str,
        model_type: Type[_ModelType],
        downloader: Callable[[], _ModelType],
        model_description: str,
    ) -> _ModelType:
        """
        Return the model cached as `filename` for the given task,
        downloading and caching it if it's missing or unreadable.
        """
        path = self.task_dir(task_id) / filename

        try:
            model = self.load_model(path, model_type)
            self._logger.info(f"Loaded {model_description} from cache")
            return model
        except FileNotFoundError:
            pass  # not cached yet; not an error
        except Exception:
            self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True)

        self._logger.info(f"Downloading {model_description}...")
        model = downloader()
        self._logger.info(f"Downloaded {model_description}")

        self.save_model(path, model)

        return model

    def ensure_chunk(self, task: Task, chunk_index: int) -> None:
        """Download the given chunk of the task into the cache, unless it's already there."""
        chunk_path = self.chunk_dir(task.id) / f"{chunk_index}.zip"
        if chunk_path.exists():
            return  # already downloaded previously

        self._logger.info(f"Downloading chunk #{chunk_index}...")

        # Nothing else in this class creates the chunk directory
        # (_initialize_task_dir only creates the task directory),
        # so make sure it exists before writing into it.
        chunk_path.parent.mkdir(exist_ok=True, parents=True)

        with atomic_writer(chunk_path, "wb") as chunk_file:
            task.download_chunk(chunk_index, chunk_file, quality="original")

    def retrieve_project(self, project_id: int) -> Project:
        """Fetch the project from the server and cache its model."""
        self._logger.info(f"Fetching project {project_id}...")
        project = self._client.projects.retrieve(project_id)

        project_dir = self.project_dir(project_id)
        project_dir.mkdir(parents=True, exist_ok=True)
        project_json_path = self.project_json_path(project_id)

        # There are currently no files cached alongside project.json,
        # so we don't need to check if we need to purge them.

        self.save_model(project_json_path, project._model)

        return project


class _CacheManagerOffline(CacheManager):
    """
    Cache manager implementing `UpdatePolicy.NEVER`: everything is served
    from the local cache and no network access is ever performed. Requests
    for data that was never cached fail.
    """

    def retrieve_task(self, task_id: int) -> Task:
        self._logger.info(f"Retrieving task {task_id} from cache...")
        saved_model = self.load_model(self.task_json_path(task_id), models.TaskRead)
        return Task(self._client, saved_model)

    def ensure_task_model(
        self,
        task_id: int,
        filename: str,
        model_type: Type[_ModelType],
        downloader: Callable[[], _ModelType],
        model_description: str,
    ) -> _ModelType:
        # `downloader` is deliberately ignored: this policy never goes online.
        self._logger.info(f"Loading {model_description} from cache...")
        return self.load_model(self.task_dir(task_id) / filename, model_type)

    def ensure_chunk(self, task: Task, chunk_index: int) -> None:
        chunk_path = self.chunk_dir(task.id) / f"{chunk_index}.zip"

        if not chunk_path.exists():
            raise FileNotFoundError(f"Chunk {chunk_index} of task {task.id} is not cached")

    def retrieve_project(self, project_id: int) -> Project:
        self._logger.info(f"Retrieving project {project_id} from cache...")
        saved_model = self.load_model(self.project_json_path(project_id), models.ProjectRead)
        return Project(self._client, saved_model)


# Maps each update policy to the cache manager implementation that realizes it.
_CACHE_MANAGER_CLASSES: Mapping[UpdatePolicy, Type[CacheManager]] = {
    UpdatePolicy.IF_MISSING_OR_STALE: _CacheManagerOnline,
    UpdatePolicy.NEVER: _CacheManagerOffline,
}


def make_cache_manager(client: Client, update_policy: UpdatePolicy) -> CacheManager:
    """
    Create a `CacheManager` for `client` that implements the given update policy.
    """
    manager_class = _CACHE_MANAGER_CLASSES[update_policy]
    return manager_class(client)
8 changes: 0 additions & 8 deletions cvat-sdk/cvat_sdk/pytorch/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
#
# SPDX-License-Identifier: MIT

import base64
from pathlib import Path
from typing import List, Mapping

import attrs
Expand Down Expand Up @@ -42,9 +40,3 @@ class Target:
A mapping from label_id values in `LabeledImage` and `LabeledShape` objects
to an integer index. This mapping is consistent across all samples for a given task.
"""


def get_server_cache_dir(client: cvat_sdk.core.Client) -> Path:
# Base64-encode the name to avoid FS-unsafe characters (like slashes)
server_dir_name = base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
return client.config.cache_dir / f"servers/{server_dir_name}"
27 changes: 14 additions & 13 deletions cvat-sdk/cvat_sdk/pytorch/project_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.pytorch.common import get_server_cache_dir
from cvat_sdk.pytorch.caching import UpdatePolicy, make_cache_manager
from cvat_sdk.pytorch.task_dataset import TaskVisionDataset


Expand Down Expand Up @@ -42,6 +42,7 @@ def __init__(
label_name_to_index: Mapping[str, int] = None,
task_filter: Optional[Callable[[models.ITaskRead], bool]] = None,
include_subsets: Optional[Container[str]] = None,
update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE,
) -> None:
"""
Creates a dataset corresponding to the project with ID `project_id` on the
Expand All @@ -61,29 +62,24 @@ def __init__(
* If `include_subsets` is set to a container, then tasks whose subset is
not a member of this container will be excluded.
`update_policy` determines when and if the local cache will be updated.
"""

self._logger = client.logger

self._logger.info(f"Fetching project {project_id}...")
project = client.projects.retrieve(project_id)

# We don't actually need to save anything to this directory (yet),
# but VisionDataset.__init__ requires a root, so make one.
# It could be useful in the future to store the project data for
# offline-only mode.
project_dir = get_server_cache_dir(client) / f"projects/{project_id}"
project_dir.mkdir(parents=True, exist_ok=True)
cache_manager = make_cache_manager(client, update_policy)
project = cache_manager.retrieve_project(project_id)

super().__init__(
os.fspath(project_dir),
os.fspath(cache_manager.project_dir(project_id)),
transforms=transforms,
transform=transform,
target_transform=target_transform,
)

self._logger.info("Fetching project tasks...")
tasks = project.get_tasks()
tasks = [cache_manager.retrieve_task(task_id) for task_id in project.tasks]

if task_filter is not None:
tasks = list(filter(task_filter, tasks))
Expand All @@ -95,7 +91,12 @@ def __init__(

self._underlying = torch.utils.data.ConcatDataset(
[
TaskVisionDataset(client, task.id, label_name_to_index=label_name_to_index)
TaskVisionDataset(
client,
task.id,
label_name_to_index=label_name_to_index,
update_policy=update_policy,
)
for task in tasks
]
)
Expand Down
Loading

0 comments on commit 33c624a

Please sign in to comment.