Skip to content

Commit 3d83278

Browse files
jouwhwchen2017raviguptaamd
authored
fix init weights issue for critic/reward model (#983)
* Add file extension (#980) Signed-off-by: Hongwei Chen <[email protected]> Signed-off-by: jouw <[email protected]> * fix init weights issue for critic/reward model Signed-off-by: jouw <[email protected]> * Update submodule link to reflect https style (#981) Signed-off-by: raviguptaamd <[email protected]> Signed-off-by: jouw <[email protected]> * fix formatting issue Signed-off-by: jouw <[email protected]> --------- Signed-off-by: Hongwei Chen <[email protected]> Signed-off-by: jouw <[email protected]> Signed-off-by: raviguptaamd <[email protected]> Co-authored-by: Hongwei Chen <[email protected]> Co-authored-by: raviguptaamd <[email protected]>
1 parent 4579df3 commit 3d83278

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from huggingface_hub import snapshot_download
1313
from transformers.integrations.deepspeed import HfDeepSpeedConfig
14+
from transformers.modeling_utils import no_init_weights
1415

1516
from dschat.utils.model.reward_model import RewardModel
1617
from dschat.utils.utils import load_state_dict_into_model, print_rank_0
@@ -99,7 +100,8 @@ def create_hf_model(model_class,
99100
dschf = None
100101
if rlhf_training:
101102
# the weight loading is handled by create critic model
102-
model = model_class.from_config(model_config)
103+
with no_init_weights():
104+
model = model_class.from_config(model_config)
103105
else:
104106
model = model_class.from_pretrained(
105107
model_name_or_path,

0 commit comments

Comments
 (0)