
llama finetune.py throws pytorch tensor datatype error with 4 bit quantization #675

Closed
AAndersn opened this issue Sep 23, 2024 · 8 comments

@AAndersn
System Info

PyTorch 2.4.0, CUDA 12.1, CentOS HPC cluster with 7x H100 GPUs

Information

  • The official example scripts
  • My own modified scripts

🐛 Describe the bug

FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 python -m torch.distributed.launch \
    --nnodes 1 \
    --nproc_per_node 5 \
    -m llama_recipes.finetuning \
    --enable_fsdp \
    --model_name meta-llama/Meta-Llama-3.1-70B \
    --quantization 4bit \
    --use_peft \
    --peft_method lora \
    --dataset grammar_dataset \
    --lr 5e-5 \
    --save_model \
    --use_wandb \
    --output_dir /qfs/people/usr/models/70B

Error logs

Loading checkpoint shards:   0%|                                                                                                                                    | 0/4 [00:02<?, ?it/s]
[rank1]: Traceback (most recent call last):
[rank1]:   File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank1]:     return _run_code(code, main_globals, None,
[rank1]:   File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
[rank1]:     exec(code, run_globals)
[rank1]:   File "/qfs/people/usr/llama-recipes/src/llama_recipes/finetuning.py", line 291, in <module>
[rank1]:     fire.Fire(main)
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
[rank1]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
[rank1]:     component, remaining_args = _CallAndUpdateTrace(
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
[rank1]:     component = fn(*varargs, **kwargs)
[rank1]:   File "/qfs/people/usr/llama-recipes/src/llama_recipes/finetuning.py", line 121, in main
[rank1]:     model = LlamaForCausalLM.from_pretrained(
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3960, in from_pretrained
[rank1]:     ) = cls._load_pretrained_model(
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4434, in _load_pretrained_model
[rank1]:     new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/transformers/modeling_utils.py", line 970, in _load_state_dict_into_meta_model
[rank1]:     value = type(value)(value.data.to("cpu"), **value.__dict__)
[rank1]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 149, in __new__
[rank1]:     self = torch.Tensor._make_subclass(cls, data, requires_grad)
[rank1]: RuntimeError: Only Tensors of floating point and complex dtype can require gradients
Loading checkpoint shards:   0%|                                                                                                                                    | 0/4 [00:00<?, ?it/s]W0922 19:40:05.383000 47946375398528 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 60528 closing signal SIGTERM
W0922 19:40:05.383000 47946375398528 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 60529 closing signal SIGTERM
W0922 19:40:05.383000 47946375398528 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 60530 closing signal SIGTERM
W0922 19:40:05.383000 47946375398528 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 60532 closing signal SIGTERM
E0922 19:40:05.857000 47946375398528 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 3 (pid: 60531) of binary: /qfs/people/usr/venv_llama_2/bin/python

This error message is then repeated by each separate GPU process, followed by:

Traceback (most recent call last):
  File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launch.py", line 208, in <module>
    main()
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/typing_extensions.py", line 2360, in wrapper
    return arg(*args, **kwargs)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launch.py", line 204, in main
    launch(args)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launch.py", line 189, in launch
    run(args)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
llama_recipes.finetuning FAILED

If the command is run without the FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 environment variables, it throws a different error:

ValueError: Cannot flatten integer dtype tensors
[rank0]: Traceback (most recent call last):
[rank0]:   File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]:     return _run_code(code, main_globals, None,
[rank0]:   File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/qfs/people/usr/llama-recipes/src/llama_recipes/finetuning.py", line 291, in <module>
[rank0]:     fire.Fire(main)
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
[rank0]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
[rank0]:     component, remaining_args = _CallAndUpdateTrace(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
[rank0]:     component = fn(*varargs, **kwargs)
[rank0]:   File "/qfs/people/usr/llama-recipes/src/llama_recipes/finetuning.py", line 179, in main
[rank0]:     model = FSDP(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
[rank0]:     _auto_wrap(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 102, in _auto_wrap
[rank0]:     _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:   [Previous line repeated 2 more times]
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 562, in _recursive_wrap
[rank0]:     return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 491, in _wrap
[rank0]:     return wrapper_cls(module, **kwargs)
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
[rank0]:     _init_param_handle_from_module(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 603, in _init_param_handle_from_module
[rank0]:     _init_param_handle_from_params(state, managed_params, fully_sharded_module)
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 615, in _init_param_handle_from_params
[rank0]:     handle = FlatParamHandle(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 583, in __init__
[rank0]:     self._init_flat_param_and_metadata(
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 633, in _init_flat_param_and_metadata
[rank0]:     ) = self._validate_tensors_to_flatten(params)
[rank0]:   File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 769, in _validate_tensors_to_flatten
[rank0]:     raise ValueError("Cannot flatten integer dtype tensors")
[rank0]: ValueError: Cannot flatten integer dtype tensors

E0923 09:17:49.746000 47893711004800 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 0 (pid: 44819) of binary: /qfs/people/usr/venv_llama_2/bin/python
Traceback (most recent call last):
  File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/share/apps/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launch.py", line 208, in <module>
    main()
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/typing_extensions.py", line 2360, in wrapper
    return arg(*args, **kwargs)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launch.py", line 204, in main
    launch(args)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launch.py", line 189, in launch
    run(args)
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/qfs/people/usr/venv_llama_2/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
llama_recipes.finetuning FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-09-23_09:17:48
  host      : h100-02.local
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 44819)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Expected behavior

This call and dataset work fine for llama3.1-8B without quantization, but fail with 4-bit quantization. The int4 quantization parameter given in https://github.com/meta-llama/llama-recipes/blob/main/recipes/quickstart/finetuning/multigpu_finetuning.md#with-fsdp--qlora does not exist.

@mreso
Contributor

mreso commented Sep 24, 2024

Hi @AAndersn, thanks for reporting. I was not able to repro this so far, but I will give it another try later today. You're right about the int4; it is a leftover from a back-and-forth while we created the PR for QLoRA. Would you be interested in creating a PR to fix this?

@AAndersn
Author

@mreso Happy to make a PR to update the docs. I'll also try rolling back to an older version of PyTorch to see if that fixes it and will update this issue tomorrow.

@AAndersn
Author

The problem appears to be with the AutoModel.from_pretrained() call inside the finetuning.py script.

I rebuilt my environment today with llama-recipes 0.0.4 and transformers 4.45.0 and am able to run this snippet successfully:

import torch
from transformers import BitsAndBytesConfig, AutoModel

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16
)

model = AutoModel.from_pretrained(
            "meta-llama/Meta-Llama-3.1-8B",
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.bfloat16
)

However, if I copy and paste this exact snippet into finetuning.py, the AutoModel call fails with the same message:

python3.11/site-packages/bitsandbytes/nn/modules.py", line 149, in __new__
[rank3]:     self = torch.Tensor._make_subclass(cls, data, requires_grad)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: RuntimeError: Only Tensors of floating point and complex dtype can require gradients

@wukaixingxp
Contributor

Hi! @AAndersn Thanks for reporting this. I am wondering if changing AutoModel to LlamaForCausalLM solves this. Can you try? Thanks!
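
For illustration, here is a minimal sketch of that suggested change, reusing the 4-bit BitsAndBytesConfig from the snippet above; the actual call site and arguments inside finetuning.py may differ.

import torch
from transformers import BitsAndBytesConfig, LlamaForCausalLM

# Same 4-bit NF4 quantization config as in the snippet earlier in this thread
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,
)

# Use the concrete LlamaForCausalLM class instead of AutoModel
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)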

@AAndersn
Author

AAndersn commented Sep 26, 2024

@wukaixingxp - Thank you so much! Changing AutoModel to LlamaForCausalLM fixed it! Testing now with 8B and 70B.

If that works, I will install the pytest suite and then update #681 to include this fix.

@wukaixingxp
Contributor

I tried your command with transformers = 4.45.0 and torch = 2.4.1, but I got this error:

[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank2]:     return _run_code(code, main_globals, None,
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/runpy.py", line 86, in _run_code
[rank2]:     exec(code, run_globals)
[rank2]:   File "/home/kaiwu/work/llama-recipes/src/llama_recipes/finetuning.py", line 332, in <module>
[rank2]:     fire.Fire(main)
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
[rank2]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
[rank2]:     component, remaining_args = _CallAndUpdateTrace(
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
[rank2]:     component = fn(*varargs, **kwargs)
[rank2]:   File "/home/kaiwu/work/llama-recipes/src/llama_recipes/finetuning.py", line 203, in main
[rank2]:     model = FSDP(
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
[rank2]:     _auto_wrap(
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 102, in _auto_wrap
[rank2]:     _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
[rank2]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
[rank2]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
[rank2]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank2]:   [Previous line repeated 2 more times]
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 562, in _recursive_wrap
[rank2]:     return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 491, in _wrap
[rank2]:     return wrapper_cls(module, **kwargs)
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
[rank2]:     _init_param_handle_from_module(
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 565, in _init_param_handle_from_module
[rank2]:     _materialize_meta_module(
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 897, in _materialize_meta_module
[rank2]:     raise e
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 890, in _materialize_meta_module
[rank2]:     module.reset_parameters()  # type: ignore[operator]
[rank2]:   File "/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1729, in __getattr__
[rank2]:     raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
[rank2]: AttributeError: 'LlamaRMSNorm' object has no attribute 'reset_parameters'. Did you mean: 'get_parameter'?
/home/kaiwu/miniconda3/envs/recipe_test/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py:892: UserWarning: Unable to call `reset_parameters()` for module on meta device with error 'LlamaRMSNorm' object has no attribute 'reset_parameters'. Please ensure that your module of type <class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'> implements a `reset_parameters()` method.

@AAndersn
Author

pip reported a conflict with torch = 2.4.1.

I was able to run 8B with 4-bit quantization on torch = 2.4.0 by replacing AutoModel with LlamaForCausalLM or LlamaForQuestionAnswering (for use with a custom dataset), as sketched below.
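
For the custom-dataset case, the same pattern applies with the question-answering head; a minimal sketch under the same assumptions as the earlier snippet (model name and config values are illustrative, not the exact training setup used here):

import torch
from transformers import BitsAndBytesConfig, LlamaForQuestionAnswering

# Same 4-bit NF4 setup as above, just with the QA head instead of the causal-LM head
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                bnb_4bit_compute_dtype=torch.bfloat16,
                                bnb_4bit_use_double_quant=True,
                                bnb_4bit_quant_storage=torch.bfloat16)

model = LlamaForQuestionAnswering.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)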

@AAndersn
Author

@wukaixingxp - I see you have made that update in https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/finetuning.py#L139, so I will close this issue as fixed by #686.

Thanks so much for your help!
