
Commit a663197

Swati Allabadi committed
Adding unit test cases
Signed-off-by: Swati Allabadi <[email protected]>
1 parent: f88f01d

2 files changed: 149 additions, 7 deletions


QEfficient/finetune/experimental/core/model.py

Lines changed: 6 additions & 7 deletions
@@ -9,13 +9,13 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Optional, Type

-import torch
 import torch.nn as nn
 from transformers import AutoTokenizer, BitsAndBytesConfig
 import transformers
+from transformers.utils.logging import get_logger

 from QEfficient.finetune.experimental.core.component_registry import registry
-from QEfficient.finetune.experimental.utils.dataset_helper import insert_pad_token
+from QEfficient.finetune.experimental.core.utils.dataset_utils import insert_pad_token

 logger = get_logger(__name__)

@@ -38,7 +38,6 @@ def create(cls, model_name: str, **model_kwargs: Any) -> "BaseModel":
         if not isinstance(module, nn.Module):
             raise TypeError(f"load_model() must return nn.Module, got {type(module)}")
         obj._model = module
-        obj.add_module("_wrapped_model", module)  # register
         return obj

     @abstractmethod
@@ -70,14 +69,14 @@ def forward(self, *args, **kwargs):
     def get_input_embeddings(self):
         if hasattr(self.model, "get_input_embeddings"):
             return self.model.get_input_embeddings()
-        logger.log_rank_zero(f"Model {self.model_name} does not expose input embeddings", logging.WARNING)
+        logger.info(f"Model {self.model_name} does not expose input embeddings", logging.WARNING)
         return None

     def resize_token_embeddings(self, new_num_tokens: int) -> None:
         if hasattr(self.model, "resize_token_embeddings"):
             self.model.resize_token_embeddings(new_num_tokens)
         else:
-            logger.log_rank_zero(f"Model {self.model_name} cannot resize token embeddings", logging.WARNING)
+            logger.info(f"Model {self.model_name} cannot resize token embeddings", logging.WARNING)

     # optional
     def to(self, *args, **kwargs):
@@ -134,7 +133,7 @@ def configure_model_kwargs(self) -> Dict[str, Any]:
         return extra

     def load_model(self) -> nn.Module:
-        logger.log_rank_zero(f"Loading HuggingFace model '{self.model_name}' via {self.auto_class.__name__}")
+        logger.info(f"Loading HuggingFace model '{self.model_name}' via {self.auto_class.__name__}")

         return self.auto_class.from_pretrained(
             self.model_name,
@@ -143,7 +142,7 @@ def load_model(self) -> nn.Module:

     def load_tokenizer(self) -> AutoTokenizer:
         """Load Hugging Face tokenizer."""
-        logger.log_rank_zero(f"Loading tokenizer '{self.tokenizer_name}'")
+        logger.info(f"Loading tokenizer '{self.tokenizer_name}'")
         tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
         insert_pad_token(tokenizer)
         return tokenizer
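
For context on the load_tokenizer hunk above, the flow it logs reduces to roughly the pattern below. This is only a sketch: the "gpt2" checkpoint name is a placeholder, and the eos-as-pad fallback stands in for whatever insert_pad_token (imported from QEfficient.finetune.experimental.core.utils.dataset_utils) actually does.

# Rough sketch of the tokenizer path exercised above; not the file's real code.
# "gpt2" is a placeholder checkpoint, and the eos-as-pad fallback is only an
# assumption about insert_pad_token's behaviour when no pad token exists.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token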
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
import pytest
import torch
import torch.nn as nn
from unittest import mock

import transformers
from QEfficient.finetune.experimental.core import model
from QEfficient.finetune.experimental.core.model import BaseModel, HFModel


class TestMockModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)


class TestCustomModel(BaseModel):
    def __init__(self, model_name):
        super().__init__(model_name)
        print("init of custom class")

    def load_model(self) -> nn.Module:
        return TestMockModel()

    def load_tokenizer(self):
        return "dummy-tokenizer"


# BaseModel tests
def test_model_property_errors_if_not_created():
    m = TestCustomModel("dummy")
    with pytest.raises(RuntimeError):
        _ = m.model  # must call .create()


def test_create_builds_and_registers():
    m = TestCustomModel.create("dummy")
    # inner model exists and registered
    assert "_model" in m._modules
    assert isinstance(m.model, TestMockModel)
    # forward works
    out = m(torch.zeros(1, 2))
    assert out.shape == (1, 2)


def test_tokenizer_lazy_loading():
    m = TestCustomModel.create("dummy")
    assert m._tokenizer is None
    tok = m.tokenizer
    assert tok == "dummy-tokenizer"
    assert m._tokenizer == tok


def test_to_moves_inner_and_returns_self():
    m = TestCustomModel.create("dummy")
    with mock.patch.object(TestMockModel, "to", autospec=True) as mocked_to:
        ret = m.to("cuda:0")
        mocked_to.assert_called_once_with(m.model, "cuda:0")
        assert ret is m


def test_train_eval_sync_flags():
    m = TestCustomModel.create("dummy")
    m.eval()
    assert m.training is False
    assert m.model.training is False
    m.train()
    assert m.training is True
    assert m.model.training is True


def test_resize_token_embeddings_and_get_input_embeddings_warn(monkeypatch):
    m = TestCustomModel.create("dummy")

    # resize_token_embeddings: underlying model lacks the method, should warn and not raise
    with mock.patch("QEfficient.finetune.experimental.core.model.logger.info") as mocked_log:
        m.resize_token_embeddings(10)
        mocked_log.assert_called_once()

    # get_input_embeddings: underlying model lacks method, should warn and return None
    with mock.patch("QEfficient.finetune.experimental.core.model.logger.info") as mocked_log:
        assert m.get_input_embeddings() is None
        mocked_log.assert_called_once()


def test_state_dict_contains_inner_params():
    m = TestCustomModel.create("dummy")
    sd = m.state_dict()
    # should contain params from TestMockModel.linear
    assert any("linear.weight" in k for k in sd)
    assert any("linear.bias" in k for k in sd)


# HFModel tests
def test_hfmodel_invalid_auto_class_raises():
    with pytest.raises(ValueError):
        HFModel.create("hf-name", auto_class_name="AutoDoesNotExist")


def test_hfmodel_loads_auto_and_tokenizer(monkeypatch):
    # fake HF Auto class
    class FakeAuto(nn.Module):
        @classmethod
        def from_pretrained(cls, name, **kwargs):
            inst = cls()
            inst.loaded = (name, kwargs)
            return inst

        def forward(self, x):
            return x

    fake_tok = mock.Mock()

    # Monkeypatch transformer classes used in HFModel
    monkeypatch.setattr(
        "QEfficient.finetune.experimental.core.model.transformers.AutoModelForCausalLM",
        FakeAuto,
        raising=False,
    )
    monkeypatch.setattr(
        model,
        "AutoTokenizer",
        mock.Mock(from_pretrained=mock.Mock(return_value=fake_tok)),
    )
    monkeypatch.setattr(
        "QEfficient.finetune.experimental.core.model.insert_pad_token",
        mock.Mock(),
        raising=False,
    )

    m = HFModel.create("hf-name")
    assert isinstance(m.model, FakeAuto)

    # load tokenizer
    tok = m.load_tokenizer()

    # tokenizer was loaded and pad token inserted
    model.AutoTokenizer.from_pretrained.assert_called_once_with("hf-name")
    model.insert_pad_token.assert_called_once_with(fake_tok)
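
To run just the new tests locally, something like the snippet below should work with pytest's Python API; the test file path is an assumption, since this commit view does not show where the added file lives in the tree.

# Minimal sketch: invoke only the new model tests via pytest's Python API.
# NOTE: the path below is assumed; adjust it to the file's real location.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["-q", "tests/finetune/experimental/core/test_model.py"]))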
