Skip to content

Commit 2088934

Browse files
committed
.
1 parent 2b47104 commit 2088934

File tree

3 files changed

+70
-43
lines changed

3 files changed

+70
-43
lines changed
Lines changed: 63 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import argparse
12
from pathlib import Path
23
from time import perf_counter_ns
34

@@ -45,51 +46,72 @@ def report_stats(times, num_frames, unit="ms"):
4546
return med, fps
4647

4748

48-
def sample(sampler, **kwargs):
49-
decoder = VideoDecoder(VIDEO_PATH)
49+
def sample(decoder, sampler, **kwargs):
5050
return sampler(
5151
decoder,
5252
num_frames_per_clip=10,
5353
**kwargs,
5454
)
5555

5656

57-
VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
58-
NUM_EXP = 30
59-
60-
for num_clips in (1, 50):
61-
print("-" * 10)
62-
print(f"{num_clips = }")
63-
64-
print("clips_at_random_indices ", end="")
65-
times, num_frames = bench(
66-
sample, clips_at_random_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
67-
)
68-
report_stats(times, num_frames, unit="ms")
69-
70-
print("clips_at_regular_indices ", end="")
71-
times, num_frames = bench(
72-
sample, clips_at_regular_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
73-
)
74-
report_stats(times, num_frames, unit="ms")
75-
76-
print("clips_at_random_timestamps ", end="")
77-
times, num_frames = bench(
78-
sample,
79-
clips_at_random_timestamps,
80-
num_clips=num_clips,
81-
num_exp=NUM_EXP,
82-
warmup=2,
83-
)
84-
report_stats(times, num_frames, unit="ms")
85-
86-
print("clips_at_regular_timestamps ", end="")
87-
seconds_between_clip_starts = 13 / num_clips # approximate. video is 13s long
88-
times, num_frames = bench(
89-
sample,
90-
clips_at_regular_timestamps,
91-
seconds_between_clip_starts=seconds_between_clip_starts,
92-
num_exp=NUM_EXP,
93-
warmup=2,
94-
)
95-
report_stats(times, num_frames, unit="ms")
57+
def main(device, video):
58+
NUM_EXP = 30
59+
60+
for num_clips in (1, 50):
61+
print("-" * 10)
62+
print(f"{num_clips = }")
63+
64+
print("clips_at_random_indices ", end="")
65+
decoder = VideoDecoder(video, device=device)
66+
times, num_frames = bench(
67+
sample,
68+
decoder,
69+
clips_at_random_indices,
70+
num_clips=num_clips,
71+
num_exp=NUM_EXP,
72+
warmup=2,
73+
)
74+
report_stats(times, num_frames, unit="ms")
75+
76+
print("clips_at_regular_indices ", end="")
77+
times, num_frames = bench(
78+
sample,
79+
decoder,
80+
clips_at_regular_indices,
81+
num_clips=num_clips,
82+
num_exp=NUM_EXP,
83+
warmup=2,
84+
)
85+
report_stats(times, num_frames, unit="ms")
86+
87+
print("clips_at_random_timestamps ", end="")
88+
times, num_frames = bench(
89+
sample,
90+
decoder,
91+
clips_at_random_timestamps,
92+
num_clips=num_clips,
93+
num_exp=NUM_EXP,
94+
warmup=2,
95+
)
96+
report_stats(times, num_frames, unit="ms")
97+
98+
print("clips_at_regular_timestamps ", end="")
99+
seconds_between_clip_starts = 13 / num_clips # approximate. video is 13s long
100+
times, num_frames = bench(
101+
sample,
102+
decoder,
103+
clips_at_regular_timestamps,
104+
seconds_between_clip_starts=seconds_between_clip_starts,
105+
num_exp=NUM_EXP,
106+
warmup=2,
107+
)
108+
report_stats(times, num_frames, unit="ms")
109+
110+
111+
if __name__ == "__main__":
112+
DEFAULT_VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
113+
parser = argparse.ArgumentParser()
114+
parser.add_argument("--device", type=str, default="cpu")
115+
parser.add_argument("--video", type=str, default=str(DEFAULT_VIDEO_PATH))
116+
args = parser.parse_args()
117+
main(args.device, args.video)

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ void convertAVFrameToDecodedOutputOnCuda(
1919
const VideoDecoder::VideoStreamDecoderOptions& options,
2020
AVCodecContext* codecContext,
2121
VideoDecoder::RawDecodedOutput& rawOutput,
22-
VideoDecoder::DecodedOutput& output) {
22+
VideoDecoder::DecodedOutput& output,
23+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
2324
throwUnsupportedDeviceError(device);
2425
}
2526

src/torchcodec/decoders/_video_decoder.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pathlib import Path
99
from typing import Literal, Optional, Tuple, Union
1010

11-
from torch import Tensor
11+
from torch import device, Tensor
1212

1313
from torchcodec import Frame, FrameBatch
1414
from torchcodec.decoders import _core as core
@@ -41,6 +41,8 @@ class VideoDecoder:
4141
instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded
4242
decoding which is best if you are running a single instance of ``VideoDecoder``.
4343
Default: 1.
44+
device (str or torch.device, optional): The device to use for decoding.
45+
4446
4547
.. note::
4648
@@ -64,6 +66,7 @@ def __init__(
6466
stream_index: Optional[int] = None,
6567
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
6668
num_ffmpeg_threads: int = 1,
69+
device: Optional[Union[str, device]] = "cpu",
6770
):
6871
if isinstance(source, str):
6972
self._decoder = core.create_from_file(source)
@@ -92,6 +95,7 @@ def __init__(
9295
stream_index=stream_index,
9396
dimension_order=dimension_order,
9497
num_threads=num_ffmpeg_threads,
98+
device=device,
9599
)
96100

97101
self.metadata, self.stream_index = _get_and_validate_stream_metadata(

0 commit comments

Comments
 (0)