
Commit 31d86d5: "remove uneccessary files"
1 parent: b05ad00


42 files changed: +1443 −986 lines

Diff for: src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py (+1 −13)

@@ -179,22 +179,10 @@ def fasterprune(
                fake_quantize,
            )

-            while scale.ndim < 2:
-                scale = scale.unsqueeze(1)
-                zero_point = zero_point.unsqueeze(1)
-
-            while q.ndim < 2:
-                q = q.unsqueeze(1)
             q = fake_quantize(
-                q,
-                scale[:, i],
-                zero_point[:, i],
-                self.layer.quantization_scheme.weights,
+                q, scale, zero_point, self.layer.quantization_scheme.weights
             )

-            while q.ndim != 1:
-                q.squeeze()
-
             Q1[:, i] = q
             Losses1[:, i] = (w - q) ** 2 / d**2
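
For context, the simplified call works because fake quantization is a quantize/dequantize round trip computed in floating point. A minimal per-tensor sketch of that operation (a hypothetical helper, not sparseml's fake_quantize signature):

import torch

def fake_quantize_sketch(x, scale, zero_point, qmin=-128, qmax=127):
    # quantize: snap to the integer grid and clamp to the representable range
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    # dequantize: map back to floating point so downstream math stays in fp
    return (q - zero_point) * scale

w = torch.randn(16, 16)
w_fq = fake_quantize_sketch(w, scale=torch.tensor(0.05), zero_point=torch.tensor(0.0))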

Diff for: src/sparseml/modifiers/quantization/pytorch.py (−6)

@@ -20,7 +20,6 @@

 from sparseml.core import Event, EventType, State
 from sparseml.modifiers.quantization.base import QuantizationModifier
-from sparseml.modifiers.quantization.modification import modify_model
 from sparseml.modifiers.quantization.utils.helpers import (
     configure_module_bn_wrappers,
     freeze_bn_stats,
@@ -74,16 +73,11 @@ def __init__(self, **kwargs):

     def on_initialize_structure(self, state: State, **kwargs):
         module = state.model.model
-        # before the structure is modified to support quantization,
-        # we need to potentially modify the model architecture
-        module = modify_model(module)
         self._enable_module_qat(module)
         state.model.model.apply(torch.quantization.disable_observer)

     def on_initialize(self, state: State, **kwargs) -> bool:
         raise_if_torch_quantization_not_available()
-        module = state.model.model
-        module = modify_model(module)
         if self.end and self.end != -1:
             raise ValueError(
                 "end_epoch is disabled for QuantizationModifier and can only be set to"

Diff for: src/sparseml/modifiers/quantization_vllm/base.py (+2)

@@ -43,12 +43,14 @@ class vLLMQuantizationModifier(Modifier):
         not be updated. Leave None to not disable observers during QAT. Default is None
     :param num_calibration_steps: Number of steps to run post training calibration for.
         When None, the entire calibration_dataloader is used
+    :param post_oneshot_calibration: Whether to rerun calibration on finalization
     """

     config_groups: Dict[str, QuantizationScheme]
     ignore: List[str] = Field(default_factory=list)
     disable_quantization_observer_epoch: Optional[float] = None
     num_calibration_steps: Optional[int] = None
+    post_oneshot_calibration: Optional[bool] = False

     def create_init_config(self) -> QuantizationConfig:
         return QuantizationConfig(
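
The new field follows the standard pydantic pattern of an optional boolean that defaults to False. A minimal sketch of how such a flag validates (standalone demo model, not the actual modifier):

from typing import Optional

from pydantic import BaseModel

class FlagDemo(BaseModel):
    # mirrors the new field: off by default, opt in to rerun calibration
    post_oneshot_calibration: Optional[bool] = False

assert FlagDemo().post_oneshot_calibration is False
assert FlagDemo(post_oneshot_calibration=True).post_oneshot_calibration is True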

Diff for: src/sparseml/modifiers/quantization_vllm/pytorch.py (+5 −3)

@@ -23,7 +23,6 @@
     set_module_for_calibration,
 )
 from sparseml.core import Event, EventType, State
-from sparseml.modifiers.quantization.modification import modify_model
 from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier
 from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward

@@ -51,7 +50,6 @@ class vLLMQuantizationModifierPyTorch(vLLMQuantizationModifier):

     def on_initialize_structure(self, state: State, **kwargs):
         module = state.model.model
-        module = modify_model(module)
         self._apply_modifier_to_model(module)
         module.apply(freeze_module_quantization)

@@ -64,7 +62,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:

         self.calibration_dataloader_ = state.data.calib
         module = state.model.model
-        module = modify_model(module)

         # intialize quantization in appropriate modules
         self._apply_modifier_to_model(module)
@@ -77,6 +74,11 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         return True

     def on_finalize(self, state: State, **kwargs) -> bool:
+        module = state.model.model
+        if self.post_oneshot_calibration:
+            module.apply(set_module_for_calibration)
+            self._calibrate_if_possible(module)
+        module.apply(freeze_module_quantization)
         return True

     def on_start(self, state: State, event: Event, **kwargs):
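
With the new flag set, on_finalize re-enables calibration before freezing: observers watch activations during forward passes, then the collected ranges are frozen into quantization parameters. A rough standalone sketch of that observe-then-freeze flow using a plain torch observer (illustrative only, not the sparseml helpers):

import torch
from torch.ao.quantization.observer import MinMaxObserver

# stand-in for a single module's activation observer
observer = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)

# calibration: run data through so the observer tracks min/max ranges
for _ in range(8):  # stand-in for num_calibration_steps batches
    observer(torch.randn(4, 32))

# freeze: read off the final quantization parameters and stop updating
scale, zero_point = observer.calculate_qparams()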

Diff for: src/sparseml/modifiers/smoothquant/base.py (+3 −1)

@@ -16,6 +16,8 @@
 from dataclasses import dataclass
 from typing import Dict, Generic, List, Optional, Tuple, TypeVar

+from pydantic import Field
+
 from sparseml.core import Modifier
 from sparseml.core.model import ModifiableModel
 from sparseml.core.model.base import LT
@@ -96,7 +98,7 @@ class SmoothQuantModifier(Modifier):
         use the whole dataset
     """

-    smoothing_strength: float = 0.5
+    smoothing_strength: float = Field(validation_alias="alpha", default=0.5)
     mappings: List[Tuple]
     ignore: Optional[List[str]] = None
     num_calibration_steps: Optional[int] = None
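
pydantic's validation_alias lets recipes populate smoothing_strength under the name "alpha", the conventional symbol for SmoothQuant's migration strength. A minimal sketch of the aliasing behavior (demo model only):

from pydantic import BaseModel, Field

class AlphaDemo(BaseModel):
    smoothing_strength: float = Field(validation_alias="alpha", default=0.5)

# the field is filled from the "alpha" key during validation
assert AlphaDemo(alpha=0.8).smoothing_strength == 0.8
assert AlphaDemo().smoothing_strength == 0.5
# note: with validation_alias set, the original field name is only accepted
# at validation time if populate_by_name is enabled in the model config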

Diff for: src/sparseml/transformers/sparsification/__init__.py (−1)

@@ -19,7 +19,6 @@

 # flake8: noqa

-from .modification import *
 from .question_answering import *
 from .sparse_config import *
 from .sparse_model import *

Diff for: src/sparseml/transformers/sparsification/compressed_tensors_utils.py (+15 −29)

@@ -83,42 +83,28 @@ def save_pretrained_wrapper(
         # state_dict gets passed in as a kwarg for FSDP models
         state_dict = kwargs.get("state_dict", None)

-        # check if we are in the old quantization framework
-        if qat_active(model) and not is_model_quantized(model):
+        if qat_active(model) or is_model_quantized(model):
             _LOGGER.info(
-                "Compression for models quantized with QuantizationModifer is not "
-                "supported. Save will be run without compression and no sparsity "
-                "statistics will be calculated. To save a quantized model in a "
-                "compressed state please use vLLMQuantizationModifier instead."
+                "Compression for quantized models is not yet supported. Save will "
+                "be run without compression and no sparsity statistics will be "
+                "calculated."
             )

             original_save_pretrained.__get__(model, model_class)(
                 save_directory, **kwargs
             )

-            return
-
-        elif qat_active(model):  # quantized in new framework
-            _LOGGER.info(
-                "Sparsity compression for quantized models is not yet supported. "
-                "No sparsity statistics will be calculated and no sparsity config "
-                "will be saved."
-            )
-
-            original_save_pretrained.__get__(model, model_class)(
-                save_directory, **kwargs
-            )
+            if is_model_quantized(model):
+                quant_config = QuantizationConfig.from_pretrained(model)
+                quant_config_data = quant_config.dict()
+                config_file_path = os.path.join(save_directory, CONFIG_NAME)

-            quant_config = QuantizationConfig.from_pretrained(model)
-            quant_config_data = quant_config.model_dump(exclude_unset=True)
-            config_file_path = os.path.join(save_directory, CONFIG_NAME)
-
-            # add the sparsity config to the model's config file
-            with open(config_file_path, "r") as config_file:
-                config_data = json.load(config_file)
-            config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
-            with open(config_file_path, "w") as config_file:
-                json.dump(config_data, config_file, indent=2, sort_keys=True)
+                # add the sparsity config to the model's config file
+                with open(config_file_path, "r") as config_file:
+                    config_data = json.load(config_file)
+                config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
+                with open(config_file_path, "w") as config_file:
+                    json.dump(config_data, config_file, indent=2, sort_keys=True)

             return

@@ -140,7 +126,7 @@ def save_pretrained_wrapper(
             "calculation of compression statistics set "
             "skip_compression_stats=True"
         )
-        sparsity_config = SparsityConfigMetadata.from_pretrained(
+        sparsity_config = SparsityConfigMetadata.infer_config_from_model(
            model, state_dict=state_dict, compress=save_compressed
        )
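
The consolidated branch serializes the quantization config and merges it into the model's existing config file. A standalone sketch of that merge step, where "config.json" and "quantization_config" are illustrative stand-ins for the CONFIG_NAME and QUANTIZATION_CONFIG_NAME constants:

import json
import os

def merge_quant_config(save_directory: str, quant_config_data: dict) -> None:
    # read the saved HF config, attach the quantization section, write it back
    config_file_path = os.path.join(save_directory, "config.json")
    with open(config_file_path, "r") as config_file:
        config_data = json.load(config_file)
    config_data["quantization_config"] = quant_config_data
    with open(config_file_path, "w") as config_file:
        json.dump(config_data, config_file, indent=2, sort_keys=True)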

Diff for: src/sparseml/transformers/sparsification/modification/__init__.py (+7 −14)

@@ -11,18 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 # flake8: noqa
-# isort:skip_file
-
-# the modification module that adds modifications
-# for transformers models to enable quantization
-
-# import all the modification functions for the different models
-from .modifying_bert import modify
-from .modifying_llama import modify
-from .modifying_mistral import modify
-from .modifying_distilbert import modify
-from .modifying_mobilebert import modify
-from .modifying_opt import modify
-from .modifying_qwen2_moe import modify
+from .modify_model import modify_model
+from .modifying_bert import *
+from .modifying_distilbert import *
+from .modifying_llama import *
+from .modifying_mistral import *
+from .modifying_mobilebert import *
+from .modifying_opt import *
Diff for: (new file, +124)

@@ -0,0 +1,124 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Set of helper objects that are used to modify
+the HuggingFace transformer models
+"""
+
+import torch
+
+
+__all__ = [
+    "QuantizableIdentity",
+    "QuantizableMatMul",
+    "QuantizableBatchMatmul",
+    "QATMatMul",
+    "QATLinear",
+]
+
+
+class QuantizableIdentity(torch.nn.Module):
+    """
+    Identity model that is introduced to be used
+    together with QuantizableMatMul to allow for
+    SparseML quantization scheme
+    """
+
+    def forward(self, x):
+        return x
+
+
+class QuantizableMatMul(torch.nn.Module):
+    """
+    Wrapper around torch.matmul with distinct inputs/output class
+    instances that could be quantized through SparseML recipe
+
+    :param left_input_cls: class instance that is used to quantize the left input
+    :param right_input_cls: class instance that is used to quantize the right input
+    :param output_cls: class instance that is used to quantize the output (optional)
+    :return: the output of the matrix multiplication
+    """
+
+    def __init__(self, left_input_cls, right_input_cls, output_cls=None):
+        super().__init__()
+        self.left_input = left_input_cls()
+        self.right_input = right_input_cls()
+        self.output = output_cls() if output_cls is not None else None
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        out = torch.matmul(self.left_input(a), self.right_input(b))
+        if self.output is not None:
+            return self.output(out)
+        return out
+
+
+class QuantizableBatchMatmul(QuantizableMatMul):
+    """
+    Wrapper around torch.bmm with distinct inputs/output class
+    instances that could be quantized through SparseML recipe
+
+    :param left_input_cls: class instance that is used to quantize the left input
+    :param right_input_cls: class instance that is used to quantize the right input
+    :param output_cls: class instance that is used to quantize the output (optional)
+    :return: the output of the batch matrix multiplication
+    """
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        out = torch.bmm(self.left_input(a), self.right_input(b))
+        if self.output is not None:
+            return self.output(out)
+        return out
+
+
+class QATMatMul(torch.nn.Module):
+    """
+    Behaves like normal torch.matmul unless a SparseML QuantizationModifier
+    is initialized (Quantization-Aware-Training is invoked)
+    """
+
+    def __init__(self):
+        super().__init__()
+
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
+
+class QATLinear(torch.nn.Module):
+    """
+    Behaves like normal torch.nn.Linear unless a SparseML QuantizationModifier
+    is initialized (Quantization-Aware-Training is invoked)
+    When initialized does not quantize inputs. Only weights are quantized
+    (inputs may come quantized)
+    """
+
+    def __init__(self, in_features, out_features):
+        super().__init__()
+
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 0,
+            "num_outputs": 1,
+        }
+
+        self.linear = torch.nn.Linear(in_features, out_features)
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
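
A short usage sketch for the new wrappers, assuming the classes above are in scope: swapping a bare torch.matmul for QuantizableMatMul turns the two inputs into distinct submodules that a quantization recipe can target, while the numerical result is unchanged until such a recipe runs:

import torch

# attention-score style matmul with quantizable input hooks (shapes illustrative)
scores_matmul = QuantizableMatMul(QuantizableIdentity, QuantizableIdentity)

query = torch.randn(2, 8, 64)
key_t = torch.randn(2, 64, 8)
scores = scores_matmul(query, key_t)

# with no recipe applied, the wrapper is a transparent torch.matmul
assert torch.equal(scores, torch.matmul(query, key_t))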
