Changes from all commits (299 commits)
fe220cf
quantization works
ArthurZucker Oct 27, 2025
c6bb839
fixes
ArthurZucker Oct 28, 2025
2fe87ce
updates
ArthurZucker Oct 28, 2025
466df96
updates
ArthurZucker Oct 28, 2025
6f6deb0
update
ArthurZucker Oct 28, 2025
0519e21
fix fp8, it now works
ArthurZucker Oct 28, 2025
7efb487
fix-copies
ArthurZucker Oct 28, 2025
62ccfd9
nits
ArthurZucker Oct 28, 2025
8e74adc
support tp dtensor
ArthurZucker Oct 28, 2025
a5859af
local changes
ArthurZucker Oct 29, 2025
c3f5437
fix tie weight embeddding?
ArthurZucker Oct 29, 2025
a8998de
fix auto for mps
ArthurZucker Oct 29, 2025
9735c6e
current updates
ArthurZucker Oct 29, 2025
965b006
small update
ArthurZucker Oct 29, 2025
ec49d73
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Oct 29, 2025
a92cb1f
Youhou
ArthurZucker Oct 29, 2025
653933c
fix fp8
ArthurZucker Oct 29, 2025
ac1af43
TP + QUANTIZE now works
ArthurZucker Oct 29, 2025
aa0ebbe
the way to make local tensor + Dtensor work
ArthurZucker Oct 29, 2025
e1eb5a4
nit
ArthurZucker Oct 30, 2025
de09779
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Oct 30, 2025
edeacc3
move progress
ArthurZucker Oct 30, 2025
f1312dc
fix llama tests ?
ArthurZucker Oct 30, 2025
c53755f
smoll QOL
ArthurZucker Oct 30, 2025
2214575
ship most fixes
ArthurZucker Oct 30, 2025
3cde7b0
fix bunch of tests
ArthurZucker Oct 30, 2025
17f25f9
fix copies
ArthurZucker Oct 30, 2025
134959c
styling
ArthurZucker Oct 30, 2025
0402e56
yups
ArthurZucker Oct 30, 2025
4443658
Merge branch 'main' of github.com:huggingface/transformers into refac…
ArthurZucker Oct 30, 2025
6c9fda4
small updates
ArthurZucker Oct 30, 2025
28a1d22
add qwen2_moe to the mapping!
ArthurZucker Oct 30, 2025
8cf9694
nit
ArthurZucker Oct 30, 2025
a01ad8d
small nits
ArthurZucker Oct 30, 2025
9f615bc
update
ArthurZucker Oct 30, 2025
fe9b047
up
ArthurZucker Oct 30, 2025
d9bb0e3
fix olmoe
ArthurZucker Oct 30, 2025
50a85ef
fix ernie
ArthurZucker Oct 30, 2025
9bed488
more fixups
ArthurZucker Oct 30, 2025
912dd2f
updates
ArthurZucker Oct 30, 2025
48c85c7
revert small granite moe stuff
ArthurZucker Oct 30, 2025
00e3604
yups
ArthurZucker Oct 30, 2025
edf96f8
update conversion mapping!
ArthurZucker Oct 30, 2025
c3c534f
licence
ArthurZucker Oct 30, 2025
6309347
smal nit
ArthurZucker Oct 30, 2025
b320474
update
ArthurZucker Oct 30, 2025
5d4d27e
up
ArthurZucker Oct 30, 2025
00846a2
Apply suggestion from @LysandreJik
ArthurZucker Oct 31, 2025
f4775fc
updates based on review
ArthurZucker Oct 31, 2025
e0fd1e4
better error handling (Am I too rust-y) ?
ArthurZucker Oct 31, 2025
d34482c
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Oct 31, 2025
904283d
Apply suggestion from @LysandreJik
ArthurZucker Oct 31, 2025
b225885
Apply suggestion from @LysandreJik
ArthurZucker Oct 31, 2025
7f196f9
small nits
ArthurZucker Oct 31, 2025
6d0aa66
fix tie weight keys?
ArthurZucker Oct 31, 2025
ef5123b
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Oct 31, 2025
9f5ec4a
nit
ArthurZucker Oct 31, 2025
2d84aba
fix glob import
ArthurZucker Oct 31, 2025
573af75
fix import and error
ArthurZucker Oct 31, 2025
e848ab6
up
ArthurZucker Oct 31, 2025
1d4411a
update
ArthurZucker Oct 31, 2025
3e4d8ea
up
ArthurZucker Oct 31, 2025
07e265d
up
ArthurZucker Oct 31, 2025
913171a
did not know glob was only 3.13
ArthurZucker Oct 31, 2025
e465bc0
fak
ArthurZucker Oct 31, 2025
19f94d0
how many tests does this fix?
ArthurZucker Oct 31, 2025
29e017d
cleanup
ArthurZucker Oct 31, 2025
7061956
qol + nits
ArthurZucker Oct 31, 2025
0ebb1b6
fixup
ArthurZucker Oct 31, 2025
6b398e1
nit
ArthurZucker Oct 31, 2025
e59b1ff
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Oct 31, 2025
52d85e0
merge
ArthurZucker Nov 1, 2025
29aa051
Merge branch 'main' of github.com:huggingface/transformers into refac…
ArthurZucker Nov 1, 2025
20b6142
small updates?
ArthurZucker Nov 1, 2025
a79de84
cleanup what is no longer used
ArthurZucker Nov 1, 2025
606452d
nits
ArthurZucker Nov 1, 2025
7eda8aa
dtype
ArthurZucker Nov 1, 2025
b148577
up
ArthurZucker Nov 1, 2025
9022bc2
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 1, 2025
0da6e92
upsates
ArthurZucker Nov 1, 2025
9cb0432
qol
ArthurZucker Nov 3, 2025
4d34ced
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 3, 2025
c515eb6
Merge branch 'main' of github.com:huggingface/transformers into refac…
ArthurZucker Nov 3, 2025
85973fc
fix triton import error
ArthurZucker Nov 3, 2025
9b6a7a4
fixup
ArthurZucker Nov 3, 2025
3baf4b7
lol so much time lost on this shit
ArthurZucker Nov 3, 2025
82a35bc
nits
ArthurZucker Nov 3, 2025
6c88206
fix the init of param
ArthurZucker Nov 3, 2025
4d79709
ah actually we don't discard lm head if missing -> needs to be moved …
ArthurZucker Nov 3, 2025
d1e84db
fix some tests
ArthurZucker Nov 3, 2025
f2938df
small fixes
ArthurZucker Nov 3, 2025
22fcdaf
up
ArthurZucker Nov 3, 2025
7d78aa1
up
ArthurZucker Nov 3, 2025
80517f5
dik why we tie weights twice but,..,,.
ArthurZucker Nov 3, 2025
2ff8532
ups
ArthurZucker Nov 3, 2025
d923061
removeunused
ArthurZucker Nov 3, 2025
ce8c1c1
fix hunyuan
ArthurZucker Nov 3, 2025
23e3ed7
small fix
ArthurZucker Nov 3, 2025
a8fb554
nits
ArthurZucker Nov 3, 2025
ab6ee8a
ish
ArthurZucker Nov 3, 2025
77ccbb1
up
ArthurZucker Nov 3, 2025
8a8beff
rev
ArthurZucker Nov 3, 2025
02386ce
fix more tie weights keys
ArthurZucker Nov 3, 2025
1c87945
small fixes
ArthurZucker Nov 3, 2025
00b95ee
nit
ArthurZucker Nov 3, 2025
a170f29
update
ArthurZucker Nov 3, 2025
8b924a3
fix and fix
ArthurZucker Nov 3, 2025
8f7b1d0
fix a test
ArthurZucker Nov 3, 2025
9386217
glubs
ArthurZucker Nov 3, 2025
4894a25
current shitty changes
ArthurZucker Nov 3, 2025
da7dc10
ship validated ones
ArthurZucker Nov 4, 2025
d7c8171
more
ArthurZucker Nov 4, 2025
e088408
more update
ArthurZucker Nov 4, 2025
4f212de
more
ArthurZucker Nov 4, 2025
dc5a22c
more
ArthurZucker Nov 4, 2025
675b2bc
more
ArthurZucker Nov 4, 2025
f85f239
mllama
ArthurZucker Nov 4, 2025
76b6a92
more up
ArthurZucker Nov 4, 2025
ba1a8b6
fix ernie
ArthurZucker Nov 4, 2025
ba3de5a
fix xopies
ArthurZucker Nov 4, 2025
8fd255c
up more
ArthurZucker Nov 4, 2025
5d7507b
more fixes
ArthurZucker Nov 4, 2025
0fb2340
up
ArthurZucker Nov 4, 2025
32b9273
up
ArthurZucker Nov 4, 2025
0b95826
fix-copies
ArthurZucker Nov 4, 2025
5794d27
fix more
ArthurZucker Nov 4, 2025
5e71bd4
more updates
ArthurZucker Nov 4, 2025
20d1b34
AI UPDATE
ArthurZucker Nov 4, 2025
89846e7
up
ArthurZucker Nov 5, 2025
a581fd7
hoey
ArthurZucker Nov 5, 2025
1652c9c
make it fast
Cyrilvallez Nov 5, 2025
dcad703
fix
Cyrilvallez Nov 5, 2025
c921ced
lol
ArthurZucker Nov 5, 2025
50714d8
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
8936cc4
fix asjusting
ArthurZucker Nov 5, 2025
5c54332
more fixes
ArthurZucker Nov 5, 2025
ff10878
_dtype nit
ArthurZucker Nov 5, 2025
9601b82
up
ArthurZucker Nov 5, 2025
db02b9d
nit
ArthurZucker Nov 5, 2025
42fd4c4
update
ArthurZucker Nov 5, 2025
4527171
update
ArthurZucker Nov 5, 2025
bd36211
remove semaphores
Cyrilvallez Nov 5, 2025
e2aefee
fix import to avoid jit execution
Cyrilvallez Nov 5, 2025
74a0e9c
try to remove custom tiing logic when its stupid
ArthurZucker Nov 5, 2025
ead2ac3
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
e7165da
fix more individual models
ArthurZucker Nov 5, 2025
2ff765e
fix whisper as well
ArthurZucker Nov 5, 2025
912562c
fix?
ArthurZucker Nov 5, 2025
c43495a
fox umt5
ArthurZucker Nov 5, 2025
57988f2
improve tqdm bar
Cyrilvallez Nov 5, 2025
8c16de1
cleanup a bit
Cyrilvallez Nov 5, 2025
b8927d6
oupsi
Cyrilvallez Nov 5, 2025
2733ff6
some updates
ArthurZucker Nov 5, 2025
8baa3fe
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
d91701f
improve
Cyrilvallez Nov 5, 2025
5146dec
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
Cyrilvallez Nov 5, 2025
acc5b24
remove all buffering -> much faster without it
Cyrilvallez Nov 5, 2025
58389a1
remove some tie_weights custome funcs when not needed
ArthurZucker Nov 5, 2025
92c0229
more fixes related to strict matching regex
ArthurZucker Nov 5, 2025
d9e7fe6
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
b57d789
remove ALL custom tie weights
ArthurZucker Nov 5, 2025
ef8b6c3
small update
ArthurZucker Nov 5, 2025
a228fd0
revert change to init scheme (no need for params)
Cyrilvallez Nov 5, 2025
07574dd
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
710b1ff
fix
SunMarc Nov 5, 2025
2526cc5
mixtral init
Cyrilvallez Nov 5, 2025
6cb3794
try less strict source check
ArthurZucker Nov 5, 2025
e4cadfb
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 5, 2025
3fea865
tied weight first shot to the fiiiixxxxxx
Cyrilvallez Nov 5, 2025
82f94b8
does this help?
ArthurZucker Nov 5, 2025
84dd6eb
:)
ArthurZucker Nov 5, 2025
cc08195
fix some ppolry defined tied_weights_keys for now
ArthurZucker Nov 5, 2025
f72f96d
fixes for more models torch_bc
ArthurZucker Nov 5, 2025
e341529
nits and fixes
ArthurZucker Nov 5, 2025
0e51dec
last update
ArthurZucker Nov 5, 2025
0f022b5
Revert "tied weight first shot to the fiiiixxxxxx"
ArthurZucker Nov 5, 2025
1dabb4c
here we go again
ArthurZucker Nov 5, 2025
0c2b667
an attempt
ArthurZucker Nov 6, 2025
c48e1ed
up?
ArthurZucker Nov 6, 2025
d223635
nits
ArthurZucker Nov 6, 2025
bdbc01a
Fix bnb loading !
SunMarc Nov 6, 2025
399388d
rm print
SunMarc Nov 6, 2025
acbeeae
Merge branch 'refactor-weight-loading' into fix-bnb
SunMarc Nov 6, 2025
f0cf8d9
add mxfp4
MekkCyber Nov 6, 2025
0364bcb
Merge branch 'fix-bnb' into fix-mxfp4
MekkCyber Nov 6, 2025
c4097db
fix missing keys
MekkCyber Nov 6, 2025
f692f4b
subclass nn.Parameters
ArthurZucker Nov 7, 2025
2fa058f
up
ArthurZucker Nov 7, 2025
78d4622
lol
ArthurZucker Nov 7, 2025
8ff4ad5
Ouiiii
ArthurZucker Nov 7, 2025
3222678
fix led
ArthurZucker Nov 7, 2025
9a76a6e
fix long cat flash
ArthurZucker Nov 7, 2025
9fde9f7
fix qwen and long cat flash
ArthurZucker Nov 7, 2025
074a449
properly fix qwen init
ArthurZucker Nov 7, 2025
0189b1d
fix dequantize path
MekkCyber Nov 7, 2025
dde5500
just push this for now
ArthurZucker Nov 7, 2025
0e7d2d0
propnet is dumb
ArthurZucker Nov 7, 2025
86082d2
first poc + tests passing
MekkCyber Nov 7, 2025
ee709ca
style
MekkCyber Nov 7, 2025
18b02ee
update
ArthurZucker Nov 7, 2025
e16da23
rm import
SunMarc Nov 7, 2025
386e259
update
SunMarc Nov 7, 2025
9c0db72
push
ArthurZucker Nov 7, 2025
9788014
Merge remote-tracking branch 'upstream/refactor-weight-loading' into …
SunMarc Nov 7, 2025
72eff97
Update src/transformers/core_model_loading.py
SunMarc Nov 7, 2025
75d3afc
remove explict sharing of some tied keys.
ArthurZucker Nov 7, 2025
85ab085
update decoder.bias
ArthurZucker Nov 7, 2025
443573a
moe case
ArthurZucker Nov 7, 2025
d841a04
Fix loadedparam
SunMarc Nov 7, 2025
e235eed
Merge remote-tracking branch 'upstream/fix-bnb' into fix-bnb
SunMarc Nov 7, 2025
e4df752
rm report
SunMarc Nov 7, 2025
f8f0973
more changes to untangle old hardcoded ting
ArthurZucker Nov 7, 2025
5c9d56c
fixup
ArthurZucker Nov 7, 2025
a0029f2
Merge branch 'main' into refactor-weight-loading
ArthurZucker Nov 7, 2025
44943fb
fix big faileurs
ArthurZucker Nov 7, 2025
3e69622
Fix tests single gpu
SunMarc Nov 7, 2025
a052513
should fix it
SunMarc Nov 7, 2025
76d66be
fix prophnet
ArthurZucker Nov 7, 2025
d176b48
Merge branch 'refactor-weight-loading' of github.com:huggingface/tran…
ArthurZucker Nov 7, 2025
3ffc59e
fix resize token embeddings
ArthurZucker Nov 10, 2025
2a00e49
nits
ArthurZucker Nov 10, 2025
f7d0183
fix xcodex
ArthurZucker Nov 10, 2025
bbf5b00
asyncio?
ArthurZucker Nov 10, 2025
0412832
fix smart apply
ArthurZucker Nov 10, 2025
c137ea3
fix data-2-vec
ArthurZucker Nov 10, 2025
7b7c990
[build-ci-image]
ArthurZucker Nov 10, 2025
de74aeb
checkout
ArthurZucker Nov 10, 2025
94a53d4
uupdate
ArthurZucker Nov 10, 2025
db4fe31
Merge branch 'refactor-weight-loading' into fix-bnb
SunMarc Nov 10, 2025
8755a4b
fix hunyuan
ArthurZucker Nov 10, 2025
5be67b9
update error message
ArthurZucker Nov 10, 2025
86a4e51
fix deformable detr
ArthurZucker Nov 10, 2025
09bcd2e
fixes
ArthurZucker Nov 10, 2025
7b457fd
fix init weights for non param gate up projs
ArthurZucker Nov 10, 2025
32bec2b
Merge branch 'fix-bnb' into fix-mxfp4
MekkCyber Nov 10, 2025
e033947
shared todo?
ArthurZucker Nov 10, 2025
95933aa
add comment
MekkCyber Nov 10, 2025
9fa1b7a
guard needed for compressed-tensors
SunMarc Nov 10, 2025
ea5822d
Merge branch 'refactor-weight-loading' into fix-bnb
SunMarc Nov 10, 2025
5881d8e
deal with buffers
SunMarc Nov 10, 2025
f93f357
update some models
ArthurZucker Nov 10, 2025
2f0a6ae
big revert, don't break this behaviour
ArthurZucker Nov 10, 2025
3c8c757
ty @SunMarc this fixes the buffers
ArthurZucker Nov 10, 2025
f5a7c33
mt5 fuck
ArthurZucker Nov 10, 2025
3651460
Merge branch 'refactor-weight-loading' into fix-bnb
SunMarc Nov 10, 2025
00b0044
Merge branch 'refactor-weight-loading' into fix-bnb
SunMarc Nov 10, 2025
7d8df52
fix
SunMarc Nov 10, 2025
501ed80
fix
MekkCyber Nov 11, 2025
3d7ce14
Merge branch 'fix-bnb' into fix-mxfp4
MekkCyber Nov 12, 2025
b8d2409
fix
MekkCyber Nov 12, 2025
1 change: 1 addition & 0 deletions Makefile
@@ -45,6 +45,7 @@ repo-consistency:
 	python utils/check_modular_conversion.py
 	python utils/check_dummies.py
 	python utils/check_repo.py
+	python utils/check_init_weights_data.py
 	python utils/check_inits.py
 	python utils/check_pipeline_typing.py
 	python utils/check_config_docstrings.py
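The new `utils/check_init_weights_data.py` hook matches the theme of the diffs below, which drop the `.data` accessor inside `_init_weights`. The script's body is not part of this diff, so the following is only a hypothetical sketch of what such a consistency check could look like (file layout and messages are assumptions):

```python
# Hypothetical sketch of a check like utils/check_init_weights_data.py
# (assumption: it flags `.data` access inside `_init_weights` bodies).
import sys
from pathlib import Path

def check_file(path: Path) -> list[str]:
    failures, in_init, def_indent = [], False, 0
    for lineno, line in enumerate(path.read_text(encoding="utf-8").splitlines(), 1):
        stripped = line.strip()
        indent = len(line) - len(line.lstrip())
        if stripped.startswith("def _init_weights"):
            in_init, def_indent = True, indent
        elif in_init and stripped and indent <= def_indent:
            in_init = False  # left the _init_weights body
        if in_init and ".data" in line:
            failures.append(f"{path}:{lineno}: `.data` used inside `_init_weights`")
    return failures

if __name__ == "__main__":
    problems = [m for f in Path("src/transformers").rglob("modeling_*.py") for m in check_file(f)]
    print("\n".join(problems))
    sys.exit(1 if problems else 0)
```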
14 changes: 7 additions & 7 deletions docs/source/de/add_new_model.md
@@ -508,16 +508,16 @@ BERT `_init_weights` method:
 def _init_weights(self, module):
     """Initialize the weights"""
     if isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
     elif isinstance(module, nn.Embedding):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
     elif isinstance(module, nn.LayerNorm):
-        module.bias.data.zero_()
-        module.weight.data.fill_(1.0)
+        module.bias.zero_()
+        module.weight.fill_(1.0)
 ```
 
 You can use further custom schemes if you need a special initialization for some modules. For example, in
@@ -533,9 +533,9 @@ def _init_weights(self, module):
         module.project_hid._is_hf_initialized = True
         module.project_q._is_hf_initialized = True
     elif isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
 ```
 
 The `_is_hf_initialized` flag is used internally to make sure that we initialize a submodule only once. If you set it to
14 changes: 7 additions & 7 deletions docs/source/en/add_new_model.md
@@ -314,16 +314,16 @@ Random initialization occurs in the `_init_weights` method of `BrandNewLlamaPreTrainedModel`.
 def _init_weights(self, module):
     """Initialize the weights"""
     if isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
     elif isinstance(module, nn.Embedding):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
     elif isinstance(module, nn.LayerNorm):
-        module.bias.data.zero_()
-        module.weight.data.fill_(1.0)
+        module.bias.zero_()
+        module.weight.fill_(1.0)
 ```
 
 The initialization scheme can look different if you need to adapt it to your model. For example, [`Wav2Vec2ForPreTraining`] initializes [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) in its last two linear layers.
@@ -339,9 +339,9 @@ def _init_weights(self, module):
         module.project_hid._is_hf_initialized = True
         module.project_q._is_hf_initialized = True
     elif isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
 ```
 
 ### Convert checkpoints to Transformers
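One hedged inference from these hunks (not stated in the diff itself): dropping `.data` only works if `_init_weights` runs with autograd disabled, since in-place ops on a leaf `Parameter` that requires grad raise a `RuntimeError`. Presumably the refactored loading path wraps the call accordingly:

```python
import torch
from torch import nn

linear = nn.Linear(4, 4)

# Without no_grad, `linear.weight.normal_()` fails with
# "a leaf Variable that requires grad is being used in an in-place operation".
with torch.no_grad():
    linear.weight.normal_(mean=0.0, std=0.02)
    if linear.bias is not None:
        linear.bias.zero_()
```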
2 changes: 1 addition & 1 deletion docs/source/en/perf_infer_gpu_multi.md
@@ -149,7 +149,7 @@ The example below packs `up_proj` and `gate_proj` into a single `gate_up_proj` module.
 ```python
 class Llama4TextExperts(nn.Module):
     ...
-    self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+    self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
 ```
 
 Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.
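To illustrate that last sentence, here is a minimal sketch of the packed batched matmul (tensor names and shapes are illustrative assumptions, not taken from the diff); note `torch.zeros` also makes the pre-load state deterministic, unlike `torch.empty`:

```python
import torch

num_experts, tokens, hidden_size, expert_dim = 4, 3, 8, 16
gate_up_proj = torch.zeros(num_experts, hidden_size, 2 * expert_dim)
hidden_states = torch.randn(num_experts, tokens, hidden_size)

# One bmm replaces separate gate_proj/up_proj matmuls; then split the halves.
gate_up = torch.bmm(hidden_states, gate_up_proj)  # (num_experts, tokens, 2 * expert_dim)
gate, up = gate_up.chunk(2, dim=-1)               # two (num_experts, tokens, expert_dim) tensors
```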
14 changes: 7 additions & 7 deletions docs/source/ja/add_new_model.md
@@ -406,16 +406,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
 def _init_weights(self, module):
     """Initialize the weights"""
     if isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
     elif isinstance(module, nn.Embedding):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
     elif isinstance(module, nn.LayerNorm):
-        module.bias.data.zero_()
-        module.weight.data.fill_(1.0)
+        module.bias.zero_()
+        module.weight.fill_(1.0)
 ```
 
 You can also have further custom schemes if specific modules need special initialization. For example,
@@ -431,9 +431,9 @@ def _init_weights(self, module):
         module.project_hid._is_hf_initialized = True
         module.project_q._is_hf_initialized = True
     elif isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
 ```
 
 The `_is_hf_initialized` flag is used internally to ensure that a submodule is initialized only once.
14 changes: 7 additions & 7 deletions docs/source/ko/add_new_model.md
@@ -348,16 +348,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
 def _init_weights(self, module):
     """Initialize the weights"""
     if isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
     elif isinstance(module, nn.Embedding):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
     elif isinstance(module, nn.LayerNorm):
-        module.bias.data.zero_()
-        module.weight.data.fill_(1.0)
+        module.bias.zero_()
+        module.weight.fill_(1.0)
 ```
 
 You can also use a custom scheme if some modules need special initialization. For example, in `Wav2Vec2ForPreTraining` the last two linear layers must use the standard PyTorch `nn.Linear` initialization, while all the other layers should use the initialization above. This is coded as follows:
@@ -371,9 +371,9 @@ def _init_weights(self, module):
         module.project_hid._is_hf_initialized = True
         module.project_q._is_hf_initialized = True
     elif isinstance(module, nn.Linear):
-        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         if module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
 ```
 
 The `_is_hf_initialized` flag is used internally so that a submodule is initialized only once. By setting it to `True` for `module.project_q` and `module.project_hid`, we ensure that the custom initialization we performed is not overwritten later, i.e. the `_init_weights` function is not applied to them.
2 changes: 1 addition & 1 deletion docs/source/ko/perf_infer_gpu_multi.md
@@ -152,7 +152,7 @@ class ParallelInterface(MutableMapping):
 ```python
 class Llama4TextExperts(nn.Module):
     ...
-    self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
+    self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
 ```
 
 Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.
20 changes: 7 additions & 13 deletions examples/modular-transformers/modeling_dummy_bert.py
@@ -502,16 +502,10 @@ def __init__(self, config):
 
         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
 
-        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
-        self.decoder.bias = self.bias
-
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
 
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -536,18 +530,18 @@ class DummyBertPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+            module.bias.zero_()
+            module.weight.fill_(1.0)
         elif isinstance(module, DummyBertLMPredictionHead):
-            module.bias.data.zero_()
+            module.bias.zero_()
 
 
 @auto_docstring(
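A hedged reading of the first hunk: with `bias=True` the output bias is owned by the `nn.Linear` itself, so the separate `self.bias` parameter and the `_tie_weights` re-link become unnecessary, and resizing the layer carries the bias with it. A toy sketch (shapes are illustrative):

```python
import torch
from torch import nn

hidden_size, vocab_size = 8, 32

# New style: a single owned bias; no manual `decoder.bias = self.bias` link needed.
decoder = nn.Linear(hidden_size, vocab_size, bias=True)
assert decoder.bias.shape == (vocab_size,)

# Rebuilding the layer for a new vocab size resizes the bias along with it.
decoder = nn.Linear(hidden_size, vocab_size + 4, bias=True)
assert decoder.bias.shape == (vocab_size + 4,)
```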
2 changes: 1 addition & 1 deletion examples/modular-transformers/modeling_my_new_model2.py
@@ -265,7 +265,7 @@ def _init_weights(self, module):
 
         # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
         if "RMSNorm" in module.__class__.__name__:
-            module.weight.data.zero_()
+            module.weight.zero_()
 
 
 class MyNewModel2ForSequenceClassification(GenericForSequenceClassification, MyNewModel2PreTrainedModel):
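The comment in the hunk explains the zero init; as a quick hedged illustration (a toy `(1 + weight)`-style RMSNorm, not the model's actual class), a zero weight yields an identity scale:

```python
import torch

hidden = torch.randn(2, 8)
weight = torch.zeros(8)  # zero-initialized, as in the hunk

variance = hidden.pow(2).mean(-1, keepdim=True)
normed = hidden * torch.rsqrt(variance + 1e-6)
out = normed * (1.0 + weight)  # zero weight => scale of exactly 1
assert torch.allclose(out, normed)
```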
16 changes: 12 additions & 4 deletions examples/modular-transformers/modeling_new_task_model.py
@@ -104,9 +104,9 @@ def _init_weights(self, module):
         std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
 
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
+            module.weight.normal_(mean=0.0, std=std)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
 
 
 def token_type_ids_mask_function(
@@ -428,7 +428,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
         "^multi_modal_projector": "model.multi_modal_projector",
         "^language_model.lm_head": "lm_head",
     }
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
     main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related
 
     def __init__(self, config):
@@ -440,7 +440,15 @@ def __init__(self, config):
         self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
 
         if self.language_model._tied_weights_keys is not None:
-            self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
+            prefix = "model.language_model."
+            prefixed_mapping = {
+                f"{prefix}{target}": f"{prefix}{source}"
+                for target, source in self.language_model._tied_weights_keys.items()
+            }
+            if isinstance(self._tied_weights_keys, dict):
+                self._tied_weights_keys.update(prefixed_mapping)
+            else:
+                self._tied_weights_keys = prefixed_mapping
         self.post_init()
 
     def get_input_embeddings(self):
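As these hunks read, `_tied_weights_keys` moves from a flat list of tied parameter names to an explicit target-to-source mapping, and nested models' mappings get re-prefixed. A toy restatement (names mirror the diff; the dict semantics are as shown there):

```python
# Old convention: only the tied targets were listed.
old_style = ["lm_head.weight"]

# New convention: target parameter -> source parameter it is tied to.
new_style = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

# Re-prefixing a nested language model's mapping, as in __init__ above.
prefix = "model.language_model."
prefixed = {f"{prefix}{target}": f"{prefix}{source}" for target, source in new_style.items()}
```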
20 changes: 7 additions & 13 deletions examples/modular-transformers/modeling_roberta.py
@@ -505,16 +505,10 @@ def __init__(self, config):
 
         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
 
-        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
-        self.decoder.bias = self.bias
-
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
 
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -539,18 +533,18 @@ class RobertaPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+            module.bias.zero_()
+            module.weight.fill_(1.0)
         elif isinstance(module, RobertaLMPredictionHead):
-            module.bias.data.zero_()
+            module.bias.zero_()
 
 
 @auto_docstring(
6 changes: 3 additions & 3 deletions examples/modular-transformers/modeling_test_detr.py
@@ -846,11 +846,11 @@ def _init_weights(self, module):
             nn.init.xavier_uniform_(module.output_proj.weight.data)
             nn.init.constant_(module.output_proj.bias.data, 0.0)
         elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
-            module.weight.data.normal_(mean=0.0, std=std)
+            module.weight.normal_(mean=0.0, std=std)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
+            module.weight.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
         if hasattr(module, "reference_points") and not self.config.two_stage:
10 changes: 9 additions & 1 deletion examples/modular-transformers/modular_new_task_model.py
@@ -19,7 +19,15 @@ def __init__(self, config):
         self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
 
         if self.language_model._tied_weights_keys is not None:
-            self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
+            prefix = "model.language_model."
+            prefixed_mapping = {
+                f"{prefix}{target}": f"{prefix}{source}"
+                for target, source in self.language_model._tied_weights_keys.items()
+            }
+            if isinstance(self._tied_weights_keys, dict):
+                self._tied_weights_keys.update(prefixed_mapping)
+            else:
+                self._tied_weights_keys = prefixed_mapping
 
         self.post_init()
2 changes: 1 addition & 1 deletion setup.py
@@ -138,7 +138,7 @@
     "pyyaml>=5.1",
     "pydantic>=2",
     "pytest>=7.2.0",
-    "pytest-asyncio",
+    "pytest-asyncio>=1.2.0",
     "pytest-rerunfailures<16.0",
     "pytest-timeout",
     "pytest-xdist",
4 changes: 2 additions & 2 deletions src/transformers/configuration_utils.py
@@ -876,7 +876,7 @@ def to_diff_dict(self) -> dict[str, Any]:
         if hasattr(self, "quantization_config"):
             serializable_config_dict["quantization_config"] = (
                 self.quantization_config.to_dict()
-                if not isinstance(self.quantization_config, dict)
+                if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
                 else self.quantization_config
             )
         self.dict_dtype_to_str(serializable_config_dict)
@@ -910,7 +910,7 @@ def to_dict(self) -> dict[str, Any]:
         if hasattr(self, "quantization_config"):
             output["quantization_config"] = (
                 self.quantization_config.to_dict()
-                if not isinstance(self.quantization_config, dict)
+                if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
                 else self.quantization_config
             )
         self.dict_dtype_to_str(output)
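The added `and self.quantization_config is not None` guard routes `None` through the `else` branch instead of calling `.to_dict()` on it. A small behavioral sketch of the guarded expression (standalone, with a hypothetical helper name):

```python
def serialize_quantization_config(qc):
    # Mirrors the guarded conditional: dicts and None pass through unchanged.
    return qc.to_dict() if not isinstance(qc, dict) and qc is not None else qc

assert serialize_quantization_config(None) is None            # previously: AttributeError
assert serialize_quantization_config({"bits": 4}) == {"bits": 4}
```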