-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a257267
commit c2b81d7
Showing
6 changed files
with
513 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .store import RecordStore | ||
from .base import RecordStore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
import os | ||
import socket | ||
import threading | ||
from contextlib import contextmanager | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
from prefect.results import BaseResult | ||
|
||
|
||
class RecordStore: | ||
def read(self, key: str): | ||
raise NotImplementedError | ||
|
||
def write(self, key: str, value: dict): | ||
raise NotImplementedError | ||
|
||
def exists(self, key: str) -> bool: | ||
return False | ||
|
||
def acquire_lock( | ||
self, | ||
key: str, | ||
holder: Optional[str] = None, | ||
acquire_timeout: Optional[float] = None, | ||
hold_timeout: Optional[float] = None, | ||
) -> bool: | ||
""" | ||
Acquire a lock for a transaction record with the given key. Will block other | ||
actors from updating this transaction record until the lock is | ||
released. | ||
Args: | ||
key: Unique identifier for the transaction record. | ||
holder: Unique identifier for the holder of the lock. If not provided, | ||
a default holder based on the current host, process, and thread will | ||
be used. | ||
acquire_timeout: Max number of seconds to wait for the record to become | ||
available if it is locked while attempting to acquire a lock. Pass 0 | ||
to attempt to acquire a lock without waiting. Blocks indefinitely by | ||
default. | ||
hold_timeout: Max number of seconds to hold the lock for. Holds the lock | ||
indefinitely by default. | ||
Returns: | ||
bool: True if the lock was successfully acquired; False otherwise. | ||
""" | ||
raise NotImplementedError | ||
|
||
def release_lock(self, key: str, holder: Optional[str] = None): | ||
""" | ||
Releases the lock on the corresponding transaction record. | ||
Args: | ||
key: Unique identifier for the transaction record. | ||
holder: Unique identifier for the holder of the lock. Must match the | ||
holder provided when acquiring the lock. | ||
""" | ||
raise NotImplementedError | ||
|
||
def is_locked(self, key: str) -> bool: | ||
""" | ||
Simple check to see if the corresponding record is currently locked. | ||
Args: | ||
key: Unique identifier for the transaction record. | ||
Returns: | ||
True is the record is locked; False otherwise. | ||
""" | ||
raise NotImplementedError | ||
|
||
def is_lock_holder(self, key: str, holder: Optional[str] = None) -> bool: | ||
""" | ||
Check if the current holder is the lock holder for the transaction record. | ||
Args: | ||
key: Unique identifier for the transaction record. | ||
holder: Unique identifier for the holder of the lock. If not provided, | ||
a default holder based on the current host, process, and thread will | ||
be used. | ||
Returns: | ||
bool: True if the current holder is the lock holder; False otherwise. | ||
""" | ||
raise NotImplementedError | ||
|
||
def wait_for_lock(self, key: str, timeout: Optional[float] = None) -> bool: | ||
""" | ||
Wait for the corresponding transaction record to become free. | ||
Args: | ||
key: Unique identifier for the transaction record. | ||
timeout: Maximum time to wait. None means to wait indefinitely. | ||
Returns: | ||
bool: True if the lock becomes free within the timeout; False | ||
otherwise. | ||
""" | ||
... | ||
|
||
@staticmethod | ||
def generate_default_holder() -> str: | ||
""" | ||
Generate a default holder string using hostname, PID, and thread ID. | ||
Returns: | ||
str: A unique identifier string. | ||
""" | ||
hostname = socket.gethostname() | ||
pid = os.getpid() | ||
thread_name = threading.current_thread().name | ||
thread_id = threading.get_ident() | ||
return f"{hostname}:{pid}:{thread_id}:{thread_name}" | ||
|
||
@contextmanager | ||
def lock( | ||
self, | ||
key: str, | ||
holder: Optional[str] = None, | ||
acquire_timeout: Optional[float] = None, | ||
hold_timeout: Optional[float] = None, | ||
): | ||
""" | ||
Context manager to lock the transaction record during the execution | ||
of the nested code block. | ||
Args: | ||
key: Unique identifier for the transaction record. | ||
holder: Unique identifier for the holder of the lock. If not provided, | ||
a default holder based on the current host, process, and thread will | ||
be used. | ||
acquire_timeout: Max number of seconds to wait for the record to become | ||
available if it is locked while attempting to acquire a lock. Pass 0 | ||
to attempt to acquire a lock without waiting. Blocks indefinitely by | ||
default. | ||
hold_timeout: Max number of seconds to hold the lock for. Holds the lock | ||
indefinitely by default. | ||
Example: | ||
Hold a lock while during an operation: | ||
```python | ||
with TransactionRecord(key="my-transaction-record-key").lock(): | ||
do_stuff() | ||
``` | ||
""" | ||
self.acquire_lock( | ||
key=key, | ||
holder=holder, | ||
acquire_timeout=acquire_timeout, | ||
hold_timeout=hold_timeout, | ||
) | ||
|
||
try: | ||
yield | ||
finally: | ||
self.release_lock(key=key, holder=holder) | ||
|
||
|
||
@dataclass | ||
class TransactionRecord: | ||
""" | ||
A dataclass representation of a transaction record. | ||
""" | ||
|
||
key: str | ||
result: Optional[BaseResult] = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import threading | ||
from typing import Dict, Optional, TypedDict | ||
|
||
from prefect.results import BaseResult | ||
|
||
from .base import RecordStore, TransactionRecord | ||
|
||
|
||
class _LockInfo(TypedDict): | ||
""" | ||
A dictionary containing information about a lock. | ||
Attributes: | ||
holder: The holder of the lock. | ||
lock: The lock object. | ||
expiration_timer: The timer for the lock expiration | ||
""" | ||
|
||
holder: str | ||
lock: threading.Lock | ||
expiration_timer: Optional[threading.Timer] | ||
|
||
|
||
class MemoryRecordStore(RecordStore): | ||
""" | ||
A record store that stores data in memory. | ||
""" | ||
|
||
_instance = None | ||
|
||
def __new__(cls, *args, **kwargs): | ||
if cls._instance is None: | ||
cls._instance = super().__new__(cls) | ||
return cls._instance | ||
|
||
def __init__(self): | ||
self._locks_dict_lock = threading.Lock() | ||
self._locks: Dict[str, _LockInfo] = {} | ||
self._records: Dict[str, TransactionRecord] = {} | ||
|
||
def read(self, key: str) -> Optional[TransactionRecord]: | ||
return self._records.get(key) | ||
|
||
def write(self, key: str, value: BaseResult, holder: Optional[str] = None) -> None: | ||
holder = holder or self.generate_default_holder() | ||
|
||
with self._locks_dict_lock: | ||
if self.is_locked(key) and not self.is_lock_holder(key, holder): | ||
raise ValueError( | ||
f"Cannot write to transaction with key {key} because it is locked by another holder." | ||
) | ||
self._records[key] = TransactionRecord(key=key, result=value) | ||
|
||
def exists(self, key: str) -> bool: | ||
return key in self._records | ||
|
||
def _expire_lock(self, key: str): | ||
""" | ||
Expire the lock for the given key. | ||
Used as a callback for the expiration timer of a lock. | ||
Args: | ||
key: The key of the lock to expire. | ||
""" | ||
with self._locks_dict_lock: | ||
if key in self._locks: | ||
lock_info = self._locks[key] | ||
if lock_info["lock"].locked(): | ||
lock_info["lock"].release() | ||
if lock_info["expiration_timer"]: | ||
lock_info["expiration_timer"].cancel() | ||
del self._locks[key] | ||
|
||
def acquire_lock( | ||
self, | ||
key: str, | ||
holder: Optional[str] = None, | ||
acquire_timeout: Optional[float] = None, | ||
hold_timeout: Optional[float] = None, | ||
) -> bool: | ||
holder = holder or self.generate_default_holder() | ||
with self._locks_dict_lock: | ||
if key not in self._locks: | ||
lock = threading.Lock() | ||
lock.acquire() | ||
expiration_timer = None | ||
if hold_timeout is not None: | ||
expiration_timer = threading.Timer( | ||
hold_timeout, self._expire_lock, args=(key,) | ||
) | ||
expiration_timer.start() | ||
self._locks[key] = _LockInfo( | ||
holder=holder, lock=lock, expiration_timer=expiration_timer | ||
) | ||
return True | ||
elif self._locks[key]["holder"] == holder: | ||
return True | ||
else: | ||
existing_lock_info = self._locks[key] | ||
|
||
if acquire_timeout is not None: | ||
existing_lock_acquired = existing_lock_info["lock"].acquire( | ||
timeout=acquire_timeout | ||
) | ||
else: | ||
existing_lock_acquired = existing_lock_info["lock"].acquire() | ||
|
||
if existing_lock_acquired: | ||
with self._locks_dict_lock: | ||
if ( | ||
expiration_timer := existing_lock_info["expiration_timer"] | ||
) is not None: | ||
expiration_timer.cancel() | ||
expiration_timer = None | ||
if hold_timeout is not None: | ||
expiration_timer = threading.Timer( | ||
hold_timeout, self._expire_lock, args=(key,) | ||
) | ||
expiration_timer.start() | ||
self._locks[key] = _LockInfo( | ||
holder=holder, | ||
lock=existing_lock_info["lock"], | ||
expiration_timer=expiration_timer, | ||
) | ||
return True | ||
return False | ||
|
||
def release_lock(self, key: str, holder: Optional[str] = None) -> None: | ||
holder = holder or self.generate_default_holder() | ||
with self._locks_dict_lock: | ||
if key in self._locks and self._locks[key]["holder"] == holder: | ||
if ( | ||
expiration_timer := self._locks[key]["expiration_timer"] | ||
) is not None: | ||
expiration_timer.cancel() | ||
self._locks[key]["lock"].release() | ||
del self._locks[key] | ||
else: | ||
raise ValueError( | ||
f"No lock held by {holder} for transaction with key {key}" | ||
) | ||
|
||
def is_locked(self, key: str) -> bool: | ||
return key in self._locks and self._locks[key]["lock"].locked() | ||
|
||
def is_lock_holder(self, key: str, holder: Optional[str] = None) -> bool: | ||
holder = holder or self.generate_default_holder() | ||
lock_info = self._locks.get(key) | ||
return ( | ||
lock_info is not None | ||
and lock_info["lock"].locked() | ||
and lock_info["holder"] == holder | ||
) | ||
|
||
def wait_for_lock(self, key: str, timeout: Optional[float] = None) -> bool: | ||
if lock := self._locks.get(key, {}).get("lock"): | ||
if timeout is not None: | ||
lock_acquired = lock.acquire(timeout=timeout) | ||
else: | ||
lock_acquired = lock.acquire() | ||
if lock_acquired: | ||
lock.release() | ||
return lock_acquired | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.