diff --git a/tests/test_reports.py b/tests/test_reports.py index 63a6cdd..a6d550e 100644 --- a/tests/test_reports.py +++ b/tests/test_reports.py @@ -177,11 +177,11 @@ class CustomChartFactory(CustomDataclassFactory[wr.CustomChart]): @classmethod def query(cls): return {"history": {"keys": ["x", "y"], "id": None, "name": None}} - + @classmethod def chart_fields(cls): return {"x": "x", "y": "y"} - + @classmethod def chart_strings(cls): return {"x": "x-axis", "y": "y-axis"} diff --git a/tests/test_workspaces.py b/tests/test_workspaces.py index 383ad22..48429a5 100644 --- a/tests/test_workspaces.py +++ b/tests/test_workspaces.py @@ -6,8 +6,8 @@ from polyfactory.pytest_plugin import register_fixture import wandb_workspaces.expr -import wandb_workspaces.workspaces as ws import wandb_workspaces.reports.v2 as wr +import wandb_workspaces.workspaces as ws from wandb_workspaces.utils.validators import ( validate_no_emoji, validate_spec_version, @@ -50,20 +50,15 @@ def runset_settings(cls): wandb_workspaces.expr.Metric("vwx").notin([8, 9, 0, "broccoli"]), ], ) - + @classmethod def sections(cls): return [ ws.Section(name="section1", panels=[wr.LinePlot()]), - ws.Section(name="section2", panels=[wr.BarPlot(title='tomato')]), + ws.Section(name="section2", panels=[wr.BarPlot(title="tomato")]), ] -@register_fixture -class WorkspaceSettingsFactory(CustomDataclassFactory[ws.WorkspaceSettings]): - __model__ = ws.WorkspaceSettings - - @register_fixture class SectionFactory(CustomDataclassFactory[ws.Section]): __model__ = ws.Section @@ -85,7 +80,6 @@ class SectionPanelSettingsFactory(CustomDataclassFactory[ws.SectionPanelSettings factory_names = [ "workspace_factory", - "workspace_settings_factory", "section_factory", "section_panel_settings_factory", "section_panel_settings_factory", diff --git a/wandb_workspaces/reports/v2/interface.py b/wandb_workspaces/reports/v2/interface.py index 7622639..ad9ae50 100644 --- a/wandb_workspaces/reports/v2/interface.py +++ b/wandb_workspaces/reports/v2/interface.py @@ -1,4 +1,5 @@ """Public interfaces for the Report API.""" + import os from datetime import datetime from typing import Dict, Iterable, Optional, Tuple, Union @@ -1081,7 +1082,13 @@ def from_model(cls, model: internal.WeaveBlock): "val" ] file = inputs["path"]["val"] - return cls(entity=entity, project=project, artifact=artifact, version=version, file=file) + return cls( + entity=entity, + project=project, + artifact=artifact, + version=version, + file=file, + ) @dataclass(config=dataclass_config) diff --git a/wandb_workspaces/workspaces/interface.py b/wandb_workspaces/workspaces/interface.py index e9d2e81..241d8a9 100644 --- a/wandb_workspaces/workspaces/interface.py +++ b/wandb_workspaces/workspaces/interface.py @@ -226,6 +226,18 @@ class WorkspaceSettings(Base): ignore_outliers: bool = False "Ignore outliers in all panels." + # Section settings + sort_panels_alphabetically: bool = False + "Sorts panels in all sections alphabetically" + + group_by_prefix: Literal["first", "last"] = "last" + """ + Group panels by the first or up to last prefix + + first: "a/b/c/d" -> section a + last: "a/b/c/d" -> section a/b/c + """ + # Panel settings remove_legends_from_panels: bool = False "Remove legends from all panels." @@ -262,67 +274,6 @@ class WorkspaceSettings(Base): ) "Search history for the panel search bar." - @classmethod - def from_model(cls, model: internal.ViewspecSectionSettings): - x_axis = expr._convert_be_to_fe_metric_name(model.x_axis) - - # TODO: There can be better mapping here between sane API names and the internal names - point_viz_method: Literal["bucketing", "downsampling"] - if (pvm := model.point_visualization_method) == "bucketing-gorilla": - point_viz_method = "bucketing" - elif pvm == "sampling": - point_viz_method = "downsampling" - elif pvm is None: - point_viz_method = "bucketing" - else: - raise Exception(f"Unexpected value for {pvm=}") - remove_legends_from_panels = False if model.suppress_legends is None else True - tooltip_number_of_runs = ( - "default" - if model.tooltip_number_of_runs is None - else model.tooltip_number_of_runs - ) - tooltip_color_run_names = True if model.color_run_names is None else False - max_runs = 10 if model.max_runs is None else model.max_runs - - return cls( - x_axis=x_axis, - x_min=model.x_axis_min, - x_max=model.x_axis_max, - smoothing_type=model.smoothing_type, - smoothing_weight=model.smoothing_weight, - ignore_outliers=model.ignore_outliers, - remove_legends_from_panels=remove_legends_from_panels, - tooltip_number_of_runs=tooltip_number_of_runs, - tooltip_color_run_names=tooltip_color_run_names, - max_runs=max_runs, - point_visualization_method=point_viz_method, - ) - - def to_model(self): - x_axis = expr._convert_fe_to_be_metric_name(self.x_axis) - point_viz_method = ( - "bucketing-gorilla" - if self.point_visualization_method == "bucketing" - else "sampling" - ) - suppress_legends = None if not self.remove_legends_from_panels else True - color_run_names = None if self.tooltip_color_run_names else False - - return internal.ViewspecSectionSettings( - x_axis=x_axis, - x_axis_min=self.x_min, - x_axis_max=self.x_max, - smoothing_type=self.smoothing_type, - smoothing_weight=self.smoothing_weight, - ignore_outliers=self.ignore_outliers, - suppress_legends=suppress_legends, - tooltip_number_of_runs=self.tooltip_number_of_runs, - color_run_names=color_run_names, - max_runs=self.max_runs, - point_visualization_method=point_viz_method, - ) - @dataclass(config=dataclass_config, repr=False) class RunSettings(Base): @@ -354,7 +305,11 @@ class RunsetSettings(Base): groupby: LList[expr.MetricType] = Field(default_factory=list) "A list of metrics to group by in the runset." - order: LList[expr.Ordering] = Field(default_factory=lambda: [expr.Ordering(expr.Metric("CreatedTimestamp"), ascending=False)]) + order: LList[expr.Ordering] = Field( + default_factory=lambda: [ + expr.Ordering(expr.Metric("CreatedTimestamp"), ascending=False) + ] + ) "A list of metrics and ordering to apply to the runset." run_settings: Dict[str, RunSettings] = Field(default_factory=dict) @@ -444,6 +399,52 @@ def from_model(cls, model: internal.View): regex_query = True if model.spec.section.run_sets[0].search.is_regex else False + section_settings = model.spec.section.settings + panel_bank_settings = model.spec.section.panel_bank_config.settings + x_axis = expr._convert_be_to_fe_metric_name(section_settings.x_axis) + point_viz_method: Literal["bucketing", "downsampling"] + if (pvm := section_settings.point_visualization_method) == "bucketing-gorilla": + point_viz_method = "bucketing" + elif pvm == "sampling": + point_viz_method = "downsampling" + elif pvm is None: + point_viz_method = "bucketing" + else: + raise Exception(f"Unexpected value for {pvm=}") + remove_legends_from_panels = ( + False if section_settings.suppress_legends is None else True + ) + tooltip_number_of_runs = ( + "default" + if section_settings.tooltip_number_of_runs is None + else section_settings.tooltip_number_of_runs + ) + tooltip_color_run_names = ( + True if section_settings.color_run_names is None else False + ) + max_runs = ( + 10 if section_settings.max_runs is None else section_settings.max_runs + ) + group_by_prefix = ( + "last" if panel_bank_settings.auto_organize_prefix == 2 else "first" + ) + + workspace_settings = WorkspaceSettings( + x_axis=x_axis, + x_min=section_settings.x_axis_min, + x_max=section_settings.x_axis_max, + smoothing_type=section_settings.smoothing_type, + smoothing_weight=section_settings.smoothing_weight, + ignore_outliers=section_settings.ignore_outliers, + remove_legends_from_panels=remove_legends_from_panels, + tooltip_number_of_runs=tooltip_number_of_runs, + tooltip_color_run_names=tooltip_color_run_names, + max_runs=max_runs, + point_visualization_method=point_viz_method, + sort_panels_alphabetically=panel_bank_settings.sort_alphabetically, + group_by_prefix=group_by_prefix, + ) + # then construct the Workspace object obj = cls( entity=model.entity, @@ -453,7 +454,7 @@ def from_model(cls, model: internal.View): Section.from_model(s) for s in model.spec.section.panel_bank_config.sections ], - settings=WorkspaceSettings.from_model(model.spec.section.settings), + settings=workspace_settings, runset_settings=RunsetSettings( query=model.spec.section.run_sets[0].search.query, regex_query=regex_query, @@ -491,6 +492,36 @@ def to_model(self) -> internal.View: sections = base_sections + hidden_sections is_regex = True if self.runset_settings.regex_query else None + auto_organize_prefix = 2 if self.settings.group_by_prefix == "last" else 1 + + if self.settings.sort_panels_alphabetically: + wandb.termwarn( + "settings.sort_panels_alphabetically=True; this may change panel order from what is currently defined in sections!" + ) + + x_axis = expr._convert_fe_to_be_metric_name(self.settings.x_axis) + point_viz_method = ( + "bucketing-gorilla" + if self.settings.point_visualization_method == "bucketing" + else "sampling" + ) + suppress_legends = ( + None if not self.settings.remove_legends_from_panels else True + ) + color_run_names = None if self.settings.tooltip_color_run_names else False + internal_settings = internal.ViewspecSectionSettings( + x_axis=x_axis, + x_axis_min=self.settings.x_min, + x_axis_max=self.settings.x_max, + smoothing_type=self.settings.smoothing_type, + smoothing_weight=self.settings.smoothing_weight, + ignore_outliers=self.settings.ignore_outliers, + suppress_legends=suppress_legends, + tooltip_number_of_runs=self.settings.tooltip_number_of_runs, + color_run_names=color_run_names, + max_runs=self.settings.max_runs, + point_visualization_method=point_viz_method, + ) return internal.View( entity=self.entity, @@ -502,12 +533,16 @@ def to_model(self) -> internal.View: section=internal.ViewspecSection( panel_bank_config=internal.PanelBankConfig( state=1, # TODO: What is this? + settings=internal.PanelBankConfigSettings( + sort_alphabetically=self.settings.sort_panels_alphabetically, + auto_organize_prefix=auto_organize_prefix, + ), sections=sections, ), panel_bank_section_config=internal.PanelBankSectionConfig( pinned=False ), - settings=self.settings.to_model(), + settings=internal_settings, run_sets=[ internal.Runset( id=self._internal_runset_id,