diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index bbdde86c6..fb93e16a6 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -6437,6 +6437,7 @@ def scaled_dot_product_attention(context, node):
     - attn_mask : (target_seq, source_seq) or (B, target_seq, source_seq) or
                   (B, h, target_seq, source_seq) or (B, ..., target_seq, source_seq)
     - is_causal : bool
+    - scale : optional float
 
     Output shape: (target_seq, d_v) or (B,...,target_seq, d_v)
 
@@ -6448,14 +6449,22 @@ def scaled_dot_product_attention(context, node):
     https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     """
     inputs = _get_inputs(context, node, min_expected=3)
-    q, k, v = inputs[: 3]
+    q, k, v = inputs[:3]
     attn_mask = None if len(inputs) < 4 else inputs[3]
     dropout = 0.0 if len(inputs) < 5 else inputs[4]
     is_causal = False if len(inputs) < 6 else inputs[5].val
+
+    # When len(inputs) == 7, the inputs are (q, k, v, attn_mask, dropout, is_causal, scale)
+    if len(inputs) == 7 and inputs[6] is not None:
+        raise NotImplementedError(
+            "scaled_dot_product_attention op: scale parameter is not handled."
+        )
+
     if attn_mask is not None and is_causal:
         raise ValueError(
             "scaled_dot_product_attention op: attn_mask cannot be provided when is_causal is set to True."
         )
+
     if dropout is not None and (dropout.val is None or dropout.val != 0.0):
         raise ValueError("scaled_dot_product_attention op: dropout is not supported yet")
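
For reference, below is a minimal, untested sketch of how the new check would surface during conversion. It assumes a PyTorch build (>= 2.1) whose `torch.nn.functional.scaled_dot_product_attention` exposes the `scale` keyword, so that tracing emits the 7-input form handled above (newer PyTorch versions may append further trailing arguments such as `enable_gqa`), plus a coremltools build containing this patch. The module name `Attn`, the tensor shapes, and the scale value 0.5 are arbitrary, and the exception may arrive wrapped depending on the coremltools version.

```python
# Repro sketch (assumptions: PyTorch >= 2.1 with the `scale` kwarg,
# a coremltools build containing this patch; shapes/values are arbitrary).
import torch
import coremltools as ct


class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        # Passing an explicit, non-default scale traces to a
        # scaled_dot_product_attention node carrying a 7th input.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=0.5)


q = k = v = torch.rand(1, 4, 16, 32)
traced = torch.jit.trace(Attn().eval(), (q, k, v))

try:
    ct.convert(traced, inputs=[ct.TensorType(shape=t.shape) for t in (q, k, v)])
except NotImplementedError as err:
    # With this patch, conversion fails loudly instead of silently ignoring `scale`:
    # "scaled_dot_product_attention op: scale parameter is not handled."
    print(err)
```

Raising here seems preferable to silently dropping a non-default scale, since the scale directly changes the attention weights; a possible follow-up would be to fold a custom scale into the query (multiplying `q` by `scale * sqrt(d)`) so the default 1/sqrt(d)-scaled lowering still produces the requested result.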