diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index bc80b389d3..e884ae6fec 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -4669,74 +4669,105 @@ def get_qkv_layout( `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`} `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`} `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`} + q: torch.Tensor + Query tensor. It may be different from input `q` as we try to fit tensors to + a supported layout. + k: torch.Tensor + Key tensor. It may be different from input `k` as we try to fit tensors to + a supported layout. + v: torch.Tensor + Value tensor. It may be different from input `v` as we try to fit tensors to + a supported layout. """ check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v]) assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!" def run_iteratively(q, k, v): + # check data pointers data_ptr = q.untyped_storage().data_ptr() check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v]) + check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k]) data_ptr = k.untyped_storage().data_ptr() check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v]) + # check tensor shapes + shape = q.shape + check_shapes_qkv = all(shape == x.shape for x in [q, k, v]) + shape = k.shape + check_shapes_kv = shape[:-1] == v.shape[:-1] + + # check tensor strides stride = q.stride() check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple( sv / v.shape[-1] for sv in v.stride()[:-1] ) - shape = q.shape - check_shapes_qkv = all(shape == x.shape for x in [q, k, v]) - shape = k.shape - check_shapes_kv = shape[:-1] == v.shape[:-1] + # check tensor offsets for h3d and 3hd layouts + prod_h_d = q.shape[-1] * q.shape[-2] + check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v])) + check_h3d_offsets = all( + x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v]) + ) - last_dim_size = q.shape[-1] - check_last_dim_offsets_qkv = all( - i * last_dim_size == x.storage_offset() for i, x in enumerate([q, k, v]) + # check tensor offsets for hd_h2d and hd_2hd layouts + prod_all_dims = [np.prod(x.shape) for x in [q, k]] + offset = prod_all_dims[0] if check_ptrs_qkv else 0 + prod_h_d = k.shape[-1] * k.shape[-2] + check_2hd_offsets = all( + x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v]) ) - last_dim_size = k.shape[-1] - check_last_dim_offsets_kv = all( - i * last_dim_size == x.storage_offset() for i, x in enumerate([k, v]) + check_h2d_offsets = all( + x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v]) ) - last_two_dims_size = q.shape[-1] * q.shape[-2] - check_last_two_dims_offsets_qkv = all( - i * last_two_dims_size == x.storage_offset() for i, x in enumerate([q, k, v]) + # check tensor offsets for hd_hd_hd layouts + check_hd_offsets_qkv = ( + all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v])) + if check_ptrs_qkv + else all(x.storage_offset() == 0 for i, x in enumerate([q, k, v])) + ) + check_hd_offsets_qk = ( + all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k])) + if not check_ptrs_qkv and check_ptrs_qk + else all(x.storage_offset() == 0 for i, x in enumerate([q, k])) ) - last_two_dims_size = k.shape[-1] * k.shape[-2] - check_last_two_dims_offsets_kv = all( - i * last_two_dims_size == x.storage_offset() for i, x in enumerate([k, v]) + check_hd_offsets_kv = ( + all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v])) + if not check_ptrs_qkv and check_ptrs_kv + else all(x.storage_offset() == 0 for i, x in enumerate([k, v])) ) - if ( - check_ptrs_qkv - and check_strides_qkv - and check_shapes_qkv - and check_last_two_dims_offsets_qkv - and not check_last_dim_offsets_qkv - ): + if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets: # sb3hd, bs3hd, t3hd + # one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:] - elif ( - check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_last_dim_offsets_qkv - ): + elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets: # sbh3d, bsh3d, th3d + # one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:] - elif ( - check_ptrs_kv - and check_strides_kv - and check_shapes_kv - and check_last_two_dims_offsets_kv - and not check_last_dim_offsets_kv - ): + elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets: # sbhd_sb2hd, bshd_bs2hd, thd_t2hd + # two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv + # q and kv may be disjoint or consecutive in memory, and when consecutive, they may + # have the same data pointer, i.e. check_ptrs_qkv=True qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] - elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_last_dim_offsets_kv: + elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets: # sbhd_sbh2d, bshd_bsh2d, thd_th2d + # two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv + # q and kv may be disjoint or consecutive in memory, and when consecutive, they may + # have the same data pointer, i.e. check_ptrs_qkv=True qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:] - elif check_strides_kv and check_shapes_kv: + elif ( + check_strides_kv + and check_shapes_kv + and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk) + ): # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd + # three chunks of memory, q, k and v, which may be disjoint or consecutive, and + # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or + # check_ptrs_qk=True or check_ptrs_kv=True qkv_layout = "_".join(list([qkv_format]) * 3) else: qkv_layout = "not_supported"