diff --git a/dagster-cloud-cli/dagster_cloud_cli/commands/deployment/__init__.py b/dagster-cloud-cli/dagster_cloud_cli/commands/deployment/__init__.py index e879f35..51cb281 100644 --- a/dagster-cloud-cli/dagster_cloud_cli/commands/deployment/__init__.py +++ b/dagster-cloud-cli/dagster_cloud_cli/commands/deployment/__init__.py @@ -2,9 +2,12 @@ import yaml from typer import Argument, Typer +from typing_extensions import Annotated from ... import gql, ui from ...config_utils import dagster_cloud_options +from ...core.artifacts import download_artifact, upload_artifact +from ...core.headers.auth import DagsterCloudInstanceScope from ...utils import create_stub_app try: @@ -43,3 +46,59 @@ def get_command( with gql.graphql_client_from_url(url, api_token) as client: settings = gql.get_deployment_settings(client) ui.print_yaml(settings) + + +@app.command(name="upload-artifact") +@dagster_cloud_options(allow_empty=True, requires_url=True) +def upload_artifact_command( + key: Annotated[ + str, + Argument( + help="The key for the artifact being uploaded.", + ), + ], + path: Annotated[ + str, + Argument( + help="The path to the file to upload.", + ), + ], + api_token: str, + url: str, +): + """Upload a deployment scoped artifact.""" + upload_artifact( + url=url, + scope=DagsterCloudInstanceScope.DEPLOYMENT, + api_token=api_token, + key=key, + path=path, + ) + + +@app.command(name="download-artifact") +@dagster_cloud_options(allow_empty=True, requires_url=True) +def download_artifact_command( + key: Annotated[ + str, + Argument( + help="The key for the artifact to download.", + ), + ], + path: Annotated[ + str, + Argument( + help="Path to the file that the contents should be written to.", + ), + ], + api_token: str, + url: str, +): + """Download a deployment scoped artifact.""" + download_artifact( + url=url, + scope=DagsterCloudInstanceScope.DEPLOYMENT, + api_token=api_token, + key=key, + path=path, + ) diff --git a/dagster-cloud-cli/dagster_cloud_cli/commands/deployment/alert_policies/commands.py b/dagster-cloud-cli/dagster_cloud_cli/commands/deployment/alert_policies/commands.py index dd8a50a..7727415 100644 --- a/dagster-cloud-cli/dagster_cloud_cli/commands/deployment/alert_policies/commands.py +++ b/dagster-cloud-cli/dagster_cloud_cli/commands/deployment/alert_policies/commands.py @@ -5,7 +5,7 @@ from .... 
import gql, ui from ....config_utils import dagster_cloud_options -from .config_schema import process_alert_policies_config +from .config_schema import INSIGHTS_ALERT_POLICIES_SCHEMA, process_alert_policies_config DEFAULT_ALERT_POLICIES_YAML_FILENAME = "alert_policies.yaml" @@ -44,7 +44,7 @@ def sync_command( config = yaml.load(f.read(), Loader=yaml.SafeLoader) try: - process_alert_policies_config(config) + process_alert_policies_config(config, INSIGHTS_ALERT_POLICIES_SCHEMA) alert_policies = gql.reconcile_alert_policies(client, config) diff --git a/dagster-cloud-cli/dagster_cloud_cli/commands/deployment/alert_policies/config_schema.py b/dagster-cloud-cli/dagster_cloud_cli/commands/deployment/alert_policies/config_schema.py index 0ba6e87..de8a6a0 100644 --- a/dagster-cloud-cli/dagster_cloud_cli/commands/deployment/alert_policies/config_schema.py +++ b/dagster-cloud-cli/dagster_cloud_cli/commands/deployment/alert_policies/config_schema.py @@ -1,3 +1,5 @@ +from typing import Any + import dagster._check as check from dagster import Array, Enum, EnumValue, Field, Selector, Shape from dagster._config import validate_config @@ -5,23 +7,25 @@ SINGLETON_REPOSITORY_NAME, ) +from dagster_cloud_cli.core.alert_types import InsightsAlertComparisonOperator -def validate_alert_policy_config(alert_policy_config): - validation = validate_config(ALERT_POLICY_SCHEMA, alert_policy_config) - return [error.message for error in validation.errors] +def validate_alert_policy_config(alert_policy_config, schema: Any): + validation = validate_config(schema, alert_policy_config) + return [error.message for error in validation.errors] if validation.errors else [] -def validate_alert_policies_config(alert_policies_config): - validation = validate_config(ALERT_POLICIES_SCHEMA, alert_policies_config) - return [error.message for error in validation.errors] +def validate_alert_policies_config(alert_policies_config, schema: Any): + validation = validate_config(schema, alert_policies_config) + return [error.message for error in validation.errors] if validation.errors else [] -def process_alert_policies_config(alert_policies_config): + +def process_alert_policies_config(alert_policies_config, schema: Any): validation = validate_config(ALERT_POLICIES_SCHEMA, alert_policies_config) check.invariant( validation.success, - ", ".join([error.message for error in validation.errors]), + ", ".join([error.message for error in validation.errors] if validation.errors else []), ) # Validate each individual alert policy @@ -43,224 +47,389 @@ def process_alert_policies_config(alert_policies_config): ) -ALERT_POLICY_SCHEMA = Shape( - fields={ - "name": Field( - config=str, - is_required=True, - description="Alert policy name.", - ), - "description": Field( - config=str, - default_value="", - description="Description of alert policy", - ), - "tags": Field( - config=Array( - Shape( +TARGET_TYPES_SCHEMA = { + "asset_group_target": Field( + config=Shape( + fields={ + "asset_group": Field( + config=str, + is_required=True, + description="The name of the asset group.", + ), + "location_name": Field( + config=str, + is_required=True, + description=("The name of the code location that contains the asset" " group."), + ), + "repo_name": Field( + config=str, + is_required=False, + description=( + "The name of the repository that contains the asset" + " group. Only required if there are multiple" + " repositories with the same code location." 
+ ), + default_value=SINGLETON_REPOSITORY_NAME, + ), + } + ) + ), + "asset_key_target": Field( + config=Shape( + fields={ + "asset_key": Field( + config=Array(str), + is_required=True, + description="The key of the asset.", + ) + } + ) + ), +} + + +insights_operator_enum = Enum.from_python_enum(InsightsAlertComparisonOperator) + + +INSIGHTS_TARGET_TYPES_SCHEMA = { + **TARGET_TYPES_SCHEMA, + "insights_deployment_threshold_target": Field( + config=Shape( + fields={ + "metric_name": Field( + config=str, + is_required=True, + description="The name of the metric to target.", + ), + "threshold": Field( + config=float, + is_required=True, + description="The threshold value to alert if exceeded.", + ), + "selection_period_days": Field( + config=int, + is_required=True, + description="The number of days to use for the selection period.", + ), + "operator": Field( + config=insights_operator_enum, + is_required=True, + description="The operator to use for the threshold comparison.", + ), + } + ) + ), + "insights_asset_group_threshold_target": Field( + config=Shape( + fields={ + "metric_name": Field( + config=str, + is_required=True, + description="The name of the metric to target.", + ), + "threshold": Field( + config=float, + is_required=True, + description="The threshold value to alert if exceeded.", + ), + "selection_period_days": Field( + config=int, + is_required=True, + description="The number of days to use for the selection period.", + ), + "operator": Field( + config=insights_operator_enum, + is_required=True, + description="The operator to use for the threshold comparison.", + ), + "asset_group": Shape( fields={ - "key": Field( + "location_name": Field( config=str, is_required=True, - description="Specify a tag key.", + description="The name of the code location that contains the asset group.", ), - "value": Field( + "asset_group_name": Field( config=str, is_required=True, - description="Specify a tag value.", - ), - }, - description="A tag key-value pair.", - ) - ), - description=( - "The alert policy will apply to code artifacts that have all the specified tags." - " When tags are explicitly omitted, this alert policy will apply to all code" - " artifacts." - ), - is_required=False, - ), - "event_types": Field( - config=Array( - Enum( - name="AlertPolicyEventType", - enum_values=[ - EnumValue("JOB_FAILURE", description="Alert on job failure."), - EnumValue("JOB_SUCCESS", description="Alert on job success."), - EnumValue("TICK_FAILURE", description="Alert on schedule/sensor failure."), - EnumValue("AGENT_UNAVAILABLE", description="Alert on agent downtime."), - EnumValue( - "CODE_LOCATION_ERROR", description="Alert on code location error." - ), - EnumValue( - "ASSET_MATERIALIZATION_SUCCESS", - description="Alert when an asset successfully materializes.", + description="The name of the asset group.", ), - EnumValue( - "ASSET_MATERIALIZATION_FAILURE", + "repo_name": Field( + config=str, + is_required=False, description=( - "Alert when a planned asset materialization fails to occur." + "The name of the repository that contains the asset group." ), + default_value=SINGLETON_REPOSITORY_NAME, ), - EnumValue( - "ASSET_CHECK_PASSED", description="Alert on asset check success." 
+ } + ), + } + ) + ), + "insights_asset_threshold_target": Field( + config=Shape( + fields={ + "metric_name": Field( + config=str, + is_required=True, + description="The name of the metric to target.", + ), + "threshold": Field( + config=float, + is_required=True, + description="The threshold value to alert if exceeded.", + ), + "selection_period_days": Field( + config=int, + is_required=True, + description="The number of days to use for the selection period.", + ), + "operator": Field( + config=insights_operator_enum, + is_required=True, + description="The operator to use for the threshold comparison.", + ), + "asset_key": Field( + config=Array(str), + is_required=True, + description="The key of the asset.", + ), + } + ) + ), + "insights_job_threshold_target": Field( + config=Shape( + fields={ + "metric_name": Field( + config=str, + is_required=True, + description="The name of the metric to target.", + ), + "threshold": Field( + config=float, + is_required=True, + description="The threshold value to alert if exceeded.", + ), + "selection_period_days": Field( + config=int, + is_required=True, + description="The number of days to use for the selection period.", + ), + "operator": Field( + config=insights_operator_enum, + is_required=True, + description="The operator to use for the threshold comparison.", + ), + "job": Shape( + fields={ + "job_name": Field( + config=str, + is_required=True, + description="The name of the job.", ), - EnumValue( - "ASSET_CHECK_EXECUTION_FAILURE", - description=( - "Alert when a planned asset check fails before it evaluates." - ), + "location_name": Field( + config=str, + is_required=True, + description="The name of the code location that contains the job.", ), - EnumValue( - "ASSET_CHECK_SEVERITY_WARN", - description=( - "Alert when a planned asset check fails with severity warn." - ), + "repo_name": Field( + config=str, + is_required=False, + description=("The name of the repository that contains the job."), + default_value=SINGLETON_REPOSITORY_NAME, ), - EnumValue( - "ASSET_CHECK_SEVERITY_ERROR", - description=( - "Alert when a planned asset check fails with severity error." 
+ } + ), + } + ) + ), +} + +ALERT_EVENT_TYPES = [ + EnumValue("JOB_FAILURE", description="Alert on job failure."), + EnumValue("JOB_SUCCESS", description="Alert on job success."), + EnumValue("TICK_FAILURE", description="Alert on schedule/sensor failure."), + EnumValue("AGENT_UNAVAILABLE", description="Alert on agent downtime."), + EnumValue("CODE_LOCATION_ERROR", description="Alert on code location error."), + EnumValue( + "ASSET_MATERIALIZATION_SUCCESS", + description="Alert when an asset successfully materializes.", + ), + EnumValue( + "ASSET_MATERIALIZATION_FAILURE", + description=("Alert when a planned asset materialization fails to occur."), + ), + EnumValue("ASSET_CHECK_PASSED", description="Alert on asset check success."), + EnumValue( + "ASSET_CHECK_EXECUTION_FAILURE", + description=("Alert when a planned asset check fails before it evaluates."), + ), + EnumValue( + "ASSET_CHECK_SEVERITY_WARN", + description=("Alert when a planned asset check fails with severity warn."), + ), + EnumValue( + "ASSET_CHECK_SEVERITY_ERROR", + description=("Alert when a planned asset check fails with severity error."), + ), + EnumValue( + "ASSET_OVERDUE", + description="Alert when an asset is overdue, based on its freshness policy.", + ), +] + +INSIGHTS_ALERT_EVENT_TYPES = [ + *ALERT_EVENT_TYPES, + EnumValue( + "INSIGHTS_CONSUMPTION_EXCEEDED", + description="Alert when insights consumption exceeds the threshold.", + ), +] + +ALERT_POLICY_SCHEMA, INSIGHTS_ALERT_POLICY_SCHEMA = [ + Shape( + fields={ + "name": Field( + config=str, + is_required=True, + description="Alert policy name.", + ), + "description": Field( + config=str, + default_value="", + description="Description of alert policy", + ), + "tags": Field( + config=Array( + Shape( + fields={ + "key": Field( + config=str, + is_required=True, + description="Specify a tag key.", ), - ), - EnumValue( - "ASSET_OVERDUE", - description="Alert when an asset is overdue, based on its freshness policy.", - ), - ], - ) + "value": Field( + config=str, + is_required=True, + description="Specify a tag value.", + ), + }, + description="A tag key-value pair.", + ) + ), + description=( + "The alert policy will apply to code artifacts that have all the specified tags." + " When tags are explicitly omitted, this alert policy will apply to all code" + " artifacts." + ), + is_required=False, ), - description="The selected system event types that will trigger the alert policy.", - ), - "notification_service": Field( - Selector( - fields={ - "email": Field( - config=Shape( - fields={ - "email_addresses": Field( - config=Array(str), - is_required=True, - description="Email addresses to send alerts to.", - ) - } - ), - description=( - "Details to customize email notifications for this alert policy." - ), - ), - "slack": Field( - config=Shape( - fields={ - "slack_workspace_name": Field( - config=str, - is_required=True, - description="The name of your slack workspace.", - ), - "slack_channel_name": Field( - config=str, - is_required=True, - description=( - "The name of the slack channel in which to post alerts." - ), - ), - } - ) - ), - "email_owners": Field(config=Shape(fields={})), - "microsoft_teams": Field( - config=Shape( - fields={ - "webhook_url": Field( - config=str, - is_required=True, - description="The incoming webhook URL for your Microsoft Team connector. 
" - "Must match the form https://xxxxx.webhook.office.com/xxxxx", - ) - } - ) - ), - "pagerduty": Field( - config=Shape( - fields={ - "integration_key": Field( - config=str, - is_required=True, - description="The integration key for your PagerDuty app.", - ) - } - ) - ), - } + "event_types": Field( + config=Array( + Enum( + name="AlertPolicyEventType", + enum_values=event_types, + ) + ), + description="The selected system event types that will trigger the alert policy.", ), - is_required=True, - description="Configure how the alert policy should send a notification.", - ), - "enabled": Field( - config=bool, - default_value=True, - description="Whether the alert policy is active or not.", - ), - "alert_targets": Field( - config=Array( + "notification_service": Field( Selector( fields={ - "asset_group_target": Field( + "email": Field( + config=Shape( + fields={ + "email_addresses": Field( + config=Array(str), + is_required=True, + description="Email addresses to send alerts to.", + ) + } + ), + description=( + "Details to customize email notifications for this alert policy." + ), + ), + "slack": Field( config=Shape( fields={ - "asset_group": Field( + "slack_workspace_name": Field( config=str, is_required=True, - description="The name of the asset group.", + description="The name of your slack workspace.", ), - "location_name": Field( + "slack_channel_name": Field( config=str, is_required=True, description=( - "The name of the code location that contains the asset" - " group." + "The name of the slack channel in which to post alerts." ), ), - "repo_name": Field( + } + ) + ), + "email_owners": Field(config=Shape(fields={})), + "microsoft_teams": Field( + config=Shape( + fields={ + "webhook_url": Field( config=str, - is_required=False, - description=( - "The name of the repository that contains the asset" - " group. Only required if there are multiple" - " repositories with the same code location." - ), - default_value=SINGLETON_REPOSITORY_NAME, - ), + is_required=True, + description="The incoming webhook URL for your Microsoft Team connector. " + "Must match the form https://xxxxx.webhook.office.com/xxxxx", + ) } ) ), - "asset_key_target": Field( + "pagerduty": Field( config=Shape( fields={ - "asset_key": Field( - config=Array(str), + "integration_key": Field( + config=str, is_required=True, - description="The key of the asset.", + description="The integration key for your PagerDuty app.", ) } ) ), - }, - description=( - "Information for targeting events for this alert policy. If no target is" - " specified, the alert policy will apply to all events of a particular" - " type." + } + ), + is_required=True, + description="Configure how the alert policy should send a notification.", + ), + "enabled": Field( + config=bool, + default_value=True, + description="Whether the alert policy is active or not.", + ), + "alert_targets": Field( + config=Array( + Selector( + fields=target_types_schema, + description=( + "Information for targeting events for this alert policy. If no target is" + " specified, the alert policy will apply to all events of a particular" + " type." 
+ ), ), ), + is_required=False, ), - is_required=False, - ), - }, - description="Details to customize an alert policy in Dagster Cloud.", -) + }, + description="Details to customize an alert policy in Dagster Cloud.", + ) + for target_types_schema, event_types in ( + (TARGET_TYPES_SCHEMA, ALERT_EVENT_TYPES), + (INSIGHTS_TARGET_TYPES_SCHEMA, INSIGHTS_ALERT_EVENT_TYPES), + ) +] -ALERT_POLICIES_SCHEMA = Shape( - fields={ - "alert_policies": Array(ALERT_POLICY_SCHEMA), - } -) +INSIGHTS_ALERT_POLICIES_SCHEMA, ALERT_POLICIES_SCHEMA = [ + Shape( + fields={ + "alert_policies": Array(alert_policy_schema), + } + ) + for alert_policy_schema in (INSIGHTS_ALERT_POLICY_SCHEMA, ALERT_POLICY_SCHEMA) +] diff --git a/dagster-cloud-cli/dagster_cloud_cli/commands/organization/__init__.py b/dagster-cloud-cli/dagster_cloud_cli/commands/organization/__init__.py index ba44099..924909d 100644 --- a/dagster-cloud-cli/dagster_cloud_cli/commands/organization/__init__.py +++ b/dagster-cloud-cli/dagster_cloud_cli/commands/organization/__init__.py @@ -2,9 +2,12 @@ import yaml from typer import Argument, Typer +from typing_extensions import Annotated from ... import gql, ui from ...config_utils import dagster_cloud_options +from ...core.artifacts import download_artifact, upload_artifact +from ...core.headers.auth import DagsterCloudInstanceScope from .saml import commands as saml_cli app = Typer(help="Customize your Dagster Cloud organization.") @@ -61,3 +64,71 @@ def get_command( with gql.graphql_client_from_url(url, api_token) as client: settings = gql.get_organization_settings(client) ui.print_yaml(settings) + + +@app.command(name="upload-artifact") +@dagster_cloud_options(allow_empty=True, requires_url=True) +def upload_artifact_command( + key: Annotated[ + str, + Argument( + help="The key for the artifact being uploaded.", + ), + ], + path: Annotated[ + str, + Argument( + help="The path to the file to upload.", + ), + ], + organization: str, + api_token: str, + url: str, +): + """Upload a organization scoped artifact.""" + if not url and not organization: + raise ui.error("Must provide either organization name or URL.") + if not url: + url = gql.url_from_config(organization=organization) + + upload_artifact( + url=url, + scope=DagsterCloudInstanceScope.ORGANIZATION, + api_token=api_token, + key=key, + path=path, + ) + + +@app.command(name="download-artifact") +@dagster_cloud_options(allow_empty=True, requires_url=True) +def download_artifact_command( + key: Annotated[ + str, + Argument( + help="The key for the artifact to download.", + ), + ], + path: Annotated[ + str, + Argument( + help="Path to the file that the contents should be written to.", + ), + ], + organization: str, + api_token: str, + url: str, +): + """Download a organization scoped artifact.""" + if not url and not organization: + raise ui.error("Must provide either organization name or URL.") + if not url: + url = gql.url_from_config(organization=organization) + + download_artifact( + url=url, + scope=DagsterCloudInstanceScope.ORGANIZATION, + api_token=api_token, + key=key, + path=path, + ) diff --git a/dagster-cloud-cli/dagster_cloud_cli/commands/workspace/__init__.py b/dagster-cloud-cli/dagster_cloud_cli/commands/workspace/__init__.py index 3027493..3bc3b39 100644 --- a/dagster-cloud-cli/dagster_cloud_cli/commands/workspace/__init__.py +++ b/dagster-cloud-cli/dagster_cloud_cli/commands/workspace/__init__.py @@ -1,6 +1,6 @@ import time from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, 
Mapping, Optional import dagster._check as check import yaml @@ -15,7 +15,7 @@ dagster_cloud_options, get_location_document, ) -from dagster_cloud_cli.core.graphql_client import GqlShimClient +from dagster_cloud_cli.core.graphql_client import DagsterCloudGraphQLClient from dagster_cloud_cli.core.workspace import CodeDeploymentMetadata from dagster_cloud_cli.utils import add_options @@ -44,7 +44,7 @@ def _get_location_input(location: str, kwargs: Dict[str, Any]) -> gql.CliInputCo def _add_or_update_location( - client: GqlShimClient, + client: DagsterCloudGraphQLClient, location_document: Dict[str, Any], location_load_timeout: int, agent_heartbeat_timeout: int, @@ -120,6 +120,24 @@ def update_command( ) +def _format_error(load_error: Mapping[str, Any]): + result = [ + load_error["message"], + "".join(load_error["stack"]), + ] + + for chain_link in load_error["errorChain"]: + result.append( + "The above exception was caused by the following exception:" + if chain_link["isExplicitLink"] + else "The above exception occurred during handling of the following exception:" + ) + + result.extend([chain_link["error"]["message"], "".join(chain_link["error"]["stack"])]) + + return "\n".join(result) + + def wait_for_load( client, locations, @@ -179,8 +197,8 @@ def wait_for_load( "Some locations failed to load after being synced by the agent:\n" + "\n".join( [ - f"Error loading {error_location}:" - f" {nodes_by_location[error_location]['locationOrLoadError']}" + f"Error loading {error_location}:\n" + f"{_format_error(nodes_by_location[error_location]['locationOrLoadError'])}" for error_location in error_locations ] ) diff --git a/dagster-cloud-cli/dagster_cloud_cli/core/_utils.py b/dagster-cloud-cli/dagster_cloud_cli/core/_utils.py deleted file mode 100644 index 66e7be3..0000000 --- a/dagster-cloud-cli/dagster_cloud_cli/core/_utils.py +++ /dev/null @@ -1,11 +0,0 @@ -# Same as dagster._utils.merge_dicts -def merge_dicts(*args) -> dict: - """Returns a dictionary with with all the keys in all of the input dictionaries. - - If multiple input dictionaries have different values for the same key, the returned dictionary - contains the value from the dictionary that comes latest in the list. - """ - result: dict = {} - for arg in args: - result.update(arg) - return result diff --git a/dagster-cloud-cli/dagster_cloud_cli/core/alert_types.py b/dagster-cloud-cli/dagster_cloud_cli/core/alert_types.py new file mode 100644 index 0000000..58d5369 --- /dev/null +++ b/dagster-cloud-cli/dagster_cloud_cli/core/alert_types.py @@ -0,0 +1,26 @@ +from enum import Enum + +from dagster._serdes import whitelist_for_serdes + + +@whitelist_for_serdes +class InsightsAlertComparisonOperator(Enum): + """Possible comparison operators for an insights alert type, used to + determine when to trigger an alert based on the value of the metric. + """ + + LESS_THAN = "LESS_THAN" + GREATER_THAN = "GREATER_THAN" + + def compare(self, computed_value: float, target_value: float) -> bool: + if self == InsightsAlertComparisonOperator.LESS_THAN: + return computed_value < target_value + return computed_value > target_value + + def as_text(self) -> str: + """Used in alert text to describe the comparison operator, + e.g. usage is less than the limit or usage is greater than the limit. 
+ """ + if self == InsightsAlertComparisonOperator.LESS_THAN: + return "less than" + return "greater than" diff --git a/dagster-cloud-cli/dagster_cloud_cli/core/artifacts.py b/dagster-cloud-cli/dagster_cloud_cli/core/artifacts.py new file mode 100644 index 0000000..817bfa8 --- /dev/null +++ b/dagster-cloud-cli/dagster_cloud_cli/core/artifacts.py @@ -0,0 +1,190 @@ +import os +from pathlib import Path +from typing import Optional, Union + +import requests + +from dagster_cloud_cli.config_utils import ( + URL_ENV_VAR_NAME, + get_deployment, + get_organization, + get_user_token, +) +from dagster_cloud_cli.core.headers.auth import DagsterCloudInstanceScope +from dagster_cloud_cli.core.headers.impl import get_dagster_cloud_api_headers +from dagster_cloud_cli.gql import url_from_config + + +def _dagster_cloud_http_client(): + # indirection for mocking, could share a session + return requests + + +def download_artifact( + *, + url: str, + api_token: str, + scope: DagsterCloudInstanceScope, + key: str, + path: Union[Path, str], +): + response = _dagster_cloud_http_client().post( + url=f"{url}/gen_artifact_get", + headers=get_dagster_cloud_api_headers( + api_token, + scope, + ), + json={"key": key}, + ) + response.raise_for_status() + + payload = response.json() + + response = requests.get( + payload["url"], + ) + response.raise_for_status() + Path(path).write_bytes(response.content) + + +def upload_artifact( + *, + url: str, + api_token: str, + scope: DagsterCloudInstanceScope, + key: str, + path: Union[Path, str], + deployment: Optional[str] = None, +): + upload_file = Path(path).resolve(strict=True) + + response = _dagster_cloud_http_client().post( + url=f"{url}/gen_artifact_post", + headers=get_dagster_cloud_api_headers( + api_token, + scope, + ), + json={"key": key}, + ) + response.raise_for_status() + payload = response.json() + + response = requests.post( + payload["url"], + data=payload["fields"], + files={ + "file": upload_file.open(), + }, + ) + response.raise_for_status() + + +def _resolve_org(passed_org: Optional[str]) -> str: + org = passed_org or get_organization() + if org is None: + raise Exception( + "Unable to resolve organization, pass organization or set the " + "DAGSTER_CLOUD_ORGANIZATION environment variable." + ) + return org + + +def _resolve_token(passed_token: Optional[str]) -> str: + api_token = get_user_token() + if api_token is None: + raise Exception( + "Unable to resolve api_token, pass api_token or set the " + "DAGSTER_CLOUD_API_TOKEN environment variable" + ) + return api_token + + +def _resolve_deploy(passed_deploy: Optional[str]) -> str: + deploy = passed_deploy or get_deployment() + if deploy is None: + raise Exception( + "Unable to resolve deployment, pass deployment or set the " + "DAGSTER_CLOUD_DEPLOYMENT environment variable." 
+ ) + return deploy + + +def _resolve_url(organization: str, deployment: Optional[str] = None) -> str: + env_url = os.getenv(URL_ENV_VAR_NAME) + if env_url: + return env_url + + return url_from_config(organization, deployment) + + +def upload_organization_artifact( + key: str, + path: Union[str, Path], + organization: Optional[str] = None, + api_token: Optional[str] = None, +): + upload_artifact( + url=_resolve_url( + organization=_resolve_org(organization), + ), + api_token=_resolve_token(api_token), + scope=DagsterCloudInstanceScope.ORGANIZATION, + key=key, + path=path, + ) + + +def upload_deployment_artifact( + key: str, + path: Union[str, Path], + organization: Optional[str] = None, + deployment: Optional[str] = None, + api_token: Optional[str] = None, +): + upload_artifact( + url=_resolve_url( + organization=_resolve_org(organization), + deployment=_resolve_deploy(deployment), + ), + api_token=_resolve_token(api_token), + scope=DagsterCloudInstanceScope.DEPLOYMENT, + key=key, + path=path, + ) + + +def download_organization_artifact( + key: str, + path: Union[str, Path], + organization: Optional[str] = None, + api_token: Optional[str] = None, +): + download_artifact( + url=_resolve_url( + organization=_resolve_org(organization), + ), + api_token=_resolve_token(api_token), + scope=DagsterCloudInstanceScope.ORGANIZATION, + key=key, + path=path, + ) + + +def download_deployment_artifact( + key: str, + path: Union[str, Path], + organization: Optional[str] = None, + deployment: Optional[str] = None, + api_token: Optional[str] = None, +): + deployment = _resolve_deploy(deployment) + download_artifact( + url=_resolve_url( + organization=_resolve_org(organization), + deployment=deployment, + ), + api_token=_resolve_token(api_token), + scope=DagsterCloudInstanceScope.DEPLOYMENT, + key=key, + path=path, + ) diff --git a/dagster-cloud-cli/dagster_cloud_cli/core/errors.py b/dagster-cloud-cli/dagster_cloud_cli/core/errors.py index 5315159..9543650 100644 --- a/dagster-cloud-cli/dagster_cloud_cli/core/errors.py +++ b/dagster-cloud-cli/dagster_cloud_cli/core/errors.py @@ -1,7 +1,7 @@ from requests import HTTPError -class GraphQLStorageError(Exception): +class DagsterCloudAgentServerError(Exception): """Raise this when there's an error in the GraphQL layer.""" diff --git a/dagster-cloud-cli/dagster_cloud_cli/core/graphql_client.py b/dagster-cloud-cli/dagster_cloud_cli/core/graphql_client.py index 76abb35..6176ed9 100644 --- a/dagster-cloud-cli/dagster_cloud_cli/core/graphql_client.py +++ b/dagster-cloud-cli/dagster_cloud_cli/core/graphql_client.py @@ -1,9 +1,9 @@ import logging import re import time -from contextlib import ExitStack, contextmanager +from contextlib import contextmanager from email.utils import mktime_tz, parsedate_tz -from typing import Any, Dict, Mapping, Optional +from typing import Any, Callable, Dict, Mapping, Optional import dagster._check as check import requests @@ -13,7 +13,11 @@ ReadTimeout as RequestsReadTimeout, ) -from .errors import DagsterCloudHTTPError, DagsterCloudMaintenanceException, GraphQLStorageError +from .errors import ( + DagsterCloudAgentServerError, + DagsterCloudHTTPError, + DagsterCloudMaintenanceException, +) from .headers.auth import DagsterCloudInstanceScope from .headers.impl import get_dagster_cloud_api_headers @@ -25,8 +29,6 @@ RETRY_STATUS_CODES = [ - # retry on server errors to recover on transient issue - 500, 502, 503, 504, @@ -34,9 +36,177 @@ ] -class GqlShimClient: - """Adapter for gql.Client that wraps errors in human-readable format.""" +class 
DagsterCloudAgentHttpClient: + def __init__( + self, + session: requests.Session, + headers: Optional[Dict[str, Any]] = None, + verify: bool = True, + timeout: int = DEFAULT_TIMEOUT, + cookies: Optional[Dict[str, Any]] = None, + proxies: Optional[Dict[str, Any]] = None, + max_retries: int = 0, + backoff_factor: float = DEFAULT_BACKOFF_FACTOR, + ): + self.headers = headers or {} + self.verify = verify + self.timeout = timeout + self.cookies = cookies + self._session = session + self._proxies = proxies + self._max_retries = max_retries + self._backoff_factor = backoff_factor + + @property + def session(self) -> requests.Session: + return self._session + + def post(self, *args, **kwargs): + return self.execute("POST", *args, **kwargs) + + def get(self, *args, **kwargs): + return self.execute("GET", *args, **kwargs) + + def put(self, *args, **kwargs): + return self.execute("PUT", *args, **kwargs) + + def execute( + self, + method: str, + url: str, + headers: Optional[Mapping[str, str]] = None, + idempotent: bool = False, + **kwargs, + ): + retry_on_read_timeout = idempotent or bool( + headers.get("Idempotency-Key") if headers else False + ) + + return _retry_loop( + lambda: self._execute_retry(method, url, headers, **kwargs), + max_retries=self._max_retries, + backoff_factor=self._backoff_factor, + retry_on_read_timeout=retry_on_read_timeout, + ) + def _execute_retry( + self, + method: str, + url: str, + headers: Optional[Mapping[str, Any]], + **kwargs, + ): + response = self._session.request( + method, + url, + headers={ + **(self.headers if self.headers is not None else {}), + **(headers if headers is not None else {}), + }, + cookies=self.cookies, + timeout=self.timeout, + verify=self.verify, + proxies=self._proxies, + **kwargs, + ) + try: + result = response.json() + if not isinstance(result, dict): + result = {} + except ValueError: + result = {} + + if "maintenance" in result: + maintenance_info = result["maintenance"] + raise DagsterCloudMaintenanceException( + message=maintenance_info.get("message"), + timeout=maintenance_info.get("timeout"), + retry_interval=maintenance_info.get("retry_interval"), + ) + + if "errors" in result: + raise DagsterCloudAgentServerError(f"Error in GraphQL response: {result['errors']}") + + response.raise_for_status() + + return result + + +def _retry_loop( + execute_retry: Callable, + max_retries: int, + backoff_factor: float, + retry_on_read_timeout: bool, +): + start_time = time.time() + retry_number = 0 + error_msg_set = set() + requested_sleep_time = None + while True: + try: + return execute_retry() + except (HTTPError, RequestsConnectionError, RequestsReadTimeout) as e: + retryable_error = False + if isinstance(e, HTTPError): + retryable_error = e.response.status_code in RETRY_STATUS_CODES + error_msg = e.response.status_code + requested_sleep_time = _get_retry_after_sleep_time(e.response.headers) + elif isinstance(e, RequestsReadTimeout): + retryable_error = retry_on_read_timeout + error_msg = str(e) + else: + retryable_error = True + error_msg = str(e) + + error_msg_set.add(error_msg) + if retryable_error and retry_number < max_retries: + retry_number += 1 + sleep_time = 0 + if requested_sleep_time: + sleep_time = requested_sleep_time + elif retry_number > 1: + sleep_time = backoff_factor * (2 ** (retry_number - 1)) + + if sleep_time > 0: + logger.warning( + f"Error in Dagster Cloud request ({error_msg}). Retrying in" + f" {sleep_time} seconds..." + ) + time.sleep(sleep_time) + else: + logger.warning(f"Error in Dagster Cloud request ({error_msg}). 
Retrying now.") + else: + # Throw the error straight if no retries were involved + if max_retries == 0 or not retryable_error: + if isinstance(e, HTTPError): + raise DagsterCloudHTTPError(e) from e + else: + raise + else: + if len(error_msg_set) == 1: + status_code_msg = str(next(iter(error_msg_set))) + else: + status_code_msg = str(error_msg_set) + raise DagsterCloudAgentServerError( + f"Max retries ({max_retries}) exceeded, too many" + f" {status_code_msg} error responses." + ) from e + except DagsterCloudMaintenanceException as e: + if time.time() - start_time > e.timeout: + raise + + logger.warning( + "Dagster Cloud is currently unavailable due to scheduled maintenance. Retrying" + f" in {e.retry_interval} seconds..." + ) + time.sleep(e.retry_interval) + except DagsterCloudAgentServerError: + raise + except Exception as e: + raise DagsterCloudAgentServerError(str(e)) from e + + +class DagsterCloudGraphQLClient: def __init__( self, url: str, @@ -49,8 +219,6 @@ def __init__( max_retries: int = 0, backoff_factor: float = DEFAULT_BACKOFF_FACTOR, ): - self._exit_stack = ExitStack() - self.url = url self.headers = headers self.verify = verify @@ -72,81 +240,20 @@ def execute( headers: Optional[Mapping[str, str]] = None, idempotent_mutation: bool = False, ): - start_time = time.time() - retry_number = 0 - error_msg_set = set() - requested_sleep_time = None - while True: - try: - return self._execute_retry(query, variable_values, headers) - except (HTTPError, RequestsConnectionError, RequestsReadTimeout) as e: - retryable_error = False - if isinstance(e, HTTPError): - retryable_error = e.response.status_code in RETRY_STATUS_CODES - error_msg = e.response.status_code - requested_sleep_time = _get_retry_after_sleep_time(e.response.headers) - elif isinstance(e, RequestsReadTimeout): - # "mutation " must appear in the document if its a mutation - if "mutation " in query and not idempotent_mutation: - # mutations can be made idempotent if they use Idempotency-Key header - retryable_error = ( - bool(headers.get("Idempotency-Key")) if headers is not None else False - ) - # otherwise assume its a query that is naturally idempotent - else: - retryable_error = True + if "mutation " in query and not idempotent_mutation: + # mutations can be made idempotent if they use Idempotency-Key header + retry_on_read_timeout = ( + bool(headers.get("Idempotency-Key")) if headers is not None else False + ) + else: + retry_on_read_timeout = True - error_msg = str(e) - else: - retryable_error = True - error_msg = str(e) - - error_msg_set.add(error_msg) - if retryable_error and retry_number < self._max_retries: - retry_number += 1 - sleep_time = 0 - if requested_sleep_time: - sleep_time = requested_sleep_time - elif retry_number > 1: - sleep_time = self._backoff_factor * (2 ** (retry_number - 1)) - - if sleep_time > 0: - logger.warning( - f"Error in Dagster Cloud request ({error_msg}). Retrying in" - f" {sleep_time} seconds..." - ) - time.sleep(sleep_time) - else: - logger.warning( - f"Error in Dagster Cloud request ({error_msg}). Retrying now." 
- ) - else: - # Throw the error straight if no retries were involved - if self._max_retries == 0 or not retryable_error: - if isinstance(e, HTTPError): - raise DagsterCloudHTTPError(e) from e - else: - raise GraphQLStorageError(str(e)) from e - else: - if len(error_msg_set) == 1: - status_code_msg = str(next(iter(error_msg_set))) - else: - status_code_msg = str(error_msg_set) - raise GraphQLStorageError( - f"Max retries ({self._max_retries}) exceeded, too many" - f" {status_code_msg} error responses." - ) from e - except DagsterCloudMaintenanceException as e: - if time.time() - start_time > e.timeout: - raise - - logger.warning( - "Dagster Cloud is currently unavailable due to scheduled maintenance. Retrying" - f" in {e.retry_interval} seconds..." - ) - time.sleep(e.retry_interval) - except Exception as e: - raise GraphQLStorageError(str(e)) from e + return _retry_loop( + lambda: self._execute_retry(query, variable_values, headers), + max_retries=self._max_retries, + backoff_factor=self._backoff_factor, + retry_on_read_timeout=retry_on_read_timeout, + ) def _execute_retry( self, @@ -190,9 +297,9 @@ def _execute_retry( ) if "errors" in result: - raise GraphQLStorageError(f"Error in GraphQL response: {result['errors']}") - else: - return result + raise DagsterCloudAgentServerError(f"Error in GraphQL response: {result['errors']}") + + return result def get_agent_headers(config_value: Dict[str, Any], scope: DagsterCloudInstanceScope): @@ -210,13 +317,33 @@ def create_graphql_requests_session(): yield session -def create_proxy_client( +def create_agent_http_client( + session: requests.Session, + config_value: Dict[str, Any], + scope: DagsterCloudInstanceScope = DagsterCloudInstanceScope.DEPLOYMENT, +): + return DagsterCloudAgentHttpClient( + headers=get_agent_headers(config_value, scope=scope), + verify=config_value.get("verify", True), + timeout=config_value.get("timeout", DEFAULT_TIMEOUT), + cookies=config_value.get("cookies", {}), + # Requests library modifies proxies dictionary so create a copy + proxies=( + check.is_dict(config_value.get("proxies")).copy() if config_value.get("proxies") else {} + ), + session=session, + max_retries=config_value.get("retries", DEFAULT_RETRIES), + backoff_factor=config_value.get("backoff_factor", DEFAULT_BACKOFF_FACTOR), + ) + + +def create_agent_graphql_client( session: requests.Session, url: str, config_value: Dict[str, Any], scope: DagsterCloudInstanceScope = DagsterCloudInstanceScope.DEPLOYMENT, ): - return GqlShimClient( + return DagsterCloudGraphQLClient( url=url, headers=get_agent_headers(config_value, scope=scope), verify=config_value.get("verify", True), @@ -254,7 +381,7 @@ def _get_retry_after_sleep_time(headers): @contextmanager def create_cloud_webserver_client(url: str, api_token: str, retries=3): with create_graphql_requests_session() as session: - yield GqlShimClient( + yield DagsterCloudGraphQLClient( session=session, url=f"{url}/graphql", headers=get_dagster_cloud_api_headers( diff --git a/dagster-cloud-cli/dagster_cloud_cli/core/headers/impl.py b/dagster-cloud-cli/dagster_cloud_cli/core/headers/impl.py index 56ddc94..0db1314 100644 --- a/dagster-cloud-cli/dagster_cloud_cli/core/headers/impl.py +++ b/dagster-cloud-cli/dagster_cloud_cli/core/headers/impl.py @@ -2,7 +2,6 @@ from typing import Dict, Optional from ...version import __version__ -from .._utils import merge_dicts from .auth import ( API_TOKEN_HEADER, DAGSTER_CLOUD_SCOPE_HEADER, @@ -18,13 +17,13 @@ def get_dagster_cloud_api_headers( deployment_name: Optional[str] = None, 
additional_headers: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: - return merge_dicts( - { + return { + **{ API_TOKEN_HEADER: agent_token, PYTHON_VERSION_HEADER: platform.python_version(), DAGSTER_CLOUD_VERSION_HEADER: __version__, DAGSTER_CLOUD_SCOPE_HEADER: scope.value, }, - {DEPLOYMENT_NAME_HEADER: deployment_name} if deployment_name else {}, - additional_headers if additional_headers else {}, - ) + **({DEPLOYMENT_NAME_HEADER: deployment_name} if deployment_name else {}), + **(additional_headers if additional_headers else {}), + } diff --git a/dagster-cloud-cli/dagster_cloud_cli/core/pex_builder/util.py b/dagster-cloud-cli/dagster_cloud_cli/core/pex_builder/util.py index cf0d3ec..67bf6dc 100644 --- a/dagster-cloud-cli/dagster_cloud_cli/core/pex_builder/util.py +++ b/dagster-cloud-cli/dagster_cloud_cli/core/pex_builder/util.py @@ -45,10 +45,13 @@ def get_pex_flags(python_version: version.Version, build_sdists: bool = True) -> # available for a dependency, then the build will fail. # see also https://linear.app/elementl/issue/CLOUD-2023/pex-builds-fail-for-dbt-core-dependency resolve_local = ["--resolve-local-platforms"] if build_sdists else [] + # This is mainly useful in local mac test environments + include_current = ["--platform=current"] if os.getenv("PEX_INCLUDE_CURRENT_PLATFORM") else [] return [ # this platform matches what can run on our serverless base images # the version tag is a major/minor string like "38" f"--platform=manylinux2014_x86_64-cp-{version_tag}-cp{version_tag}", + *include_current, # this ensures PEX_PATH is not cleared and any subprocess invoked can also use this. # this is important for running console scripts that use the pex environment (eg dbt) "--no-strip-pex-env", @@ -80,7 +83,7 @@ def build_pex( https://peps.python.org/pep-0425/ https://peps.python.org/pep-0427/ - Packages for the current platform are always included. (--platform=current) + Packages for the current platform are only included if requested with PEX_INCLUDE_CURRENT_PLATFORM The manylinux platform ensures pexes built on local machines (macos, windows) are compatible with linux on cloud. 
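# Minimal usage sketch (not part of this diff): opting into current-platform
# wheels when building a PEX in a local macOS test environment. Per the change
# above, get_pex_flags only appends --platform=current when the
# PEX_INCLUDE_CURRENT_PLATFORM environment variable is set.
# Assumption: `version` here refers to packaging.version, matching the
# `version.Version` annotation on get_pex_flags; the "3.10" value is a
# placeholder interpreter version.
import os

from packaging import version

from dagster_cloud_cli.core.pex_builder.util import get_pex_flags

os.environ["PEX_INCLUDE_CURRENT_PLATFORM"] = "1"
flags = get_pex_flags(version.Version("3.10"), build_sdists=True)
assert "--platform=current" in flags

# A similar sketch for the artifact helpers introduced in
# dagster_cloud_cli/core/artifacts.py above. It assumes the
# DAGSTER_CLOUD_ORGANIZATION and DAGSTER_CLOUD_API_TOKEN environment variables
# are set, so _resolve_org and _resolve_token can supply the organization and
# token; the key and file paths below are placeholders.
from dagster_cloud_cli.core.artifacts import (
    download_organization_artifact,
    upload_organization_artifact,
)

# Store a local file under an organization-scoped key...
upload_organization_artifact("shared-config", "config.yaml")

# ...and retrieve it later, e.g. from another machine or a CI job.
download_organization_artifact("shared-config", "downloaded-config.yaml")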
diff --git a/dagster-cloud-cli/dagster_cloud_cli/gql.py b/dagster-cloud-cli/dagster_cloud_cli/gql.py index cd6c297..e2fff66 100644 --- a/dagster-cloud-cli/dagster_cloud_cli/gql.py +++ b/dagster-cloud-cli/dagster_cloud_cli/gql.py @@ -5,13 +5,13 @@ CliEventType, ) -from .core.graphql_client import GqlShimClient, create_cloud_webserver_client +from .core.graphql_client import DagsterCloudGraphQLClient, create_cloud_webserver_client @contextmanager def graphql_client_from_url( url: str, token: str, retries: int = 3 -) -> Generator[GqlShimClient, None, None]: +) -> Generator[DagsterCloudGraphQLClient, None, None]: with create_cloud_webserver_client(url.rstrip("/"), token, retries) as client: yield client @@ -38,7 +38,7 @@ def url_from_config(organization: str, deployment: Optional[str] = None) -> str: """ -def fetch_full_deployments(client: GqlShimClient) -> List[Any]: +def fetch_full_deployments(client: DagsterCloudGraphQLClient) -> List[Any]: return client.execute(FULL_DEPLOYMENTS_QUERY)["data"]["fullDeployments"] @@ -112,7 +112,7 @@ def get_location_input(self): """ -def fetch_agent_status(client: GqlShimClient) -> List[Any]: +def fetch_agent_status(client: DagsterCloudGraphQLClient) -> List[Any]: return client.execute(AGENT_STATUS_QUERY)["data"]["agents"] @@ -128,7 +128,7 @@ def fetch_agent_status(client: GqlShimClient) -> List[Any]: """ -def fetch_workspace_entries(client: GqlShimClient) -> List[Any]: +def fetch_workspace_entries(client: DagsterCloudGraphQLClient) -> List[Any]: return client.execute(WORKSPACE_ENTRIES_QUERY)["data"]["workspace"]["workspaceEntries"] @@ -149,6 +149,13 @@ def fetch_workspace_entries(client: GqlShimClient) -> List[Any]: ... on PythonError { message stack + errorChain { + isExplicitLink + error { + message + stack + } + } } } } @@ -162,7 +169,7 @@ def fetch_workspace_entries(client: GqlShimClient) -> List[Any]: """ -def fetch_code_locations(client: GqlShimClient) -> List[Any]: +def fetch_code_locations(client: DagsterCloudGraphQLClient) -> List[Any]: result = client.execute(REPOSITORY_LOCATIONS_QUERY)["data"]["workspaceOrError"] if result["__typename"] != "Workspace": raise Exception("Unable to query code locations: ", result["message"]) @@ -188,7 +195,9 @@ def fetch_code_locations(client: GqlShimClient) -> List[Any]: """ -def add_or_update_code_location(client: GqlShimClient, location_document: Dict[str, Any]) -> None: +def add_or_update_code_location( + client: DagsterCloudGraphQLClient, location_document: Dict[str, Any] +) -> None: result = client.execute( ADD_OR_UPDATE_LOCATION_FROM_DOCUMENT_MUTATION, variable_values={"document": location_document}, @@ -215,7 +224,7 @@ def add_or_update_code_location(client: GqlShimClient, location_document: Dict[s """ -def delete_code_location(client: GqlShimClient, location_name: str) -> None: +def delete_code_location(client: DagsterCloudGraphQLClient, location_name: str) -> None: result = client.execute( DELETE_LOCATION_MUTATION, variable_values={"locationName": location_name} ) @@ -246,7 +255,7 @@ def delete_code_location(client: GqlShimClient, location_name: str) -> None: def reconcile_code_locations( - client: GqlShimClient, locations_document: Dict[str, Any] + client: DagsterCloudGraphQLClient, locations_document: Dict[str, Any] ) -> List[str]: result = client.execute( RECONCILE_LOCATIONS_FROM_DOCUMENT_MUTATION, @@ -279,7 +288,7 @@ def reconcile_code_locations( """ -def fetch_locations_as_document(client: GqlShimClient) -> Dict[str, Any]: +def fetch_locations_as_document(client: DagsterCloudGraphQLClient) -> 
Dict[str, Any]: result = client.execute(GET_LOCATIONS_AS_DOCUMENT_QUERY) return result["data"]["locationsAsDocument"]["document"] @@ -304,7 +313,9 @@ def fetch_locations_as_document(client: GqlShimClient) -> Dict[str, Any]: """ -def set_deployment_settings(client: GqlShimClient, deployment_settings: Dict[str, Any]) -> None: +def set_deployment_settings( + client: DagsterCloudGraphQLClient, deployment_settings: Dict[str, Any] +) -> None: result = client.execute( SET_DEPLOYMENT_SETTINGS_MUTATION, variable_values={"deploymentSettings": deployment_settings}, @@ -323,7 +334,7 @@ def set_deployment_settings(client: GqlShimClient, deployment_settings: Dict[str """ -def get_deployment_settings(client: GqlShimClient) -> Dict[str, Any]: +def get_deployment_settings(client: DagsterCloudGraphQLClient) -> Dict[str, Any]: result = client.execute(DEPLOYMENT_SETTINGS_QUERY) if result.get("data", {}).get("deploymentSettings", {}).get("settings") is None: @@ -363,7 +374,7 @@ def get_deployment_settings(client: GqlShimClient) -> Dict[str, Any]: """ -def get_alert_policies(client: GqlShimClient) -> Dict[str, Any]: +def get_alert_policies(client: DagsterCloudGraphQLClient) -> Dict[str, Any]: result = client.execute(ALERT_POLICIES_QUERY) if result.get("data", {}).get("alertPolicies", {}) is None: @@ -397,7 +408,7 @@ def get_alert_policies(client: GqlShimClient) -> Dict[str, Any]: def reconcile_alert_policies( - client: GqlShimClient, alert_policy_inputs: Sequence[dict] + client: DagsterCloudGraphQLClient, alert_policy_inputs: Sequence[dict] ) -> Sequence[str]: result = client.execute( RECONCILE_ALERT_POLICIES_FROM_DOCUMENT_MUTATION, @@ -435,7 +446,9 @@ def reconcile_alert_policies( """ -def set_organization_settings(client: GqlShimClient, organization_settings: Dict[str, Any]) -> None: +def set_organization_settings( + client: DagsterCloudGraphQLClient, organization_settings: Dict[str, Any] +) -> None: result = client.execute( SET_ORGANIZATION_SETTINGS_MUTATION, variable_values={"organizationSettings": organization_settings}, @@ -454,7 +467,7 @@ def set_organization_settings(client: GqlShimClient, organization_settings: Dict """ -def get_organization_settings(client: GqlShimClient) -> Dict[str, Any]: +def get_organization_settings(client: DagsterCloudGraphQLClient) -> Dict[str, Any]: result = client.execute(ORGANIZATION_SETTINGS_QUERY) if result.get("data", {}).get("organizationSettings", {}).get("settings") is None: @@ -483,7 +496,7 @@ def get_organization_settings(client: GqlShimClient) -> Dict[str, Any]: def create_or_update_branch_deployment( - client: GqlShimClient, + client: DagsterCloudGraphQLClient, repo_name: str, branch_name: str, commit_hash: str, @@ -558,7 +571,7 @@ def create_or_update_branch_deployment( def launch_run( - client: GqlShimClient, + client: DagsterCloudGraphQLClient, location_name: str, repo_name: str, job_name: str, @@ -601,7 +614,7 @@ def launch_run( """ -def get_ecr_info(client: GqlShimClient) -> Any: +def get_ecr_info(client: DagsterCloudGraphQLClient) -> Any: data = client.execute(GET_ECR_CREDS_QUERY)["data"] return { "registry_url": data["serverless"]["registryUrl"], @@ -624,7 +637,7 @@ def get_ecr_info(client: GqlShimClient) -> Any: """ -def run_status(client: GqlShimClient, run_id: str) -> Any: +def run_status(client: DagsterCloudGraphQLClient, run_id: str) -> Any: data = client.execute( GET_RUN_STATUS_QUERY, variable_values={"runId": run_id}, @@ -650,7 +663,7 @@ def run_status(client: GqlShimClient, run_id: str) -> Any: def mark_cli_event( - client: GqlShimClient, + client: 
DagsterCloudGraphQLClient, event_type: CliEventType, duration_seconds: float, success: bool = True, @@ -700,7 +713,7 @@ def mark_cli_event( """ -def get_deployment_by_name(client: GqlShimClient, deployment: str) -> Dict[str, Any]: +def get_deployment_by_name(client: DagsterCloudGraphQLClient, deployment: str) -> Dict[str, Any]: result = client.execute( GET_DEPLOYMENT_BY_NAME_QUERY, variable_values={"deploymentName": deployment} )["data"]["deploymentByName"] @@ -711,7 +724,7 @@ def get_deployment_by_name(client: GqlShimClient, deployment: str) -> Dict[str, raise Exception(f"Unable to find deployment {deployment}") -def delete_branch_deployment(client: GqlShimClient, deployment: str) -> Any: +def delete_branch_deployment(client: DagsterCloudGraphQLClient, deployment: str) -> Any: deployment_info = get_deployment_by_name(client, deployment) if not deployment_info["deploymentType"] == "BRANCH": raise Exception(f"Deployment {deployment} is not a branch deployment") diff --git a/dagster-cloud-cli/dagster_cloud_cli/version.py b/dagster-cloud-cli/dagster_cloud_cli/version.py index 93d2517..14d9d2f 100644 --- a/dagster-cloud-cli/dagster_cloud_cli/version.py +++ b/dagster-cloud-cli/dagster_cloud_cli/version.py @@ -1 +1 @@ -__version__ = "1.6.14" +__version__ = "1.7.0" diff --git a/dagster-cloud-examples/dagster_cloud_examples/version.py b/dagster-cloud-examples/dagster_cloud_examples/version.py index 93d2517..14d9d2f 100644 --- a/dagster-cloud-examples/dagster_cloud_examples/version.py +++ b/dagster-cloud-examples/dagster_cloud_examples/version.py @@ -1 +1 @@ -__version__ = "1.6.14" +__version__ = "1.7.0" diff --git a/dagster-cloud-examples/setup.py b/dagster-cloud-examples/setup.py index 28363a8..fc9eef6 100644 --- a/dagster-cloud-examples/setup.py +++ b/dagster-cloud-examples/setup.py @@ -19,7 +19,7 @@ def get_version() -> str: name="dagster-cloud-examples", version=ver, packages=find_packages(exclude=["dagster_cloud_examples_tests*"]), - install_requires=["dagster_cloud==1.6.14"], + install_requires=["dagster_cloud==1.7.0"], extras_require={"tests": ["mypy", "pylint", "pytest"]}, author="Elementl", author_email="hello@elementl.com", diff --git a/dagster-cloud/dagster_cloud/agent/dagster_cloud_agent.py b/dagster-cloud/dagster_cloud/agent/dagster_cloud_agent.py index 53777e8..b071fb8 100644 --- a/dagster-cloud/dagster_cloud/agent/dagster_cloud_agent.py +++ b/dagster-cloud/dagster_cloud/agent/dagster_cloud_agent.py @@ -1,9 +1,7 @@ import logging import os import sys -import tempfile import time -import zlib from collections import deque from concurrent.futures import Future, ThreadPoolExecutor from contextlib import ExitStack @@ -31,7 +29,7 @@ from dagster._utils.interrupts import raise_interrupts_as from dagster._utils.merger import merge_dicts from dagster._utils.typed_dict import init_optional_typeddict -from dagster_cloud_cli.core.errors import GraphQLStorageError, raise_http_error +from dagster_cloud_cli.core.errors import raise_http_error from dagster_cloud_cli.core.workspace import CodeDeploymentMetadata from dagster_cloud.api.dagster_cloud_api import ( @@ -54,7 +52,7 @@ UserCodeLauncherEntry, ) -from ..util import SERVER_HANDLE_TAG, is_isolated_run +from ..util import SERVER_HANDLE_TAG, compressed_namedtuple_upload_file, is_isolated_run from ..version import __version__ from .queries import ( ADD_AGENT_HEARTBEATS_MUTATION, @@ -268,9 +266,7 @@ def run_loop( self._check_add_heartbeat(instance, agent_uuid, heartbeat_interval_seconds) except Exception: self._logger.error( - "Failed to add 
heartbeat: \n{}".format( - serializable_error_info_from_exc_info(sys.exc_info()) - ) + f"Failed to add heartbeat: \n{serializable_error_info_from_exc_info(sys.exc_info())}" ) # Check for any received interrupts @@ -282,9 +278,7 @@ def run_loop( except Exception: self._logger.error( - "Failed to check for workspace updates: \n{}".format( - serializable_error_info_from_exc_info(sys.exc_info()) - ) + f"Failed to check for workspace updates: \n{serializable_error_info_from_exc_info(sys.exc_info())}" ) # Check for any received interrupts @@ -374,7 +368,7 @@ def _check_add_heartbeat( self._last_heartbeat_time = curr_time - res = instance.organization_scoped_graphql_client().execute( + instance.organization_scoped_graphql_client().execute( ADD_AGENT_HEARTBEATS_MUTATION, variable_values={ "serializedAgentHeartbeats": serialized_agent_heartbeats, @@ -382,9 +376,6 @@ def _check_add_heartbeat( idempotent_mutation=True, ) - if "errors" in res: - raise GraphQLStorageError(res) - @property def executor(self) -> ThreadPoolExecutor: return self._executor @@ -620,7 +611,7 @@ def _get_location_origin_from_request( DagsterCloudApi.GET_SUBSET_EXTERNAL_PIPELINE_RESULT, }: external_pipeline_origin = request.request_args.job_origin - return external_pipeline_origin.external_repository_origin.code_location_origin + return external_pipeline_origin.repository_origin.code_location_origin elif api_name in { DagsterCloudApi.GET_EXTERNAL_PARTITION_CONFIG, DagsterCloudApi.GET_EXTERNAL_PARTITION_TAGS, @@ -787,7 +778,7 @@ def _handle_api_request( run_location_name = cast( str, - run.external_job_origin.external_repository_origin.code_location_origin.location_name, + run.external_job_origin.repository_origin.code_location_origin.location_name, ) server = user_code_launcher.get_grpc_server(deployment_name, run_location_name) @@ -831,7 +822,7 @@ def _handle_api_request( else: run_location_name = cast( str, - run.external_job_origin.external_repository_origin.code_location_origin.location_name, + run.external_job_origin.repository_origin.code_location_origin.location_name, ) server = user_code_launcher.get_grpc_server( @@ -872,8 +863,8 @@ def _process_api_request( if request_api not in DagsterCloudApi.__members__: api_result = DagsterCloudApiUnknownCommandResponse(request_api) self._logger.warning( - "Ignoring request {request}: Unknown command. This is likely due to running an " - "older version of the agent.".format(request=json_request) + f"Ignoring request {json_request}: Unknown command. This is likely due to running an " + "older version of the agent." 
) else: try: @@ -956,11 +947,8 @@ def run_iteration( else: self._logger.warning( - "Iteration #{iteration}: Waiting to pull requests from the queue since there are" - " already {num_pending_requests} in the queue".format( - iteration=self._iteration, - num_pending_requests=len(self._pending_requests), - ) + f"Iteration #{self._iteration}: Waiting to pull requests from the queue since there are" + f" already {len(self._pending_requests)} in the queue" ) invalid_requests = [] @@ -1045,17 +1033,12 @@ def upload_api_response( deployment_name: str, upload_response: DagsterCloudUploadApiResponse, ): - with tempfile.TemporaryDirectory() as temp_dir: - dst = os.path.join(temp_dir, "api_response.tmp") - with open(dst, "wb") as f: - f.write(zlib.compress(serialize_value(upload_response).encode("utf-8"))) - - with open(dst, "rb") as f: - resp = instance.rest_requests_session.put( - instance.dagster_cloud_upload_api_response_url, - headers=instance.headers_for_deployment(deployment_name), - files={"api_response.tmp": f}, - timeout=instance.dagster_cloud_api_timeout, - proxies=instance.dagster_cloud_api_proxies, - ) - raise_http_error(resp) + with compressed_namedtuple_upload_file(upload_response) as f: + resp = instance.requests_managed_retries_session.put( + instance.dagster_cloud_upload_api_response_url, + headers=instance.headers_for_deployment(deployment_name), + files={"api_response.tmp": f}, + timeout=instance.dagster_cloud_api_timeout, + proxies=instance.dagster_cloud_api_proxies, + ) + raise_http_error(resp) diff --git a/dagster-cloud/dagster_cloud/anomaly_detection/__init__.py b/dagster-cloud/dagster_cloud/anomaly_detection/__init__.py new file mode 100644 index 0000000..81e1a83 --- /dev/null +++ b/dagster-cloud/dagster_cloud/anomaly_detection/__init__.py @@ -0,0 +1,4 @@ +from .defs import ( + build_anomaly_detection_freshness_checks as build_anomaly_detection_freshness_checks, +) +from .types import AnomalyDetectionModelParams as AnomalyDetectionModelParams diff --git a/dagster-cloud/dagster_cloud/anomaly_detection/defs.py b/dagster-cloud/dagster_cloud/anomaly_detection/defs.py new file mode 100644 index 0000000..80cd2ea --- /dev/null +++ b/dagster-cloud/dagster_cloud/anomaly_detection/defs.py @@ -0,0 +1,209 @@ +from typing import Iterable, Optional, Sequence, Union, cast + +from dagster import ( + AssetCheckExecutionContext, + MetadataValue, + _check as check, +) +from dagster._core.definitions.asset_check_result import AssetCheckResult +from dagster._core.definitions.asset_check_spec import AssetCheckSeverity, AssetCheckSpec +from dagster._core.definitions.asset_checks import AssetChecksDefinition +from dagster._core.definitions.asset_key import AssetKey +from dagster._core.definitions.assets import AssetsDefinition +from dagster._core.definitions.decorators.asset_check_decorator import multi_asset_check +from dagster._core.definitions.events import CoercibleToAssetKey +from dagster._core.definitions.freshness_checks.utils import ( + asset_to_keys_iterable, + seconds_in_words, + unique_id_from_asset_keys, +) +from dagster._core.definitions.source_asset import SourceAsset +from dagster._core.errors import ( + DagsterError, + DagsterInvariantViolationError, +) +from dagster._core.instance import DagsterInstance +from dagster_cloud_cli.core.graphql_client import create_cloud_webserver_client + +from dagster_cloud import DagsterCloudAgentInstance + +from .mutation import ANOMALY_DETECTION_INFERENCE_MUTATION +from .types import ( + AnomalyDetectionModelParams, + 
BetaFreshnessAnomalyDetectionParams, +) + +DEFAULT_MODEL_PARAMS = BetaFreshnessAnomalyDetectionParams(sensitivity=0.1) + + +class DagsterCloudAnomalyDetectionFailed(DagsterError): + """Raised when an anomaly detection check fails host-side.""" + + +def _build_check_for_assets( + asset_keys: Sequence[AssetKey], + params: AnomalyDetectionModelParams, +) -> AssetChecksDefinition: + @multi_asset_check( + specs=[ + AssetCheckSpec( + name="freshness_anomaly_detection_check", + description=f"Detects anomalies in the freshness of the asset using model {params.model_version.value.lower()}.", + asset=asset_key, + ) + for asset_key in asset_keys + ], + can_subset=True, + name=f"anomaly_detection_freshness_check_{unique_id_from_asset_keys(asset_keys)}", + ) + def the_check(context: AssetCheckExecutionContext) -> Iterable[AssetCheckResult]: + if not _is_agent_instance(context.instance): + raise DagsterInvariantViolationError( + f"This anomaly detection check is not being launched from a dagster agent. " + "Anomaly detection is only available for dagster cloud deployments." + f"Instance type: {type(context.instance)}." + ) + instance = cast(DagsterCloudAgentInstance, context.instance) + with create_cloud_webserver_client( + instance.dagit_url, + check.str_param(instance.dagster_cloud_agent_token, "dagster_cloud_agent_token"), + ) as client: + for check_key in context.selected_asset_check_keys: + asset_key = check_key.asset_key + if not context.job_def.asset_layer.has(asset_key): + raise Exception(f"Could not find targeted asset {asset_key.to_string()}.") + result = client.execute( + ANOMALY_DETECTION_INFERENCE_MUTATION, + { + "modelVersion": params.model_version.value, + "params": { + **dict(params), + "asset_key_user_string": asset_key.to_user_string(), + }, + }, + ) + metadata = { + "model_params": {**params.as_metadata}, + "model_version": params.model_version.value, + } + if result["anomalyDetectionInference"]["__typename"] != "AnomalyDetectionSuccess": + yield handle_anomaly_detection_inference_failure( + result, metadata, params, asset_key + ) + continue + response = result["anomalyDetectionInference"]["response"] + overdue_seconds = check.float_param(response["overdue_seconds"], "overdue_seconds") + overdue_deadline_timestamp = response["overdue_deadline_timestamp"] + metadata["overdue_deadline_timestamp"] = MetadataValue.timestamp( + overdue_deadline_timestamp + ) + metadata["model_training_range_start_timestamp"] = MetadataValue.timestamp( + response["model_training_range_start_timestamp"] + ) + metadata["model_training_range_end_timestamp"] = MetadataValue.timestamp( + response["model_training_range_end_timestamp"] + ) + + last_updated_timestamp = response["last_updated_timestamp"] + if last_updated_timestamp is None: + yield AssetCheckResult( + passed=True, + description="The asset has never been materialized or otherwise observed to have been updated", + ) + continue + + evaluation_timestamp = response["evaluation_timestamp"] + last_update_lag_str = seconds_in_words( + evaluation_timestamp - last_updated_timestamp + ) + expected_lag_str = seconds_in_words( + overdue_deadline_timestamp - last_updated_timestamp + ) + gt_or_lte_str = "greater than" if overdue_seconds > 0 else "less than or equal to" + lag_comparison_str = ( + f"At the time of this check's evaluation, {last_update_lag_str} had passed since its " + f"last update. This is {gt_or_lte_str} the allowed {expected_lag_str} threshold, which " + "is based on its prior history of updates." 
+ ) + + if overdue_seconds > 0: + metadata["overdue_minutes"] = round(overdue_seconds / 60, 2) + + yield AssetCheckResult( + passed=False, + severity=AssetCheckSeverity.WARN, + metadata=metadata, + description=f"The asset is overdue for an update. {lag_comparison_str}", + asset_key=asset_key, + ) + else: + yield AssetCheckResult( + passed=True, + metadata=metadata, + description=f"The asset is fresh. {lag_comparison_str}", + asset_key=asset_key, + ) + + return the_check + + +def handle_anomaly_detection_inference_failure( + result: dict, metadata: dict, params: AnomalyDetectionModelParams, asset_key: AssetKey +) -> AssetCheckResult: + if ( + result["anomalyDetectionInference"]["__typename"] == "AnomalyDetectionFailure" + and result["anomalyDetectionInference"]["message"] + == params.model_version.minimum_required_records_msg + ): + # Intercept failure in the case of not enough records, and return a pass to avoid + # being too noisy with failures. + return AssetCheckResult( + passed=True, + severity=AssetCheckSeverity.WARN, + metadata=metadata, + description=result["anomalyDetectionInference"]["message"], + asset_key=asset_key, + ) + raise DagsterCloudAnomalyDetectionFailed( + f"Anomaly detection failed: {result['anomalyDetectionInference']['message']}" + ) + + +def build_anomaly_detection_freshness_checks( + *, + assets: Sequence[Union[CoercibleToAssetKey, AssetsDefinition, SourceAsset]], + params: Optional[AnomalyDetectionModelParams], +) -> AssetChecksDefinition: + """Builds an asset checks definition which utilizes anomaly detection algorithms to + determine the freshness of data. + + Args: + assets (Sequence[Union[CoercibleToAssetKey, AssetsDefinition, SourceAsset]]): The assets to construct checks for. For each passed in + asset, a corresponding check spec is included in the returned `AssetChecksDefinition`. + params (AnomalyDetectionModelParams): The parameters to use for the model. The parameterization corresponds to the model used. + + Returns: + AssetChecksDefinition: A single definition containing a freshness anomaly detection check for each asset in the `assets` parameter. + + Examples: + ..
code-block:: python + + from dagster_cloud import build_anomaly_detection_freshness_checks, BetaFreshnessAnomalyDetectionParams + + checks_def = build_anomaly_detection_freshness_checks( + assets=[AssetKey("foo_asset"), AssetKey("foo_asset")], + params=BetaFreshnessAnomalyDetectionParams(sensitivity=0.1), + ) + """ + params = check.opt_inst_param( + params, "params", AnomalyDetectionModelParams, DEFAULT_MODEL_PARAMS + ) + return _build_check_for_assets( + [asset_key for asset in assets for asset_key in asset_to_keys_iterable(asset)], params + ) + + +def _is_agent_instance(instance: DagsterInstance) -> bool: + if hasattr(instance, "dagster_cloud_agent_token") and hasattr(instance, "dagit_url"): + return True + return False diff --git a/dagster-cloud/dagster_cloud/dagster_anomaly_detection/mutation.py b/dagster-cloud/dagster_cloud/anomaly_detection/mutation.py similarity index 100% rename from dagster-cloud/dagster_cloud/dagster_anomaly_detection/mutation.py rename to dagster-cloud/dagster_cloud/anomaly_detection/mutation.py diff --git a/dagster-cloud/dagster_cloud/dagster_anomaly_detection/types.py b/dagster-cloud/dagster_cloud/anomaly_detection/types.py similarity index 79% rename from dagster-cloud/dagster_cloud/dagster_anomaly_detection/types.py rename to dagster-cloud/dagster_cloud/anomaly_detection/types.py index 55a8a36..3bdc703 100644 --- a/dagster-cloud/dagster_cloud/dagster_anomaly_detection/types.py +++ b/dagster-cloud/dagster_cloud/anomaly_detection/types.py @@ -9,6 +9,16 @@ class AnomalyDetectionModelVersion(Enum): FRESHNESS_BETA = "FRESHNESS_BETA" + @property + def minimum_required_records(self) -> int: + if self == AnomalyDetectionModelVersion.FRESHNESS_BETA: + return 15 + raise NotImplementedError(f"Minimum required records not implemented for {self}") + + @property + def minimum_required_records_msg(self) -> str: + return f"Not enough records found to detect anomalies. Need at least {self.minimum_required_records}." 
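(Illustrative usage sketch, not part of the patch: the `orders` asset and the `Definitions` wiring are assumptions for the example; the builder, the params class, and the `from dagster_cloud import ...` path follow the docstring above.)

from dagster import Definitions, asset
from dagster_cloud import BetaFreshnessAnomalyDetectionParams, build_anomaly_detection_freshness_checks


@asset
def orders():
    # Hypothetical asset standing in for whatever assets the deployment already defines.
    ...


# One AssetChecksDefinition (a single multi_asset_check) covering every passed asset.
freshness_checks = build_anomaly_detection_freshness_checks(
    assets=[orders],
    params=BetaFreshnessAnomalyDetectionParams(sensitivity=0.1),
)

defs = Definitions(assets=[orders], asset_checks=[freshness_checks])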
+ ### INTERNAL MODEL PARAMETER SETS ### diff --git a/dagster-cloud/dagster_cloud/api/dagster_cloud_api.py b/dagster-cloud/dagster_cloud/api/dagster_cloud_api.py index e89f5d4..2e025b3 100644 --- a/dagster-cloud/dagster_cloud/api/dagster_cloud_api.py +++ b/dagster-cloud/dagster_cloud/api/dagster_cloud_api.py @@ -7,6 +7,7 @@ import pendulum from dagster._core.code_pointer import CodePointer from dagster._core.definitions.selector import JobSelector +from dagster._core.events.log import EventLogEntry from dagster._core.remote_representation import ( CodeLocationOrigin, ExternalRepositoryData, @@ -402,6 +403,14 @@ def __str__(self) -> str: ) +@whitelist_for_serdes +class StoreEventBatchRequest( + NamedTuple("_StoreEventBatchRequest", [("event_log_entries", Sequence[EventLogEntry])]) +): + def __new__(cls, event_log_entries: Sequence[EventLogEntry]): + return super().__new__(cls, event_log_entries=event_log_entries) + + @whitelist_for_serdes class DagsterCloudUploadApiResponse( NamedTuple( diff --git a/dagster-cloud/dagster_cloud/dagster_anomaly_detection/__init__.py b/dagster-cloud/dagster_cloud/artifacts/__init__.py similarity index 100% rename from dagster-cloud/dagster_cloud/dagster_anomaly_detection/__init__.py rename to dagster-cloud/dagster_cloud/artifacts/__init__.py diff --git a/dagster-cloud/dagster_cloud/dagster_anomaly_detection/defs.py b/dagster-cloud/dagster_cloud/dagster_anomaly_detection/defs.py deleted file mode 100644 index e6cb0b8..0000000 --- a/dagster-cloud/dagster_cloud/dagster_anomaly_detection/defs.py +++ /dev/null @@ -1,139 +0,0 @@ -import os -from typing import Optional, Sequence, Union, cast - -from dagster import ( - _check as check, - asset_check, -) -from dagster._core.definitions.asset_check_result import AssetCheckResult -from dagster._core.definitions.asset_check_spec import AssetCheckSeverity -from dagster._core.definitions.asset_checks import AssetChecksDefinition -from dagster._core.definitions.assets import AssetsDefinition -from dagster._core.definitions.events import CoercibleToAssetKey -from dagster._core.definitions.source_asset import SourceAsset -from dagster._core.errors import ( - DagsterError, - DagsterInvariantViolationError, -) -from dagster._core.execution.context.compute import AssetExecutionContext -from dagster._core.instance import DagsterInstance -from gql import Client, gql -from gql.transport.requests import RequestsHTTPTransport - -from dagster_cloud import DagsterCloudAgentInstance - -from .mutation import ANOMALY_DETECTION_INFERENCE_MUTATION -from .types import ( - AnomalyDetectionModelParams, - BetaFreshnessAnomalyDetectionParams, -) - -DEFAULT_MODEL_PARAMS = BetaFreshnessAnomalyDetectionParams(sensitivity=0.1) - - -class DagsterCloudAnomalyDetectionFailed(DagsterError): - """Raised when an anomaly detection check fails host-side.""" - - -def _build_check_for_asset( - asset: Union[CoercibleToAssetKey, AssetsDefinition, SourceAsset], - params: AnomalyDetectionModelParams, -) -> AssetChecksDefinition: - @asset_check( - asset=asset, - description=f"Detects anomalies in the freshness of the asset using model {params.model_version.value.lower()}.", - name="freshness_anomaly_detection", - ) - def the_check(context: AssetExecutionContext) -> AssetCheckResult: - if not _is_agent_instance(context.instance): - raise DagsterInvariantViolationError( - f"This anomaly detection check is not being launched from a dagster agent. " - "Anomaly detection is only available for dagster cloud deployments." - f"Instance type: {type(context.instance)}." 
- ) - instance = cast(DagsterCloudAgentInstance, context.instance) - transport = RequestsHTTPTransport( - url=os.getenv("DAGSTER_METRICS_DAGIT_URL", f"{instance.dagit_url}graphql"), - use_json=True, - timeout=300, - headers={"Dagster-Cloud-Api-Token": instance.dagster_cloud_agent_token}, - ) - client = Client(transport=transport, fetch_schema_from_transport=True) - asset_key = next(iter(context.assets_def.check_keys)).asset_key - if not context.job_def.asset_layer.has(asset_key): - raise Exception(f"Could not find targeted asset {asset_key.to_string()}.") - result = client.execute( - gql(ANOMALY_DETECTION_INFERENCE_MUTATION), - { - "modelVersion": params.model_version.value, - "params": { - **dict(params), - "asset_key_user_string": asset_key.to_user_string(), - }, - }, - ) - if result["anomalyDetectionInference"]["__typename"] != "AnomalyDetectionSuccess": - raise DagsterCloudAnomalyDetectionFailed( - f"Anomaly detection failed: {result['anomalyDetectionInference']['message']}" - ) - response = result["anomalyDetectionInference"]["response"] - overdue_seconds = check.float_param(response["overdue_seconds"], "overdue_seconds") - expected_event_timestamp = response["overdue_deadline_timestamp"] - model_training_range_start = response["model_training_range_start_timestamp"] - model_training_range_end = response["model_training_range_end_timestamp"] - metadata = { - "model_params": {**params.as_metadata}, - "model_version": params.model_version.value, - "model_training_range_start_timestamp": model_training_range_start, - "model_training_range_end_timestamp": model_training_range_end, - "overdue_deadline_timestamp": expected_event_timestamp, - } - if overdue_seconds > 0: - metadata["overdue_minutes"] = overdue_seconds / 60 - return AssetCheckResult( - passed=False, - severity=AssetCheckSeverity.WARN, - metadata=metadata, - ) - else: - return AssetCheckResult(passed=True, metadata=metadata) - - return the_check - - -def build_anomaly_detection_freshness_checks( - *, - assets: Sequence[Union[CoercibleToAssetKey, AssetsDefinition, SourceAsset]], - params: Optional[AnomalyDetectionModelParams], -) -> Sequence[AssetChecksDefinition]: - """Builds a list of asset checks which utilize anomaly detection algorithms to - determine the freshness of data. - - Args: - assets (Sequence[Union[CoercibleToAssetKey, AssetsDefinition, SourceAsset]]): The assets to construct checks for. For each passed in - asset, there will be a corresponding constructed `AssetChecksDefinition`. - params (AnomalyDetectionModelParams): The parameters to use for the model. The parameterization corresponds to the model used. - - Returns: - Sequence[AssetChecksDefinition]: A list of `AssetChecksDefinition` objects, each corresponding to an asset in the `assets` parameter. - - Examples: - .. 
code-block:: python - - from dagster_cloud import build_anomaly_detection_freshness_checks, AnomalyDetectionModel, BetaFreshnessAnomalyDetectionParams - - checks = build_anomaly_detection_freshness_checks( - assets=[AssetKey("foo_asset"), AssetKey("foo_asset")], - params=BetaFreshnessAnomalyDetectionParams(sensitivity=0.1), - ) - """ - params = check.opt_inst_param( - params, "params", AnomalyDetectionModelParams, DEFAULT_MODEL_PARAMS - ) - return [_build_check_for_asset(asset, params) for asset in assets] - - -def _is_agent_instance(instance: DagsterInstance) -> bool: - if hasattr(instance, "dagster_cloud_agent_token") and hasattr(instance, "dagit_url"): - return True - return False diff --git a/dagster-cloud/dagster_cloud/dagster_insights/bigquery/dbt_wrapper.py b/dagster-cloud/dagster_cloud/dagster_insights/bigquery/dbt_wrapper.py index a694a0f..5459167 100644 --- a/dagster-cloud/dagster_cloud/dagster_insights/bigquery/dbt_wrapper.py +++ b/dagster-cloud/dagster_cloud/dagster_insights/bigquery/dbt_wrapper.py @@ -2,32 +2,36 @@ from dataclasses import dataclass from typing import ( TYPE_CHECKING, - Any, Iterable, Iterator, Optional, Union, ) -import dagster._check as check +import yaml from dagster import ( AssetCheckResult, AssetExecutionContext, AssetKey, AssetMaterialization, AssetObservation, - MetadataValue, OpExecutionContext, Output, ) -from dagster._core.definitions.metadata import MetadataMapping -from google.cloud import bigquery +from dagster_dbt import DbtCliInvocation +from dagster_dbt.version import __version__ as dagster_dbt_version +from packaging import version from ..insights_utils import extract_asset_info_from_event from .bigquery_utils import build_bigquery_cost_metadata, marker_asset_key_for_job if TYPE_CHECKING: - from dagster_dbt import DbtCliInvocation + from dbt.adapters.base.impl import BaseAdapter + from google.cloud import bigquery + +OPAQUE_ID_SQL_SIGIL = "bigquery_dagster_dbt_v1_opaque_id" +DEFAULT_BQ_REGION = "region-us" +MIN_DAGSTER_DBT_VERSION = "1.7.0" @dataclass @@ -47,31 +51,14 @@ def asset_partition_key(self) -> str: ) -def _extract_metadata_value(value: Optional[MetadataValue], default_value: Any = None) -> Any: - return value.value if value else default_value - - -def _extract_bigquery_info_from_metadata(metadata: MetadataMapping): - job_id = _extract_metadata_value(metadata.get("job_id")) - bytes_billed = _extract_metadata_value(metadata.get("bytes_billed"), 0) - slots_ms = _extract_metadata_value(metadata.get("slots_ms"), 0) - return job_id, bytes_billed, slots_ms - - -def _asset_partition_key(asset_key: AssetKey, partition_key: Optional[str]) -> str: - return f"{asset_key.to_string()}:{partition_key}" if partition_key else asset_key.to_string() - - def dbt_with_bigquery_insights( context: Union[OpExecutionContext, AssetExecutionContext], - dbt_cli_invocation: "DbtCliInvocation", + dbt_cli_invocation: DbtCliInvocation, dagster_events: Optional[ Iterable[Union[Output, AssetMaterialization, AssetObservation, AssetCheckResult]] ] = None, + skip_config_check=False, record_observation_usage: bool = True, - explicitly_query_information_schema: bool = False, - bigquery_client: Optional[bigquery.Client] = None, - bigquery_location: Optional[str] = None, ) -> Iterator[Union[Output, AssetMaterialization, AssetObservation, AssetCheckResult]]: """Wraps a dagster-dbt invocation to associate each BigQuery query with the produced asset materializations. 
This allows the cost of each query to be associated with the asset @@ -102,60 +89,97 @@ def jaffle_shop_dbt_assets( dbt_cli_invocation = dbt.cli(["build"], context=context) yield from dbt_with_bigquery_insights(context, dbt_cli_invocation) """ - if explicitly_query_information_schema: - check.inst_param(bigquery_client, "bigquery_client", bigquery.Client) - check.str_param(bigquery_location, "bigquery_location") + if not skip_config_check: + adapter_type = dbt_cli_invocation.manifest["metadata"]["adapter_type"] + if adapter_type != "bigquery": + raise RuntimeError( + f"The 'bigquery' adapter must be used but instead found '{adapter_type}'" + ) + dbt_project_config = yaml.safe_load( + (dbt_cli_invocation.project_dir / "dbt_project.yml").open("r") + ) + # sanity check that the sigil is present somewhere in the query comment + query_comment = dbt_project_config.get("query-comment") + if query_comment is None: + raise RuntimeError("query-comment is required in dbt_project.yml but it was missing") + comment = query_comment.get("comment") + if comment is None: + raise RuntimeError( + "query-comment.comment is required in dbt_project.yml but it was missing" + ) + if OPAQUE_ID_SQL_SIGIL not in comment: + raise RuntimeError( + "query-comment.comment in dbt_project.yml must contain the string" + f" '{OPAQUE_ID_SQL_SIGIL}'. Read the Dagster Insights docs for more info." + ) if dagster_events is None: dagster_events = dbt_cli_invocation.stream() - asset_info_by_job_id = {} - cost_by_asset = defaultdict(list) + asset_info_by_unique_id = {} for dagster_event in dagster_events: - if isinstance(dagster_event, (AssetMaterialization, AssetObservation, Output)): + if isinstance( + dagster_event, (AssetMaterialization, AssetObservation, Output, AssetCheckResult) + ): + unique_id = dagster_event.metadata["unique_id"].value asset_key, partition = extract_asset_info_from_event( context, dagster_event, record_observation_usage ) if not asset_key: asset_key = marker_asset_key_for_job(context.job_def) - job_id, bytes_billed, slots_ms = _extract_bigquery_info_from_metadata( - dagster_event.metadata - ) - if bytes_billed or slots_ms: - asset_info_by_job_id[job_id] = (asset_key, partition) - cost_info = BigQueryCostInfo(asset_key, partition, job_id, slots_ms, bytes_billed) - cost_by_asset[cost_info.asset_partition_key].append(cost_info) + asset_info_by_unique_id[unique_id] = (asset_key, partition) yield dagster_event - if explicitly_query_information_schema and bigquery_client: - marker_asset_key = marker_asset_key_for_job(context.job_def) - run_results_json = dbt_cli_invocation.get_artifact("run_results.json") - invocation_id = run_results_json["metadata"]["invocation_id"] - assert bigquery_client - try: - cost_query = f""" - SELECT job_id, SUM(total_bytes_billed) AS bytes_billed, SUM(total_slot_ms) AS slots_ms - FROM `{bigquery_location}`.INFORMATION_SCHEMA.JOBS - WHERE query like '%{invocation_id}%' - GROUP BY job_id - """ - context.log.info(f"Querying INFORMATION_SCHEMA.JOBS for bytes billed: {cost_query}") - query_result = bigquery_client.query(cost_query) - - # overwrite cost_by_asset that is computed from the metadata - cost_by_asset = defaultdict(list) + marker_asset_key = marker_asset_key_for_job(context.job_def) + run_results_json = dbt_cli_invocation.get_artifact("run_results.json") + invocation_id = run_results_json["metadata"]["invocation_id"] + + # backcompat-proof in case the invocation does not have an instantiated adapter on it + adapter: Optional["BaseAdapter"] = getattr(dbt_cli_invocation, 
"adapter", None) + if not adapter: + if version.parse(dagster_dbt_version) < version.parse(MIN_DAGSTER_DBT_VERSION): + upgrade_message = f" Extracting cost information requires dagster_dbt>={MIN_DAGSTER_DBT_VERSION} (found {dagster_dbt_version}). " + else: + upgrade_message = "" + + context.log.error( + "Could not find a BigQuery adapter on the dbt CLI invocation. Skipping cost analysis." + + upgrade_message + ) + return + + cost_by_asset = defaultdict(list) + try: + with adapter.connection_named("dagster_insights:bigquery_cost"): + client: "bigquery.Client" = adapter.connections.get_thread_connection().handle + dataset = client.get_dataset(adapter.config.credentials.schema) + region = f"region-{dataset.location.lower()}" if dataset.location else DEFAULT_BQ_REGION + query_result = client.query( + rf""" + SELECT + job_id, + regexp_extract(query, r"{OPAQUE_ID_SQL_SIGIL}\[\[\[(.*?):{invocation_id}\]\]\]") as unique_id, + total_bytes_billed AS bytes_billed, + total_slot_ms AS slots_ms + FROM `{dataset.project}`.`{region}`.INFORMATION_SCHEMA.JOBS + WHERE query like '%{invocation_id}%' + """ + ) for row in query_result: - asset_key, partition = asset_info_by_job_id.get( - row.job_id, (marker_asset_key, None) + if not row.unique_id: + continue + asset_key, partition = asset_info_by_unique_id.get( + row.unique_id, (marker_asset_key, None) ) if row.bytes_billed or row.slots_ms: cost_info = BigQueryCostInfo( asset_key, partition, row.job_id, row.bytes_billed, row.slots_ms ) cost_by_asset[cost_info.asset_partition_key].append(cost_info) - except: - context.log.exception("Could not query information_schema.jobs for bytes billed") + except: + context.log.exception("Could not query information_schema for BigQuery cost information") + return for cost_info_list in cost_by_asset.values(): bytes_billed = sum(item.bytes_billed for item in cost_info_list) diff --git a/dagster-cloud/dagster_cloud/dagster_insights/insights_utils.py b/dagster-cloud/dagster_cloud/dagster_insights/insights_utils.py index e34d787..3a2be23 100644 --- a/dagster-cloud/dagster_cloud/dagster_insights/insights_utils.py +++ b/dagster-cloud/dagster_cloud/dagster_insights/insights_utils.py @@ -2,6 +2,7 @@ import dagster._check as check from dagster import ( + AssetCheckResult, AssetExecutionContext, AssetKey, AssetMaterialization, @@ -42,10 +43,11 @@ def extract_asset_info_from_event(context, dagster_event, record_observation_usa if isinstance(dagster_event, AssetMaterialization): return dagster_event.asset_key, dagster_event.partition - if isinstance(dagster_event, AssetObservation) and record_observation_usage: - return dagster_event.asset_key, dagster_event.partition + if isinstance(dagster_event, (AssetCheckResult, AssetObservation)) and record_observation_usage: + partition = dagster_event.partition if isinstance(dagster_event, AssetObservation) else None + return dagster_event.asset_key, partition - if isinstance(dagster_event, AssetObservation): + if isinstance(dagster_event, (AssetCheckResult, AssetObservation)): return None, None if isinstance(dagster_event, Output): diff --git a/dagster-cloud/dagster_cloud/dagster_insights/metrics_utils.py b/dagster-cloud/dagster_cloud/dagster_insights/metrics_utils.py index 73654c6..2fa2932 100644 --- a/dagster-cloud/dagster_cloud/dagster_insights/metrics_utils.py +++ b/dagster-cloud/dagster_cloud/dagster_insights/metrics_utils.py @@ -101,7 +101,7 @@ def get_post_request_params( raise RuntimeError("This asset only functions in a running Dagster Cloud instance") return ( - 
instance.rest_requests_session, + instance.requests_managed_retries_session, instance.dagster_cloud_gen_insights_url_url, instance.dagster_cloud_api_headers(DagsterCloudInstanceScope.DEPLOYMENT), instance.dagster_cloud_api_timeout, diff --git a/dagster-cloud/dagster_cloud/dagster_insights/snowflake/dbt_wrapper.py b/dagster-cloud/dagster_cloud/dagster_insights/snowflake/dbt_wrapper.py index 9c1137f..6c9a6ef 100644 --- a/dagster-cloud/dagster_cloud/dagster_insights/snowflake/dbt_wrapper.py +++ b/dagster-cloud/dagster_cloud/dagster_insights/snowflake/dbt_wrapper.py @@ -103,7 +103,9 @@ def jaffle_shop_dbt_assets( asset_and_partition_key_to_unique_id: List[Tuple[AssetKey, Optional[str], Any]] = [] for dagster_event in dagster_events: - if isinstance(dagster_event, (AssetMaterialization, AssetObservation, Output)): + if isinstance( + dagster_event, (AssetMaterialization, AssetObservation, Output, AssetCheckResult) + ): unique_id = dagster_event.metadata["unique_id"].value asset_key, partition = extract_asset_info_from_event( context, dagster_event, record_observation_usage diff --git a/dagster-cloud/dagster_cloud/instance/__init__.py b/dagster-cloud/dagster_cloud/instance/__init__.py index 810e36e..f9fdfee 100644 --- a/dagster-cloud/dagster_cloud/instance/__init__.py +++ b/dagster-cloud/dagster_cloud/instance/__init__.py @@ -20,8 +20,9 @@ from dagster._core.storage.dagster_run import DagsterRun from dagster._serdes import ConfigurableClassData from dagster_cloud_cli.core.graphql_client import ( + create_agent_graphql_client, + create_agent_http_client, create_graphql_requests_session, - create_proxy_client, get_agent_headers, ) from dagster_cloud_cli.core.headers.auth import DagsterCloudInstanceScope @@ -91,6 +92,7 @@ def __init__( self._graphql_requests_session: Optional[Session] = None self._rest_requests_session: Optional[Session] = None self._graphql_client = None + self._http_client = None assert self.dagster_cloud_url @@ -157,16 +159,16 @@ def ref_for_deployment(self, deployment_name: str) -> InstanceRef: return my_ref._replace(custom_instance_class_data=new_class_data) def organization_scoped_graphql_client(self): - return create_proxy_client( - self.graphql_requests_session, + return create_agent_graphql_client( + self.client_managed_retries_requests_session, self.dagster_cloud_graphql_url, self._dagster_cloud_api_config_for_deployment(None), scope=DagsterCloudInstanceScope.ORGANIZATION, ) def graphql_client_for_deployment(self, deployment_name: Optional[str]): - return create_proxy_client( - self.graphql_requests_session, + return create_agent_graphql_client( + self.client_managed_retries_requests_session, self.dagster_cloud_graphql_url, self._dagster_cloud_api_config_for_deployment(deployment_name), scope=DagsterCloudInstanceScope.DEPLOYMENT, @@ -181,15 +183,15 @@ def headers_for_deployment(self, deployment_name: str): def create_graphql_client( self, scope: DagsterCloudInstanceScope = DagsterCloudInstanceScope.DEPLOYMENT ): - return create_proxy_client( - self.graphql_requests_session, + return create_agent_graphql_client( + self.client_managed_retries_requests_session, self.dagster_cloud_graphql_url, self._dagster_cloud_api_config, scope=scope, ) @property - def graphql_requests_session(self): + def client_managed_retries_requests_session(self): """A shared requests Session to use between GraphQL clients. Retries handled in GraphQL client layer. 
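(Sketch of how the renamed request sessions divide responsibilities; `instance`, `deployment_name`, and the uploaded payload are placeholders, while the property and helper names are the ones introduced in this diff.)

from io import BytesIO


def _illustrate_session_split(instance, deployment_name: str) -> None:
    # GraphQL traffic goes through clients built on client_managed_retries_requests_session;
    # retries are handled inside the GraphQL client layer.
    graphql_client = instance.graphql_client_for_deployment(deployment_name)  # noqa: F841

    # Plain REST uploads use requests_managed_retries_session, where `requests` itself
    # manages the retries.
    resp = instance.requests_managed_retries_session.put(
        instance.dagster_cloud_upload_api_response_url,
        headers=instance.headers_for_deployment(deployment_name),
        files={"api_response.tmp": BytesIO(b"placeholder payload")},
        timeout=instance.dagster_cloud_api_timeout,
        proxies=instance.dagster_cloud_api_proxies,
    )
    resp.raise_for_status()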
@@ -202,7 +204,7 @@ def graphql_requests_session(self): return self._graphql_requests_session @property - def rest_requests_session(self): + def requests_managed_retries_session(self): """A requests session to use for non-GraphQL Rest API requests. Retries handled by requests. @@ -228,6 +230,17 @@ def graphql_client(self): return self._graphql_client + @property + def http_client(self): + if self._http_client is None: + self._http_client = create_agent_http_client( + self.client_managed_retries_requests_session, + self._dagster_cloud_api_config, + scope=DagsterCloudInstanceScope.DEPLOYMENT, + ) + + return self._http_client + @property def dagster_cloud_url(self): if "url" in self._dagster_cloud_api_config: @@ -284,6 +297,10 @@ def dagit_url(self): def dagster_cloud_graphql_url(self): return f"{self.dagster_cloud_url}/graphql" + @property + def dagster_cloud_store_events_url(self): + return f"{self.dagster_cloud_url}/store_events" + @property def dagster_cloud_upload_logs_url(self): return f"{self.dagster_cloud_url}/upload_logs" @@ -296,14 +313,6 @@ def dagster_cloud_gen_logs_url_url(self): def dagster_cloud_gen_insights_url_url(self) -> str: return f"{self.dagster_cloud_url}/gen_insights_url" - @property - def dagster_cloud_gen_artifacts_post(self) -> str: - return f"{self.dagster_cloud_url}/gen_artifacts_post" - - @property - def dagster_cloud_gen_artifacts_get(self) -> str: - return f"{self.dagster_cloud_url}/gen_artifacts_get" - @property def dagster_cloud_upload_job_snap_url(self): return f"{self.dagster_cloud_url}/upload_job_snapshot" diff --git a/dagster-cloud/dagster_cloud/metrics/__init__.py b/dagster-cloud/dagster_cloud/metrics/__init__.py index 803213b..b30924e 100644 --- a/dagster-cloud/dagster_cloud/metrics/__init__.py +++ b/dagster-cloud/dagster_cloud/metrics/__init__.py @@ -64,7 +64,7 @@ def put_context_metrics( "stepKey": context.get_step_execution_context().step.key, "codeLocationName": context.dagster_run.external_job_origin.location_name, "repositoryName": ( - context.dagster_run.external_job_origin.external_repository_origin.repository_name + context.dagster_run.external_job_origin.repository_origin.repository_name ), "assetMetricDefinitions": [ { @@ -94,7 +94,7 @@ def put_context_metrics( "stepKey": context.get_step_execution_context().step.key, "codeLocationName": context.dagster_run.external_job_origin.location_name, "repositoryName": ( - context.dagster_run.external_job_origin.external_repository_origin.repository_name + context.dagster_run.external_job_origin.repository_origin.repository_name ), "jobMetricDefinitions": [ { @@ -219,7 +219,7 @@ def store_dbt_adapter_metrics( "stepKey": context.get_step_execution_context().step.key, "codeLocationName": context.dagster_run.external_job_origin.location_name, "repositoryName": ( - context.dagster_run.external_job_origin.external_repository_origin.repository_name + context.dagster_run.external_job_origin.repository_origin.repository_name ), "assetMetricDefinitions": assetMetricDefinitions, } diff --git a/dagster-cloud/dagster_cloud/secrets/loader.py b/dagster-cloud/dagster_cloud/secrets/loader.py index f8ae6f2..9ccc11d 100644 --- a/dagster-cloud/dagster_cloud/secrets/loader.py +++ b/dagster-cloud/dagster_cloud/secrets/loader.py @@ -15,8 +15,6 @@ } """ -from dagster_cloud_cli.core.errors import GraphQLStorageError - class DagsterCloudSecretsLoader(SecretsLoader, ConfigurableClass): def __init__( @@ -26,10 +24,7 @@ def __init__( self._inst_data = inst_data def _execute_query(self, query, variables=None): - res = 
self._instance.graphql_client.execute(query, variable_values=variables) - if "errors" in res: - raise GraphQLStorageError(res) - return res + return self._instance.graphql_client.execute(query, variable_values=variables) def get_secrets_for_environment(self, location_name: Optional[str]) -> Dict[str, str]: res = self._execute_query( diff --git a/dagster-cloud/dagster_cloud/storage/compute_logs/compute_log_manager.py b/dagster-cloud/dagster_cloud/storage/compute_logs/compute_log_manager.py index 821b30f..8850d3d 100644 --- a/dagster-cloud/dagster_cloud/storage/compute_logs/compute_log_manager.py +++ b/dagster-cloud/dagster_cloud/storage/compute_logs/compute_log_manager.py @@ -101,7 +101,7 @@ def upload_to_cloud_storage( } if partial: params["partial"] = True - resp = self._instance.rest_requests_session.post( + resp = self._instance.requests_managed_retries_session.post( self._instance.dagster_cloud_gen_logs_url_url, params=params, headers=self._instance.dagster_cloud_api_headers(DagsterCloudInstanceScope.DEPLOYMENT), diff --git a/dagster-cloud/dagster_cloud/storage/event_logs/queries.py b/dagster-cloud/dagster_cloud/storage/event_logs/queries.py index 4f860ab..e11b67b 100644 --- a/dagster-cloud/dagster_cloud/storage/event_logs/queries.py +++ b/dagster-cloud/dagster_cloud/storage/event_logs/queries.py @@ -353,7 +353,17 @@ } } } - """ +""" + +STORE_EVENT_BATCH_MUTATION = """ + mutation StoreEventBatch($eventRecords: [EventLogEntryInput!]!) { + eventLogs { + StoreEventBatch(eventRecords: $eventRecords) { + ok + } + } + } +""" DELETE_EVENTS_MUTATION = """ mutation DeleteEvents($runId: String!) { diff --git a/dagster-cloud/dagster_cloud/storage/event_logs/storage.py b/dagster-cloud/dagster_cloud/storage/event_logs/storage.py index 64c3217..19b70f2 100644 --- a/dagster-cloud/dagster_cloud/storage/event_logs/storage.py +++ b/dagster-cloud/dagster_cloud/storage/event_logs/storage.py @@ -1,6 +1,8 @@ import json +import os from collections import defaultdict from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -57,10 +59,16 @@ ) from dagster._utils.error import SerializableErrorInfo from dagster._utils.merger import merge_dicts -from dagster_cloud_cli.core.errors import GraphQLStorageError +from dagster_cloud_cli.core.errors import DagsterCloudAgentServerError from typing_extensions import Self +from dagster_cloud.api.dagster_cloud_api import StoreEventBatchRequest from dagster_cloud.storage.event_logs.utils import truncate_event +from dagster_cloud.util import compressed_namedtuple_upload_file + +if TYPE_CHECKING: + from dagster_cloud.instance import DagsterCloudAgentInstance + from .queries import ( ADD_DYNAMIC_PARTITIONS_MUTATION, @@ -100,6 +108,7 @@ IS_PERSISTENT_QUERY, REINDEX_MUTATION, SET_CONCURRENCY_SLOTS_MUTATION, + STORE_EVENT_BATCH_MUTATION, STORE_EVENT_MUTATION, UPDATE_ASSET_CACHED_STATUS_DATA_MUTATION, UPGRADE_EVENT_LOG_STORAGE_MUTATION, @@ -109,6 +118,20 @@ ) +def _input_for_event(event: EventLogEntry): + event = truncate_event(event) + return { + "errorInfo": _input_for_serializable_error_info(event.error_info), + "level": event.level, + "userMessage": event.user_message, + "runId": event.run_id, + "timestamp": event.timestamp, + "stepKey": event.step_key, + "pipelineName": event.job_name, + "dagsterEvent": _input_for_dagster_event(event.dagster_event), + } + + def _input_for_serializable_error_info(serializable_error_info: Optional[SerializableErrorInfo]): check.opt_inst_param(serializable_error_info, "serializable_error_info", SerializableErrorInfo) @@ -395,16 +418,19 @@ 
def _graphql_client(self): else self._instance.graphql_client ) + @property + def agent_instance(self) -> "DagsterCloudAgentInstance": + from dagster_cloud.instance import DagsterCloudAgentInstance + + return cast(DagsterCloudAgentInstance, self._instance) + def _execute_query(self, query, variables=None, headers=None, idempotent_mutation=False): - res = self._graphql_client.execute( + return self._graphql_client.execute( query, variable_values=variables, headers=headers, idempotent_mutation=idempotent_mutation, ) - if "errors" in res: - raise GraphQLStorageError(res) - return res def get_records_for_run( self, @@ -509,29 +535,44 @@ def get_step_stats_for_run( for stats in step_stats ] + def _store_events_http(self, headers, events: Sequence[EventLogEntry]): + batch_request = StoreEventBatchRequest(event_log_entries=events) + with compressed_namedtuple_upload_file(batch_request) as f: + self.agent_instance.http_client.post( + url=self.agent_instance.dagster_cloud_store_events_url, + headers=headers, + files={"store_events.tmp": f}, + ) + def store_event(self, event: EventLogEntry): check.inst_param(event, "event", EventLogEntry) + headers = {"Idempotency-Key": str(uuid4())} - event = truncate_event(event) + if os.getenv("DAGSTER_CLOUD_STORE_EVENT_OVER_HTTP"): + self._store_events_http(headers, [event]) + else: + self._execute_query( + STORE_EVENT_MUTATION, + variables={ + "eventRecord": _input_for_event(event), + }, + headers=headers, + ) + def store_event_batch(self, events: Sequence[EventLogEntry]): + check.sequence_param(events, "events", of_type=EventLogEntry) headers = {"Idempotency-Key": str(uuid4())} - self._execute_query( - STORE_EVENT_MUTATION, - variables={ - "eventRecord": { - "errorInfo": _input_for_serializable_error_info(event.error_info), - "level": event.level, - "userMessage": event.user_message, - "runId": event.run_id, - "timestamp": event.timestamp, - "stepKey": event.step_key, - "pipelineName": event.job_name, - "dagsterEvent": _input_for_dagster_event(event.dagster_event), - } - }, - headers=headers, - ) + if os.getenv("DAGSTER_CLOUD_STORE_EVENT_OVER_HTTP"): + self._store_events_http(headers, events) + else: + self._execute_query( + STORE_EVENT_BATCH_MUTATION, + variables={ + "eventRecords": [_input_for_event(event) for event in events], + }, + headers=headers, + ) def delete_events(self, run_id: str): self._execute_query( @@ -802,10 +843,13 @@ def get_latest_asset_partition_materialization_attempts_without_materializations "afterStorageId": after_storage_id, }, ) - return res["data"]["eventLogs"][ + result = res["data"]["eventLogs"][ "getLatestAssetPartitionMaterializationAttemptsWithoutMaterializations" ] + # Translate list to tuple + return {key: tuple(val) for key, val in result.items()} + def get_event_tags_for_asset( self, asset_key: AssetKey, @@ -898,7 +942,7 @@ def initialize_concurrency_limit_to_default(self, concurrency_key: str) -> bool: if error["className"] == "DagsterInvalidInvocationError": raise DagsterInvalidInvocationError(error["message"]) else: - raise GraphQLStorageError(res) + raise DagsterCloudAgentServerError(res) return result.get("success") def set_concurrency_slots(self, concurrency_key: str, num: int) -> None: @@ -915,7 +959,7 @@ def set_concurrency_slots(self, concurrency_key: str, num: int) -> None: if error["className"] == "DagsterInvalidInvocationError": raise DagsterInvalidInvocationError(error["message"]) else: - raise GraphQLStorageError(res) + raise DagsterCloudAgentServerError(res) return res def delete_concurrency_limit(self, 
concurrency_key: str) -> None: @@ -930,7 +974,7 @@ def delete_concurrency_limit(self, concurrency_key: str) -> None: if error["className"] == "DagsterInvalidInvocationError": raise DagsterInvalidInvocationError(error["message"]) else: - raise GraphQLStorageError(res) + raise DagsterCloudAgentServerError(res) def get_concurrency_keys(self) -> Set[str]: res = self._execute_query(GET_CONCURRENCY_KEYS_QUERY) diff --git a/dagster-cloud/dagster_cloud/storage/runs/storage.py b/dagster-cloud/dagster_cloud/storage/runs/storage.py index 6e3d06f..9b8766a 100644 --- a/dagster-cloud/dagster_cloud/storage/runs/storage.py +++ b/dagster-cloud/dagster_cloud/storage/runs/storage.py @@ -22,7 +22,7 @@ ) from dagster._core.events import DagsterEvent from dagster._core.execution.backfill import BulkActionStatus, PartitionBackfill -from dagster._core.remote_representation.origin import ExternalJobOrigin +from dagster._core.remote_representation.origin import RemoteJobOrigin from dagster._core.snap import ( ExecutionPlanSnapshot, JobSnapshot, @@ -47,7 +47,7 @@ ) from dagster._utils import utc_datetime_from_timestamp from dagster._utils.merger import merge_dicts -from dagster_cloud_cli.core.errors import GraphQLStorageError +from dagster_cloud_cli.core.errors import DagsterCloudAgentServerError from typing_extensions import Self from .queries import ( @@ -172,12 +172,9 @@ def _graphql_client(self): ) def _execute_query(self, query, variables=None, idempotent_mutation=False): - res = self._graphql_client.execute( + return self._graphql_client.execute( query, variable_values=variables, idempotent_mutation=idempotent_mutation ) - if "errors" in res: - raise GraphQLStorageError(res) - return res def add_run(self, dagster_run: DagsterRun): check.inst_param(dagster_run, "dagster_run", DagsterRun) @@ -196,7 +193,7 @@ def add_run(self, dagster_run: DagsterRun): if error["className"] == "DagsterSnapshotDoesNotExist": raise DagsterSnapshotDoesNotExist(error["message"]) else: - raise GraphQLStorageError(res) + raise DagsterCloudAgentServerError(res) return dagster_run @@ -513,7 +510,7 @@ def set_cursor_values(self, pairs: Mapping[str, str]): return NotImplementedError("KVS is not supported from the user cloud") # Migrating run history - def replace_job_origin(self, run: DagsterRun, job_origin: ExternalJobOrigin): + def replace_job_origin(self, run: DagsterRun, job_origin: RemoteJobOrigin): self._execute_query( MUTATE_JOB_ORIGIN, variables={ diff --git a/dagster-cloud/dagster_cloud/storage/schedules/storage.py b/dagster-cloud/dagster_cloud/storage/schedules/storage.py index 6396874..68d9d5d 100644 --- a/dagster-cloud/dagster_cloud/storage/schedules/storage.py +++ b/dagster-cloud/dagster_cloud/storage/schedules/storage.py @@ -21,7 +21,6 @@ deserialize_value, serialize_value, ) -from dagster_cloud_cli.core.errors import GraphQLStorageError from typing_extensions import Self from .queries import ( @@ -67,10 +66,7 @@ def _graphql_client(self): ) def _execute_query(self, query, variables=None): - res = self._graphql_client.execute(query, variable_values=variables) - if "errors" in res: - raise GraphQLStorageError(res) - return res + return self._graphql_client.execute(query, variable_values=variables) def wipe(self): raise Exception("Not allowed to wipe from user cloud") diff --git a/dagster-cloud/dagster_cloud/util/__init__.py b/dagster-cloud/dagster_cloud/util/__init__.py index 0ccd00a..acc7ff0 100644 --- a/dagster-cloud/dagster_cloud/util/__init__.py +++ b/dagster-cloud/dagster_cloud/util/__init__.py @@ -1,11 +1,15 @@ +import 
zlib from collections import namedtuple -from typing import Any, Dict, List, Mapping +from contextlib import contextmanager +from io import BytesIO +from typing import Any, Dict, List, Mapping, NamedTuple from dagster import ( Field, _check as check, ) from dagster._config import BoolSourceType, IntSourceType, StringSourceType +from dagster._serdes import serialize_value from dagster._serdes.utils import create_snapshot_id @@ -81,3 +85,10 @@ def keys_not_none( dictionary: Mapping[str, Any], ) -> bool: return all(key in dictionary and dictionary[key] is not None for key in keys) + + +@contextmanager +def compressed_namedtuple_upload_file(to_serialize: NamedTuple): + compressed_data = zlib.compress(serialize_value(to_serialize).encode("utf-8")) + with BytesIO(compressed_data) as f: + yield f diff --git a/dagster-cloud/dagster_cloud/version.py b/dagster-cloud/dagster_cloud/version.py index 93d2517..14d9d2f 100644 --- a/dagster-cloud/dagster_cloud/version.py +++ b/dagster-cloud/dagster_cloud/version.py @@ -1 +1 @@ -__version__ = "1.6.14" +__version__ = "1.7.0" diff --git a/dagster-cloud/dagster_cloud/workspace/config_schema/ecs.py b/dagster-cloud/dagster_cloud/workspace/config_schema/ecs.py index 5514bd9..8795618 100644 --- a/dagster-cloud/dagster_cloud/workspace/config_schema/ecs.py +++ b/dagster-cloud/dagster_cloud/workspace/config_schema/ecs.py @@ -120,6 +120,11 @@ is_required=False, description="Additional tags to apply to the launched ECS task for a code server.", ), + "server_health_check": Field( + Permissive(), + is_required=False, + description="Health check to include in code server task definitions.", + ), } diff --git a/dagster-cloud/dagster_cloud/workspace/ecs/client.py b/dagster-cloud/dagster_cloud/workspace/ecs/client.py index d26d0f1..7551cc6 100644 --- a/dagster-cloud/dagster_cloud/workspace/ecs/client.py +++ b/dagster-cloud/dagster_cloud/workspace/ecs/client.py @@ -148,6 +148,7 @@ def register_task_definition( mount_points=None, volumes=None, linux_parameters=None, + health_check=None, ): container_name = container_name or family @@ -187,6 +188,7 @@ def register_task_definition( mount_points=mount_points, volumes=volumes, linux_parameters=linux_parameters, + health_check=health_check, ) try: @@ -239,6 +241,7 @@ def create_service( runtime_platform=None, mount_points=None, volumes=None, + health_check=None, ): logger = logger or logging.getLogger("dagster_cloud.EcsClient") @@ -266,6 +269,7 @@ def create_service( mount_points=mount_points, volumes=volumes, linux_parameters=ECS_EXEC_LINUX_PARAMETERS if allow_ecs_exec else None, + health_check=health_check, ) service_registry_arn = None diff --git a/dagster-cloud/dagster_cloud/workspace/ecs/launcher.py b/dagster-cloud/dagster_cloud/workspace/ecs/launcher.py index c7f31d2..8d07ef9 100644 --- a/dagster-cloud/dagster_cloud/workspace/ecs/launcher.py +++ b/dagster-cloud/dagster_cloud/workspace/ecs/launcher.py @@ -82,6 +82,7 @@ def __init__( run_sidecar_containers: Optional[Sequence[Mapping[str, Any]]] = None, server_ecs_tags: Optional[Sequence[Mapping[str, Optional[str]]]] = None, run_ecs_tags: Optional[Sequence[Mapping[str, Optional[str]]]] = None, + server_health_check: Optional[Mapping[str, Any]] = None, **kwargs, ): self.ecs = boto3.client("ecs") @@ -142,6 +143,10 @@ def __init__( self.server_ecs_tags = check.opt_sequence_param(server_ecs_tags, "server_ecs_tags") self.run_ecs_tags = check.opt_sequence_param(run_ecs_tags, "run_ecs_tags") + self.server_health_check = check.opt_mapping_param( + server_health_check, 
"server_health_check" + ) + self.client = Client( cluster_name=self.cluster, subnet_ids=self.subnets, @@ -387,6 +392,7 @@ def _start_new_server_spinup( run_sidecar_containers=self.run_sidecar_containers, server_ecs_tags=self.server_ecs_tags, run_ecs_tags=self.run_ecs_tags, + server_health_check=self.server_health_check, ).merge(EcsContainerContext.create_from_config(metadata.container_context)) # disallow multiple replicas for code locations acting as pex servers @@ -457,11 +463,10 @@ def _start_new_server_spinup( runtime_platform=container_context.runtime_platform, mount_points=container_context.mount_points, volumes=container_context.volumes, + health_check=container_context.server_health_check, ) self._logger.info( - "Created a new service at hostname {} for {}:{}, waiting for server to be ready...".format( - service.hostname, deployment_name, location_name - ) + f"Created a new service at hostname {service.hostname} for {deployment_name}:{location_name}, waiting for server to be ready..." ) endpoint = ServerEndpoint( diff --git a/dagster-cloud/dagster_cloud/workspace/ecs/utils.py b/dagster-cloud/dagster_cloud/workspace/ecs/utils.py index f7db6ae..0c8325c 100644 --- a/dagster-cloud/dagster_cloud/workspace/ecs/utils.py +++ b/dagster-cloud/dagster_cloud/workspace/ecs/utils.py @@ -2,7 +2,7 @@ import re from typing import Optional -from dagster._core.remote_representation.origin import ExternalJobOrigin +from dagster._core.remote_representation.origin import RemoteJobOrigin from dagster_aws.ecs.utils import sanitize_family from ..user_code_launcher.utils import get_human_readable_label, unique_resource_name @@ -55,14 +55,14 @@ def get_server_task_definition_family( def get_run_task_definition_family( organization_name: Optional[str], deployment_name: str, - job_origin: ExternalJobOrigin, + job_origin: RemoteJobOrigin, ) -> str: # Truncate the location name if it's too long (but add a unique suffix at the end so that no matter what it's unique) # Relies on the fact that org name and deployment name are always <= 64 characters long to # stay well underneath the 255 character limit imposed by ECS job_name = job_origin.job_name - repo_name = job_origin.external_repository_origin.repository_name - location_name = job_origin.external_repository_origin.code_location_origin.location_name + repo_name = job_origin.repository_origin.repository_name + location_name = job_origin.repository_origin.code_location_origin.location_name assert len(str(organization_name)) <= 64 assert len(deployment_name) <= 64 diff --git a/dagster-cloud/dagster_cloud/workspace/kubernetes/launcher.py b/dagster-cloud/dagster_cloud/workspace/kubernetes/launcher.py index 189d1a3..7c025d2 100644 --- a/dagster-cloud/dagster_cloud/workspace/kubernetes/launcher.py +++ b/dagster-cloud/dagster_cloud/workspace/kubernetes/launcher.py @@ -430,10 +430,7 @@ def _start_new_server_spinup( ), ) self._logger.info( - "Created deployment {} in namespace {}".format( - deployment_reponse.metadata.name, - container_context.namespace, - ) + f"Created deployment {deployment_reponse.metadata.name} in namespace {container_context.namespace}" ) except ApiException as e: self._logger.error( @@ -635,9 +632,7 @@ def _remove_server_handle(self, server_handle: K8sHandle) -> None: raise self._logger.info( - "Removed deployment and service {} in namespace {}".format( - server_handle.name, server_handle.namespace - ) + f"Removed deployment and service {server_handle.name} in namespace {server_handle.namespace}" ) def __exit__(self, exception_type, 
exception_value, traceback): diff --git a/dagster-cloud/dagster_cloud/workspace/kubernetes/utils.py b/dagster-cloud/dagster_cloud/workspace/kubernetes/utils.py index 1f5168c..f2d4182 100644 --- a/dagster-cloud/dagster_cloud/workspace/kubernetes/utils.py +++ b/dagster-cloud/dagster_cloud/workspace/kubernetes/utils.py @@ -4,7 +4,6 @@ from typing import Mapping, Optional import kubernetes -from dagster._utils.merger import merge_dicts from dagster_k8s.client import DagsterKubernetesClient from dagster_k8s.models import k8s_model_from_dict from kubernetes import client @@ -113,75 +112,34 @@ def construct_code_location_deployment( args, server_timestamp: float, ): - pull_policy = container_context.image_pull_policy - env_config_maps = container_context.env_config_maps - env_secrets = container_context.env_secrets - service_account_name = container_context.service_account_name - image_pull_secrets = container_context.image_pull_secrets - volume_mounts = container_context.volume_mounts - - volumes = container_context.volumes - resources = container_context.resources - - scheduler_name = container_context.scheduler_name - security_context = container_context.security_context - - env = merge_dicts( - metadata.get_grpc_server_env( - SERVICE_PORT, location_name, instance.ref_for_deployment(deployment_name) - ), - container_context.get_environment_dict(), + env = metadata.get_grpc_server_env( + SERVICE_PORT, location_name, instance.ref_for_deployment(deployment_name) ) user_defined_config = container_context.server_k8s_config container_config = copy.deepcopy(user_defined_config.container_config) - container_config["args"] = args - - if pull_policy: - container_config["image_pull_policy"] = pull_policy - user_defined_env_vars = container_config.pop("env", []) - user_defined_env_from = container_config.pop("env_from", []) - user_defined_volume_mounts = container_config.pop("volume_mounts", []) - user_defined_resources = container_config.pop("resources", {}) - user_defined_security_context = container_config.pop("security_context", None) - - container_name = container_config.get("name", "dagster") + container_name = container_config.pop("name", "dagster") container_config = { **container_config, + "args": args, "name": container_name, "image": metadata.image, - "env": [{"name": key, "value": value} for key, value in env.items()] - + user_defined_env_vars, - "env_from": ( - [{"config_map_ref": {"name": config_map}} for config_map in env_config_maps] - + [{"secret_ref": {"name": secret_name}} for secret_name in env_secrets] - + user_defined_env_from + "env": ( + [{"name": key, "value": value} for key, value in env.items()] + user_defined_env_vars ), - "volume_mounts": volume_mounts + user_defined_volume_mounts, - "resources": user_defined_resources or resources, - "security_context": user_defined_security_context or security_context, } pod_spec_config = copy.deepcopy(user_defined_config.pod_spec_config) - user_defined_image_pull_secrets = pod_spec_config.pop("image_pull_secrets", []) - user_defined_service_account_name = pod_spec_config.pop("service_account_name", None) user_defined_containers = pod_spec_config.pop("containers", []) - user_defined_volumes = pod_spec_config.pop("volumes", []) - user_defined_scheduler_name = pod_spec_config.pop("scheduler_name", None) pod_spec_config = { **pod_spec_config, - "image_pull_secrets": [{"name": x["name"]} for x in image_pull_secrets] - + user_defined_image_pull_secrets, - "service_account_name": user_defined_service_account_name or service_account_name, 
"containers": [container_config] + user_defined_containers, - "volumes": volumes + user_defined_volumes, - "scheduler_name": user_defined_scheduler_name or scheduler_name, } pod_template_spec_metadata = copy.deepcopy(user_defined_config.pod_template_spec_metadata) diff --git a/dagster-cloud/dagster_cloud/workspace/user_code_launcher/user_code_launcher.py b/dagster-cloud/dagster_cloud/workspace/user_code_launcher/user_code_launcher.py index ffaa55f..6f725b1 100644 --- a/dagster-cloud/dagster_cloud/workspace/user_code_launcher/user_code_launcher.py +++ b/dagster-cloud/dagster_cloud/workspace/user_code_launcher/user_code_launcher.py @@ -9,7 +9,7 @@ import time import zlib from abc import abstractmethod, abstractproperty -from concurrent.futures import ThreadPoolExecutor, wait +from concurrent.futures import ThreadPoolExecutor, as_completed, wait from contextlib import AbstractContextManager from typing import ( Any, @@ -31,6 +31,7 @@ ) import dagster._check as check +import grpc import pendulum from dagster import BoolSource, Field, IntSource from dagster._api.list_repositories import sync_list_repositories_grpc @@ -38,7 +39,7 @@ from dagster._core.errors import DagsterUserCodeUnreachableError from dagster._core.instance import MayHaveInstanceWeakref from dagster._core.launcher import RunLauncher -from dagster._core.remote_representation import ExternalRepositoryOrigin +from dagster._core.remote_representation import RemoteRepositoryOrigin from dagster._core.remote_representation.origin import ( CodeLocationOrigin, RegisteredCodeLocationOrigin, @@ -49,7 +50,7 @@ from dagster._utils.error import SerializableErrorInfo, serializable_error_info_from_exc_info from dagster._utils.merger import merge_dicts from dagster._utils.typed_dict import init_optional_typeddict -from dagster_cloud_cli.core.errors import GraphQLStorageError, raise_http_error +from dagster_cloud_cli.core.errors import raise_http_error from dagster_cloud_cli.core.workspace import CodeDeploymentMetadata from typing_extensions import Self, TypeAlias @@ -92,8 +93,8 @@ # Check on pending delete servers every 30th reconcile PENDING_DELETE_SERVER_CHECK_INTERVAL = 30 -# How often to sync actual_entries with pex server liveness -MULTIPEX_ACTUAL_ENTRIES_REFRESH_INTERVAL = 30 +# How often to sync actual_entries with server liveness +ACTUAL_ENTRIES_REFRESH_INTERVAL = 30 CLEANUP_SERVER_GRACE_PERIOD_SECONDS = 3600 @@ -318,6 +319,8 @@ def __init__( self._grpc_servers: Dict[ DeploymentAndLocation, Union[DagsterCloudGrpcServer, SerializableErrorInfo] ] = {} + self._first_unavailable_times: Dict[DeploymentAndLocation, float] = {} + self._pending_delete_grpc_server_handles: Set[ServerHandle] = set() self._grpc_servers_lock = threading.Lock() self._per_location_metrics: Dict[ @@ -387,11 +390,8 @@ def get_active_agent_ids(self) -> Set[str]: "retrieve active agent_ids. Just returning this agent as an active ID." 
) return set([self._instance.instance_uuid]) - if "errors" in result: - raise GraphQLStorageError(result["errors"]) - else: - self._logger.info(f"Active agent ids response: {result}") - return set(agent_data["id"] for agent_data in result["data"]["agents"]) + self._logger.info(f"Active agent ids response: {result}") + return set(agent_data["id"] for agent_data in result["data"]["agents"]) @property def code_server_metrics_enabled(self) -> bool: @@ -554,14 +554,10 @@ def _update_workspace_entry( with open(dst, "rb") as f: self._logger.info( - "Uploading workspace entry for {deployment_name}:{location_name} ({size} bytes)".format( - deployment_name=deployment_name, - location_name=workspace_entry.location_name, - size=os.path.getsize(dst), - ) + f"Uploading workspace entry for {deployment_name}:{workspace_entry.location_name} ({os.path.getsize(dst)} bytes)" ) - resp = self._instance.rest_requests_session.put( + resp = self._instance.requests_managed_retries_session.put( self._instance.dagster_cloud_upload_workspace_entry_url, headers=self._instance.headers_for_deployment(deployment_name), data={}, @@ -602,7 +598,7 @@ def _update_workspace_entry( _ = [f.result() for f in futures] with open(dst, "rb") as f: - resp = self._instance.rest_requests_session.put( + resp = self._instance.requests_managed_retries_session.put( self._instance.dagster_cloud_upload_workspace_entry_url, headers=self._instance.headers_for_deployment(deployment_name), data={}, @@ -645,7 +641,7 @@ def _get_upload_location_data( ) in list_repositories_response.repository_code_pointer_dict.items(): external_repository_chunks = list( client.streaming_external_repository( - external_repository_origin=ExternalRepositoryOrigin( + external_repository_origin=RemoteRepositoryOrigin( location_origin, repository_name, ), @@ -686,12 +682,8 @@ def _update_location_error( metadata: CodeDeploymentMetadata, ): self._logger.error( - "Unable to update {deployment_name}:{location_name}. Updating location with error data:" - " {error_info}.".format( - deployment_name=deployment_name, - location_name=location_name, - error_info=str(error_info), - ) + f"Unable to update {deployment_name}:{location_name}. Updating location with error data:" + f" {error_info!s}." 
) # Update serialized error @@ -1086,9 +1078,7 @@ def _reconcile_thread(self, shutdown_event): self.reconcile() except Exception: self._logger.error( - "Failure updating user code servers: {exc_info}".format( - exc_info=serializable_error_info_from_exc_info(sys.exc_info()), - ) + f"Failure updating user code servers: {serializable_error_info_from_exc_info(sys.exc_info())}" ) def reconcile(self) -> None: @@ -1103,15 +1093,14 @@ def reconcile(self) -> None: # Wait for the first time the desired metadata is set before reconciling return - if ( - time.time() - self._last_refreshed_actual_entries - > MULTIPEX_ACTUAL_ENTRIES_REFRESH_INTERVAL - ): + now = pendulum.now("UTC").timestamp() + + if now - self._last_refreshed_actual_entries > ACTUAL_ENTRIES_REFRESH_INTERVAL: try: self._refresh_actual_entries() except: self._logger.exception("Failed to refresh actual entries.") - self._last_refreshed_actual_entries = time.time() + self._last_refreshed_actual_entries = now self._reconcile( desired_entries, @@ -1145,9 +1134,7 @@ def _update_metrics_thread(self, shutdown_event): ) except Exception: self._logger.error( - "Failure updating user code server metrics: {exc_info}".format( - exc_info=serializable_error_info_from_exc_info(sys.exc_info()), - ) + f"Failure updating user code server metrics: {serializable_error_info_from_exc_info(sys.exc_info())}" ) @property @@ -1155,6 +1142,10 @@ def ready_to_serve_requests(self) -> bool: # thread-safe since reconcile_count is an integer return self._reconcile_count > 0 + def _make_check_on_running_server_endpoint(self, server_endpoint: ServerEndpoint): + # Ensure that server_endpoint is bound correctly + return lambda: server_endpoint.create_client().ping("") + def _refresh_actual_entries(self) -> None: for deployment_location, server in self._multipex_servers.items(): if deployment_location in self._actual_entries: @@ -1168,7 +1159,7 @@ def _refresh_actual_entries(self) -> None: # If it isn't, this is expected if ECS is currently spinning up this service # after it crashed. In this case, we want to wait for it to fully come up # before we remove actual entries. This ensures the recon loop uses the ECS - # replacement multiplex server and not try to spin up a new multipex server. + # replacement multipex server and not try to spin up a new multipex server. self._logger.info( "Multipex server entry exists but server is not running. " "Will wait for server to come up." @@ -1177,13 +1168,90 @@ def _refresh_actual_entries(self) -> None: deployment_name, location_name = deployment_location if not self._get_existing_pex_servers(deployment_name, location_name): self._logger.warning( - "Pex servers disappeared for %s, %s. Removing actual entries to" + "Pex servers disappeared for %s:%s. 
Removing actual entries to"
" activate reconciliation logic.",
deployment_name,
location_name,
)
del self._actual_entries[deployment_location]
+ # Check to see if any servers have become unresponsive
+ unavailable_server_timeout = int(
+ os.getenv(
+ "DAGSTER_CLOUD_CODE_SERVER_HEALTH_CHECK_REDEPLOY_TIMEOUT",
+ str(self._server_process_startup_timeout),
+ )
+ )
+
+ if unavailable_server_timeout < 0:
+ return
+
+ running_locations = {
+ deployment_location: endpoint_or_error
+ for deployment_location, endpoint_or_error in self.get_grpc_endpoints().items()
+ if (
+ isinstance(endpoint_or_error, ServerEndpoint)
+ and deployment_location in self._actual_entries
+ )
+ }
+
+ if not running_locations:
+ return
+
+ with ThreadPoolExecutor(
+ max_workers=max(
+ len(running_locations),
+ int(os.getenv("DAGSTER_CLOUD_CODE_SERVER_HEALTH_CHECK_MAX_WORKERS", "8")),
+ ),
+ thread_name_prefix="dagster_cloud_agent_server_health_check",
+ ) as executor:
+ futures = {}
+ for deployment_location, endpoint_or_error in running_locations.items():
+ deployment_name, location_name = deployment_location
+
+ futures[
+ executor.submit(self._make_check_on_running_server_endpoint(endpoint_or_error))
+ ] = deployment_location
+
+ for future in as_completed(futures):
+ deployment_location = futures[future]
+
+ deployment_name, location_name = deployment_location
+ try:
+ future.result()
+
+ # A successful ping resets the tracked first unavailable time for this code server, if set
+ self._first_unavailable_times.pop(deployment_location, None)
+ except Exception as e:
+ if (
+ isinstance(e, DagsterUserCodeUnreachableError)
+ and isinstance(e.__cause__, grpc.RpcError)
+ and cast(grpc.RpcError, e.__cause__).code() == grpc.StatusCode.UNAVAILABLE
+ ):
+ first_unavailable_time = self._first_unavailable_times.get(
+ deployment_location
+ )
+
+ now = pendulum.now("UTC").timestamp()
+
+ if not first_unavailable_time:
+ self._logger.warning(
+ f"Code server for {deployment_name}:{location_name} failed a health check. If it continues failing for more than {unavailable_server_timeout} seconds, a replacement code server will be deployed."
+ )
+ # Record the first unavailable time if it is not already set
+ self._first_unavailable_times[deployment_location] = now
+ elif now > first_unavailable_time + unavailable_server_timeout:
+ self._logger.warning(
+ f"Code server for {deployment_name}:{location_name} has been unresponsive for more than {unavailable_server_timeout} seconds. Deploying a new code server."
+ )
+ del self._actual_entries[deployment_location]
+ del self._first_unavailable_times[deployment_location]
+ else:
+ self._logger.exception(
+ f"Code server for {deployment_name}:{location_name} health check failed, but the error did not indicate that the server was unavailable."
+ ) + self._first_unavailable_times.pop(deployment_location, None) + def _write_liveness_sentinel(self) -> None: """Write a sentinel file to indicate that the agent is alive and grpc servers have been spun up.""" pass @@ -1376,12 +1444,8 @@ def _reconcile( error_info = serializable_error_info_from_exc_info(sys.exc_info()) self._logger.error( - "Error while waiting for multipex server for {deployment_name}:{location_name}:" - " {error_info}".format( - deployment_name=deployment_name, - location_name=location_name, - error_info=error_info, - ) + f"Error while waiting for multipex server for {deployment_name}:{location_name}:" + f" {error_info}" ) new_dagster_servers[to_update_key] = error_info # Clear out this multipex server so we don't try to use it again @@ -1454,13 +1518,8 @@ def _reconcile( except Exception: error_info = serializable_error_info_from_exc_info(sys.exc_info()) self._logger.error( - "Error while waiting for server for {deployment_name}:{location_name} for {deployment_info} to be" - " ready: {error_info}".format( - deployment_name=deployment_name, - location_name=location_name, - error_info=error_info, - deployment_info=deployment_info, - ) + f"Error while waiting for server for {deployment_name}:{location_name} for {deployment_info} to be" + f" ready: {error_info}" ) server_or_error = error_info @@ -1487,6 +1546,7 @@ def _reconcile( # the server to start serving new requests with self._grpc_servers_lock: self._grpc_servers[to_update_key] = server_or_error + self._first_unavailable_times.pop(to_update_key, None) for to_update_key in to_update_keys: deployment_name, location_name = to_update_key @@ -1498,11 +1558,7 @@ def _reconcile( if server_handles: removed_any_servers = True self._logger.info( - "Removing {num_servers} existing servers for {deployment_name}:{location_name}".format( - num_servers=len(server_handles), - location_name=location_name, - deployment_name=deployment_name, - ) + f"Removing {len(server_handles)} existing servers for {deployment_name}:{location_name}" ) for server_handle in server_handles: @@ -1511,11 +1567,7 @@ def _reconcile( except Exception: self._logger.error( "Error while cleaning up after updating server for" - " {deployment_name}:{location_name}: {error_info}".format( - deployment_name=deployment_name, - location_name=location_name, - error_info=serializable_error_info_from_exc_info(sys.exc_info()), - ) + f" {deployment_name}:{location_name}: {serializable_error_info_from_exc_info(sys.exc_info())}" ) # Remove any existing multipex servers other than the current one for each location @@ -1545,11 +1597,7 @@ def _reconcile( except Exception: self._logger.error( "Error while cleaning up old multipex server for" - " {deployment_name}:{location_name}: {error_info}".format( - deployment_name=deployment_name, - location_name=location_name, - error_info=serializable_error_info_from_exc_info(sys.exc_info()), - ) + f" {deployment_name}:{location_name}: {serializable_error_info_from_exc_info(sys.exc_info())}" ) # On the current multipex server, shut down any old pex servers @@ -1572,11 +1620,7 @@ def _reconcile( except Exception: self._logger.error( "Error while cleaning up after updating server for" - " {deployment_name}:{location_name}: {error_info}".format( - deployment_name=deployment_name, - location_name=location_name, - error_info=serializable_error_info_from_exc_info(sys.exc_info()), - ) + f" {deployment_name}:{location_name}: {serializable_error_info_from_exc_info(sys.exc_info())}" ) if removed_any_servers: @@ -1835,7 +1879,7 @@ def 
_wait_for_dagster_server_process( client: DagsterGrpcClient, timeout, additional_check: Optional[Callable[[], None]] = None, - get_timeout_debug_info: Optional[Callable[[], None]] = None, + get_timeout_debug_info: Optional[Callable[[], Any]] = None, ) -> None: self._wait_for_server_process( client, timeout, additional_check, get_timeout_debug_info=get_timeout_debug_info @@ -1898,7 +1942,7 @@ def upload_job_snapshot( client = server.server_endpoint.create_client() location_origin = self._get_code_location_origin(job_selector.location_name) response = client.external_job( - ExternalRepositoryOrigin(location_origin, job_selector.repository_name), + RemoteRepositoryOrigin(location_origin, job_selector.repository_name), job_selector.job_name, ) if not response.serialized_job_data: @@ -1914,7 +1958,7 @@ def upload_job_snapshot( f.write(zlib.compress(response.serialized_job_data.encode("utf-8"))) with open(dst, "rb") as f: - resp = self._instance.rest_requests_session.put( + resp = self._instance.requests_managed_retries_session.put( self._instance.dagster_cloud_upload_job_snap_url, headers=self._instance.headers_for_deployment(deployment_name), data={}, diff --git a/dagster-cloud/setup.py b/dagster-cloud/setup.py index 40703c1..89aedda 100644 --- a/dagster-cloud/setup.py +++ b/dagster-cloud/setup.py @@ -40,8 +40,8 @@ def get_description() -> str: packages=find_packages(exclude=["dagster_cloud_tests*"]), include_package_data=True, install_requires=[ - "dagster==1.6.14", - "dagster-cloud-cli==1.6.14", + "dagster==1.7.0", + "dagster-cloud-cli==1.7.0", "pex>=2.1.132,<3", "questionary", "requests", @@ -66,13 +66,13 @@ def get_description() -> str: "dbt-snowflake", "dbt-postgres", "dbt-duckdb", - "dagster-dbt==0.22.14", - "dagster_k8s==0.22.14", + "dagster-dbt==0.23.0", + "dagster_k8s==0.23.0", ], "insights": ["pyarrow"], - "docker": ["docker", "dagster_docker==0.22.14"], - "kubernetes": ["kubernetes", "dagster_k8s==0.22.14"], - "ecs": ["dagster_aws==0.22.14", "boto3"], + "docker": ["docker", "dagster_docker==0.23.0"], + "kubernetes": ["kubernetes", "dagster_k8s==0.23.0"], + "ecs": ["dagster_aws==0.23.0", "boto3"], "sandbox": ["supervisor"], "pex": ["boto3"], "serverless": ["boto3"],
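
A minimal sketch of the closure late-binding behavior that the `_make_check_on_running_server_endpoint` factory above appears to guard against ("Ensure that server_endpoint is bound correctly"): lambdas defined inline in a loop all close over the same loop variable, so health checks submitted to the executor could end up pinging the wrong endpoint. The `FakeEndpoint` and `make_check` names below are hypothetical stand-ins for illustration only, not part of dagster-cloud.

from concurrent.futures import ThreadPoolExecutor, as_completed


class FakeEndpoint:
    """Hypothetical stand-in for a ServerEndpoint; ping() reports which endpoint answered."""

    def __init__(self, name: str):
        self.name = name

    def ping(self) -> str:
        return self.name


def make_check(endpoint: FakeEndpoint):
    # The endpoint is bound here, when the factory is called, so each returned
    # callable pings its own endpoint (mirrors _make_check_on_running_server_endpoint).
    return lambda: endpoint.ping()


endpoints = [FakeEndpoint("a"), FakeEndpoint("b"), FakeEndpoint("c")]

# Late binding: every inline lambda closes over the same loop variable, so all of
# them see the value from the final iteration once the loop has finished.
inline_checks = [lambda: endpoint.ping() for endpoint in endpoints]
print([check() for check in inline_checks])  # ['c', 'c', 'c']

# Factory binding: each submitted check pings the endpoint it was created for.
with ThreadPoolExecutor(max_workers=3) as executor:
    futures = {executor.submit(make_check(endpoint)): endpoint.name for endpoint in endpoints}
    print(sorted(future.result() for future in as_completed(futures)))  # ['a', 'b', 'c']

Routing every submitted ping through a factory (or a default argument) binds the endpoint at submission time rather than at execution time, which is presumably why the health-check loop calls the factory method instead of defining the lambda inline.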