From f2c88ab587a4b3de5e62cc5b756eee598c008814 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20R=C3=B6sti?= Date: Mon, 12 Aug 2024 16:21:07 -0600 Subject: [PATCH] Interleave sync instructions in matmul designs (#1678) --- .../matrix_multiplication/single_core/aie2.py | 93 ++++++---- .../whole_array/README.md | 2 + .../matrix_multiplication/whole_array/aie2.py | 161 +++++++++++++----- .../whole_array/diagram.png | Bin 0 -> 545460 bytes 4 files changed, 181 insertions(+), 75 deletions(-) create mode 100644 programming_examples/basic/matrix_multiplication/whole_array/diagram.png diff --git a/programming_examples/basic/matrix_multiplication/single_core/aie2.py b/programming_examples/basic/matrix_multiplication/single_core/aie2.py index 08c1eeb0b2..bb84ad7fd7 100644 --- a/programming_examples/basic/matrix_multiplication/single_core/aie2.py +++ b/programming_examples/basic/matrix_multiplication/single_core/aie2.py @@ -13,6 +13,7 @@ from aie.dialects.aiex import * from aie.dialects.scf import * import aie.utils.trace as trace_utils +from aie.utils.trace import PortEvent def main(): @@ -245,44 +246,74 @@ def sequence(A, B, C): ddr_id=2, size=trace_size, offset=C_sz_in_bytes, + events=[ + PortEvent( + trace_utils.CoreEvent.PORT_RUNNING_0, + port_number=1, + master=True, + ), + PortEvent( + trace_utils.CoreEvent.PORT_RUNNING_1, + port_number=2, + master=True, + ), + PortEvent( + trace_utils.CoreEvent.PORT_RUNNING_2, + port_number=5, + master=True, + ), + trace_utils.CoreEvent.INSTR_EVENT_0, + trace_utils.CoreEvent.INSTR_EVENT_1, + trace_utils.CoreEvent.MEMORY_STALL, + trace_utils.CoreEvent.LOCK_STALL, + trace_utils.CoreEvent.INSTR_VECTOR, + ], ) - # only do 5 tile rows at a time before synchronizing, so we can reuse BDs - rows_per_block = 5 + # only do 4 tile rows at a time before synchronizing, so we can reuse BDs + rows_per_block = 6 for tile_row_block in range(ceildiv(M_div_m, rows_per_block)): - C_row_offset = tile_row_block * rows_per_block * m * N - num_tile_rows = min( - [rows_per_block, M_div_m - tile_row_block * rows_per_block] - ) - npu_dma_memcpy_nd( - metadata="outC", - bd_id=0, - mem=C, - offsets=[0, 0, 0, C_row_offset], - sizes=[num_tile_rows, N_div_n, m, n], - strides=[m_x_N, n, N, 1], - ) - for tile_row in range(num_tile_rows): - A_row_offset = ( - ((tile_row_block * rows_per_block) + tile_row) * m * K + # we only sync on half the BDs before reusing them, so the other half can concurrently keep running + # that's what this loop is for + for pingpong in [0, 1]: + C_row_offset = ( + tile_row_block * rows_per_block * m * N + + pingpong * rows_per_block // 2 * m * N ) - npu_dma_memcpy_nd( - metadata="inA", - bd_id=2 * tile_row + 1, - mem=A, - offsets=[0, 0, 0, A_row_offset], - sizes=[N_div_n, K_div_k, m, k], - strides=[0, k, K, 1], + row_base = ( + tile_row_block * rows_per_block + + pingpong * rows_per_block // 2 ) + bd_id_base = 8 * pingpong + num_tile_rows = min([rows_per_block // 2, M_div_m - row_base]) npu_dma_memcpy_nd( - metadata="inB", - bd_id=2 * tile_row + 2, - mem=B, - sizes=[N_div_n, K_div_k, k, n], - strides=[n, k_x_N, N, 1], + metadata="outC", + bd_id=bd_id_base, + mem=C, + offsets=[0, 0, 0, C_row_offset], + sizes=[num_tile_rows, N_div_n, m, n], + strides=[m_x_N, n, N, 1], ) - - npu_sync(column=0, row=0, direction=0, channel=0) + for tile_row in range(num_tile_rows): + A_row_offset = (row_base + tile_row) * m * K + npu_dma_memcpy_nd( + metadata="inA", + bd_id=bd_id_base + 2 * tile_row + 1, + mem=A, + offsets=[0, 0, 0, A_row_offset], + sizes=[N_div_n, K_div_k, m, k], + strides=[0, k, K, 1], + ) + npu_dma_memcpy_nd( + metadata="inB", + bd_id=bd_id_base + 2 * tile_row + 2, + mem=B, + sizes=[N_div_n, K_div_k, k, n], + strides=[n, k_x_N, N, 1], + ) + if tile_row_block > 0 or (tile_row_block == 0 and pingpong > 0): + npu_sync(column=0, row=0, direction=0, channel=0) + npu_sync(column=0, row=0, direction=0, channel=0) print(ctx.module) diff --git a/programming_examples/basic/matrix_multiplication/whole_array/README.md b/programming_examples/basic/matrix_multiplication/whole_array/README.md index 61ea47aef7..28e1247f9a 100644 --- a/programming_examples/basic/matrix_multiplication/whole_array/README.md +++ b/programming_examples/basic/matrix_multiplication/whole_array/README.md @@ -108,6 +108,8 @@ Each of `inA_fifos`, `inB_fifos`, `OutC_fifos`, `memA_fifos`, `memB_fifos` and ` Of note is the `object_fifo_link()` operation. This operation establishes a connection between the `mem*` FIFOs and the `in*` and `outC` FIFOs. By linking ObjectFIFOs, the output received at one end of the source FIFO is fed as input into the ObjectFIFO listed as the destination. +[![data movement diagram](diagram.png)](https://excalidraw.com/#room=23df780b85d72d80cbc6,1czLdPr_vK9-OjtxFIWTpw) + !wK zb*KcAut)%Q$$$wA!uKY!u(Y;6>dIUL*U;Ce385B46&5%6*tH}0IPbj+2KOTRIn)IyK%ujggzhN4S}TT5y8HaZG3g{n zg^J4X(>m~i%p(&HC)0n+@1U9fB;k6H1Xgtc$}ga&7vD=8wde`T<3BrVnVlIdfX^el zUcU>l`CcOrfiJ4#PU6wPX>EC};m%jd7n{eKGgbyVs@n1o986|6iYixT1|ju9 zSQpfo4}X1VJjuK5gcS}mIKVc`a?*gGa(^!VzYQe_p6MRS*R3o;BEY4wymN7!f2o-g zD7S2I<3RQ9%wY7hx#o7J)>o(i^S}w(fdk%i?1(ntADQBJ!(<<4XTb;gz55=jVuTAD zFKxOO3!X&mM%oXa?r`H(jE`wsq_0YfaJ(R~!R^x}+&KqJ_(V#b-5ZY1Eoif`fGr%) z{Fk@X(Cn0dT*&jT*9l$mB`YbIa{E(KQU+TG$W&6@x21zVs7{$#MCUkn^Zalzor{}W zM>cIo>m^kzOBChgokeE{@g1%QAd||>k_?SM$>k4Wh7v1GCXg|0Pm;n1E_=*Hp{gIt zl~q-&-G0@k_VJ&g^qkhIIN(flrkhmJ2Y8Hx+p?DMU^&ylV7;~U&UHb@VBR3X*Qhh^ zAjHYZiMk*vJJxUaq$nYwABRCJkzGOH6O@3)P`BfYLmVG=8U`r!q9awQ-C{9NgXSpd z1Nh9y(yFJ4pD~wcJ4a4D7hO=;o-=L>GqmGo+b0p227@yZ%0T-v+fIGlZmA!2+ z%!k-9I!dIMi%rG?{_m{!m`e4UV2@`O&EjK03JY;iExB>Q3hVd>9#6!8OP^Fgq#OR_ zCh`v>r|xkRfF|4@U_QfD-3x_OXOr4^q1YX*us19SQh{u`f9>sF6Zt(Yt?Op2%LRrIT@xd zp#1Y-nxeiF(Qky0iHT|ccIMZv&SBKsFVVasBO_`MTiZrQ6V_+D*s3ca^tA^w2nYyh ze}69Lem3b$<99XpX6W$X|9!kYv3{*<|=NC4k-S+bR$#a(|$8Mvp=)$HVMnB0~PsQC)F$?5`T+ zW`l* zO%Os*AoUrtzn^pRBc);sd+x^_lk!C;q~=-tkZgI9b(H2?wvb#;C5Js@spD7fBUbUAo$bW z<{gJE>~)w_Q&U5On1Hcn)GFRe0&c%43=Q}1jyzfE9*q6{{XDq6Vci$XhF(ak-sg`G z^FIURaU`dIKKYp*GOkWG;i4{+apwg_&}pWfH0{5d!Tr}CrpA3Y5@xM%`mpe^j^Dx4 zNr(t6o)}ukZHLnhfrjPIxcX<)l2x~&$ek9bPN&VEn*Z@Zfx`5IXom6s@V(*4u>!;9 zP76E^YNYf(e(#;Vy{O#WT=#+g{z!Ls_oq1j;mr^b5QtMzQ2~N|dh$&FAlQE%5%rr7 zre`6xO)oEV>J>rx4Lkt)f7&WEK#;&KtN#znDeAt1x?3fR3k+fhpy}}RKP@K#5s?H? z9NuQnfe}WnvA8|Zn