@@ -84,6 +84,7 @@ def fasterprune(
        prunem: int = 0,
        blocksize: int = 128,
        percdamp: float = 0.01,
+       preserve_sparsity_mask: bool = False,
    ):
        """
        Run pruning and quantization(if applicable) on the layer up to the target
@@ -95,6 +96,7 @@ def fasterprune(
        :param blocksize: Number of columns to compress in one pass
        :param percdamp: Amount of dampening to apply to H, as a fraction of the
            diagonal norm
+       :param preserve_sparsity_mask: Extend or ignore the base sparsity mask
        """
        final_shape = self.layer.weight.shape
        final_dtype = self.layer.weight.dtype
@@ -123,6 +125,13 @@ def fasterprune(
        Hinv = self.H

        mask = None
+       if preserve_sparsity_mask:
+           # compute existing sparsity mask
+           mask = torch.where(
+               W == 0,
+               torch.tensor(1, dtype=torch.bool),
+               torch.tensor(0, dtype=torch.bool),
+           )

        # See section 3.4 of https://arxiv.org/abs/2203.07259
        for i1 in range(0, self.columns, blocksize):
@@ -138,12 +147,32 @@ def fasterprune(
            if prunen == 0:
                if mask is not None:
                    mask1 = mask[:, i1:i2]
+                   if int(W1.numel() * sparsity) > mask1.sum():
+                       # target sparsity is higher than base sparsity, extend mask1
+                       tmp = (
+                           (~mask[:, i1:i2])
+                           * W1**2
+                           / (torch.diag(Hinv1).reshape((1, -1))) ** 2
+                       )
+                       thresh = torch.sort(tmp.flatten())[0][
+                           int(tmp.numel() * sparsity)
+                       ]
+                       mask1 = tmp <= thresh
+                   else:
+                       raise ValueError(
+                           "The target sparsity is lower than the sparsity "
+                           "of the base model. Please retry "
+                           "after turning preserve_sparsity_mask=False"
+                       )
                else:
                    tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                    mask1 = tmp <= thresh
            else:
-               mask1 = torch.zeros_like(W1) == 1
+               if mask is not None:
+                   mask1 = mask[:, i1:i2]
+               else:
+                   mask1 = torch.zeros_like(W1) == 1

            for i in range(count):
                w = W1[:, i]
@@ -154,6 +183,10 @@ def fasterprune(
                        W1[:, i : (i + prunem)] ** 2
                        / (torch.diag(Hinv1)[i : (i + prunem)].reshape((1, -1))) ** 2
                    )
+
+                   if mask is not None:
+                       tmp = tmp * (~mask[:, i : (i + prunem)])
+
                    mask1.scatter_(
                        1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True
                    )
@@ -174,7 +207,12 @@ def fasterprune(
            W[:, i1:i2] = Q1
            Losses += torch.sum(Losses1, 1) / 2

-           W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+           if preserve_sparsity_mask:
+               # respect the sparsity of other groups
+               # really not needed, but kept for explicitness
+               W[:, i2:] -= (~mask[:, i2:]) * Err1.matmul(Hinv[i1:i2, i2:])
+           else:
+               W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        _LOGGER.info("time %.2f" % (time.time() - tick))
        _LOGGER.info("error %.2f" % torch.sum(Losses).item())