
Add functionality to support IA3 #578

Merged
pacman100 merged 39 commits into huggingface:main from SumanthRH:ia3
Jul 13, 2023

Conversation

@SumanthRH
Copy link
Contributor

Hi,

We've added some code to support IA3 from the T-few paper. Most of the code is inspired by the LoRA implementation. Currently, the implementation supports multiple adapters, int-8 training, and merging/unmerging. I've only added a minimal set of models for now. IA3 introduces learned vectors that rescale the key, value, and feed-forward activations, and the forward pass differs for feed-forward vs. non-feed-forward layers.
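To illustrate the two forward variants described above, here is a minimal sketch (not the PR's actual code; the function name `ia3_forward` is hypothetical): for feed-forward layers the learned vector rescales the input activations, while for key/value projections it rescales the output activations.

```python
import torch

def ia3_forward(x, weight, ia3_vector, is_feedforward):
    """Hypothetical sketch of the two IA3 forward variants.

    Feed-forward layers: rescale the *input* activations before the linear.
    Key/value projections: rescale the *output* activations after the linear.
    """
    if is_feedforward:
        return torch.nn.functional.linear(x * ia3_vector, weight)
    return torch.nn.functional.linear(x, weight) * ia3_vector

# With the vector initialized to all ones (as in the paper), the layer
# computes exactly the same function as the base linear layer:
x = torch.randn(2, 8)
w = torch.randn(4, 8)
out = ia3_forward(x, w, torch.ones(8), is_feedforward=True)
assert torch.allclose(out, x @ w.T)
```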

I hope this can be merged into the repo!


@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jun 15, 2023

The documentation is not available anymore as the PR was closed or merged.

Copy link
Contributor

@pacman100 pacman100 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @SumanthRH for adding IA3 (Infused Adapter by Inhibiting and Amplifying Inner Activations) PEFT method 🤗🚀✨! This will provide more options for the community to explore for their specific problems.

Could you please remove changes from existing example notebooks and only have new notebook examples specific to IA3?

I left a comment regarding refactoring some code.

@pacman100
Copy link
Contributor

Also, please run make style and make quality to resolve the quality issues

@pacman100
Copy link
Contributor

Hello, it would be great if you could add a test file for testing the minimal core components only (forward, save_pretrained, generate). cc @younesbelkada

@SumanthRH
Copy link
Contributor Author

Hi @pacman100,

Thanks for checking out the pull request. I can add tests similar to what you have for LoRA/Prefix tuning. I wasn't sure if you would be merging this, so I hadn't added those changes yet.

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @SumanthRH
Thank you very much for your great work
I second what @pacman100 said: it would be great if you can add a test file to make sure your implementation won't get broken by future PRs. For that, please have a look at https://github.com/huggingface/peft/blob/main/tests/test_adaption_prompt.py and adapt it to your needs. Let us know if you need any help designing the tests.

@SumanthRH
Copy link
Contributor Author

SumanthRH commented Jun 26, 2023

Hi @pacman100 @younesbelkada

I looked through the tests folder and thought it was better to add IA3 to the existing set of common tests in tests_common. Currently, here are all the tests that are supported:

  • _test_model_attr
  • _test_prepare_for_training
  • _test_save_pretrained
  • _test_merge_layers (Caveat here: With LoRA, the initialization is random, so the results before and after merging differ slightly. With IA3, the learned vectors are initialized to all ones, as per the paper, so the results don't change. Not sure if I should add any additional tests here?)
  • _test_generate
  • _test_training

All of these tests are run for the encoder-decoder models (T5 and BART) and the decoder models (GPT2, OPT, BLOOM, GPT-NeoX, GPT-Neo). Also, the original authors only specify the IA3 idea for the T0 model, so for other architectures I've simply made an appropriate choice (there could be other choices that are better for a given architecture). Let me know if this works!

TODO:

  • Half-precision support
  • Gradient checkpointing
  • Run training with GPT-2 and compare performance (an additional sanity check for the fan_in_fan_out implementation). I'm able to get better performance than prefix tuning on the data in peft_prefix_tuning_clm.ipynb
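For context on the fan_in_fan_out sanity check mentioned in the TODO: GPT-2 uses transformers' Conv1D modules, which store weights transposed relative to nn.Linear, so a tuner has to transpose before applying or merging its update. A small sketch of the layout difference (illustrative only):

```python
import torch

# nn.Linear stores weights as (out_features, in_features); GPT-2's Conv1D
# stores them as (in_features, out_features). fan_in_fan_out flags the
# transposed layout so the adapter math lines up either way.
x = torch.randn(2, 8)
w_linear = torch.randn(4, 8)   # nn.Linear layout
w_conv1d = w_linear.T          # Conv1D layout (transposed storage)

out_linear = x @ w_linear.T
out_conv1d = x @ w_conv1d      # same computation despite different storage
assert torch.allclose(out_linear, out_conv1d)
```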

@SumanthRH
Copy link
Contributor Author

Hello @pacman100 , @younesbelkada . I cleaned up the code a bit more. The example notebooks might still need some edits as there's some code that's specific to my username (model saving part). But apart from that, I hope this looks good!

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @SumanthRH for this addition. It looks really nice for the most part and is consistent with the existing code base, well done.

I'm not an expert on this topic (yet), so I cannot comment on every detail. I made a few suggestions, please take a look if they make sense.

In general, I think a docs entry would be good to have, since this introduces a completely new method to the library.

A question more in the direction of the other maintainers: There is a lot of code duplication between this and LoRA. Are we fine with keeping it as or would it be better to add some abstraction? As always, there are pros and cons to both. I just want to ensure that this duplication is indeed the desired state.

@SumanthRH SumanthRH requested a review from BenjaminBossan July 4, 2023 19:57
@SumanthRH
Copy link
Contributor Author

Also @BenjaminBossan, I noticed your comment on documentation a bit late, sorry about that! I've added some documentation for $\text{IA}^3$ now, with a basic conceptual guide. Docs should get updated here. In terms of supported methods/models, I've only indicated those covered by tests or in examples. Let me know if that works!

@BenjaminBossan
Copy link
Member

I've added some documentation for IA3 now, with a basic conceptual guide

Fantastic, well written, thank you.

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for your great work @SumanthRH , looking forward to merging this! Thanks also for adding the testing suite for that method!
I left a few open questions; let me know what you think.
Also, @BenjaminBossan raised a great question. Regarding code de-duplication, we should maybe start with the same approach as in transformers (# Copied from ..) and test whether the method/class is effectively copied from somewhere else. I would say let's take care of that in a follow-up PR.

self.assertTrue(torch.allclose(logits_lora, logits_merged, atol=1e-4, rtol=1e-4))
self.assertFalse(torch.allclose(logits_merged, logits_transformers, atol=1e-10, rtol=1e-10))
self.assertTrue(torch.allclose(logits_unmerged, logits_merged, atol=1e-4, rtol=1e-4))
if config_cls == LoraConfig: # merge does not change logits for IA3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate on that? What do you mean specifically by "merge does not change logits for IA3"? If merging is not supported, I think we should just raise an error. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So merging is supported, but the specific check here is testing for a difference between logits_merged (after merging IA3 vectors) and logits_transformers (without IA3 vectors). With IA3, the initialization is all ones (from the paper: "...[the learned vectors] are all initialized with ones so that the overall function computed by the model does not change when they are added."). In the case of LoRA, the random initialization will change model outputs ever so slightly after adding LoRA weights, which is what the assertFalse statement tests. This isn't going to be true for IA3, so I have simply skipped that check. I'm not sure what other tests could be added for merging, though.
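A tiny sketch of why the all-ones initialization makes merging a no-op (illustrative, not the PR's code): merging folds the IA3 vector into the base weight by a row-wise rescale, and rescaling by ones leaves the weight, and therefore the logits, unchanged.

```python
import torch

# Merging an IA3 vector rescales the rows of the base weight.
weight = torch.randn(4, 8)
ia3_vector = torch.ones(4, 1)    # the paper's all-ones initialization

merged = weight * ia3_vector     # fold the vector into the weight
# With all-ones init, the merged weight equals the base weight, so
# logits_merged == logits_transformers and the assertFalse check fails.
assert torch.allclose(merged, weight)
```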

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks a lot for the detailed explanation! Do you think it is possible to force the initialization to be something different from all ones? I think in the past I faced the same issue with LoRA because the B matrix was initialized to all zeros. I added a boolean here:

init_lora_weights: bool = field(
let me know if you need help on that !
Otherwise all good I think, it can be done in a follow up PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okay! There's actually an argument already called init_ia3_weights that's supposed to give this control, but even with init_ia3_weights set to False, the initialization is currently all ones:
https://github.com/SumanthRH/peft/blob/5ed92fad11aeec3ec94c8dfd908e81a3cb09d4e4/src/peft/tuners/ia3.py#L399

If this is going to be different, then the initialization should be such that the merged version is close enough to the pre-trained weights, which means that the entries in the learned vectors should be close to one (because they rescale activations). I'm not sure what standard initialization fits here. Maybe something like PyTorch's nn.Linear initialization for matrices, but shifted by one: $U(-\sqrt{k}, \sqrt{k}) + 1$, where $k = 1/\text{vector-length}$. I also don't want to make it look like we conjured up an initialization, so I'm open to hearing your thoughts on this!
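The shifted-uniform idea above could be sketched as follows (hypothetical; the function name `init_near_one` is made up for illustration, and this is not what the PR implements):

```python
import math
import torch

def init_near_one(vector_length: int) -> torch.Tensor:
    """Sample from U(-sqrt(k), sqrt(k)) + 1 with k = 1 / vector_length.

    Entries stay close to 1, so the merged weights stay close to the
    pre-trained weights while still breaking the exact identity.
    """
    bound = math.sqrt(1.0 / vector_length)
    return torch.empty(vector_length).uniform_(-bound, bound) + 1.0

v = init_near_one(64)
# Every entry lies within sqrt(1/64) = 0.125 of 1.0:
assert ((v - 1.0).abs() <= 0.125).all()
```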

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thanks for explaining. Yes, let's leave it as is and we'll take care of it in a follow-up PR!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@younesbelkada Could you explain why we want to require the outputs to be different each time when initializing the model? Is it not actually a good thing if by default, the added weights lead to an identity operation? I'm running into the same test error when using LoRA with certain custom models.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @BenjaminBossan, LoRA initializes lora_B to all zeros so that we start with an identity function; so, for all normal usage, init_lora_weights=True. Only during testing, when we can't actually train the model but want to test the functionality of merge_and_unload, is init_lora_weights set to False; otherwise, even after merging, it would be an identity op and we couldn't test whether the merge was successful.
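A minimal sketch of the LoRA point above (illustrative shapes and names, not PEFT's code): because lora_B starts at zero, the low-rank update B @ A is zero and the adapted layer is exactly the base layer until training moves B.

```python
import torch

# LoRA's update is delta_W = lora_B @ lora_A. With lora_B initialized to
# zeros (the init_lora_weights=True default), delta_W is zero, so
# W + delta_W == W: an identity op at initialization.
r, d_in, d_out = 4, 16, 16
lora_A = torch.randn(r, d_in)       # random init, as in LoRA
lora_B = torch.zeros(d_out, r)      # zero init -> no change at start

delta_W = lora_B @ lora_A
assert torch.count_nonzero(delta_W) == 0
```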

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, thanks for explaining. This should probably be added as a comment on that line, since I don't think I'm the only one who was confused.

Regarding the testing of the merging feature, would it be possible to check the weights directly, instead of the model outputs?

@SumanthRH SumanthRH requested a review from younesbelkada July 8, 2023 06:47
Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking great on my side! Thanks for all your great work on this!
Let's see what @pacman100 and @BenjaminBossan say; let me know if I can help you address the last comments and any incoming ones.
Thanks again!
PS: can you run the styling checks? make style && make quality

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, this LGTM and I don't have much to add besides what Younes said.

It is not quite clear to me why _check_target_module_exists differs from the LoRA implementation, but I think we'll refactor some of that soon, so maybe we can keep it as is.

Also @BenjaminBossan raised a great question, regarding code de-duplication we should maybe start by doing the same approach as what we have in transformers # Copied from .. and test if the method/ class is effectively copied from somewhere else. I would say let's take care of that in a follow up PR

Okay, let's deal with that later. I'm not sure the initial reasons why # Copied from was introduced in transformers necessarily apply here as well, but we can discuss it then.

Copy link
Contributor

@pacman100 pacman100 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much @SumanthRH for all the effort put in to add IA3. Left a suggestion; other than that, LGTM!

Remove unused attribute merge_weights

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
@SumanthRH SumanthRH requested a review from pacman100 July 12, 2023 03:32
@pacman100 pacman100 merged commit c33c42f into huggingface:main Jul 13, 2023
@SumanthRH
Copy link
Contributor Author

Happy to see this merged! A small note to the maintainers @pacman100 @younesbelkada: The example notebooks peft_ia3_seq2seq.ipynb and IA3.ipynb might need some minor changes (specifically, in the "Share adapters to the Hub" sections at the end). Other than that, we should be okay! This was also my first major OSS contribution, so thanks for helping me out!

Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
* Added initial ia3 code

* Implemented ia3 correctly for feedforward layers; Fixed regex matching

* Fixed module mapping for mt5

* Merged changes from huggingface:main

* Merged changes

* Fixed lora merge conflicts

* Different bloom config

* Added save option for ia3

* Added loading code for ia3

* Added feedforward implementation in utils and seq cls example

* Added feedforward implementation in utils and seq cls example

* Implemented merge, unmerge, enable/disable adapters functionality

* Fixed feedforward during merge

* Debugging Merge

* Removing debug messages

* Cleaned up repo

* Removed non-IA3 changes

* Refactor save and load

* Added support to all models in tests; Added IA3Config for common tests

* Added half-precision support and test for gradient checkpointing; Formatted jupyter notebooks

* Added target modules for new models GPTBigCode and LLama

* Cleaned up code

* Cleaned up code

* Cleaned up example notebook

* Cleaned up  seq2seq notebook

* Corrected function docstrings; refactored find_and_replace

* Corrected function docstrings; refactored find_and_replace

* Added basic docs for IA3

* Added new conceptual guide in source tree for documentation

* Minor fix to documentation

* Minor fixes to docstrings; Added error handling for 4bit quantization; Cleaned unused merge/unmerge methods

* styling changes after merge from main

* Update src/peft/tuners/ia3.py

Remove unused attribute merge_weights

Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

---------

Co-authored-by: Abhishek2304 <abhishekgupta2304@gmail.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>