Skip to content

Commit bc89ce1

Browse files
authored
Rename get_frames_at and get_frames_displayed_at (#302)
1 parent aa428fb commit bc89ce1

File tree

4 files changed

+28
-28
lines changed

4 files changed

+28
-28
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ decoder.get_frame_at(len(decoder) - 1)
6565
# pts_seconds: 9.960000038146973
6666
# duration_seconds: 0.03999999910593033
6767

68-
decoder.get_frames_at(start=10, stop=30, step=5)
68+
decoder.get_frames_in_range(start=10, stop=30, step=5)
6969
# FrameBatch:
7070
# data (shape): torch.Size([4, 3, 400, 640])
7171
# pts_seconds: tensor([0.4000, 0.6000, 0.8000, 1.0000])

examples/basic_example.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
120120
# their :term:`pts` (Presentation Time Stamp), and their duration.
121121
# This can be achieved using the
122122
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at` and
123-
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at` methods, which
123+
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_in_range` methods, which
124124
# will return a :class:`~torchcodec.Frame` and
125125
# :class:`~torchcodec.FrameBatch` objects respectively.
126126

@@ -129,7 +129,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
129129
print(last_frame)
130130

131131
# %%
132-
middle_frames = decoder.get_frames_at(start=10, stop=20, step=2)
132+
middle_frames = decoder.get_frames_in_range(start=10, stop=20, step=2)
133133
print(f"{type(middle_frames) = }")
134134
print(middle_frames)
135135

@@ -152,7 +152,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
152152
# So far, we have retrieved frames based on their index. We can also retrieve
153153
# frames based on *when* they are displayed with
154154
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_displayed_at` and
155-
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_at`, which
155+
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_in_range`, which
156156
# also returns :class:`~torchcodec.Frame` and :class:`~torchcodec.FrameBatch`
157157
# respectively.
158158

@@ -161,7 +161,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
161161
print(frame_at_2_seconds)
162162

163163
# %%
164-
first_two_seconds = decoder.get_frames_displayed_at(
164+
first_two_seconds = decoder.get_frames_displayed_in_range(
165165
start_seconds=0,
166166
stop_seconds=2,
167167
)

src/torchcodec/decoders/_video_decoder.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def get_frame_at(self, index: int) -> Frame:
181181
duration_seconds=duration_seconds.item(),
182182
)
183183

184-
def get_frames_at(self, start: int, stop: int, step: int = 1) -> FrameBatch:
184+
def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatch:
185185
"""Return multiple frames at the given index range.
186186
187187
Frames are in [start, stop).
@@ -238,7 +238,7 @@ def get_frame_displayed_at(self, seconds: float) -> Frame:
238238
duration_seconds=duration_seconds.item(),
239239
)
240240

241-
def get_frames_displayed_at(
241+
def get_frames_displayed_in_range(
242242
self, start_seconds: float, stop_seconds: float
243243
) -> FrameBatch:
244244
"""Returns multiple frames in the given range.

test/decoders/test_video_decoder.py

+21-21
Original file line numberDiff line numberDiff line change
@@ -366,14 +366,14 @@ def test_get_frame_displayed_at_fails(self):
366366
frame = decoder.get_frame_displayed_at(100.0) # noqa
367367

368368
@pytest.mark.parametrize("stream_index", [0, 3, None])
369-
def test_get_frames_at(self, stream_index):
369+
def test_get_frames_in_range(self, stream_index):
370370
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)
371371

372372
# test degenerate case where we only actually get 1 frame
373373
ref_frames9 = NASA_VIDEO.get_frame_data_by_range(
374374
start=9, stop=10, stream_index=stream_index
375375
)
376-
frames9 = decoder.get_frames_at(start=9, stop=10)
376+
frames9 = decoder.get_frames_in_range(start=9, stop=10)
377377

