Skip to content

Commit

Permalink
Add options to sort panels alphabetically and by prefix (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong authored Jul 8, 2024
1 parent e5e4f7f commit cf0aba4
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 76 deletions.
4 changes: 2 additions & 2 deletions tests/test_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,11 @@ class CustomChartFactory(CustomDataclassFactory[wr.CustomChart]):
@classmethod
def query(cls):
return {"history": {"keys": ["x", "y"], "id": None, "name": None}}

@classmethod
def chart_fields(cls):
return {"x": "x", "y": "y"}

@classmethod
def chart_strings(cls):
return {"x": "x-axis", "y": "y-axis"}
Expand Down
12 changes: 3 additions & 9 deletions tests/test_workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from polyfactory.pytest_plugin import register_fixture

import wandb_workspaces.expr
import wandb_workspaces.workspaces as ws
import wandb_workspaces.reports.v2 as wr
import wandb_workspaces.workspaces as ws
from wandb_workspaces.utils.validators import (
validate_no_emoji,
validate_spec_version,
Expand Down Expand Up @@ -50,20 +50,15 @@ def runset_settings(cls):
wandb_workspaces.expr.Metric("vwx").notin([8, 9, 0, "broccoli"]),
],
)

@classmethod
def sections(cls):
return [
ws.Section(name="section1", panels=[wr.LinePlot()]),
ws.Section(name="section2", panels=[wr.BarPlot(title='tomato')]),
ws.Section(name="section2", panels=[wr.BarPlot(title="tomato")]),
]


@register_fixture
class WorkspaceSettingsFactory(CustomDataclassFactory[ws.WorkspaceSettings]):
__model__ = ws.WorkspaceSettings


