From 77c1da7fe5c9a25f613bd4045bb1da4856310aab Mon Sep 17 00:00:00 2001
From: lshamis
Date: Wed, 16 Oct 2024 14:40:09 +0000
Subject: [PATCH] Revert "remove pre-dep torch (fairinternal/xformers#1237)"
 (fairinternal/xformers#1241)

This reverts commit 51111b0c7affd6c0c88b3beae45c62e47143bf83.

__original_commit__ = fairinternal/xformers@16cb3bc0d9f1cbeea387048426a331684fad908a
---
 setup.py | 233 ++++++++++++++++++++++++++-----------------------------
 1 file changed, 108 insertions(+), 125 deletions(-)

diff --git a/setup.py b/setup.py
index 2530854d23..4f0da97fb0 100644
--- a/setup.py
+++ b/setup.py
@@ -21,10 +21,30 @@
 from typing import List, Optional
 
 import setuptools
-from setuptools.command.build_ext import build_ext
+import torch
+from torch.utils.cpp_extension import (
+    CUDA_HOME,
+    ROCM_HOME,
+    BuildExtension,
+    CppExtension,
+    CUDAExtension,
+)
+
+this_dir = os.path.dirname(__file__)
+pt_attn_compat_file_path = os.path.join(
+    this_dir, "xformers", "ops", "fmha", "torch_attention_compat.py"
+)
 
-this_file = Path(__file__)
-this_dir = this_file.parent
+# Define the module name
+module_name = "torch_attention_compat"
+
+# Load the module
+spec = importlib.util.spec_from_file_location(module_name, pt_attn_compat_file_path)
+assert spec is not None
+attn_compat_module = importlib.util.module_from_spec(spec)
+sys.modules[module_name] = attn_compat_module
+assert spec.loader is not None
+spec.loader.exec_module(attn_compat_module)
 
 
 def get_extra_nvcc_flags_for_build_type(cuda_version: int) -> List[str]:
@@ -50,18 +70,18 @@ def fetch_requirements():
 
 
 def get_local_version_suffix() -> str:
-    if not (this_dir / ".git").is_dir():
+    if not (Path(__file__).parent / ".git").is_dir():
         # Most likely installing from a source distribution
         return ""
     date_suffix = datetime.datetime.now().strftime("%Y%m%d")
     git_hash = subprocess.check_output(
-        ["git", "rev-parse", "--short", "HEAD"], cwd=this_dir
+        ["git", "rev-parse", "--short", "HEAD"], cwd=Path(__file__).parent
     ).decode("ascii")[:-1]
     return f"+{git_hash}.d{date_suffix}"
 
 
 def get_flash_version() -> str:
-    flash_dir = this_dir / "third_party" / "flash-attention"
+    flash_dir = Path(__file__).parent / "third_party" / "flash-attention"
     try:
         return subprocess.check_output(
             ["git", "describe", "--tags", "--always"],
@@ -84,15 +104,17 @@ def generate_version_py(version: str) -> str:
 
 
 def symlink_package(name: str, path: Path, is_building_wheel: bool) -> None:
-    cwd = this_file.resolve().parent
+    cwd = Path(__file__).resolve().parent
     path_from = cwd / path
-    path_to = cwd / Path(name.replace(".", os.sep))
+    path_to = os.path.join(cwd, *name.split("."))
 
     try:
-        if path_to.is_dir() and not path_to.is_symlink():
+        if os.path.islink(path_to):
+            os.unlink(path_to)
+        elif os.path.isdir(path_to):
             shutil.rmtree(path_to)
         else:
-            path_to.unlink()
+            os.remove(path_to)
     except FileNotFoundError:
         pass
     # OSError: [WinError 1314] A required privilege is not held by the client
@@ -101,7 +123,6 @@ def symlink_package(name: str, path: Path, is_building_wheel: bool) -> None:
     # So we force a copy, see #611
     use_symlink = os.name != "nt" and not is_building_wheel
     if use_symlink:
-        # path_to.symlink_to(path_from)
         os.symlink(src=path_from, dst=path_to)
     else:
         shutil.copytree(src=path_from, dst=path_to)
@@ -190,31 +211,29 @@ def get_flash_attention2_nvcc_archs_flags(cuda_version: int):
 
 
 def get_flash_attention2_extensions(cuda_version: int, extra_compile_args):
-    from torch.utils.cpp_extension import CUDAExtension
-
     nvcc_archs_flags = get_flash_attention2_nvcc_archs_flags(cuda_version)
     if not nvcc_archs_flags:
         return []
 
-    flash_root = this_dir / "third_party" / "flash-attention"
-    cutlass_inc = flash_root / "csrc" / "cutlass" / "include"
-    if not flash_root.exists() or not cutlass_inc.exists():
+    flash_root = os.path.join(this_dir, "third_party", "flash-attention")
+    cutlass_inc = os.path.join(flash_root, "csrc", "cutlass", "include")
+    if not os.path.exists(flash_root) or not os.path.exists(cutlass_inc):
         raise RuntimeError(
             "flashattention submodule not found. Did you forget "
             "to run `git submodule update --init --recursive` ?"
         )
 
     sources = ["csrc/flash_attn/flash_api.cpp"]
-    for f in (flash_root / "csrc" / "flash_attn" / "src").glob("*.cu"):
-        if "hdim224" in f.name:
+    for f in glob.glob(os.path.join(flash_root, "csrc", "flash_attn", "src", "*.cu")):
+        if "hdim224" in Path(f).name:
             continue
-        sources.append(str(f.relative_to(flash_root)))
+        sources.append(str(Path(f).relative_to(flash_root)))
 
     common_extra_compile_args = ["-DFLASHATTENTION_DISABLE_ALIBI"]
     return [
         CUDAExtension(
             name="xformers._C_flashattention",
-            sources=[str(flash_root / path) for path in sources],
+            sources=[os.path.join(flash_root, path) for path in sources],
             extra_compile_args={
                 "cxx": extra_compile_args.get("cxx", []) + common_extra_compile_args,
                 "nvcc": extra_compile_args.get("nvcc", [])
@@ -237,9 +256,9 @@ def get_flash_attention2_extensions(cuda_version: int, extra_compile_args):
             include_dirs=[
                 p.absolute()
                 for p in [
-                    flash_root / "csrc" / "flash_attn",
-                    flash_root / "csrc" / "flash_attn" / "src",
-                    flash_root / "csrc" / "cutlass" / "include",
+                    Path(flash_root) / "csrc" / "flash_attn",
+                    Path(flash_root) / "csrc" / "flash_attn" / "src",
+                    Path(flash_root) / "csrc" / "cutlass" / "include",
                 ]
             ],
         )
@@ -250,8 +269,6 @@ def get_flash_attention2_extensions(cuda_version: int, extra_compile_args):
 # FLASH-ATTENTION v3
 ######################################
 def get_flash_attention3_nvcc_archs_flags(cuda_version: int):
-    import torch
-
     if os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") != "0":
         return []
     if platform.system() != "Linux" or cuda_version < 1203:
@@ -280,33 +297,29 @@ def get_flash_attention3_nvcc_archs_flags(cuda_version: int):
 
 
 def get_flash_attention3_extensions(cuda_version: int, extra_compile_args):
-    from torch.utils.cpp_extension import CUDAExtension
-
     nvcc_archs_flags = get_flash_attention3_nvcc_archs_flags(cuda_version)
     if not nvcc_archs_flags:
         return []
 
-    flash_root = this_dir / "third_party" / "flash-attention"
-    cutlass_inc = flash_root / "csrc" / "cutlass" / "include"
-    if not flash_root.exists() or not cutlass_inc.exists():
+    flash_root = os.path.join(this_dir, "third_party", "flash-attention")
+    cutlass_inc = os.path.join(flash_root, "csrc", "cutlass", "include")
+    if not os.path.exists(flash_root) or not os.path.exists(cutlass_inc):
         raise RuntimeError(
             "flashattention submodule not found. Did you forget "
            "to run `git submodule update --init --recursive` ?"
         )
 
-    sources = []
-    sources += [
-        str(f.relative_to(flash_root)) for f in (flash_root / "hopper").glob("*.cu")
-    ]
-    sources += [
-        str(f.relative_to(flash_root)) for f in (flash_root / "hopper").glob("*.cpp")
+    sources = [
+        str(Path(f).relative_to(flash_root))
+        for f in glob.glob(os.path.join(flash_root, "hopper", "*.cu"))
+        + glob.glob(os.path.join(flash_root, "hopper", "*.cpp"))
     ]
 
     sources = [s for s in sources if "flash_bwd_hdim256_fp16_sm90.cu" not in s]
     return [
         CUDAExtension(
             name="xformers._C_flashattention3",
-            sources=[str(flash_root / path) for path in sources],
+            sources=[os.path.join(flash_root, path) for path in sources],
             extra_compile_args={
                 "cxx": extra_compile_args.get("cxx", []),
                 "nvcc": extra_compile_args.get("nvcc", [])
@@ -339,8 +352,8 @@ def get_flash_attention3_extensions(cuda_version: int, extra_compile_args):
             include_dirs=[
                 p.absolute()
                 for p in [
-                    flash_root / "csrc" / "cutlass" / "include",
-                    flash_root / "hopper",
+                    Path(flash_root) / "csrc" / "cutlass" / "include",
+                    Path(flash_root) / "hopper",
                 ]
             ],
         )
@@ -353,29 +366,6 @@ def rename_cpp_cu(cpp_files):
 
 
 def get_extensions():
-    import torch
-    from torch.utils.cpp_extension import (
-        CUDA_HOME,
-        ROCM_HOME,
-        CppExtension,
-        CUDAExtension,
-    )
-
-    pt_attn_compat_file_path = (
-        this_dir / "xformers" / "ops" / "fmha" / "torch_attention_compat.py"
-    )
-
-    # Define the module name
-    module_name = "torch_attention_compat"
-
-    # Load the module
-    spec = importlib.util.spec_from_file_location(module_name, pt_attn_compat_file_path)
-    assert spec is not None
-    attn_compat_module = importlib.util.module_from_spec(spec)
-    sys.modules[module_name] = attn_compat_module
-    assert spec.loader is not None
-    spec.loader.exec_module(attn_compat_module)
-
     extensions_dir = os.path.join("xformers", "csrc")
 
     sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)
@@ -402,7 +392,7 @@ def get_extensions():
         source_cuda = list(set(source_cuda) - set(source_hip_generated))
         sources = list(set(sources) - set(source_hip))
 
-    sputnik_dir = this_dir / "third_party" / "sputnik"
+    sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
 
     xformers_pt_cutlass_attn = os.getenv("XFORMERS_PT_CUTLASS_ATTN")
     # By default, we try to link to torch internal CUTLASS attention implementation
@@ -415,12 +405,12 @@ def get_extensions():
     ):
         source_cuda = list(set(source_cuda) - set(fmha_source_cuda))
 
-    cutlass_dir = this_dir / "third_party" / "cutlass" / "include"
-    cutlass_util_dir = (
-        this_dir / "third_party" / "cutlass" / "tools" / "util" / "include"
+    cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")
+    cutlass_util_dir = os.path.join(
+        this_dir, "third_party", "cutlass", "tools", "util", "include"
     )
-    cutlass_examples_dir = this_dir / "third_party" / "cutlass" / "examples"
-    if not cutlass_dir.exists():
+    cutlass_examples_dir = os.path.join(this_dir, "third_party", "cutlass", "examples")
+    if not os.path.exists(cutlass_dir):
         raise RuntimeError(
             f"CUTLASS submodule not found at {cutlass_dir}. "
             "Did you forget to run "
@@ -448,7 +438,7 @@ def get_extensions():
 
     use_pt_flash = False
     if (
-        (torch.cuda.is_available() and (CUDA_HOME is not None))
+        (torch.cuda.is_available() and ((CUDA_HOME is not None)))
         or os.getenv("FORCE_CUDA", "0") == "1"
         or os.getenv("TORCH_CUDA_ARCH_LIST", "") != ""
     ):
@@ -456,10 +446,10 @@ def get_extensions():
         extension = CUDAExtension
         sources += source_cuda
         include_dirs += [
-            str(sputnik_dir),
-            str(cutlass_dir),
-            str(cutlass_util_dir),
-            str(cutlass_examples_dir),
+            sputnik_dir,
+            cutlass_dir,
+            cutlass_util_dir,
+            cutlass_examples_dir,
         ]
         nvcc_flags = [
             "-DHAS_PYTORCH",
@@ -544,10 +534,12 @@ def get_extensions():
         extension = CUDAExtension
         sources += source_hip_cu
 
-        include_dirs += [this_dir / "xformers" / "csrc" / "attention" / "hip_fmha"]
+        include_dirs += [
+            Path(this_dir) / "xformers" / "csrc" / "attention" / "hip_fmha"
+        ]
 
         include_dirs += [
-            this_dir / "third_party" / "composable_kernel_tiled" / "include"
+            Path(this_dir) / "third_party" / "composable_kernel_tiled" / "include"
         ]
 
         generator_flag = []
@@ -635,61 +627,41 @@ def run(self):
         distutils.command.clean.clean.run(self)
 
 
-class TorchBuildExtension(build_ext):
-    def run(self):
-        from torch.utils.cpp_extension import BuildExtension
-
-        class BuildExtensionWithExtraFiles(BuildExtension):
-            def __init__(self, *args, **kwargs) -> None:
-                self.xformers_build_metadata = kwargs.pop("extra_files")
-                self.pkg_name = "xformers"
-                super().__init__(*args, **kwargs)
-
-            def build_extensions(self) -> None:
-                super().build_extensions()
-                for filename, content in self.xformers_build_metadata.items():
-                    with open(
-                        os.path.join(self.build_lib, self.pkg_name, filename), "w+"
-                    ) as fp:
-                        fp.write(content)
-
-            def copy_extensions_to_source(self) -> None:
-                """
-                Used for `pip install -e .`
-                Copies everything we built back into the source repo
-                """
-                build_py = self.get_finalized_command("build_py")
-                package_dir = build_py.get_package_dir(self.pkg_name)
-
-                for filename in self.xformers_build_metadata.keys():
-                    inplace_file = os.path.join(package_dir, filename)
-                    regular_file = os.path.join(self.build_lib, self.pkg_name, filename)
-                    self.copy_file(regular_file, inplace_file, level=self.verbose)
-                super().copy_extensions_to_source()
-
-        extensions, extensions_metadata = get_extensions()
-
-        setuptools.setup(
-            ext_modules=extensions,
-            cmdclass={
-                "build_ext": BuildExtensionWithExtraFiles.with_options(
-                    no_python_abi_suffix=True,
-                    extra_files={
-                        "cpp_lib.json": json.dumps(extensions_metadata),
-                        "version.py": generate_version_py(version),
-                    },
-                ),
-                "clean": clean,
-            },
-        )
+class BuildExtensionWithExtraFiles(BuildExtension):
+    def __init__(self, *args, **kwargs) -> None:
+        self.xformers_build_metadata = kwargs.pop("extra_files")
+        self.pkg_name = "xformers"
+        super().__init__(*args, **kwargs)
+
+    def build_extensions(self) -> None:
+        super().build_extensions()
+        for filename, content in self.xformers_build_metadata.items():
+            with open(
+                os.path.join(self.build_lib, self.pkg_name, filename), "w+"
+            ) as fp:
+                fp.write(content)
+
+    def copy_extensions_to_source(self) -> None:
+        """
+        Used for `pip install -e .`
+        Copies everything we built back into the source repo
+        """
+        build_py = self.get_finalized_command("build_py")
+        package_dir = build_py.get_package_dir(self.pkg_name)
+
+        for filename in self.xformers_build_metadata.keys():
+            inplace_file = os.path.join(package_dir, filename)
+            regular_file = os.path.join(self.build_lib, self.pkg_name, filename)
+            self.copy_file(regular_file, inplace_file, level=self.verbose)
+        super().copy_extensions_to_source()
 
 
 if __name__ == "__main__":
     if os.getenv("BUILD_VERSION"):  # In CI
         version = os.getenv("BUILD_VERSION", "0.0.0")
     else:
-        version_txt = this_dir / "version.txt"
-        with version_txt.open() as f:
+        version_txt = os.path.join(this_dir, "version.txt")
+        with open(version_txt) as f:
            version = f.readline().strip()
     version += get_local_version_suffix()
 
@@ -704,13 +676,24 @@ def copy_extensions_to_source(self) -> None:
         Path("third_party") / "flash-attention" / "flash_attn",
         is_building_wheel,
     )
+    extensions, extensions_metadata = get_extensions()
     setuptools.setup(
         name="xformers",
         description="XFormers: A collection of composable Transformer building blocks.",
         version=version,
         install_requires=fetch_requirements(),
        packages=setuptools.find_packages(exclude=("tests*", "benchmarks*")),
-        cmdclass={"build_ext": TorchBuildExtension},
+        ext_modules=extensions,
+        cmdclass={
+            "build_ext": BuildExtensionWithExtraFiles.with_options(
+                no_python_abi_suffix=True,
+                extra_files={
+                    "cpp_lib.json": json.dumps(extensions_metadata),
+                    "version.py": generate_version_py(version),
+                },
+            ),
+            "clean": clean,
+        },
         url="https://facebookresearch.github.io/xformers/",
         python_requires=">=3.7",
         author="Facebook AI Research",
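
Note on the pattern this revert restores: setup.py again imports torch at module level and loads xformers/ops/fmha/torch_attention_compat.py by file path with importlib before get_extensions() runs, instead of deferring both to a custom build_ext command. A minimal standalone sketch of that importlib loading pattern follows; the file path and module name mirror the patch, everything else is illustrative only.

    import importlib.util
    import sys

    # Load a module directly from a file path, without requiring it to be
    # importable from sys.path (the patch does this for torch_attention_compat.py).
    module_path = "xformers/ops/fmha/torch_attention_compat.py"
    module_name = "torch_attention_compat"

    spec = importlib.util.spec_from_file_location(module_name, module_path)
    assert spec is not None and spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # register before executing, as in the patch
    spec.loader.exec_module(module)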