[PyTorch] Activation operations #1164

Open

wants to merge 13 commits into main
89 changes: 89 additions & 0 deletions tests/pytorch/test_fusible_ops.py
@@ -854,6 +854,95 @@ def test_make_extra_output(
        torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
    @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("fp8_input", (False, True))
    @pytest.mark.parametrize("fp8_output", (False, True))
    def test_activation(
        self,
        *,
        activation: str,
        out_shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
        fp8_input: bool,
        fp8_output: bool,
    ) -> None:
        """Activation functions"""

        # Tensor dimensions
        in_shape = list(out_shape)
        if activation in ("geglu", "reglu", "swiglu"):
            in_shape[-1] *= 2

        # Skip invalid configurations
        if fp8_input or fp8_output:
            if not fp8_available:
                pytest.skip(reason_for_no_fp8)
            if torch.device(device).type != "cuda":
                pytest.skip("FP8 is only supported on CUDA devices")

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            in_shape,
            test_dtype=dtype,
            test_device=device,
            test_is_fp8=fp8_input,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            out_shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref: torch.Tensor
        if activation == "gelu":
            y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
        elif activation == "relu":
            y_ref = torch.nn.functional.relu(x_ref)
        elif activation == "geglu":
            x1, x2 = x_ref.chunk(2, dim=-1)
            y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2
        elif activation == "reglu":
            x1, x2 = x_ref.chunk(2, dim=-1)
            y_ref = torch.nn.functional.relu(x1) * x2
        elif activation == "swiglu":
            x1, x2 = x_ref.chunk(2, dim=-1)
            y_ref = torch.nn.functional.silu(x1) * x2
        else:
            raise ValueError(f"Unexpected activation function ({activation})")
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        make_op = dict(
            gelu=te_ops.GELU,
            relu=te_ops.ReLU,
            geglu=te_ops.GEGLU,
            reglu=te_ops.ReGLU,
            swiglu=te_ops.SwiGLU,
        )[activation]
        forward = te_ops.Sequential(
            make_op(),
            te_ops.CastFloat8(forward=fp8_output, backward=False),
        )
        with te.fp8_autocast(enabled=fp8_output):
            y_test = forward(x_test)
        y_test.backward(dy_test)

        # Expected numerical error
        tols = dtype_tols(dtype)
        if fp8_output:
            tols = dtype_tols(tex.DType.kFloat8E4M3)

        # Check results
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)


class TestFusedOps:
    """Tests for fused operations"""
12 changes: 1 addition & 11 deletions transformer_engine/pytorch/ops/__init__.py
@@ -8,17 +8,7 @@

"""

from transformer_engine.pytorch.ops.basic import (
    AddInPlace,
    AllGather,
    AllReduce,
    BasicLinear,
    Bias,
    Identity,
    MakeExtraOutput,
    ReduceScatter,
    Reshape,
)
from transformer_engine.pytorch.ops.basic import *
from transformer_engine.pytorch.ops.linear import Linear
from transformer_engine.pytorch.ops.op import FusibleOperation
from transformer_engine.pytorch.ops.sequential import Sequential
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/ops/basic/__init__.py
@@ -4,11 +4,13 @@

"""Single tensor operations supported by the operation fuser."""

from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU
from .add_in_place import AddInPlace
from .all_gather import AllGather
from .all_reduce import AllReduce
from .basic_linear import BasicLinear
from .bias import Bias
from .cast_float8 import CastFloat8
from .identity import Identity
from .make_extra_output import MakeExtraOutput
from .reduce_scatter import ReduceScatter
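Taken together, the two __init__.py changes above re-export the new activation operations at the ops package level. A quick sketch of the resulting import paths (assuming the wildcard import is not narrowed by an __all__ list):

# Exported directly by the basic-ops subpackage in this PR.
from transformer_engine.pytorch.ops.basic import GEGLU, GELU, ReGLU, ReLU, SwiGLU

# Via `from transformer_engine.pytorch.ops.basic import *`, the same names
# should also be reachable one level up (assumption: no __all__ filtering).
import transformer_engine.pytorch.ops as te_ops
swiglu = te_ops.SwiGLU()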