Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for rankllama #294

Merged
merged 1 commit into from
Aug 13, 2024
Merged

Conversation

aniquetahir
Copy link
Contributor

This pull request adds support for RankLlama with LoRA. Detailed instructions are included in README.md.

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.

@aniquetahir aniquetahir marked this pull request as draft August 6, 2024 17:15
@aniquetahir aniquetahir marked this pull request as ready for review August 6, 2024 17:16
@OctoberChang OctoberChang self-assigned this Aug 6, 2024
@aniquetahir aniquetahir force-pushed the xmr_argparse branch 7 times, most recently from 2b78a90 to dd94777 Compare August 6, 2024 18:47
@aniquetahir aniquetahir marked this pull request as draft August 6, 2024 18:50
@aniquetahir aniquetahir marked this pull request as ready for review August 6, 2024 18:57
pecos/xmr/reranker/trainer.py Outdated Show resolved Hide resolved
@@ -0,0 +1,91 @@
# PECOS XMR Reranker
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some general comments about the README.md:

  • We should also introduce the data schema for training/inference?
  • For the Command Line Usage (CLI), we have pecos.xmr.reranker.train. Should we also have pecos.xmc.reranker.predict?
  • Do we want to support Python API usage?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another high level comment. Can we avoid hard-coded the input columns and make them configurable in the config JSON file? Some hard-coded columns, for example:

  • Line 79 of data_utils.py: keywords
  • Line 110 of data_utils.py: contents, titles
  • Line 296-298 of model.py: inp_id, ret_idxs, rel

Copy link
Contributor Author

@aniquetahir aniquetahir Aug 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another high level comment. Can we avoid hard-coded the input columns and make them configurable in the config JSON file? Some hard-coded columns, for example:

* Line 79 of `data_utils.py`: `keywords`

* Line 110 of `data_utils.py`: `contents`, `titles`

* Line 296-298 of `model.py`: `inp_id`, `ret_idxs`, `rel`

This can now be specified in the configuration. I added details in the README.md.

Some general comments about the README.md:

* We should also introduce the data schema for training/inference?

* For the Command Line Usage (CLI), we have `pecos.xmr.reranker.train`.  Should we also have `pecos.xmc.reranker.predict`?

* Do we want to support Python API usage?

Added predictions.

pecos/xmr/reranker/model.py Outdated Show resolved Hide resolved
params: The model parameters (RankingModelParams)
"""
training_args = train_params.training_args
training_args.remove_unused_columns = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this line still necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are adding additional information to the output of the collate function, this is needed to avoid it being removed by the trainer.

"""
Enable gradient checkpointing for the model
"""
self.hf_model.enable_input_require_grads()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this is the right place to call hf_model.enable_input_require_grads().
From Tevatron RankLlaMA implementation (https://github.com/texttron/tevatron/blob/main/src/tevatron/reranker/modeling.py#L79), they are calling only when both "LoRA" and "training_args.gradient_checkpointing" is enable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following sequence of operation is expected:

  1. First get_modifed_model is called.
  2. the trainer calls gradient_checkpointing_enable, when gradient checkpointing is enabled.

@aniquetahir aniquetahir marked this pull request as draft August 8, 2024 16:51
@aniquetahir aniquetahir force-pushed the xmr_argparse branch 6 times, most recently from 87cc9f6 to 63a083b Compare August 8, 2024 17:39
@aniquetahir aniquetahir marked this pull request as ready for review August 8, 2024 17:40
@aniquetahir aniquetahir marked this pull request as draft August 8, 2024 18:22
@aniquetahir aniquetahir marked this pull request as ready for review August 8, 2024 18:25
lbl_idxs: List[int],
):
"""
Collate function for training. Tokenizes the input and return features and returns the collated batch.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doc String seems to be out-dated.

model_params (RankingModel.ModelParams): the model parameters
train_params (RankingModel.TrainParams): the training parameters
Returns:
An instance of UberGlobalModel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doc String seems to be out-dated. Remove UberGlobalModel.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated docstrings

@aniquetahir aniquetahir marked this pull request as draft August 8, 2024 21:59
@aniquetahir aniquetahir marked this pull request as ready for review August 8, 2024 22:03
setup.py Outdated
'transformers>=4.4.2; python_version>="3.9"'
'transformers>=4.4.2; python_version>="3.9"',
'tqdm>=4.66.4',
'peft>=0.11.0',
Copy link
Contributor

@OctoberChang OctoberChang Aug 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The peft package requires python_version >= 3.8 that's why the latest unit test failed.

Please also check the minimal support python version for other libraries you introduce.

@aniquetahir aniquetahir marked this pull request as ready for review August 12, 2024 23:46
Copy link
Contributor

@OctoberChang OctoberChang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@OctoberChang OctoberChang merged commit ea254b0 into amzn:mainline Aug 13, 2024
25 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants