diff --git a/src/layouter/layout.ts b/src/layouter/layout.ts index 5e83848..af61cf2 100644 --- a/src/layouter/layout.ts +++ b/src/layouter/layout.ts @@ -779,7 +779,7 @@ function layoutConditionalBlock( }); // layout each block individually to get its size. - let offsetAmount = 0; + let maxBranchHeight = 0; for (let id = 0; id < condBlock.branches.length; id++) { const [condition, block] = condBlock.branches[id]; block.id = id; @@ -817,19 +817,26 @@ function layoutConditionalBlock( blockEl.data.layout = blockInfo; blockEl.set_layout(); - blockEl.y = ( - condBlockElem.y + (blockInfo.height / 2) + + maxBranchHeight = Math.max(maxBranchHeight, blockInfo.height); + } + + // Offset each branch inside the conditional region and size it according + // to the height of the highest / longest branch. + let offsetAmount = 0; + for (const [_, branch] of condBlockElem.branches) { + branch.height = maxBranchHeight; + branch.y = ( + condBlockElem.y + (maxBranchHeight / 2) + ConditionalBlock.CONDITION_SPACING ); - blockEl.x = condBlockElem.x + (blockInfo.width / 2) + offsetAmount; - if (!blockEl.attributes()?.is_collapsed) { - offsetControlFlowRegion(block, blockEl, { + branch.x = condBlockElem.x + (branch.width / 2) + offsetAmount; + if (!branch.attributes()?.is_collapsed) { + offsetControlFlowRegion(branch.data.block, branch, { x: offsetAmount + BLOCK_MARGIN, y: BLOCK_MARGIN + ConditionalBlock.CONDITION_SPACING, }); } - - offsetAmount += blockInfo.width; + offsetAmount += branch.width; } // Annotate the JSON with layout information diff --git a/src/renderer/renderer_elements.ts b/src/renderer/renderer_elements.ts index 298b55e..d136427 100644 --- a/src/renderer/renderer_elements.ts +++ b/src/renderer/renderer_elements.ts @@ -3,12 +3,7 @@ import { DagreGraph, JsonSDFG, - JsonSDFGBlock, - JsonSDFGConditionalBlock, JsonSDFGControlFlowRegion, - JsonSDFGEdge, - JsonSDFGNode, - JsonSDFGState, Point2D, SimpleRect, } from '../index'; @@ -19,7 +14,6 @@ import { sdfg_property_to_string, sdfg_range_elem_to_string, } from '../utils/sdfg/display'; -import { check_and_redirect_edge } from '../utils/sdfg/sdfg_utils'; import { SDFVSettings } from '../utils/sdfv_settings'; import { SDFGRenderer } from './renderer'; @@ -889,12 +883,7 @@ export class ReturnBlock extends ControlFlowBlock { export class ConditionalBlock extends ControlFlowBlock { public static get CONDITION_SPACING(): number { - return 3 * SDFV.LINEHEIGHT; - } - - public static get LOOP_STATEMENT_FONT(): string { - return (SDFV.DEFAULT_CANVAS_FONTSIZE * 1.5).toString() + - 'px sans-serif'; + return 4 * SDFV.LINEHEIGHT; } public branches: ( @@ -904,6 +893,15 @@ export class ConditionalBlock extends ControlFlowBlock { ] )[] = []; + public simple_draw( + renderer: SDFGRenderer, ctx: CanvasRenderingContext2D, + mousepos?: Point2D + ): void { + this._internal_draw( + renderer, ctx, mousepos, 'conditional-background-simple-color', true + ); + } + public draw( renderer: SDFGRenderer, ctx: CanvasRenderingContext2D, _mousepos?: Point2D @@ -968,35 +966,38 @@ export class ConditionalBlock extends ControlFlowBlock { renderer, '--conditional-foreground-color' ); - let topSpacing = LoopRegion.META_LABEL_MARGIN; - let remainingHeight = this.height; - let x = topleft.x, y = topleft.y - for (const [condition, region] of this.branches) { - topSpacing += ConditionalBlock.CONDITION_SPACING; - const initBottomLineY = y + ConditionalBlock.CONDITION_SPACING; + if (!too_far_away_for_text(renderer)) { + const labelHeight = 1.5 * SDFV.LINEHEIGHT; ctx.beginPath(); - ctx.moveTo(x, initBottomLineY); - ctx.lineTo(x + this.width, initBottomLineY); + ctx.moveTo(topleft.x, topleft.y + labelHeight); + ctx.lineTo(topleft.x + this.width, topleft.y + labelHeight); ctx.stroke(); - if (!too_far_away_for_text(renderer)) { - const initTextY = ( - (y + (ConditionalBlock.CONDITION_SPACING / 2)) + - (SDFV.LINEHEIGHT / 2) - ); - const initTextMetrics = ctx.measureText( - condition?.string_data ?? 'else' - ); - const initTextX = x + (initTextMetrics.width / 2); - ctx.fillText( - condition?.string_data ?? 'else', initTextX, initTextY + const oldFont = ctx.font; + ctx.font = 'bold ' + oldFont; + + let nextX = topleft.x; + let nextY = topleft.y + labelHeight; + const condHeight = ConditionalBlock.CONDITION_SPACING - labelHeight; + for (const [condition, region] of this.branches) { + ctx.beginPath(); + ctx.moveTo(nextX, nextY); + ctx.lineTo(nextX, nextY + condHeight); + ctx.stroke(); + + const condTextY = nextY + condHeight / 2 + SDFV.LINEHEIGHT / 4; + const condText = condition?.string_data ? + 'if ' + condition.string_data : 'else'; + const condTextMetrics = ctx.measureText(condText); + const initTextX = ( + nextX + region.width / 2 - condTextMetrics.width / 2 ); + ctx.fillText(condText, initTextX, condTextY); - ctx.fillText( - 'if', topleft.x + LoopRegion.META_LABEL_MARGIN, initTextY - ); + nextX = nextX + region.width; } - y = initBottomLineY + region.height + + ctx.font = oldFont; } if (visibleRect && visibleRect.x <= topleft.x && @@ -1004,7 +1005,8 @@ export class ConditionalBlock extends ControlFlowBlock { SDFVSettings.get('showStateNames')) { if (!too_far_away_for_text(renderer)) { ctx.fillText( - this.label(), topleft.x + LoopRegion.META_LABEL_MARGIN, + this.label(), + topleft.x + ControlFlowRegion.META_LABEL_MARGIN, topleft.y + SDFV.LINEHEIGHT ); } @@ -1025,7 +1027,10 @@ export class ConditionalBlock extends ControlFlowBlock { // If collapsed, draw a "+" sign in the middle if (this.attributes().is_collapsed) { - const plusCenterY = topleft.y + (remainingHeight / 2) + topSpacing; + const contentH = this.height - ConditionalBlock.CONDITION_SPACING; + const plusCenterY = topleft.y + ( + contentH / 2 + ) + ConditionalBlock.CONDITION_SPACING; ctx.beginPath(); ctx.moveTo(this.x, plusCenterY - SDFV.LINEHEIGHT); ctx.lineTo(this.x, plusCenterY + SDFV.LINEHEIGHT); @@ -3166,7 +3171,7 @@ function batchedDrawEdges( } } -export function drawStateContents( +function drawStateContents( stateGraph: DagreGraph, ctx: CanvasRenderingContext2D, renderer: SDFGRenderer, ppp: number, visibleRect?: SimpleRect, mousePos?: Point2D @@ -3253,7 +3258,7 @@ export function drawStateContents( ); } -export function drawStateMachine( +function drawStateMachine( stateMachineGraph: DagreGraph, ctx: CanvasRenderingContext2D, renderer: SDFGRenderer, ppp: number, visibleRect?: SimpleRect, mousePos?: Point2D @@ -3346,7 +3351,7 @@ type AdaptiveTextPadding = { bottom?: number, }; -export function drawAdaptiveText( +function drawAdaptiveText( ctx: CanvasRenderingContext2D, renderer: SDFGRenderer, far_text: string, close_text: string, x: number, y: number, w: number, h: number, ppp_thres: number, @@ -3436,7 +3441,7 @@ export function drawAdaptiveText( ctx.font = oldfont; } -export function drawHexagon( +function drawHexagon( ctx: CanvasRenderingContext2D, x: number, y: number, w: number, h: number, _offset: Point2D ): void { @@ -3453,7 +3458,7 @@ export function drawHexagon( ctx.closePath(); } -export function drawOctagon( +function drawOctagon( ctx: CanvasRenderingContext2D, topleft: Point2D, width: number, height: number ): void { @@ -3471,13 +3476,13 @@ export function drawOctagon( ctx.closePath(); } -export function drawEllipse( +function drawEllipse( ctx: CanvasRenderingContext2D, x: number, y: number, w: number, h: number ): void { ctx.ellipse(x + w / 2, y + h / 2, w / 2, h / 2, 0, 0, 2 * Math.PI); } -export function drawTrapezoid( +function drawTrapezoid( ctx: CanvasRenderingContext2D, topleft: Point2D, node: SDFGNode, inverted: boolean = false ): void { @@ -3502,7 +3507,7 @@ export function drawTrapezoid( // Returns the distance from point p to line defined by two points // (line1, line2) -export function ptLineDistance( +function ptLineDistance( p: Point2D, line1: Point2D, line2: Point2D ): number { const dx = (line2.x - line1.x); diff --git a/tests/unit/sdfg_diff_viewer.test.ts b/tests/unit/sdfg_diff_viewer.test.ts index f8fbdce..4580b21 100644 --- a/tests/unit/sdfg_diff_viewer.test.ts +++ b/tests/unit/sdfg_diff_viewer.test.ts @@ -6,8 +6,9 @@ import { checkCompatLoad, parse_sdfg, } from '../../src/utils/sdfg/json_serializer'; -import { JsonSDFG, relayoutStateMachine, SDFG } from '../../src'; +import { JsonSDFG, SDFG } from '../../src'; import { SDFGDiffViewer } from '../../src/sdfg_diff_viewer'; +import { layoutSDFG } from '../../src/layouter/layout'; function _loadSDFG(name: string): JsonSDFG { const file = path.join( @@ -23,8 +24,8 @@ async function testDiffTiledGemm(): Promise { const sdfgAjson = _loadSDFG('gemm_expanded_pure'); const sdfgBjson = _loadSDFG('gemm_expanded_pure_tiled'); - const graphA = relayoutStateMachine(sdfgAjson, sdfgAjson); - const graphB = relayoutStateMachine(sdfgBjson, sdfgBjson); + const graphA = layoutSDFG(sdfgAjson); + const graphB = layoutSDFG(sdfgBjson); const sdfgA = new SDFG(sdfgAjson); sdfgA.sdfgDagreGraph = graphA;