From 7899a668cde8737b75c2b6250fe1ccb95983dd15 Mon Sep 17 00:00:00 2001 From: jlbmorales <163588116+jlbmorales@users.noreply.github.com> Date: Tue, 22 Apr 2025 15:11:48 -0700 Subject: [PATCH] Update sam2_base.py Fixes bug where "clear_old_points" crashes due to assertion expecting no previous SAM mask logits --- torchao/_models/sam2/modeling/sam2_base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchao/_models/sam2/modeling/sam2_base.py b/torchao/_models/sam2/modeling/sam2_base.py index 4c2a24a0ef..5c4eda1d6c 100644 --- a/torchao/_models/sam2/modeling/sam2_base.py +++ b/torchao/_models/sam2/modeling/sam2_base.py @@ -788,9 +788,10 @@ def _track_step( if prev_sam_mask_logits is not None: assert point_inputs is not None and mask_inputs is None mask_inputs = prev_sam_mask_logits + else: + assert mask_inputs is None multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) - - assert mask_inputs is None + assert multimask_output if point_inputs is not None: point_inputs = {k: point_inputs[k].contiguous() for k in point_inputs}