378378
assert_tensor_equal(ref_frames9, frames9.data)
379379
assert frames9.pts_seconds[0].item() == pytest.approx(
@@ -389,7 +389,7 @@ def test_get_frames_at(self, stream_index):
389389
ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(
390390
start=0, stop=10, stream_index=stream_index
391391
)
392-
frames0_9 = decoder.get_frames_at(start=0, stop=10)
392+
frames0_9 = decoder.get_frames_in_range(start=0, stop=10)
393393
assert frames0_9.data.shape == torch.Size(
394394
[
395395
10,
@@ -412,7 +412,7 @@ def test_get_frames_at(self, stream_index):
412412
ref_frames0_8_2 = NASA_VIDEO.get_frame_data_by_range(
413413
start=0, stop=10, step=2, stream_index=stream_index
414414
)
415-
frames0_8_2 = decoder.get_frames_at(start=0, stop=10, step=2)
415+
frames0_8_2 = decoder.get_frames_in_range(start=0, stop=10, step=2)
416416
assert frames0_8_2.data.shape == torch.Size(
417417
[
418418
5,
@@ -434,13 +434,13 @@ def test_get_frames_at(self, stream_index):
434434
)
435435

436436
# test numpy.int64 for indices
437-
frames0_8_2 = decoder.get_frames_at(
437+
frames0_8_2 = decoder.get_frames_in_range(
438438
start=numpy.int64(0), stop=numpy.int64(10), step=numpy.int64(2)
439439
)
440440
assert_tensor_equal(ref_frames0_8_2, frames0_8_2.data)
441441

442442
# an empty range is valid!
443-
empty_frames = decoder.get_frames_at(5, 5)
443+
empty_frames = decoder.get_frames_in_range(5, 5)
444444
assert_tensor_equal(
445445
empty_frames.data,
446446
NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index),
@@ -456,10 +456,10 @@ def test_get_frames_at(self, stream_index):
456456
(
457457
lambda decoder: decoder[0],
458458
lambda decoder: decoder.get_frame_at(0).data,
459-
lambda decoder: decoder.get_frames_at(0, 4).data,
459+
lambda decoder: decoder.get_frames_in_range(0, 4).data,
460460
lambda decoder: decoder.get_frame_displayed_at(0).data,
461461
# TODO: uncomment once D60001893 lands
462-
# lambda decoder: decoder.get_frames_displayed_at(0, 1).data,
462+
# lambda decoder: decoder.get_frames_displayed_in_range(0, 1).data,
463463
),
464464
)
465465
def test_dimension_order(self, dimension_order, frame_getter):
@@ -487,7 +487,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
487487
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)
488488

489489
# Note that we are comparing the results of VideoDecoder's method:
490-
# get_frames_displayed_at()
490+
# get_frames_displayed_in_range()
491491
# With the testing framework's method:
492492
# get_frame_data_by_range()
493493
# That is, we are testing the correctness of a pts-based range against an index-
@@ -504,7 +504,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
504504
# value for frame 5 that we have access to on the Python side is slightly less than the pts
505505
# value on the C++ side. This test still produces the correct result because a slightly
506506
# less value still falls into the correct window.
507-
frames0_4 = decoder.get_frames_displayed_at(
507+
frames0_4 = decoder.get_frames_displayed_in_range(
508508
decoder.get_frame_at(0).pts_seconds, decoder.get_frame_at(5).pts_seconds
509509
)
510510
assert_tensor_equal(
@@ -513,15 +513,15 @@ def test_get_frames_by_pts_in_range(self, stream_index):
513513
)
514514

515515
# Range where the stop seconds is about halfway between pts values for two frames.
516-
also_frames0_4 = decoder.get_frames_displayed_at(
516+
also_frames0_4 = decoder.get_frames_displayed_in_range(
517517
decoder.get_frame_at(0).pts_seconds,
518518
decoder.get_frame_at(4).pts_seconds + HALF_DURATION,
519519
)
520520
assert_tensor_equal(also_frames0_4.data, frames0_4.data)
521521

522522
# Again, the intention here is to provide the exact values we care about. In practice, our
523523
# pts values are slightly smaller, so we nudge the start upwards.
524-
frames5_9 = decoder.get_frames_displayed_at(
524+
frames5_9 = decoder.get_frames_displayed_in_range(
525525
decoder.get_frame_at(5).pts_seconds,
526526
decoder.get_frame_at(10).pts_seconds,
527527
)
@@ -533,7 +533,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
533533
# Range where we provide start_seconds and stop_seconds that are different, but
534534
# also should land in the same window of time between two frame's pts values. As
535535
# a result, we should only get back one frame.
536-
frame6 = decoder.get_frames_displayed_at(
536+
frame6 = decoder.get_frames_displayed_in_range(
537537
decoder.get_frame_at(6).pts_seconds,
538538
decoder.get_frame_at(6).pts_seconds + HALF_DURATION,
539539
)
@@ -543,7 +543,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
543543
)
544544

545545
# Very small range that falls in the same frame.
546-
frame35 = decoder.get_frames_displayed_at(
546+
frame35 = decoder.get_frames_displayed_in_range(
547547
decoder.get_frame_at(35).pts_seconds,
548548
decoder.get_frame_at(35).pts_seconds + 1e-10,
549549
)
@@ -555,7 +555,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
555555
# Single frame where the start seconds is before frame i's pts, and the stop is
556556
# after frame i's pts, but before frame i+1's pts. In that scenario, we expect
557557
# to see frames i-1 and i.
558-
frames7_8 = decoder.get_frames_displayed_at(
558+
frames7_8 = decoder.get_frames_displayed_in_range(
559559
NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds
560560
- HALF_DURATION,
561561
NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds
@@ -567,7 +567,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
567567
)
568568

569569
# Start and stop seconds are the same value, which should not return a frame.
570-
empty_frame = decoder.get_frames_displayed_at(
570+
empty_frame = decoder.get_frames_displayed_in_range(
571571
NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds,
572572
NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds,
573573
)
@@ -583,7 +583,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
583583
)
584584

585585
# Start and stop seconds land within the first frame.
586-
frame0 = decoder.get_frames_displayed_at(
586+
frame0 = decoder.get_frames_displayed_in_range(
587587
NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds,
588588
NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds
589589
+ HALF_DURATION,
@@ -595,7 +595,7 @@ def test_get_frames_by_pts_in_range(self, stream_index):
595595

596596
# We should be able to get all frames by giving the beginning and ending time
597597
# for the stream.
598-
all_frames = decoder.get_frames_displayed_at(
598+
all_frames = decoder.get_frames_displayed_in_range(
599599
decoder.metadata.begin_stream_seconds, decoder.metadata.end_stream_seconds
600600
)
601601
assert_tensor_equal(all_frames.data, decoder[:])
@@ -604,13 +604,13 @@ def test_get_frames_by_pts_in_range_fails(self):
604604
decoder = VideoDecoder(NASA_VIDEO.path)
605605

606606
with pytest.raises(ValueError, match="Invalid start seconds"):
607-
frame = decoder.get_frames_displayed_at(100.0, 1.0) # noqa
607+
frame = decoder.get_frames_displayed_in_range(100.0, 1.0) # noqa
608608

609609
with pytest.raises(ValueError, match="Invalid start seconds"):
610-
frame = decoder.get_frames_displayed_at(20, 23) # noqa
610+
frame = decoder.get_frames_displayed_in_range(20, 23) # noqa
611611

612612
with pytest.raises(ValueError, match="Invalid stop seconds"):
613-
frame = decoder.get_frames_displayed_at(0, 23) # noqa
613+
frame = decoder.get_frames_displayed_in_range(0, 23) # noqa
614614

615615

616616
if __name__ == "__main__":

0 commit comments

Comments
 (0)