Commit 844df31

deploy: 1705ffe
facebook-github-bot committed Sep 11, 2024
1 parent 6e65ed8 commit 844df31
Showing 7 changed files with 67 additions and 65 deletions.
4 changes: 3 additions & 1 deletion _modules/xformers/ops/fmha/flash.html
@@ -302,6 +302,7 @@ Source code for xformers.ops.fmha.flash
             raise
         assert is_pt_flash_compatible(force=True)
         FLASH_VERSION = torch.nn.attention._get_flash_version()  # type: ignore
         FLASH_VERSION = f"v{FLASH_VERSION}"
+        VARLEN_LSE_PACKED = False
         _USE_PT_FLASH_ATTN = True

@@ -774,6 +775,7 @@ Source code for xformers.ops.fmha.flash
     lse: torch.Tensor,
     inp: Inputs,
     original_query_shape: Tuple[int, ...],
+    varlen_lse_packed: bool = VARLEN_LSE_PACKED,
 ) -> torch.Tensor:
     # Easy case: no varlen
     if not isinstance(inp.attn_bias, VARLEN_BIASES):
@@ -783,7 +785,7 @@ Source code for xformers.ops.fmha.flash
         return lse

     # Already packed: just bring back the batch dimension
-    if VARLEN_LSE_PACKED:
+    if varlen_lse_packed:
         if len(original_query_shape) == 5:
             # (1, G, H, total_q)
             return lse.unflatten(0, original_query_shape[2:4]).unsqueeze(0)
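Taken together, the hunks route the varlen-LSE packing convention through a parameter (defaulting to the module-level VARLEN_LSE_PACKED) instead of reading the constant inside the LSE post-processing helper, and the PyTorch-flash path now sets VARLEN_LSE_PACKED = False. A minimal sketch of what the reshape in the varlen_lse_packed branch does, assuming a 5-D query shape whose dims 2 and 3 are the group and head counts; the sizes below are hypothetical and not from the diff:

    import torch

    G, H, total_q = 2, 4, 16                        # hypothetical group/head/token counts
    original_query_shape = (1, total_q, G, H, 64)   # 5-D query: dims 2 and 3 are (G, H)

    lse = torch.randn(G * H, total_q)               # varlen LSE packed over groups and heads
    # unflatten the fused (G * H) axis and restore a leading batch dimension
    unpacked = lse.unflatten(0, original_query_shape[2:4]).unsqueeze(0)
    assert unpacked.shape == (1, G, H, total_q)     # matches the "(1, G, H, total_q)" comment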