Commit c791cf2

compare prepared vs. converted outputs for Embedding

Naveen Suda authored and facebook-github-bot committed

Summary: Fixed the embedding op and updated the test.
Reviewed By: telgamal-1, jerryzh168
Differential Revision: D73266106
1 parent 0045d88 · commit c791cf2

File tree

2 files changed: +12 -4 lines changed
  test/quantization/test_qat.py (+8 -3)
  torchao/quantization/qat/embedding.py (+4 -1)

test/quantization/test_qat.py (+8 -3)

@@ -808,6 +808,8 @@ def test_qat_4w_embedding(self):
             _quantized_decomposed_quantize_per_channel_group_wrapper,
         )
         from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
+        from torchao.quantization.utils import compute_error
+
 
         group_size = 256
         model = M2()
@@ -816,9 +818,9 @@ def test_qat_4w_embedding(self):
         quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size)
         prepared = quantizer.prepare(model)
         prepared_embedding_weight = copy.deepcopy(prepared.embedding.weight)
-        prepared(*x)
-        converted = quantizer.convert(model)
-        converted(*x)
+        prepared_output = prepared(*x)
+        converted = quantizer.convert(copy.deepcopy(prepared))
+        converted_output = converted(*x)
 
         # Assert the scales, zero points, and weights are correct after convert
         qmin, qmax = -8, 7
@@ -837,9 +839,12 @@ def test_qat_4w_embedding(self):
             torch.int8,
             group_size,
         )
+        sqnr = compute_error(prepared_output, converted_output).detach().item()
         torch.testing.assert_close(converted.embedding.weight, q_weight)
         torch.testing.assert_close(converted.embedding.scale, s)
         torch.testing.assert_close(converted.embedding.zero_point, zp)
+        torch.testing.assert_close(sqnr, float('inf'))
+
 
     def test_fake_quantize_config_granularity(self):
         """

torchao/quantization/qat/embedding.py (+4 -1)

@@ -226,6 +226,7 @@ def _convert_helper(self, module: torch.nn.Module):
                     group_size=group_size,
                     scale_precision=scale_precision,
                     zero_point_precision=zero_point_precision,
+                    weight_original_precision=child.weight.dtype,
                     device=child.weight.device,
                 )
                 setattr(module, name, quantized_embedding)
@@ -323,6 +324,7 @@ def __init__(
         group_size: int = 32,
         scale_precision: torch.dtype = torch.float32,
         zero_point_precision: torch.dtype = torch.int32,
+        weight_original_precision: torch.dtype = torch.float32,
         device: torch.device = None,
     ):
         super().__init__()
@@ -341,6 +343,7 @@ def __init__(
         self.group_size = group_size
         self.scale_precision = scale_precision
         self.zero_point_precision = zero_point_precision
+        self.weight_original_precision = weight_original_precision
 
         # currently storing unpacked int8 weights
         self.register_buffer(
@@ -380,7 +383,7 @@ def forward(self, x):
             qmax,
             torch.int8,
             self.group_size,
-            x.dtype,
+            self.weight_original_precision,
         )
         return F.embedding(
             x,
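The embedding.py side is the actual fix. In the quantized embedding's forward, x holds integer token indices, so dequantizing the stored weight to x.dtype targeted an integer dtype rather than the float dtype of the original weight; the converted module now records that dtype at convert time (weight_original_precision=child.weight.dtype) and dequantizes to it. The sketch below illustrates the failure mode with plain torch ops standing in for torchao's dequantize_per_channel_group; the toy weight and qparams are made up.

    import torch

    # Int8-stored quantized weight plus per-row scale/zero point (toy values).
    q_weight = torch.randint(-8, 8, (16, 4), dtype=torch.int8)
    scale = torch.full((16, 1), 0.05)
    zero_point = torch.zeros(16, 1)

    x = torch.tensor([3, 7, 1])  # embedding input: int64 token indices

    # Before the fix: casting the dequantized weight to x.dtype (int64)
    # truncates every value in (-1, 1) to zero.
    bad = ((q_weight - zero_point) * scale).to(x.dtype)

    # After the fix: cast to the remembered dtype of the original float
    # weight (weight_original_precision, torch.float32 by default).
    good = ((q_weight - zero_point) * scale).to(torch.float32)

    print(bad[x])   # all-zero rows: the F.embedding lookup returns junk
    print(good[x])  # the intended dequantized rows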
