Enhancement: Use timestamp returned from the server in QueryResult #529

Open
wants to merge 2 commits into master
42 changes: 32 additions & 10 deletions dagshub/data_engine/model/datasource.py
@@ -294,13 +294,15 @@ def all(self, load_documents=True, load_annotations=True) -> "QueryResult":
"""
Executes the query and returns a :class:`.QueryResult` object containing all datapoints

If there's an active MLflow run, logs an artifact with information about the query to the run.

Args:
load_documents: Automatically download all document blob fields
load_annotations: Automatically download all annotation blob fields
"""
self._check_preprocess()
self._autolog_mlflow()
res = self._source.client.get_datapoints(self)
self._autolog_mlflow(res)
res._load_autoload_fields(documents=load_documents, annotations=load_annotations)

return res
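For context, a minimal usage sketch of the new flow (the repo and datasource names below are placeholders, and get_datasource is assumed to be the usual data engine entry point): calling all() inside an active MLflow run now autologs the datasource state named by the server-returned query timestamp rather than the local machine time.

import mlflow
from dagshub.data_engine import datasources

# Placeholder repo/datasource names, for illustration only.
ds = datasources.get_datasource("<user>/<repo>", name="my-datasource")

with mlflow.start_run():
    qr = ds.all()  # fetches datapoints, then calls _autolog_mlflow(res) with the QueryResult
    # The autologged artifact name is derived from qr.query_data_time (server time),
    # e.g. "autolog_my-datasource_<query-time>_<uuid>.dagshub.json"
    print(qr.query_data_time)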
@@ -833,7 +835,7 @@ def save_dataset(self, name: str) -> "Datasource":
copy_with_ds_assigned.load_from_dataset(dataset_name=name, change_query=False)
return copy_with_ds_assigned

def _autolog_mlflow(self):
def _autolog_mlflow(self, qr: "QueryResult"):
if not is_mlflow_installed:
return
# Run ONLY if there's an active run going on
@@ -842,21 +844,32 @@ def _autolog_mlflow(self):
return
source_name = self.source.name

now_time = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # Not ISO format to make it a valid filename
now_time = qr.query_data_time.strftime("%Y-%m-%dT%H-%M-%S") # Not ISO format to make it a valid filename
uuid_chunk = str(uuid.uuid4())[-4:]

artifact_name = f"autolog_{source_name}_{now_time}_{uuid_chunk}.dagshub.json"
threading.Thread(target=self.log_to_mlflow, kwargs={"artifact_name": artifact_name, "run": active_run}).start()
threading.Thread(
target=self.log_to_mlflow,
kwargs={"artifact_name": artifact_name, "run": active_run, "as_of": qr.query_data_time},
).start()

def log_to_mlflow(
self, artifact_name=DEFAULT_MLFLOW_ARTIFACT_NAME, run: Optional["mlflow.entities.Run"] = None
self,
artifact_name=DEFAULT_MLFLOW_ARTIFACT_NAME,
run: Optional["mlflow.entities.Run"] = None,
as_of: Optional[datetime.datetime] = None,
) -> "mlflow.Entities.Run":
"""
Logs the current datasource state to MLflow as an artifact.

Args:
artifact_name: Name of the artifact that will be stored in the MLflow run.
run: MLflow run to save to. If ``None``, uses the active MLflow run or creates a new run.
as_of: The querying time for which to save the artifact.
Any time the datasource is recreated from the artifact, it will be queried as of this timestamp.
If None, the current machine time will be used.
If the artifact is autologged to MLflow (which happens when there is an active MLflow run),
the timestamp of the query will be used.

Returns:
Run to which the artifact was logged.
@@ -869,7 +882,7 @@ def log_to_mlflow(
client.set_tag(run.info.run_id, MLFLOW_DATASOURCE_TAG_NAME, self.source.id)
if self.assigned_dataset is not None:
client.set_tag(run.info.run_id, MLFLOW_DATASET_TAG_NAME, self.assigned_dataset.dataset_id)
client.log_dict(run.info.run_id, self._to_dict(), artifact_name)
client.log_dict(run.info.run_id, self._to_dict(as_of), artifact_name)
log_message(f'Saved the datasource state to MLflow (run "{run.info.run_name}") as "{artifact_name}"')
return run
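A short sketch of calling log_to_mlflow explicitly with the new as_of argument (ds and qr as in the sketch above; the artifact name is arbitrary):

qr = ds.all()
# Pin the logged state to the moment the data was actually queried, so a datasource
# recreated from this artifact is queried as of that timestamp.
run = ds.log_to_mlflow(artifact_name="my_query.dagshub.json", as_of=qr.query_data_time)
print(run.info.run_name)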

@@ -902,23 +915,30 @@ def save_to_file(self, path: Union[str, PathLike] = ".") -> Path:

return path

def _serialize(self) -> "DatasourceSerializedState":
def _serialize(self, as_of: datetime.datetime) -> "DatasourceSerializedState":
res = DatasourceSerializedState(
repo=self.source.repo,
datasource_id=self.source.id,
datasource_name=self.source.name,
query=self._query,
timestamp=datetime.datetime.now().timestamp(),
timestamp=as_of.timestamp(),
modified=self.is_query_different_from_dataset,
link=self._generate_visualize_url(),
)
if self.assigned_dataset is not None:
res.dataset_id = self.assigned_dataset.dataset_id
res.dataset_name = self.assigned_dataset.dataset_name
if self._query.as_of is not None:
res.timed_link = res.link
elif as_of is not None:
timed_ds = self.as_of(as_of)
res.timed_link = timed_ds._generate_visualize_url()
return res

def _to_dict(self) -> Dict:
res = self._serialize().to_dict()
def _to_dict(self, as_of: Optional[datetime.datetime] = None) -> Dict:
if as_of is None:
as_of = datetime.datetime.now()
res = self._serialize(as_of).to_dict()
# Skip Nones in the result
res = {k: v for k, v in res.items() if v is not None}
return res
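To illustrate the serialization change (internal methods; values are placeholders): the stored timestamp now comes from the caller-supplied as_of instead of datetime.now(), and timed_link points at the datasource pinned to that time whenever the query has no explicit as_of of its own.

import datetime

as_of = datetime.datetime(2024, 5, 1, 12, 30, 0)  # e.g. qr.query_data_time
state = ds._to_dict(as_of)

assert state["timestamp"] == as_of.timestamp()
# state["timed_link"] is ds.as_of(as_of)._generate_visualize_url() when the query
# itself has no as_of; otherwise it simply reuses state["link"].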
@@ -1779,6 +1799,8 @@ class DatasourceSerializedState(DataClassJsonMixin):
"""Does the query differ from the query in the assigned dataset"""
link: Optional[str] = None
"""URL to open this datasource on DagsHub"""
timed_link: Optional[str] = None
"""URL to open this datasource with the data at the time of querying"""


@dataclass