Commit 6270dd3

Added the ability to benchmark on a directory and output to a chart
1 parent b942423 commit 6270dd3

2 files changed: +153 -6 lines


benchmarks/decoders/benchmark_decoders.py

+23 -2
@@ -11,6 +11,7 @@
 
 from benchmark_decoders_library import (
     DecordNonBatchDecoderAccurateSeek,
+    plot_data,
     run_benchmarks,
     TorchAudioDecoder,
     TorchcodecCompiled,
@@ -71,6 +72,18 @@ def main() -> None:
         type=str,
         default="decord,tcoptions:,torchvision,torchaudio,torchcodec_compiled,tcoptions:num_threads=1",
     )
+    parser.add_argument(
+        "--bm_video_dir",
+        help="Directory where video files reside. We will run benchmarks on all .mp4 files in this directory.",
+        type=str,
+        default="",
+    )
+    parser.add_argument(
+        "--plot_path",
+        help="Path where the generated plot is stored, if non-empty",
+        type=str,
+        default="",
+    )
 
     args = parser.parse_args()
     decoders = set(args.decoders.split(","))
@@ -118,13 +131,21 @@ def main() -> None:
         decoder_dict["TorchcodecNonCompiled:" + options] = (
             TorchcodecNonCompiledWithOptions(**kwargs_dict)
         )
-    run_benchmarks(
+    video_paths = args.bm_video_paths.split(",")
+    if args.bm_video_dir:
+        video_paths = []
+        for entry in os.scandir(args.bm_video_dir):
+            if entry.is_file() and entry.name.endswith(".mp4"):
+                video_paths.append(entry.path)
+
+    df_data = run_benchmarks(
         decoder_dict,
-        args.bm_video_paths,
+        video_paths,
         num_uniform_samples,
         args.bm_video_speed_min_run_seconds,
         args.bm_video_creation,
     )
+    plot_data(df_data, args.plot_path)
 
 
 if __name__ == "__main__":
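
For reference, a standalone sketch (not part of the commit) of the path-selection behavior the hunk above introduces: a non-empty --bm_video_dir takes precedence over --bm_video_paths, and every .mp4 file found in that directory is benchmarked. The helper name resolve_video_paths is hypothetical, and the empty-string guard is added here only so the sketch returns an empty list when both arguments are empty.

import os

def resolve_video_paths(bm_video_paths, bm_video_dir):
    # Default behavior: the comma-separated --bm_video_paths value, as before.
    video_paths = bm_video_paths.split(",") if bm_video_paths else []
    # New behavior: a non-empty --bm_video_dir wins and is scanned for .mp4 files.
    if bm_video_dir:
        video_paths = [
            entry.path
            for entry in os.scandir(bm_video_dir)
            if entry.is_file() and entry.name.endswith(".mp4")
        ]
    return video_paths

# resolve_video_paths("a.mp4,b.mp4", "")        -> ["a.mp4", "b.mp4"]
# resolve_video_paths("a.mp4", "/data/videos")  -> every .mp4 file under /data/videos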

benchmarks/decoders/benchmark_decoders_library.py

+130 -4
@@ -1,7 +1,12 @@
 import abc
 import json
+import os
 import timeit
 
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+
 import torch
 import torch.utils.benchmark as benchmark
 from torchcodec.decoders import VideoDecoder
@@ -118,17 +123,19 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
 
 
 class TorchcodecNonCompiledWithOptions(AbstractDecoder):
-    def __init__(self, num_threads=None, color_conversion_library=None):
+    def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
         self._print_each_iteration_time = False
         self._num_threads = int(num_threads) if num_threads else None
         self._color_conversion_library = color_conversion_library
+        self._device = device
 
     def get_frames_from_video(self, video_file, pts_list):
         decoder = create_from_file(video_file)
         _add_video_stream(
             decoder,
             num_threads=self._num_threads,
             color_conversion_library=self._color_conversion_library,
+            device=self._device,
         )
         frames = []
         times = []
@@ -292,6 +299,97 @@ def create_torchcodec_decoder_from_file(video_file):
     return video_decoder
 
 
+def plot_data(df_data, plot_path):
+    # Creating the DataFrame
+    df = pd.DataFrame(df_data)
+
+    # Sorting by video, type, and frame_count
+    df_sorted = df.sort_values(by=["video", "type", "frame_count"])
+
+    # Group by video first
+    grouped_by_video = df_sorted.groupby("video")
+
+    # Define colors (consistent across decoders)
+    colors = plt.get_cmap("tab10")
+
+    # Find the unique combinations of (type, frame_count) per video
+    video_type_combinations = {
+        video: video_group.groupby(["type", "frame_count"]).ngroups
+        for video, video_group in grouped_by_video
+    }
+
+    # Get the unique videos and the maximum number of (type, frame_count) combinations per video
+    unique_videos = list(video_type_combinations.keys())
+    max_combinations = max(video_type_combinations.values())
+
+    # Create subplots: each row is a video, and each column is for a unique (type, frame_count)
+    fig, axes = plt.subplots(
+        nrows=len(unique_videos),
+        ncols=max_combinations,
+        figsize=(max_combinations * 6, len(unique_videos) * 4),
+        sharex=True,
+        sharey=True,
+    )
+
+    # Handle cases where there's only one row or column
+    if len(unique_videos) == 1:
+        axes = np.array([axes])  # Make sure axes is a list of lists
+    if max_combinations == 1:
+        axes = np.expand_dims(axes, axis=1)  # Ensure a 2D array for axes
+
+    # Loop through each video and its sub-groups
+    for row, (video, video_group) in enumerate(grouped_by_video):
+        sub_group = video_group.groupby(["type", "frame_count"])
+
+        # Loop through each (type, frame_count) group for this video
+        for col, ((vtype, vcount), group) in enumerate(sub_group):
+            ax = axes[row, col]  # Select the appropriate axis
+
+            # Set the title for the subplot
+            base_video = os.path.basename(video)
+            ax.set_title(
+                f"video={base_video}\ndecode_pattern={vcount} x {vtype}", fontsize=12
+            )
+
+            # Plot bars with error bars
+            ax.barh(
+                group["decoder"],
+                group["fps"],
+                xerr=[group["fps"] - group["fps_p75"], group["fps_p25"] - group["fps"]],
+                color=[colors(i) for i in range(len(group))],
+                align="center",
+                capsize=5,
+            )
+
+            # Set the labels
+            ax.set_xlabel("FPS")
+            ax.set_ylabel("Decoder")
+
+            # Reverse the order of the handles and labels to match the order of the bars
+            handles = [
+                plt.Rectangle((0, 0), 1, 1, color=colors(i)) for i in range(len(group))
+            ]
+            ax.legend(
+                handles[::-1],
+                group["decoder"][::-1],
+                title="Decoder",
+                loc="upper right",
+            )
+
+    # Remove any empty subplots for videos with fewer combinations
+    for row in range(len(unique_videos)):
+        for col in range(video_type_combinations[unique_videos[row]], max_combinations):
+            fig.delaxes(axes[row, col])
+
+    # Adjust layout to avoid overlap
+    plt.tight_layout()
+
+    # Show plot
+    plt.savefig(
+        plot_path,
+    )
+
+
 def run_benchmarks(
     decoder_dict,
     video_paths,
@@ -300,9 +398,11 @@ def run_benchmarks(
     benchmark_video_creation,
 ):
     results = []
+    df_data = []
+    print(f"video_paths={video_paths}")
     verbose = False
     for decoder_name, decoder in decoder_dict.items():
-        for video_path in video_paths.split(","):
+        for video_path in video_paths:
             print(f"video={video_path}, decoder={decoder_name}")
             # We only use the VideoDecoder to get the metadata and get
             # the list of PTS values to seek to.
@@ -331,6 +431,19 @@ def run_benchmarks(
             results.append(
                 seeked_result.blocked_autorange(min_run_time=min_runtime_seconds)
             )
+            df_item = {}
+            df_item["decoder"] = decoder_name
+            df_item["video"] = video_path
+            df_item["description"] = results[-1].description
+            df_item["frame_count"] = num_uniform_samples
+            df_item["median"] = results[-1].median
+            df_item["iqr"] = results[-1].iqr
+            df_item["type"] = "seek()+next()"
+            df_item["fps"] = 1.0 * num_uniform_samples / results[-1].median
+            df_item["fps_p75"] = 1.0 * num_uniform_samples / results[-1]._p75
+            df_item["fps_p25"] = 1.0 * num_uniform_samples / results[-1]._p25
+            df_data.append(df_item)
+
             for num_consecutive_nexts in [1, 10]:
                 consecutive_frames_result = benchmark.Timer(
                     stmt="decoder.get_consecutive_frames_from_video(video_file, consecutive_frames_to_extract)",
@@ -348,8 +461,20 @@ def run_benchmarks(
                         min_run_time=min_runtime_seconds
                     )
                 )
-
-    first_video_path = video_paths.split(",")[0]
+                df_item = {}
+                df_item["decoder"] = decoder_name
+                df_item["video"] = video_path
+                df_item["description"] = results[-1].description
+                df_item["frame_count"] = num_consecutive_nexts
+                df_item["median"] = results[-1].median
+                df_item["iqr"] = results[-1].iqr
+                df_item["type"] = "next()"
+                df_item["fps"] = 1.0 * num_consecutive_nexts / results[-1].median
+                df_item["fps_p75"] = 1.0 * num_consecutive_nexts / results[-1]._p75
+                df_item["fps_p25"] = 1.0 * num_consecutive_nexts / results[-1]._p25
+                df_data.append(df_item)
+
+    first_video_path = video_paths[0]
     if benchmark_video_creation:
         simple_decoder = VideoDecoder(first_video_path)
         metadata = simple_decoder.metadata
@@ -371,3 +496,4 @@ def run_benchmarks(
         )
     compare = benchmark.Compare(results)
     compare.print()
+    return df_data
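
As a usage note, here is a minimal sketch (not part of the commit) of the row schema that run_benchmarks() now returns and that plot_data() consumes. The import mirrors the one added to benchmark_decoders.py and assumes the snippet runs from benchmarks/decoders/; the decoder names, file path, and all numeric values are made-up placeholders. fps_p25 and fps_p75 come from the 25th/75th latency percentiles and drive the error bars in the chart.

from benchmark_decoders_library import plot_data

# Two measurements for the same video and decode pattern, so they share a subplot.
rows = [
    {
        "decoder": "TorchcodecNonCompiled",
        "video": "/tmp/videos/sample.mp4",  # placeholder path
        "description": "",
        "frame_count": 10,
        "median": 0.050,  # seconds per run (placeholder)
        "iqr": 0.004,
        "type": "seek()+next()",
        "fps": 10 / 0.050,
        "fps_p75": 10 / 0.060,  # slower percentile -> lower FPS bound
        "fps_p25": 10 / 0.045,  # faster percentile -> upper FPS bound
    },
    {
        "decoder": "TorchAudioDecoder",
        "video": "/tmp/videos/sample.mp4",  # placeholder path
        "description": "",
        "frame_count": 10,
        "median": 0.080,  # seconds per run (placeholder)
        "iqr": 0.006,
        "type": "seek()+next()",
        "fps": 10 / 0.080,
        "fps_p75": 10 / 0.095,
        "fps_p25": 10 / 0.070,
    },
]

plot_data(rows, "benchmark_plot.png")  # writes the grouped bar chart to disk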
