142 changes: 142 additions & 0 deletions QEfficient/finetune/experimental/core/model.py
@@ -4,3 +4,145 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type

import torch.nn as nn
import transformers
from transformers import AutoTokenizer
from transformers.utils.logging import get_logger

from QEfficient.finetune.experimental.core.component_registry import registry
from QEfficient.finetune.experimental.core.utils.dataset_utils import insert_pad_token


logger = get_logger(__name__)


class BaseModel(nn.Module, ABC):
"""Shared skeleton for every finetunable model in the system."""

def __init__(self, model_name: str, **model_kwargs: Any) -> None:
super().__init__()
self.model_name = model_name
self.model_kwargs: Dict[str, Any] = model_kwargs
self._model: Optional[nn.Module] = None
self._tokenizer: Any = None # HF tokenizers are not nn.Modules.

# Factory constructor: load model after __init__ finishes
@classmethod
def create(cls, model_name: str, **model_kwargs: Any) -> "BaseModel":
obj = cls(model_name, **model_kwargs)
module = obj.load_model()
if not isinstance(module, nn.Module):
raise TypeError(f"load_model() must return nn.Module, got {type(module)}")
obj._model = module
return obj

@abstractmethod
def load_model(self) -> nn.Module:
"""Create and return the underlying torch.nn.Module."""
...

def load_tokenizer(self) -> Any:
"""Override if the model exposes a tokenizer."""
raise NotImplementedError(f"{type(self).__name__} does not provide a tokenizer.")

# Lazy accessors
@property
def model(self) -> nn.Module:
if self._model is None:
raise RuntimeError("Model not loaded; use .create(...) to load.")
return self._model

@property
def tokenizer(self) -> Any:
if self._tokenizer is None:
self._tokenizer = self.load_tokenizer()
return self._tokenizer

# nn.Module API surface
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

def get_input_embeddings(self):
if hasattr(self.model, "get_input_embeddings"):
return self.model.get_input_embeddings()
logger.info(f"Model {self.model_name} does not expose input embeddings", logging.WARNING)
return None

def resize_token_embeddings(self, new_num_tokens: int) -> None:
if hasattr(self.model, "resize_token_embeddings"):
self.model.resize_token_embeddings(new_num_tokens)
else:
logger.info(f"Model {self.model_name} cannot resize token embeddings", logging.WARNING)

# optional
def to(self, *args, **kwargs):
self.model.to(*args, **kwargs)
return self

def train(self, mode: bool = True):
self.model.train(mode)
return super().train(mode)

def eval(self):
return self.train(False)


@registry.model("hf")
class HFModel(BaseModel):
"""HuggingFace-backed model with optional quantization."""

def __init__(
self,
model_name: str,
auto_class_name: str = "AutoModelForCausalLM",
*,
tokenizer_name: Optional[str] = None,
**model_kwargs: Any,
) -> None:
super().__init__(model_name, **model_kwargs)
self.tokenizer_name = tokenizer_name or model_name
self.auto_class: Type = self._resolve_auto_class(auto_class_name)

@staticmethod
def _resolve_auto_class(auto_class_name: str) -> Type:
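        # Maps a name such as "AutoModelForSeq2SeqLM" to the matching transformers
        # Auto class; unknown names fail fast with the list of available candidates.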
if not hasattr(transformers, auto_class_name):
candidates = sorted(name for name in dir(transformers) if name.startswith("AutoModel"))
raise ValueError(
f"Unsupported Auto class '{auto_class_name}'. Available candidates: {', '.join(candidates)}"
)
return getattr(transformers, auto_class_name)

# def _build_quant_config(self) -> Optional[BitsAndBytesConfig]:
# if not self.model_kwargs.get("load_in_4bit"):
# return None
# return BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_quant_type=self.model_kwargs.get("bnb_4bit_quant_type", "nf4"),
# bnb_4bit_compute_dtype=self.model_kwargs.get("bnb_4bit_compute_dtype", torch.float16),
# bnb_4bit_use_double_quant=self.model_kwargs.get("bnb_4bit_use_double_quant", True),
# )

def configure_model_kwargs(self) -> Dict[str, Any]:
"""Hook for subclasses to tweak HF `.from_pretrained` kwargs."""
extra = dict(self.model_kwargs)
# extra["quantization_config"] = self._build_quant_config()
return extra

def load_model(self) -> nn.Module:
logger.info(f"Loading HuggingFace model '{self.model_name}' via {self.auto_class.__name__}")

return self.auto_class.from_pretrained(
self.model_name,
**self.configure_model_kwargs(),
)

def load_tokenizer(self) -> AutoTokenizer:
"""Load Hugging Face tokenizer."""
logger.info(f"Loading tokenizer '{self.tokenizer_name}'")
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
insert_pad_token(tokenizer)
return tokenizer
143 changes: 143 additions & 0 deletions QEfficient/finetune/experimental/tests/test_model.py
@@ -0,0 +1,143 @@
from unittest import mock

import pytest
import torch
import torch.nn as nn

from QEfficient.finetune.experimental.core import model
from QEfficient.finetune.experimental.core.model import BaseModel, HFModel


class TestMockModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 2)

def forward(self, x):
return self.linear(x)


class TestCustomModel(BaseModel):
def __init__(self, model_name):
super().__init__(model_name)
print("init of custom class")

def load_model(self) -> nn.Module:
return TestMockModel()

def load_tokenizer(self):
return "dummy-tokenizer"


# BaseModel tests
def test_model_property_errors_if_not_created():
m = TestCustomModel("dummy")
with pytest.raises(RuntimeError):
_ = m.model # must call .create()


def test_create_builds_and_registers():
m = TestCustomModel.create("dummy")
# inner model exists and registered
assert "_model" in m._modules
assert isinstance(m.model, TestMockModel)
# forward works
out = m(torch.zeros(1, 2))
assert out.shape == (1, 2)


def test_tokenizer_lazy_loading():
m = TestCustomModel.create("dummy")
assert m._tokenizer is None
tok = m.tokenizer
assert tok == "dummy-tokenizer"
assert m._tokenizer == tok


def test_to_moves_inner_and_returns_self():
m = TestCustomModel.create("dummy")
with mock.patch.object(TestMockModel, "to", autospec=True) as mocked_to:
ret = m.to("cuda:0")
mocked_to.assert_called_once_with(m.model, "cuda:0")
assert ret is m


def test_train_eval_sync_flags():
m = TestCustomModel.create("dummy")
m.eval()
assert m.training is False
assert m.model.training is False
m.train()
assert m.training is True
assert m.model.training is True


def test_resize_token_embeddings_and_get_input_embeddings_warn(monkeypatch):
m = TestCustomModel.create("dummy")

# resize_token_embeddings: underlying model lacks the method, should warn and not raise
with mock.patch("QEfficient.finetune.experimental.core.model.logger.info") as mocked_log:
m.resize_token_embeddings(10)
mocked_log.assert_called_once()

# get_input_embeddings: underlying model lacks method, should warn and return None
with mock.patch("QEfficient.finetune.experimental.core.model.logger.info") as mocked_log:
assert m.get_input_embeddings() is None
mocked_log.assert_called_once()


def test_state_dict_contains_inner_params():
m = TestCustomModel.create("dummy")
sd = m.state_dict()
# should contain params from TestMockModel.linear
assert any("linear.weight" in k for k in sd)
assert any("linear.bias" in k for k in sd)


# HFModel tests
def test_hfmodel_invalid_auto_class_raises():
with pytest.raises(ValueError):
HFModel.create("hf-name", auto_class_name="AutoDoesNotExist")


def test_hfmodel_loads_auto_and_tokenizer(monkeypatch):
# fake HF Auto class
class FakeAuto(nn.Module):
@classmethod
def from_pretrained(cls, name, **kwargs):
inst = cls()
inst.loaded = (name, kwargs)
return inst

def forward(self, x):
return x

fake_tok = mock.Mock()

# Monkeypatch transformer classes used in HFModel
monkeypatch.setattr(
"QEfficient.finetune.experimental.core.model.transformers.AutoModelForCausalLM",
FakeAuto,
raising=False,
)
monkeypatch.setattr(
model,
"AutoTokenizer",
mock.Mock(from_pretrained=mock.Mock(return_value=fake_tok)),
)
monkeypatch.setattr(
"QEfficient.finetune.experimental.core.model.insert_pad_token",
mock.Mock(),
raising=False,
)

m = HFModel.create("hf-name")
assert isinstance(m.model, FakeAuto)

    # load tokenizer lazily and confirm it is the patched fake
    tok = m.load_tokenizer()
    assert tok is fake_tok

# tokenizer was loaded and pad token inserted
model.AutoTokenizer.from_pretrained.assert_called_once_with("hf-name")
model.insert_pad_token.assert_called_once_with(fake_tok)