
Commit 3f18cfb

Naveen Suda authored and facebook-github-bot committed
compare prepared vs. converted outputs for Embedding (pytorch#2087)
Summary: Fixed the embedding op and updated the test.

Reviewed By: telgamal-1, jerryzh168

Differential Revision: D73266106
1 parent: 4805efd

File tree: 2 files changed, +12 -4 lines

  test/quantization/test_qat.py
  torchao/quantization/qat/embedding.py

test/quantization/test_qat.py (+8 -3)
@@ -808,6 +808,8 @@ def test_qat_4w_embedding(self):
             _quantized_decomposed_quantize_per_channel_group_wrapper,
         )
         from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
+        from torchao.quantization.utils import compute_error
+
 
         group_size = 256
         model = M2()
@@ -816,9 +818,9 @@ def test_qat_4w_embedding(self):
         quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size)
         prepared = quantizer.prepare(model)
         prepared_embedding_weight = copy.deepcopy(prepared.embedding.weight)
-        prepared(*x)
-        converted = quantizer.convert(model)
-        converted(*x)
+        prepared_output = prepared(*x)
+        converted = quantizer.convert(copy.deepcopy(prepared))
+        converted_output = converted(*x)
 
         # Assert the scales, zero points, and weights are correct after convert
         qmin, qmax = -8, 7
@@ -837,9 +839,12 @@ def test_qat_4w_embedding(self):
             torch.int8,
             group_size,
         )
+        sqnr = compute_error(prepared_output, converted_output).detach().item()
         torch.testing.assert_close(converted.embedding.weight, q_weight)
         torch.testing.assert_close(converted.embedding.scale, s)
         torch.testing.assert_close(converted.embedding.zero_point, zp)
+        torch.testing.assert_close(sqnr, float('inf'))
+
 
     def test_fake_quantize_config_granularity(self):
         """

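For context, the new assertion compares the prepared and converted model outputs via SQNR. Below is a minimal sketch of that check, assuming compute_error follows the usual signal-to-quantized-noise-ratio definition in torchao.quantization.utils (identical outputs yield +inf); compute_error_sketch is a hypothetical stand-in, not the library function.

import torch

def compute_error_sketch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 20 * log10(||signal|| / ||signal - reference||).
    # If prepared and converted outputs match exactly, the error norm is
    # zero and the ratio is +inf, which is what the test asserts.
    return 20 * torch.log10(torch.linalg.norm(x) / torch.linalg.norm(x - y))

prepared_output = torch.randn(4, 16)
converted_output = prepared_output.clone()  # convert should be numerically exact
sqnr = compute_error_sketch(prepared_output, converted_output).item()
assert sqnr == float('inf')
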
torchao/quantization/qat/embedding.py (+4 -1)
@@ -251,6 +251,7 @@ def _convert_helper(self, module: torch.nn.Module):
                     group_size=group_size,
                     scale_precision=scale_precision,
                     zero_point_precision=zero_point_precision,
+                    weight_original_precision=child.weight.dtype,
                     device=child.weight.device,
                 )
                 setattr(module, name, quantized_embedding)
@@ -336,6 +337,7 @@ def __init__(
         group_size: int = 32,
         scale_precision: torch.dtype = torch.float32,
         zero_point_precision: torch.dtype = torch.int32,
+        weight_original_precision: torch.dtype = torch.float32,
         device: torch.device = None,
     ):
         super().__init__()
@@ -354,6 +356,7 @@ def __init__(
         self.group_size = group_size
         self.scale_precision = scale_precision
         self.zero_point_precision = zero_point_precision
+        self.weight_original_precision = weight_original_precision
 
         # currently storing unpacked int8 weights
         self.register_buffer(
@@ -393,7 +396,7 @@ def forward(self, x):
             qmax,
             torch.int8,
             self.group_size,
-            x.dtype,
+            self.weight_original_precision,
         )
         return F.embedding(
             x,