
architecture: factor HFCompatible out #954

Open · wants to merge 9 commits into main

Conversation

@leondz (Owner) commented Oct 21, 2024

HFCompatible was embedded in generators.base, tying slow-to-import HF-specific stuff to base classes. This PR moves HFCompatible to a separate module, with a candidate location in garak.resources.api.huggingface, enabling fast base class loading.
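As a hedged illustration of the consumer-side effect (the module path is the one named above; the generator class here is hypothetical):

from garak.generators.base import Generator

# Previously the mixin lived in generators.base; after this PR it is imported
# from its own module, so loading the base classes no longer pulls in HF dependencies.
from garak.resources.api.huggingface import HFCompatible


class SomeHFGenerator(Generator, HFCompatible):
    """Hypothetical generator consuming the relocated mixin."""

    pass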

@leondz added the enhancement (Architectural upgrades) label on Oct 21, 2024
@jmartin-tech (Collaborator) left a comment:

The move looks reasonable; I am on the fence about using garak.resources.api to represent wrappers for dependencies.

One idea would be to use garak/resources/huggingface/__init__.py to expose the class. This would allow keeping the class definitions in separate files that are imported into __init__.py for exposure, as more Compatible types are identified over time. Just a thought that came to mind; no strong argument in favor of this at this time.
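A minimal sketch of that layout, purely for illustration (the submodule split and its name are hypothetical):

# garak/resources/huggingface/__init__.py
# Re-export mixins kept in their own files, so callers can write
# `from garak.resources.huggingface import HFCompatible`.
from .hf_compatible import HFCompatible  # hypothetical submodule name

__all__ = ["HFCompatible"]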

This PR also suggests there is another consumer for this mixin in buffs.paraphrase.PegasusT5.

garak/resources/api/huggingface.py (resolved review thread)
from garak.buffs.base import Buff
from garak.resources.api.huggingface import HFCompatible


class PegasusT5(Buff):
@jmartin-tech (Collaborator) commented:
Just an observation, not required for this PR: it looks like this class could benefit from a refactor to use HFCompatible and to expose para_model_name and hf_args as DEFAULT_PARAMS.

Suggested change
class PegasusT5(Buff):
class PegasusT5(Buff, HFCompatible):
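Taking the suggestion above one step further, a hedged sketch of what those DEFAULT_PARAMS might look like (the values shown are placeholders, not necessarily the buff's actual defaults):

class PegasusT5(Buff, HFCompatible):
    # hypothetical promotion of instance attributes to configurable params
    DEFAULT_PARAMS = {
        "para_model_name": "tuner007/pegasus_paraphrase",  # placeholder; keep the buff's existing default
        "hf_args": {"device": None},  # consumed by _select_hf_device()
    }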

@jmartin-tech (Collaborator) commented:
This change is incomplete; the class also needs to be extended to consume the HFCompatible mixin in _load_model().

self.torch_device moves to the standardized self.device, which should be detected/populated from hf_args["device"] via a call to self._select_hf_device() in _load_model():

    def __init__(self, config_root=_config) -> None:
        self.max_length = 60
        self.temperature = 1.5
        self.num_return_sequences = 6
        self.num_beams = self.num_return_sequences
        self.tokenizer = None
        self.para_model = None
        super().__init__(config_root=config_root)

    def _load_model(self):
        from transformers import PegasusForConditionalGeneration, PegasusTokenizer

        self.device = self._select_hf_device()
        model_kwargs = self._gather_hf_params(
            hf_constructor=PegasusForConditionalGeneration.from_pretrained
        )  # defers to device_map if device_map was `auto`; may not match self.device

        self.para_model = PegasusForConditionalGeneration.from_pretrained(
            self.para_model_name, **model_kwargs
        ).to(self.device)
        self.tokenizer = PegasusTokenizer.from_pretrained(self.para_model_name)

Not an issue for this PR, but I suspect a few more items should be promoted to DEFAULT_PARAMS. The max_length, temperature, num_return_sequences, and possibly num_beams (if its value does not always have to equal num_return_sequences) should likely be exposed as configurable. Since Fast looks like it may have similar items to promote, I think that can be deferred.
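Extending the sketch given earlier in this thread, the deferred promotion might look like this (values mirror the __init__ shown above; hypothetical and not part of this PR):

    DEFAULT_PARAMS = {
        "para_model_name": "tuner007/pegasus_paraphrase",  # placeholder; keep the existing default
        "hf_args": {"device": None},
        "max_length": 60,
        "temperature": 1.5,
        "num_return_sequences": 6,
        "num_beams": 6,  # currently tied to num_return_sequences
    }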

@leondz (Owner, Author) commented:
Agree. Thanks for the details. Will mark as ready for review when out of draft.

@leondz (Owner, Author) commented:
Adding the model_kwargs part led to the paraphraser returning all blanks. Did the rest of the integration and added a test to catch this unwanted behaviour.
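A minimal sketch of the kind of regression check meant here (the test name, marker, and helper method names are assumptions; the real test is the one added in this PR):

import pytest

from garak.buffs.paraphrase import PegasusT5


@pytest.mark.slow  # assumed marker: this loads an HF model
def test_pegasus_paraphrases_not_blank():
    buff = PegasusT5()
    buff._load_model()  # per the _load_model() integration discussed above
    # _get_response is assumed to return a list of paraphrase strings
    outputs = buff._get_response("The quick brown fox jumps over the lazy dog.")
    assert len(outputs) > 0
    assert all(o.strip() for o in outputs), "paraphraser returned blank outputs"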

@leondz (Owner, Author) commented:
Still getting blank results when using _gather_hf_params; will take a look in a bit, but if you have suggestions, they're welcome!

@jmartin-tech (Collaborator) commented:
Let's drop the _gather_hf_params() call for now, as PegasusForConditionalGeneration.from_pretrained() does not appear to handle the extra args the way the current code expects; I suspect device vs. device_map is also affecting expectations here.
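A hedged sketch of what dropping it could look like, keeping the device selection from the mixin (not necessarily what was actually committed):

    def _load_model(self):
        from transformers import PegasusForConditionalGeneration, PegasusTokenizer

        self.device = self._select_hf_device()
        # pass the model name directly and skip _gather_hf_params() until its
        # interaction with from_pretrained() and device/device_map is sorted out
        self.para_model = PegasusForConditionalGeneration.from_pretrained(
            self.para_model_name
        ).to(self.device)
        self.tokenizer = PegasusTokenizer.from_pretrained(self.para_model_name)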

@leondz marked this pull request as ready for review on October 22, 2024, 09:29
Labels: enhancement (Architectural upgrades)

2 participants