Skip to content

Commit

Permalink
Add MemoryRecordStore (#14919)
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle authored Aug 14, 2024
1 parent a257267 commit c2b81d7
Show file tree
Hide file tree
Showing 6 changed files with 513 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/prefect/records/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .store import RecordStore
from .base import RecordStore
167 changes: 167 additions & 0 deletions src/prefect/records/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import os
import socket
import threading
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

from prefect.results import BaseResult


class RecordStore:
def read(self, key: str):
raise NotImplementedError

def write(self, key: str, value: dict):
raise NotImplementedError

def exists(self, key: str) -> bool:
return False

def acquire_lock(
self,
key: str,
holder: Optional[str] = None,
acquire_timeout: Optional[float] = None,
hold_timeout: Optional[float] = None,
) -> bool:
"""
Acquire a lock for a transaction record with the given key. Will block other
actors from updating this transaction record until the lock is
released.
Args:
key: Unique identifier for the transaction record.
holder: Unique identifier for the holder of the lock. If not provided,
a default holder based on the current host, process, and thread will
be used.
acquire_timeout: Max number of seconds to wait for the record to become
available if it is locked while attempting to acquire a lock. Pass 0
to attempt to acquire a lock without waiting. Blocks indefinitely by
default.
hold_timeout: Max number of seconds to hold the lock for. Holds the lock
indefinitely by default.
Returns:
bool: True if the lock was successfully acquired; False otherwise.
"""
raise NotImplementedError

def release_lock(self, key: str, holder: Optional[str] = None):
"""
Releases the lock on the corresponding transaction record.
Args:
key: Unique identifier for the transaction record.
holder: Unique identifier for the holder of the lock. Must match the
holder provided when acquiring the lock.
"""
raise NotImplementedError

def is_locked(self, key: str) -> bool:
"""
Simple check to see if the corresponding record is currently locked.
Args:
key: Unique identifier for the transaction record.
Returns:
True is the record is locked; False otherwise.
"""
raise NotImplementedError

def is_lock_holder(self, key: str, holder: Optional[str] = None) -> bool:
"""
Check if the current holder is the lock holder for the transaction record.
Args:
key: Unique identifier for the transaction record.
holder: Unique identifier for the holder of the lock. If not provided,
a default holder based on the current host, process, and thread will
be used.
Returns:
bool: True if the current holder is the lock holder; False otherwise.
"""
raise NotImplementedError

def wait_for_lock(self, key: str, timeout: Optional[float] = None) -> bool:
"""
Wait for the corresponding transaction record to become free.
Args:
key: Unique identifier for the transaction record.
timeout: Maximum time to wait. None means to wait indefinitely.
Returns:
bool: True if the lock becomes free within the timeout; False
otherwise.
"""
...

@staticmethod
def generate_default_holder() -> str:
"""
Generate a default holder string using hostname, PID, and thread ID.
Returns:
str: A unique identifier string.
"""
hostname = socket.gethostname()
pid = os.getpid()
thread_name = threading.current_thread().name
thread_id = threading.get_ident()
return f"{hostname}:{pid}:{thread_id}:{thread_name}"

@contextmanager
def lock(
self,
key: str,
holder: Optional[str] = None,
acquire_timeout: Optional[float] = None,
hold_timeout: Optional[float] = None,
):
"""
Context manager to lock the transaction record during the execution
of the nested code block.
Args:
key: Unique identifier for the transaction record.
holder: Unique identifier for the holder of the lock. If not provided,
a default holder based on the current host, process, and thread will
be used.
acquire_timeout: Max number of seconds to wait for the record to become
available if it is locked while attempting to acquire a lock. Pass 0
to attempt to acquire a lock without waiting. Blocks indefinitely by
default.
hold_timeout: Max number of seconds to hold the lock for. Holds the lock
indefinitely by default.
Example:
Hold a lock while during an operation:
```python
with TransactionRecord(key="my-transaction-record-key").lock():
do_stuff()
```
"""
self.acquire_lock(
key=key,
holder=holder,
acquire_timeout=acquire_timeout,
hold_timeout=hold_timeout,
)

try:
yield
finally:
self.release_lock(key=key, holder=holder)


@dataclass
class TransactionRecord:
"""
A dataclass representation of a transaction record.
"""

key: str
result: Optional[BaseResult] = None
165 changes: 165 additions & 0 deletions src/prefect/records/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import threading
from typing import Dict, Optional, TypedDict

from prefect.results import BaseResult

from .base import RecordStore, TransactionRecord


class _LockInfo(TypedDict):
"""
A dictionary containing information about a lock.
Attributes:
holder: The holder of the lock.
lock: The lock object.
expiration_timer: The timer for the lock expiration
"""

holder: str
lock: threading.Lock
expiration_timer: Optional[threading.Timer]


class MemoryRecordStore(RecordStore):
"""
A record store that stores data in memory.
"""

_instance = None

def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

def __init__(self):
self._locks_dict_lock = threading.Lock()
self._locks: Dict[str, _LockInfo] = {}
self._records: Dict[str, TransactionRecord] = {}

def read(self, key: str) -> Optional[TransactionRecord]:
return self._records.get(key)

def write(self, key: str, value: BaseResult, holder: Optional[str] = None) -> None:
holder = holder or self.generate_default_holder()

with self._locks_dict_lock:
if self.is_locked(key) and not self.is_lock_holder(key, holder):
raise ValueError(
f"Cannot write to transaction with key {key} because it is locked by another holder."
)
self._records[key] = TransactionRecord(key=key, result=value)

def exists(self, key: str) -> bool:
return key in self._records

def _expire_lock(self, key: str):
"""
Expire the lock for the given key.
Used as a callback for the expiration timer of a lock.
Args:
key: The key of the lock to expire.
"""
with self._locks_dict_lock:
if key in self._locks:
lock_info = self._locks[key]
if lock_info["lock"].locked():
lock_info["lock"].release()
if lock_info["expiration_timer"]:
lock_info["expiration_timer"].cancel()
del self._locks[key]

def acquire_lock(
self,
key: str,
holder: Optional[str] = None,
acquire_timeout: Optional[float] = None,
hold_timeout: Optional[float] = None,
) -> bool:
holder = holder or self.generate_default_holder()
with self._locks_dict_lock:
if key not in self._locks:
lock = threading.Lock()
lock.acquire()
expiration_timer = None
if hold_timeout is not None:
expiration_timer = threading.Timer(
hold_timeout, self._expire_lock, args=(key,)
)
expiration_timer.start()
self._locks[key] = _LockInfo(
holder=holder, lock=lock, expiration_timer=expiration_timer
)
return True
elif self._locks[key]["holder"] == holder:
return True
else:
existing_lock_info = self._locks[key]

if acquire_timeout is not None:
existing_lock_acquired = existing_lock_info["lock"].acquire(
timeout=acquire_timeout
)
else:
existing_lock_acquired = existing_lock_info["lock"].acquire()

if existing_lock_acquired:
with self._locks_dict_lock:
if (
expiration_timer := existing_lock_info["expiration_timer"]
) is not None:
expiration_timer.cancel()
expiration_timer = None
if hold_timeout is not None:
expiration_timer = threading.Timer(
hold_timeout, self._expire_lock, args=(key,)
)
expiration_timer.start()
self._locks[key] = _LockInfo(
holder=holder,
lock=existing_lock_info["lock"],
expiration_timer=expiration_timer,
)
return True
return False

def release_lock(self, key: str, holder: Optional[str] = None) -> None:
holder = holder or self.generate_default_holder()
with self._locks_dict_lock:
if key in self._locks and self._locks[key]["holder"] == holder:
if (
expiration_timer := self._locks[key]["expiration_timer"]
) is not None:
expiration_timer.cancel()
self._locks[key]["lock"].release()
del self._locks[key]
else:
raise ValueError(
f"No lock held by {holder} for transaction with key {key}"
)

def is_locked(self, key: str) -> bool:
return key in self._locks and self._locks[key]["lock"].locked()

def is_lock_holder(self, key: str, holder: Optional[str] = None) -> bool:
holder = holder or self.generate_default_holder()
lock_info = self._locks.get(key)
return (
lock_info is not None
and lock_info["lock"].locked()
and lock_info["holder"] == holder
)

def wait_for_lock(self, key: str, timeout: Optional[float] = None) -> bool:
if lock := self._locks.get(key, {}).get("lock"):
if timeout is not None:
lock_acquired = lock.acquire(timeout=timeout)
else:
lock_acquired = lock.acquire()
if lock_acquired:
lock.release()
return lock_acquired
return True
2 changes: 1 addition & 1 deletion src/prefect/records/result_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from prefect.results import BaseResult, PersistedResult, ResultFactory
from prefect.utilities.asyncutils import run_coro_as_sync

from .store import RecordStore
from .base import RecordStore


@dataclass
Expand Down
9 changes: 0 additions & 9 deletions src/prefect/records/store.py

This file was deleted.

Loading

0 comments on commit c2b81d7

Please sign in to comment.