
Commit 446555f

rahul-tuli and Sara Adkins authored
Preserve sparsity SPARSEGPT (#2282)

* test
* Preserve weight sparsity if greater than threshold
* Add argument to preserve sparsity mask in SPARSEGPT
* fix case when mask is none

Co-authored-by: Sara Adkins <sara@neuralmagic.com>

1 parent 14a1b08 · commit 446555f

File tree

3 files changed: +45 −2 lines changed

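For orientation: the new flag keeps the zeros of an already-pruned model frozen while SparseGPT prunes further. A minimal sketch of enabling it, assuming the keyword-style constructor implied by the field declarations in base.py below (the surrounding recipe/session wiring is not part of this commit):

from sparseml.modifiers.obcq.base import SparseGPTModifier

# Hypothetical usage; the field names come from the diff, the constructor
# call style is an assumption.
modifier = SparseGPTModifier(
    sparsity=0.7,                 # target sparsity for this run
    block_size=128,
    dampening_frac=0.01,
    preserve_sparsity_mask=True,  # new: keep existing zeros from the base model
)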

src/sparseml/modifiers/obcq/base.py (+4)

@@ -54,6 +54,9 @@ class SparseGPTModifier(Modifier):
     :param block_size: Used to determine number of columns to compress in one pass
     :param dampening_frac: Amount of dampening to apply to H, as a fraction of the
         diagonal norm
+    :param preserve_sparsity_mask: Whether or not to preserve the sparsity mask
+        when applying SparseGPT; this is useful when starting from a
+        previously pruned model. Defaults to False.
     """
 
     sparsity: Union[float, List[float]] = 0.0
@@ -68,6 +71,7 @@ class SparseGPTModifier(Modifier):
     prunem_: Optional[int] = None
     block_size: int = 128
     dampening_frac: Optional[float] = 0.01
+    preserve_sparsity_mask: bool = False
 
     def on_initialize_structure(self, state: State, **kwargs):
         """

src/sparseml/modifiers/obcq/pytorch.py (+1)

@@ -203,6 +203,7 @@ def _compression_arguments(self, sparsity):
             "prunem": self.prunem_,
             "blocksize": self.block_size,
             "percdamp": self.dampening_frac,
+            "preserve_sparsity_mask": self.preserve_sparsity_mask,
         }
 
     def _compression_class(self):

src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py (+40 −2)

@@ -84,6 +84,7 @@ def fasterprune(
         prunem: int = 0,
         blocksize: int = 128,
         percdamp: float = 0.01,
+        preserve_sparsity_mask: bool = False,
     ):
         """
         Run pruning and quantization (if applicable) on the layer up to the target
@@ -95,6 +96,7 @@ def fasterprune(
         :param blocksize: Number of columns to compress in one pass
         :param percdamp: Amount of dampening to apply to H, as a fraction of the
             diagonal norm
+        :param preserve_sparsity_mask: Whether to extend or ignore the base sparsity mask
         """
         final_shape = self.layer.weight.shape
         final_dtype = self.layer.weight.dtype
@@ -123,6 +125,13 @@ def fasterprune(
         Hinv = self.H
 
         mask = None
+        if preserve_sparsity_mask:
+            # compute existing sparsity mask
+            mask = torch.where(
+                W == 0,
+                torch.tensor(1, dtype=torch.bool),
+                torch.tensor(0, dtype=torch.bool),
+            )
 
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
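As a standalone check (not part of the commit), the torch.where call above is equivalent to a plain boolean comparison marking the already-zero weights:

import torch

W = torch.tensor([[0.0, 1.5, 0.0], [2.0, 0.0, 3.0]])
mask = torch.where(
    W == 0,
    torch.tensor(1, dtype=torch.bool),
    torch.tensor(0, dtype=torch.bool),
)
assert torch.equal(mask, W == 0)  # same mask, simpler expression
# mask is True exactly at the pruned (zero) weights, so they can be
# preserved by the later steps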
@@ -138,12 +147,32 @@ def fasterprune(
             if prunen == 0:
                 if mask is not None:
                     mask1 = mask[:, i1:i2]
+                    if int(W1.numel() * sparsity) > mask1.sum():
+                        # target sparsity is higher than base sparsity, extend mask1
+                        tmp = (
+                            (~mask[:, i1:i2])
+                            * W1**2
+                            / (torch.diag(Hinv1).reshape((1, -1))) ** 2
+                        )
+                        thresh = torch.sort(tmp.flatten())[0][
+                            int(tmp.numel() * sparsity)
+                        ]
+                        mask1 = tmp <= thresh
+                    else:
+                        raise ValueError(
+                            "The target sparsity is lower than the sparsity "
+                            "of the base model. Please retry "
+                            "after setting preserve_sparsity_mask=False"
+                        )
                 else:
                     tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                     thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                     mask1 = tmp <= thresh
             else:
-                mask1 = torch.zeros_like(W1) == 1
+                if mask is not None:
+                    mask1 = mask[:, i1:i2]
+                else:
+                    mask1 = torch.zeros_like(W1) == 1
 
             for i in range(count):
                 w = W1[:, i]
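To see why the extension branch preserves the old mask: previously pruned entries get zero saliency through ~mask, so they sort to the bottom and always fall under the threshold. A toy sketch under an identity-Hessian assumption (Hinv1 diagonal of ones; hypothetical values, not the commit's code):

import torch

torch.manual_seed(0)
W1 = torch.randn(4, 8)
mask = W1.abs() < 0.3                 # pretend these were pruned earlier
W1 = W1 * ~mask                       # base model already carries the zeros
Hinv1_diag = torch.ones(8)            # identity Hessian for the sketch
sparsity = 0.5

if int(W1.numel() * sparsity) > mask.sum():
    tmp = (~mask) * W1**2 / Hinv1_diag.reshape((1, -1)) ** 2
    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
    mask1 = tmp <= thresh             # extended mask
    # every previously pruned entry has tmp == 0 <= thresh, so it stays masked
    assert (mask1 & mask).sum() == mask.sum()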
@@ -154,6 +183,10 @@ def fasterprune(
                         W1[:, i : (i + prunem)] ** 2
                         / (torch.diag(Hinv1)[i : (i + prunem)].reshape((1, -1))) ** 2
                     )
+
+                    if mask is not None:
+                        tmp = tmp * (~mask[:, i : (i + prunem)])
+
                     mask1.scatter_(
                         1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True
                     )
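In the n:m branch the same idea appears as a multiplication: zeroing the saliency at already-pruned positions makes topk(..., largest=False) pick them first inside each group of prunem columns. A small illustration (hypothetical values, identity Hessian assumed; not from the commit):

import torch

prunen, prunem = 2, 4
W_group = torch.tensor([[0.0, 0.9, 0.1, 0.5]])  # one m-sized group, column 0 pre-pruned
mask_group = W_group == 0
tmp = W_group**2                                 # saliency under the identity-Hessian assumption
tmp = tmp * (~mask_group)                        # as in the diff: zero out old positions
idx = torch.topk(tmp, prunen, dim=1, largest=False)[1]
print(idx)  # tensor([[0, 2]]): the preserved zero plus the smallest live weight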
@@ -174,7 +207,12 @@ def fasterprune(
             W[:, i1:i2] = Q1
             Losses += torch.sum(Losses1, 1) / 2
 
-            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+            if preserve_sparsity_mask:
+                # respect the sparsity of other groups
+                # not strictly needed, but kept for explicitness
+                W[:, i2:] -= (~mask[:, i2:]) * Err1.matmul(Hinv[i1:i2, i2:])
+            else:
+                W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
 
         _LOGGER.info("time %.2f" % (time.time() - tick))
         _LOGGER.info("error %.2f" % torch.sum(Losses).item())