@register_fixture
class SectionFactory(CustomDataclassFactory[ws.Section]):
__model__ = ws.Section
Expand All @@ -85,7 +80,6 @@ class SectionPanelSettingsFactory(CustomDataclassFactory[ws.SectionPanelSettings

factory_names = [
"workspace_factory",
"workspace_settings_factory",
"section_factory",
"section_panel_settings_factory",
"section_panel_settings_factory",
Expand Down
9 changes: 8 additions & 1 deletion wandb_workspaces/reports/v2/interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Public interfaces for the Report API."""

import os
from datetime import datetime
from typing import Dict, Iterable, Optional, Tuple, Union
Expand Down Expand Up @@ -1081,7 +1082,13 @@ def from_model(cls, model: internal.WeaveBlock):
"val"
]
file = inputs["path"]["val"]
return cls(entity=entity, project=project, artifact=artifact, version=version, file=file)
return cls(
entity=entity,
project=project,
artifact=artifact,
version=version,
file=file,
)


@dataclass(config=dataclass_config)
Expand Down
163 changes: 99 additions & 64 deletions wandb_workspaces/workspaces/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,18 @@ class WorkspaceSettings(Base):
ignore_outliers: bool = False
"Ignore outliers in all panels."

# Section settings
sort_panels_alphabetically: bool = False
"Sorts panels in all sections alphabetically"

group_by_prefix: Literal["first", "last"] = "last"
"""
Group panels by the first or up to last prefix
first: "a/b/c/d" -> section a
last: "a/b/c/d" -> section a/b/c
"""

# Panel settings
remove_legends_from_panels: bool = False
"Remove legends from all panels."
Expand Down Expand Up @@ -262,67 +274,6 @@ class WorkspaceSettings(Base):
)
"Search history for the panel search bar."

@classmethod
def from_model(cls, model: internal.ViewspecSectionSettings):
x_axis = expr._convert_be_to_fe_metric_name(model.x_axis)

# TODO: There can be better mapping here between sane API names and the internal names
point_viz_method: Literal["bucketing", "downsampling"]
if (pvm := model.point_visualization_method) == "bucketing-gorilla":
point_viz_method = "bucketing"
elif pvm == "sampling":
point_viz_method = "downsampling"
elif pvm is None:
point_viz_method = "bucketing"
else:
raise Exception(f"Unexpected value for {pvm=}")
remove_legends_from_panels = False if model.suppress_legends is None else True
tooltip_number_of_runs = (
"default"
if model.tooltip_number_of_runs is None
else model.tooltip_number_of_runs
)
tooltip_color_run_names = True if model.color_run_names is None else False
max_runs = 10 if model.max_runs is None else model.max_runs

return cls(
x_axis=x_axis,
x_min=model.x_axis_min,
x_max=model.x_axis_max,
smoothing_type=model.smoothing_type,
smoothing_weight=model.smoothing_weight,
ignore_outliers=model.ignore_outliers,
remove_legends_from_panels=remove_legends_from_panels,
tooltip_number_of_runs=tooltip_number_of_runs,
tooltip_color_run_names=tooltip_color_run_names,
max_runs=max_runs,
point_visualization_method=point_viz_method,
)

def to_model(self):
x_axis = expr._convert_fe_to_be_metric_name(self.x_axis)
point_viz_method = (
"bucketing-gorilla"
if self.point_visualization_method == "bucketing"
else "sampling"
)
suppress_legends = None if not self.remove_legends_from_panels else True
color_run_names = None if self.tooltip_color_run_names else False

return internal.ViewspecSectionSettings(
x_axis=x_axis,
x_axis_min=self.x_min,
x_axis_max=self.x_max,
smoothing_type=self.smoothing_type,
smoothing_weight=self.smoothing_weight,
ignore_outliers=self.ignore_outliers,
suppress_legends=suppress_legends,
tooltip_number_of_runs=self.tooltip_number_of_runs,
color_run_names=color_run_names,
max_runs=self.max_runs,
point_visualization_method=point_viz_method,
)


@dataclass(config=dataclass_config, repr=False)
class RunSettings(Base):
Expand Down Expand Up @@ -354,7 +305,11 @@ class RunsetSettings(Base):
groupby: LList[expr.MetricType] = Field(default_factory=list)
"A list of metrics to group by in the runset."

order: LList[expr.Ordering] = Field(default_factory=lambda: [expr.Ordering(expr.Metric("CreatedTimestamp"), ascending=False)])
order: LList[expr.Ordering] = Field(
default_factory=lambda: [
expr.Ordering(expr.Metric("CreatedTimestamp"), ascending=False)
]
)
"A list of metrics and ordering to apply to the runset."

run_settings: Dict[str, RunSettings] = Field(default_factory=dict)
Expand Down Expand Up @@ -444,6 +399,52 @@ def from_model(cls, model: internal.View):

regex_query = True if model.spec.section.run_sets[0].search.is_regex else False

section_settings = model.spec.section.settings
panel_bank_settings = model.spec.section.panel_bank_config.settings
x_axis = expr._convert_be_to_fe_metric_name(section_settings.x_axis)
point_viz_method: Literal["bucketing", "downsampling"]
if (pvm := section_settings.point_visualization_method) == "bucketing-gorilla":
point_viz_method = "bucketing"
elif pvm == "sampling":
point_viz_method = "downsampling"
elif pvm is None:
point_viz_method = "bucketing"
else:
raise Exception(f"Unexpected value for {pvm=}")
remove_legends_from_panels = (
False if section_settings.suppress_legends is None else True
)
tooltip_number_of_runs = (
"default"
if section_settings.tooltip_number_of_runs is None
else section_settings.tooltip_number_of_runs
)
tooltip_color_run_names = (
True if section_settings.color_run_names is None else False
)
max_runs = (
10 if section_settings.max_runs is None else section_settings.max_runs
)
group_by_prefix = (
"last" if panel_bank_settings.auto_organize_prefix == 2 else "first"
)

workspace_settings = WorkspaceSettings(
x_axis=x_axis,
x_min=section_settings.x_axis_min,
x_max=section_settings.x_axis_max,
smoothing_type=section_settings.smoothing_type,
smoothing_weight=section_settings.smoothing_weight,
ignore_outliers=section_settings.ignore_outliers,
remove_legends_from_panels=remove_legends_from_panels,
tooltip_number_of_runs=tooltip_number_of_runs,
tooltip_color_run_names=tooltip_color_run_names,
max_runs=max_runs,
point_visualization_method=point_viz_method,
sort_panels_alphabetically=panel_bank_settings.sort_alphabetically,
group_by_prefix=group_by_prefix,
)

# then construct the Workspace object
obj = cls(
entity=model.entity,
Expand All @@ -453,7 +454,7 @@ def from_model(cls, model: internal.View):
Section.from_model(s)
for s in model.spec.section.panel_bank_config.sections
],
settings=WorkspaceSettings.from_model(model.spec.section.settings),
settings=workspace_settings,
runset_settings=RunsetSettings(
query=model.spec.section.run_sets[0].search.query,
regex_query=regex_query,
Expand Down Expand Up @@ -491,6 +492,36 @@ def to_model(self) -> internal.View:

sections = base_sections + hidden_sections
is_regex = True if self.runset_settings.regex_query else None
auto_organize_prefix = 2 if self.settings.group_by_prefix == "last" else 1

if self.settings.sort_panels_alphabetically:
wandb.termwarn(
"settings.sort_panels_alphabetically=True; this may change panel order from what is currently defined in sections!"
)

x_axis = expr._convert_fe_to_be_metric_name(self.settings.x_axis)
point_viz_method = (
"bucketing-gorilla"
if self.settings.point_visualization_method == "bucketing"
else "sampling"
)
suppress_legends = (
None if not self.settings.remove_legends_from_panels else True
)
color_run_names = None if self.settings.tooltip_color_run_names else False
internal_settings = internal.ViewspecSectionSettings(
x_axis=x_axis,
x_axis_min=self.settings.x_min,
x_axis_max=self.settings.x_max,
smoothing_type=self.settings.smoothing_type,
smoothing_weight=self.settings.smoothing_weight,
ignore_outliers=self.settings.ignore_outliers,
suppress_legends=suppress_legends,
tooltip_number_of_runs=self.settings.tooltip_number_of_runs,
color_run_names=color_run_names,
max_runs=self.settings.max_runs,
point_visualization_method=point_viz_method,
)

return internal.View(
entity=self.entity,
Expand All @@ -502,12 +533,16 @@ def to_model(self) -> internal.View:
section=internal.ViewspecSection(
panel_bank_config=internal.PanelBankConfig(
state=1, # TODO: What is this?
settings=internal.PanelBankConfigSettings(
sort_alphabetically=self.settings.sort_panels_alphabetically,
auto_organize_prefix=auto_organize_prefix,
),
sections=sections,
),
panel_bank_section_config=internal.PanelBankSectionConfig(
pinned=False
),
settings=self.settings.to_model(),
settings=internal_settings,
run_sets=[
internal.Runset(
id=self._internal_runset_id,
Expand Down

0 comments on commit cf0aba4

Please sign in to comment.