add glora#3098

Open
not-lain wants to merge 39 commits into huggingface:main from not-lain:glora2

@not-lain

Follow-up on #780 and #2568.
This PR adds GLoRA to the library.
I also made a small Colab notebook to test this remotely, which you can find here.

@BenjaminBossan BenjaminBossan left a comment

Member

Thanks for reviving the GLoRA implementation for PEFT. There is still a lot missing (docs, examples, more tests), but let's focus on the integration for now and work on the rest later.

Unfortunately, this PR seems to be based on the state that PEFT was in when the PR was first suggested. We made several refactors since then, which require different patterns to implement. The good news is that the final result should be much simpler with less code required. I marked the corresponding parts, please check. Maybe this is even something that a coding agent can do if asked to update the implementation and pointed to the most recent PEFT code for reference.

Comment thread docs/source/package_reference/glora.md Outdated
Comment thread src/peft/tuners/glora/config.py Outdated
Comment thread src/peft/tuners/glora/config.py Outdated
Comment thread src/peft/tuners/glora/config.py Outdated
Comment thread src/peft/tuners/glora/config.py Outdated
Comment thread src/peft/tuners/glora/layer.py Outdated


# Refactored GLoraLinear for PEFT compatibility
class GLoraLinear(GLoraLayer, nn.Linear):

Member

You probably based this PR on the very old original PR. Since then, we had several refactors in PEFT, some of which affect how we designed the adapter layers. Could you please check the latest implementation in tuners/lora/layer or PRs of other recent PEFT additions (e.g. #2851)? Most notably, we now pass the base_layer (i.e. the original layer) and wrap it inside the PEFT layer. To get the results of the base layer, we can then call self.base_layer(x) in the forward call (no need for F.linear(x, self.weight)).
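
As a rough illustration of the wrapping pattern described here, the sketch below keeps the original layer as `self.base_layer` and delegates to it in `forward`. It is not PEFT's actual layer code: the class name `WrappedLinear` and the simple low-rank path are assumptions made for the example.

```python
import torch
from torch import nn


class WrappedLinear(nn.Module):
    """Hypothetical adapter layer that wraps the original layer instead of subclassing nn.Linear."""

    def __init__(self, base_layer: nn.Linear, r: int = 4):
        super().__init__()
        self.base_layer = base_layer  # the original layer, kept intact
        for p in self.base_layer.parameters():
            p.requires_grad = False  # only the adapter weights train
        # Low-rank additive path, used purely for illustration.
        self.down = nn.Linear(base_layer.in_features, r, bias=False)
        self.up = nn.Linear(r, base_layer.out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # the adapter starts as a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.base_layer(x)  # replaces F.linear(x, self.weight)
        return result + self.up(self.down(x))


base = nn.Linear(8, 3)
layer = WrappedLinear(base)
x = torch.randn(2, 8)
out = layer(x)
```

Because the wrapper only calls `self.base_layer(x)`, it does not need to know how the base layer stores or computes its weight.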

Comment thread src/peft/tuners/glora/model.py Outdated
m.bias.requires_grad = True


class GLoraModel(BaseTuner):

Member

Similar argument as for GLoraLayer: We have substantially refactored this part of PEFT. The good news is that it should greatly simplify the overall implementation: You only need to define _create_and_replace and _create_new_module, the remaining methods should all be fine when inherited from the parent class. Moreover, we need these class attributes:

  • prefix: str = "glora_"
  • tuner_layer_cls = GloraLayer
  • target_module_mapping = TRANSFORMERS_MODELS_TO_GLORA_TARGET_MODULES_MAPPING
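
The shape of such a tuner class might look roughly like the sketch below. It uses a stand-in parent class (`ToyBaseTuner`) rather than PEFT's real `BaseTuner`, so the hook signatures, the `inject` helper, and the placeholder mapping are all assumptions; only the three class attribute names and the two hook names come from the comment above.

```python
class ToyLayer:
    """Stand-in for the adapter layer class (GloraLayer in the PR)."""

    def __init__(self, base_layer, adapter_name):
        self.base_layer = base_layer
        self.adapter_name = adapter_name


class ToyBaseTuner:
    """Stand-in parent: owns the generic logic and calls the two hooks."""

    prefix: str = ""
    tuner_layer_cls = None
    target_module_mapping: dict = {}

    def inject(self, config, adapter_name, modules):
        # Hypothetical driver loop; the real parent class does far more.
        for name, target in list(modules.items()):
            self._create_and_replace(config, adapter_name, target, name, modules)
        return modules

    def _create_and_replace(self, config, adapter_name, target, target_name, parent):
        raise NotImplementedError

    @staticmethod
    def _create_new_module(config, adapter_name, target):
        raise NotImplementedError


class ToyGloraModel(ToyBaseTuner):
    prefix: str = "glora_"
    tuner_layer_cls = ToyLayer
    target_module_mapping = {"toy-arch": ["q_proj", "v_proj"]}  # placeholder mapping

    def _create_and_replace(self, config, adapter_name, target, target_name, parent):
        # Swap the target for its wrapped version inside the parent container.
        parent[target_name] = self._create_new_module(config, adapter_name, target)

    @staticmethod
    def _create_new_module(config, adapter_name, target):
        return ToyLayer(target, adapter_name)


modules = ToyGloraModel().inject({}, "default", {"q_proj": "W_q", "v_proj": "W_v"})
```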

Comment thread tests/test_glora.py Outdated
Comment thread src/peft/tuners/glora/layer.py Outdated
self.glora_Cu: nn.ParameterDict = nn.ParameterDict()
self.glora_D: nn.ParameterDict = nn.ParameterDict()
self.glora_E: nn.ParameterDict = nn.ParameterDict()
self.eval_config: dict[str, dict[str, object]] = {}

Member

Why do we need this? IIUC, this is supposed to be an object that unifies the settings for each parameter, as we have distinct options like "vector" or "constant". I haven't checked all the details, but ideally we would just define during the layer initialization something like:

if config_A_B == "lora":
    self.glora_Ad[adapter_name] = nn.Linear(...)
elif config_A_B == "vector":
    self.glora_Ad[adapter_name] = ...

The arguments config_A_B etc. should be passed to the __init__ and update_layer methods, directly coming from the GloraConfig.

This change may require some custom nn.Modules but that would be fine. I would really like to avoid the whole prepare_path call during forward and frontload the whole resolution to the initialization. This makes the forward/merge/unmerge call simpler to understand and should also be slightly more performant.
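
One way to frontload the resolution could look like the sketch below: the string option is dispatched once at construction time, and every resulting module exposes the same call, so forward/merge code needs no string checks. The helper `make_path`, the two module classes, and the parameter shapes chosen for "vector" and "constant" are assumptions for illustration, not the PR's actual parameterization.

```python
import torch
from torch import nn


class LowRankPath(nn.Module):
    """Hypothetical 'lora' option: product of two thin matrices."""

    def __init__(self, out_features: int, in_features: int, r: int):
        super().__init__()
        self.down = nn.Parameter(torch.zeros(r, in_features))
        self.up = nn.Parameter(torch.zeros(out_features, r))

    def forward(self) -> torch.Tensor:
        return self.up @ self.down  # shape (out_features, in_features)


class BroadcastPath(nn.Module):
    """Hypothetical 'vector' / 'constant' option: one small broadcastable tensor."""

    def __init__(self, shape):
        super().__init__()
        self.value = nn.Parameter(torch.zeros(shape))

    def forward(self) -> torch.Tensor:
        return self.value


def make_path(kind: str, out_features: int, in_features: int, r: int = 4):
    # All string dispatch happens here, once, during layer initialization.
    if kind == "lora":
        return LowRankPath(out_features, in_features, r)
    if kind == "vector":
        return BroadcastPath((out_features, 1))
    if kind == "constant":
        return BroadcastPath((1, 1))
    if kind == "none":
        return None
    raise ValueError(f"Unknown path type: {kind!r}")


paths = nn.ModuleDict()
paths["default"] = make_path("vector", out_features=3, in_features=8)
delta = paths["default"]()  # no per-forward resolution step
```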

Member

This comment is still relevant.

Comment thread src/peft/tuners/glora/layer.py Outdated
not-lain and others added 4 commits March 17, 2026 07:55
@BenjaminBossan

Member

Note that due to another PEFT method being merged, there is now a merge conflict, but it should be straightforward to resolve. Once you're finished, please ping me for another review.

not-lain added 10 commits March 23, 2026 15:08
- Updated GloraLayer to inherit from BaseTunerLayer
- Enhanced adapter management with new methods and improved type checks.
- Refined initialization to ensure compatibility with nn.Linear layers only.
- Adjusted merge logic to handle weight and bias more robustly.
- Updated tests to skip unsupported quantized layers.
@not-lain

Author

@BenjaminBossan
I made the necessary changes mentioned above, plus a couple of extra changes to match test_custom_models.
I would appreciate it if you could take a look at this PR.

@BenjaminBossan BenjaminBossan left a comment

Member

Thanks for the latest updates. I found that there is still some work left to make GLoRA consistent with other PEFT methods. Please check my comments and don't hesitate to ask if something is unclear.

Also note that in the meantime, we have updated the contribution guidelines for adding new PEFT methods: https://github.com/huggingface/peft/blob/main/docs/source/developer_guides/contributing.md#add-a-new-peft-fine-tuning-method. You already tick many boxes, but please refer to this when it comes to what's still missing.

Comment thread docs/source/package_reference/glora.md Outdated
Comment thread src/peft/tuners/glora/config.py Outdated
Comment thread src/peft/tuners/glora/config.py Outdated
Comment thread src/peft/tuners/glora/config.py Outdated
Comment thread src/peft/tuners/glora/layer.py Outdated
self.out_features: int
base = self.get_base_layer()
# Use exact type check: bitsandbytes.Linear4bit subclasses nn.Linear but is not compatible with GLORA math.
if type(base) is not nn.Linear:

Member

Let's remove this. The check for the correct type should happen in _create_new_module.
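
A minimal sketch of moving the check into the factory step, assuming a free function `create_new_module` as a stand-in for the tuner's `_create_new_module` (the real signature is not shown in this thread). `Linear4bitLike` is a hypothetical placeholder for a quantized subclass of `nn.Linear`.

```python
from torch import nn


class Linear4bitLike(nn.Linear):
    """Hypothetical stand-in for a quantized subclass of nn.Linear."""


def create_new_module(target: nn.Module) -> nn.Module:
    # Exact type check: subclasses of nn.Linear are rejected on purpose.
    if type(target) is not nn.Linear:
        raise ValueError(
            f"Target module {type(target).__name__} is not supported; only torch.nn.Linear is."
        )
    return target  # a real factory would return the wrapping adapter layer here


accepted = create_new_module(nn.Linear(2, 2))

try:
    create_new_module(Linear4bitLike(2, 2))
    rejected = False
except ValueError:
    rejected = True
```

With the check in the factory, the layer's `__init__` can assume it only ever receives a supported base layer.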

Comment thread src/peft/tuners/glora/model.py Outdated
"config_D_E": peft_config.config_D_E,
}
new_module = GloraLinear(target, **kwargs_glora)
new_module.add_adapter(

Member

Again, don't call add_adapter here, new_module = GloraLinear(target, **kwargs_glora) should be enough, as GloraLinear should call self.update_layer.

Also note that, as mentioned earlier, we did a small refactor here, so the peft_config should be passed to GloraLinear's __init__.
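
The pattern being asked for could be sketched as follows: the constructor receives the config and registers the first adapter itself by calling `update_layer`, so the factory has nothing left to do after construction. `ToyConfig`, `ToyAdapterLinear`, and the parameter names are assumptions for the example, not the PR's code.

```python
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class ToyConfig:
    r: int = 4  # hypothetical stand-in for the real config object


class ToyAdapterLinear(nn.Module):
    def __init__(self, base_layer: nn.Linear, adapter_name: str, config: ToyConfig):
        super().__init__()
        self.base_layer = base_layer
        self.adapter_A = nn.ParameterDict()
        self.adapter_B = nn.ParameterDict()
        # The constructor registers the first adapter itself.
        self.update_layer(adapter_name, config)

    def update_layer(self, adapter_name: str, config: ToyConfig) -> None:
        out_f, in_f = self.base_layer.out_features, self.base_layer.in_features
        self.adapter_A[adapter_name] = nn.Parameter(torch.zeros(config.r, in_f))
        self.adapter_B[adapter_name] = nn.Parameter(torch.zeros(out_f, config.r))


# Factory side: constructing the layer is enough, with no follow-up add_adapter call.
new_module = ToyAdapterLinear(nn.Linear(8, 3), "default", ToyConfig(r=2))
```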

Comment thread src/peft/tuners/glora/model.py Outdated
Comment thread tests/test_custom_models.py Outdated
_skip_if_merging_not_supported(test_name, config_cls, config_kwargs_1)
_skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs_2)

config_1 = config_cls(**config_kwargs_1)

Member

This shouldn't be necessary. Why do the results differ?

@not-lain not-lain May 31, 2026

Author

The issue is that GLoRA's merge is multiplicative, not purely additive. During a forward pass with multiple active adapters, each adapter computes its delta against the original frozen W0, hence the formula:

 result = W0@x + (W0*A1 + B1)@x + (W0*A2 + B2)@x

But sequential merging mutates W0 in place between adapters:

  • after adapter1: W0' = W0 + W0*A1 + B1
  • after adapter2: W0'' = W0' + W0'*A2 + B2

The second merge uses the already-modified W0', meaning W0'*A2 expands to (W0 + W0*A1 + B1)*A2 instead of the original W0*A2.

LoRA doesn't have this issue because its merge is purely additive (W0 + delta), so I chose to skip the test for GLoRA for now.
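
The mismatch can be checked with a scalar toy version of the formulas above. The numbers are arbitrary and chosen only for illustration; real weights are matrices, but the algebra is the same.

```python
import math

W0 = 2.0
A1, B1 = 0.5, 0.1
A2, B2 = 0.25, 0.2


def delta(W, A, B):
    # Per-adapter weight update as written in the formulas above.
    return W * A + B


# Unmerged forward with both adapters active: every delta uses the frozen W0.
joint = W0 + delta(W0, A1, B1) + delta(W0, A2, B2)

# Naive sequential merge: the second delta sees the already-merged weight.
W1 = W0 + delta(W0, A1, B1)
sequential = W1 + delta(W1, A2, B2)

# The extra term is exactly (W0*A1 + B1) * A2.
mismatch = sequential - joint
```

The difference vanishes only when A2 is zero or the first adapter's update is zero, which is why an additive method such as LoRA is unaffected.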

Member

So you're saying that if I have a GLoRA model with two active adapters, it behaves differently when merged vs unmerged? I think this should be treated like a bug. IIUC, it should be possible to implement the merge so that it gives the same result by not merging sequentially but combining the adapters before merging.

That said, if a user wants to merge adapter 1 first and then adapter 2 in a separate call, that would indeed require extra logic, as the results of naively merging the 2nd adapter would not be correct. I see two solutions here:

  1. First unmerge adapter 1, then merge 1 and 2 together.
  2. When merging, check if there is already a merged adapter and raise an error, telling users they need to merge all adapters in one step.

WDYT?
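
Both options can be sketched on a scalar toy layer, as below. This is only an illustration of the control flow under stated assumptions: `ToyMergeable`, its `strict` flag, and the closed-form unmerge are made up for the example, and the inverse assumes `1 + sum(A)` is nonzero.

```python
import math


class ToyMergeable:
    """Scalar toy layer; merged weight = W0 + sum(W0 * A_i + B_i)."""

    def __init__(self, weight, adapters):
        self.weight = weight
        self.adapters = adapters  # name -> (A, B)
        self.merged = []  # names currently folded into self.weight

    def _merge_jointly(self, names):
        w0 = self.weight  # must be the unmerged weight at this point
        self.weight = w0 + sum(w0 * self.adapters[n][0] + self.adapters[n][1] for n in names)
        self.merged = list(names)

    def unmerge(self):
        # Invert W = W0 * (1 + sum A) + sum B for the currently merged adapters.
        sum_a = sum(self.adapters[n][0] for n in self.merged)
        sum_b = sum(self.adapters[n][1] for n in self.merged)
        self.weight = (self.weight - sum_b) / (1.0 + sum_a)
        self.merged = []

    def merge(self, names, strict=False):
        if self.merged:
            if strict:  # option 2: refuse incremental merges
                raise RuntimeError("Adapters already merged; merge all adapters in one call.")
            # Option 1: restore W0, then merge old and new adapters together.
            names = self.merged + [n for n in names if n not in self.merged]
            self.unmerge()
        self._merge_jointly(names)


layer = ToyMergeable(2.0, {"a1": (0.5, 0.1), "a2": (0.25, 0.2)})
layer.merge(["a1"])
layer.merge(["a2"])  # separate call, still ends at the joint result

strict_layer = ToyMergeable(2.0, {"a1": (0.5, 0.1), "a2": (0.25, 0.2)})
strict_layer.merge(["a1"], strict=True)
try:
    strict_layer.merge(["a2"], strict=True)
    raised = False
except RuntimeError:
    raised = True
```

Option 1 keeps the user-facing behavior uniform at the cost of an extra unmerge, while option 2 is simpler but pushes the constraint onto the caller.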

Author

I chose to unmerge adapters before merging, since I didn't want to introduce an edge case for GLoRA.
You can review the related changes in 43ceaee.

Comment thread tests/testing_common.py Outdated

@BenjaminBossan BenjaminBossan left a comment

Member

Thanks for the updates. I still found a couple of issues, but the integration is close to being ready. Please take a look.

Comment thread method_comparison/text_generation_benchmark/utils.py Outdated
Comment thread src/peft/tuners/glora/layer.py Outdated
Comment thread src/peft/tuners/glora/layer.py
Comment thread src/peft/tuners/glora/layer.py Outdated
Comment thread src/peft/tuners/glora/layer.py Outdated
Comment thread src/peft/tuners/glora/layer.py Outdated
Comment thread src/peft/tuners/glora/layer.py
Comment thread src/peft/tuners/glora/layer.py Outdated
@not-lain not-lain requested a review from BenjaminBossan June 3, 2026 07:55
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan

Member

@not-lain Your latest changes look good. I wanted to run the existing tests to see if there are any issues but the linter is complaining. Could you please run make style?

If the tests pass, the next step would be to complete the "Full PR" according to the contribution guideline.

@not-lain

not-lain commented Jun 5, 2026

Author

Thanks a lot for the follow-up. I tried addressing the rest of the points from the guideline. Could you try running the CI again and let me know if you have any feedback?

@BenjaminBossan BenjaminBossan left a comment

Member

Thank you for adding the docs, examples, and other little changes. Some tests are failing, but that's just because the error message doesn't quite match; this should be easy to fix.

To complete the PR, could you please also add the tests to the other test files? test_config.py, test_encoder_decoder_models.py, test_feature_extraction_models.py, test_seq_classifier.py.

Also, we recently added a new benchmark for image generation similar to the MetaMath benchmark. It would be a nice addition to add an experiment for that as well.

- GLoRA is a superset of LoRA: setting all paths to "lora" recovers standard LoRA.
- You can use different path types for A/B/C/D/E to experiment with new adaptation strategies.
- GLoRA supports all standard PEFT adapter management features (add, delete, switch, merge, etc).
- For unsupported module types, set `target_modules` to linear projections only.

Member

Could you please elaborate on this comment?

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--base_model", type=str, default="path/to/model")

Member

How about removing the invalid default value?
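
One possible way to do that, assuming the script should fail fast when no model is given, is to mark the argument as required instead of giving it a placeholder default. The help text and the model ID in the usage line are made up for the example.

```python
import argparse

parser = argparse.ArgumentParser()
# No placeholder default: argparse exits with a usage error if the flag is missing.
parser.add_argument("--base_model", type=str, required=True, help="Model ID or local path")

# Hypothetical model ID, passed explicitly so the snippet runs without a command line.
args = parser.parse_args(["--base_model", "my-org/my-model"])
```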

"machine": platform.machine(),
"processor": platform.processor(),
"accelerator": torch_accelerator_module.get_device_name(0) if torch_accelerator_module.is_available() else "N/A",
"accelerator": torch_accelerator_module.get_device_name(0)

Member

Please undo unrelated changes in this file.

Member

I checked the example and it runs.

Member

Ran the experiment locally. It got 42% test accuracy, which is not top scoring, but expected given the relatively low rank. We can leave it as is and check for better settings in the future.


# GLoRA

Generalized Low-Rank Adaptation ([GLoRA](https://huggingface.co/papers/2306.07967)) is a PEFT method that generalizes LoRA and related approaches. GLoRA decomposes updates into configurable paths (A, B, C, D, E), where each path can use low-rank, vector, constant, or disabled parameterization depending on the path.

Member

It would be great if you could quickly explain what the four options ("lora", "vector", "constant", "none") mean. Let's also mention that they trade off number of params vs expressiveness.

Member

I got 36% test accuracy with this one. Given the very small number of trainable parameters, I think it's a nice result.
