diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py
index c2b241bf..2c6be5f0 100644
--- a/dagshub/data_engine/model/datasource.py
+++ b/dagshub/data_engine/model/datasource.py
@@ -294,13 +294,15 @@ def all(self, load_documents=True, load_annotations=True) -> "QueryResult":
         """
         Executes the query and returns a :class:`.QueryResult` object containing all datapoints
 
+        If there's an active MLflow run, logs an artifact with information about the query to the run.
+
         Args:
             load_documents: Automatically download all document blob fields
             load_annotations: Automatically download all annotation blob fields
         """
         self._check_preprocess()
-        self._autolog_mlflow()
         res = self._source.client.get_datapoints(self)
+        self._autolog_mlflow(res)
         res._load_autoload_fields(documents=load_documents, annotations=load_annotations)
         return res
 
@@ -833,7 +835,7 @@ def save_dataset(self, name: str) -> "Datasource":
         copy_with_ds_assigned.load_from_dataset(dataset_name=name, change_query=False)
         return copy_with_ds_assigned
 
-    def _autolog_mlflow(self):
+    def _autolog_mlflow(self, qr: "QueryResult"):
         if not is_mlflow_installed:
             return
         # Run ONLY if there's an active run going on
@@ -842,14 +844,20 @@ def _autolog_mlflow(self):
             return
 
         source_name = self.source.name
-        now_time = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")  # Not ISO format to make it a valid filename
+        now_time = qr.query_data_time.strftime("%Y-%m-%dT%H-%M-%S")  # Not ISO format to make it a valid filename
         uuid_chunk = str(uuid.uuid4())[-4:]
         artifact_name = f"autolog_{source_name}_{now_time}_{uuid_chunk}.dagshub.json"
 
-        threading.Thread(target=self.log_to_mlflow, kwargs={"artifact_name": artifact_name, "run": active_run}).start()
+        threading.Thread(
+            target=self.log_to_mlflow,
+            kwargs={"artifact_name": artifact_name, "run": active_run, "as_of": qr.query_data_time},
+        ).start()
 
     def log_to_mlflow(
-        self, artifact_name=DEFAULT_MLFLOW_ARTIFACT_NAME, run: Optional["mlflow.entities.Run"] = None
+        self,
+        artifact_name=DEFAULT_MLFLOW_ARTIFACT_NAME,
+        run: Optional["mlflow.entities.Run"] = None,
+        as_of: Optional[datetime.datetime] = None,
     ) -> "mlflow.Entities.Run":
         """
         Logs the current datasource state to MLflow as an artifact.
@@ -857,6 +865,11 @@ def log_to_mlflow(
         Args:
             artifact_name: Name of the artifact that will be stored in the MLflow run.
             run: MLflow run to save to. If ``None``, uses the active MLflow run or creates a new run.
+            as_of: The querying time for which to save the artifact.
+                Any time the datasource is recreated from the artifact, it will be queried as of this timestamp.
+                If None, the current machine time will be used.
+                If the artifact is autologged to MLflow (will happen if you have an active MLflow run),
+                then the timestamp of the query will be used.
 
         Returns:
             Run to which the artifact was logged.
@@ -869,7 +882,7 @@ def log_to_mlflow(
         client.set_tag(run.info.run_id, MLFLOW_DATASOURCE_TAG_NAME, self.source.id)
         if self.assigned_dataset is not None:
             client.set_tag(run.info.run_id, MLFLOW_DATASET_TAG_NAME, self.assigned_dataset.dataset_id)
-        client.log_dict(run.info.run_id, self._to_dict(), artifact_name)
+        client.log_dict(run.info.run_id, self._to_dict(as_of), artifact_name)
         log_message(f'Saved the datasource state to MLflow (run "{run.info.run_name}") as "{artifact_name}"')
         return run
 
@@ -902,23 +915,30 @@ def save_to_file(self, path: Union[str, PathLike] = ".") -> Path:
         return path
 
-    def _serialize(self) -> "DatasourceSerializedState":
+    def _serialize(self, as_of: datetime.datetime) -> "DatasourceSerializedState":
         res = DatasourceSerializedState(
             repo=self.source.repo,
             datasource_id=self.source.id,
             datasource_name=self.source.name,
             query=self._query,
-            timestamp=datetime.datetime.now().timestamp(),
+            timestamp=as_of.timestamp(),
             modified=self.is_query_different_from_dataset,
             link=self._generate_visualize_url(),
         )
         if self.assigned_dataset is not None:
            res.dataset_id = self.assigned_dataset.dataset_id
            res.dataset_name = self.assigned_dataset.dataset_name
+        if self._query.as_of is not None:
+            res.timed_link = res.link
+        elif as_of is not None:
+            timed_ds = self.as_of(as_of)
+            res.timed_link = timed_ds._generate_visualize_url()
         return res
 
-    def _to_dict(self) -> Dict:
-        res = self._serialize().to_dict()
+    def _to_dict(self, as_of: Optional[datetime.datetime] = None) -> Dict:
+        if as_of is None:
+            as_of = datetime.datetime.now()
+        res = self._serialize(as_of).to_dict()
         # Skip Nones in the result
         res = {k: v for k, v in res.items() if v is not None}
         return res
@@ -1779,6 +1799,8 @@ class DatasourceSerializedState(DataClassJsonMixin):
     """Does the query differ from the query in the assigned dataset"""
     link: Optional[str] = None
     """URL to open this datasource on DagsHub"""
+    timed_link: Optional[str] = None
+    """URL to open this datasource with the data at the time of querying"""
 
 
 @dataclass
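
For reviewers, a minimal usage sketch of the behavior this diff introduces. The repo and datasource names are placeholders, and QueryResult.query_data_time is taken on faith from the diff; the entry points (datasources.get_datasource, all, log_to_mlflow, as_of) are the ones the changed code references.

import datetime

import mlflow
from dagshub.data_engine import datasources

# Placeholder repo/datasource names -- substitute your own.
ds = datasources.get_datasource("<user>/<repo>", name="my-datasource")

# Autolog path: because a run is active, all() now logs an
# "autolog_<source>_<time>_<uuid>.dagshub.json" artifact from a background
# thread, stamped with the query's query_data_time instead of datetime.now().
with mlflow.start_run():
    result = ds.all()

# Explicit path: pin the artifact to a chosen as_of time. Per the new
# docstring, a datasource recreated from this artifact will be queried
# as of that timestamp.
ds.log_to_mlflow(
    artifact_name="state.dagshub.json",
    as_of=datetime.datetime(2024, 1, 1),
)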
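
And a sketch of the artifact that log_to_mlflow writes via client.log_dict, assembled from the DatasourceSerializedState fields this diff touches. All values are illustrative, and _to_dict drops any field that is None:

# Illustrative shape of the logged .dagshub.json artifact (values made up).
artifact = {
    "repo": "<user>/<repo>",
    "datasource_id": "<datasource id>",
    "datasource_name": "my-datasource",
    "query": {"as_of": None},                 # serialized query state
    "timestamp": 1704067200.0,                # as_of.timestamp(), not necessarily "now"
    "modified": False,                        # is_query_different_from_dataset
    "link": "https://dagshub.com/...",        # _generate_visualize_url()
    "timed_link": "https://dagshub.com/...",  # same view pinned to the as_of time
}

The new timed_link is what makes autologged artifacts reproducible: when the query itself carries no as_of, _serialize derives the link from self.as_of(as_of), so the URL reopens the data as it was at query time rather than at click time.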