Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyTorch] Improve get_attention_backend #1214

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 66 additions & 35 deletions transformer_engine/pytorch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4669,74 +4669,105 @@ def get_qkv_layout(
`sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
`bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
`thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
q: torch.Tensor
Query tensor. It may be different from input `q` as we try to fit tensors to
a supported layout.
k: torch.Tensor
Key tensor. It may be different from input `k` as we try to fit tensors to
a supported layout.
v: torch.Tensor
Value tensor. It may be different from input `v` as we try to fit tensors to
a supported layout.
"""

check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"

def run_iteratively(q, k, v):
# check data pointers
data_ptr = q.untyped_storage().data_ptr()
check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k])
data_ptr = k.untyped_storage().data_ptr()
check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v])

# check tensor shapes
shape = q.shape
check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
shape = k.shape
check_shapes_kv = shape[:-1] == v.shape[:-1]

# check tensor strides
stride = q.stride()
check_strides_qkv = all(stride == x.stride() for x in [q, k, v])
check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple(
sv / v.shape[-1] for sv in v.stride()[:-1]
)

shape = q.shape
check_shapes_qkv = all(shape == x.shape for x in [q, k, v])
shape = k.shape
check_shapes_kv = shape[:-1] == v.shape[:-1]
# check tensor offsets for h3d and 3hd layouts
prod_h_d = q.shape[-1] * q.shape[-2]
check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v]))
check_h3d_offsets = all(
x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v])
)

last_dim_size = q.shape[-1]
check_last_dim_offsets_qkv = all(
i * last_dim_size == x.storage_offset() for i, x in enumerate([q, k, v])
# check tensor offsets for hd_h2d and hd_2hd layouts
prod_all_dims = [np.prod(x.shape) for x in [q, k]]
offset = prod_all_dims[0] if check_ptrs_qkv else 0
prod_h_d = k.shape[-1] * k.shape[-2]
check_2hd_offsets = all(
x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v])
)
last_dim_size = k.shape[-1]
check_last_dim_offsets_kv = all(
i * last_dim_size == x.storage_offset() for i, x in enumerate([k, v])
check_h2d_offsets = all(
x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v])
)

last_two_dims_size = q.shape[-1] * q.shape[-2]
check_last_two_dims_offsets_qkv = all(
i * last_two_dims_size == x.storage_offset() for i, x in enumerate([q, k, v])
# check tensor offsets for hd_hd_hd layouts
check_hd_offsets_qkv = (
all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v]))
if check_ptrs_qkv
else all(x.storage_offset() == 0 for i, x in enumerate([q, k, v]))
)
check_hd_offsets_qk = (
all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k]))
if not check_ptrs_qkv and check_ptrs_qk
else all(x.storage_offset() == 0 for i, x in enumerate([q, k]))
)
last_two_dims_size = k.shape[-1] * k.shape[-2]
check_last_two_dims_offsets_kv = all(
i * last_two_dims_size == x.storage_offset() for i, x in enumerate([k, v])
check_hd_offsets_kv = (
all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v]))
if not check_ptrs_qkv and check_ptrs_kv
else all(x.storage_offset() == 0 for i, x in enumerate([k, v]))
)

if (
check_ptrs_qkv
and check_strides_qkv
and check_shapes_qkv
and check_last_two_dims_offsets_qkv
and not check_last_dim_offsets_qkv
):
if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets:
# sb3hd, bs3hd, t3hd
# one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv
qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:]
elif (
check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_last_dim_offsets_qkv
):
elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets:
# sbh3d, bsh3d, th3d
# one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv
qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:]
elif (
check_ptrs_kv
and check_strides_kv
and check_shapes_kv
and check_last_two_dims_offsets_kv
and not check_last_dim_offsets_kv
):
elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets:
# sbhd_sb2hd, bshd_bs2hd, thd_t2hd
# two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv
# q and kv may be disjoint or consecutive in memory, and when consecutive, they may
# have the same data pointer, i.e. check_ptrs_qkv=True
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_last_dim_offsets_kv:
elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets:
# sbhd_sbh2d, bshd_bsh2d, thd_th2d
# two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv
# q and kv may be disjoint or consecutive in memory, and when consecutive, they may
# have the same data pointer, i.e. check_ptrs_qkv=True
qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:]
elif check_strides_kv and check_shapes_kv:
elif (
check_strides_kv
and check_shapes_kv
and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk)
):
# sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd
# three chunks of memory, q, k and v, which may be disjoint or consecutive, and
# when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
# check_ptrs_qk=True or check_ptrs_kv=True
qkv_layout = "_".join(list([qkv_format]) * 3)
else:
qkv_layout = "not_supported"
Expand Down
Loading