Skip to content

Commit b78a737

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into mac_wheels_ci
2 parents 571ed17 + c8de21c commit b78a737

File tree

9 files changed

+96
-28
lines changed

9 files changed

+96
-28
lines changed

.github/workflows/macos_wheel.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ jobs:
5555
package-name: torchcodec
5656
trigger-event: ${{ github.event_name }}
5757
build-platform: "python-build-package"
58-
build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 ${CONDA_RUN} python3 -m build --wheel -vvv --no-isolation"
58+
build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 python -m build --wheel -vvv --no-isolation"
5959

6060
install-and-test:
6161
runs-on: macos-m1-stable

benchmarks/decoders/benchmark_decoders.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def get_frames_from_video(self, video_file, pts_list):
209209
best_video_stream = metadata["bestVideoStreamIndex"]
210210
indices_list = [int(pts * average_fps) for pts in pts_list]
211211
frames = []
212-
frames = get_frames_at_indices(
212+
frames, *_ = get_frames_at_indices(
213213
decoder, stream_index=best_video_stream, frame_indices=indices_list
214214
)
215215
return frames
@@ -226,7 +226,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
226226
best_video_stream = metadata["bestVideoStreamIndex"]
227227
frames = []
228228
indices_list = list(range(numFramesToDecode))
229-
frames = get_frames_at_indices(
229+
frames, *_ = get_frames_at_indices(
230230
decoder, stream_index=best_video_stream, frame_indices=indices_list
231231
)
232232
return frames

packaging/check_glibcxx.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@
3838

3939
MAX_ALLOWED = (3, 4, 19)
4040

41-
libraries = sys.argv[1].split("\n")
41+
symbol_matches = sys.argv[1].split("\n")
4242
all_symbols = set()
43-
for line in libraries:
43+
for line in symbol_matches:
4444
# We search for GLIBCXX_major.minor.micro
4545
if match := re.search(r"GLIBCXX_\d+\.\d+\.\d+", line):
4646
all_symbols.add(match.group(0))
4747

4848
if not all_symbols:
49-
raise ValueError(f"No GLIBCXX symbols found in {libraries}. Something is wrong.")
49+
raise ValueError(f"No GLIBCXX symbols found in {symbol_matches}. Something is wrong.")
5050

5151
all_versions = (symbol.split("_")[1].split(".") for symbol in all_symbols)
5252
all_versions = (tuple(int(v) for v in version) for version in all_versions)

src/torchcodec/_samplers/video_clip_sampler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def _get_clips_for_index_based_sampling(
240240
clip_start_idx + i * index_based_sampler_args.video_frame_dilation
241241
for i in range(index_based_sampler_args.frames_per_clip)
242242
]
243-
frames = get_frames_at_indices(
243+
frames, *_ = get_frames_at_indices(
244244
video_decoder,
245245
stream_index=metadata_json["bestVideoStreamIndex"],
246246
frame_indices=batch_indexes,

src/torchcodec/decoders/_core/VideoDecoder.cpp

+46-13
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,14 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
191191
int64_t numFrames,
192192
const VideoStreamDecoderOptions& options,
193193
const StreamMetadata& metadata)
194-
: ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
195-
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})),
196-
frames(torch::empty(
194+
: frames(torch::empty(
197195
{numFrames,
198196
options.height.value_or(*metadata.height),
199197
options.width.value_or(*metadata.width),
200198
3},
201-
{torch::kUInt8})) {}
199+
{torch::kUInt8})),
200+
ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
201+
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}
202202

203203
VideoDecoder::VideoDecoder() {}
204204

@@ -1017,24 +1017,57 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10171017
validateUserProvidedStreamIndex(streamIndex);
10181018
validateScannedAllStreams("getFramesAtIndices");
10191019

1020+
auto indicesAreSorted =
1021+
std::is_sorted(frameIndices.begin(), frameIndices.end());
1022+
1023+
std::vector<size_t> argsort;
1024+
if (!indicesAreSorted) {
1025+
// if frameIndices is [13, 10, 12, 11]
1026+
// when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
1027+
// to use to decode the frames
1028+
// and argsort is [ 1, 3, 2, 0]
1029+
argsort.resize(frameIndices.size());
1030+
for (size_t i = 0; i < argsort.size(); ++i) {
1031+
argsort[i] = i;
1032+
}
1033+
std::sort(
1034+
argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) {
1035+
return frameIndices[a] < frameIndices[b];
1036+
});
1037+
}
1038+
10201039
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
10211040
const auto& stream = streams_[streamIndex];
10221041
const auto& options = stream.options;
10231042
BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
10241043

1044+
auto previousIndexInVideo = -1;
10251045
for (auto f = 0; f < frameIndices.size(); ++f) {
1026-
auto frameIndex = frameIndices[f];
1027-
if (frameIndex < 0 || frameIndex >= stream.allFrames.size()) {
1046+
auto indexInOutput = indicesAreSorted ? f : argsort[f];
1047+
auto indexInVideo = frameIndices[indexInOutput];
1048+
if (indexInVideo < 0 || indexInVideo >= stream.allFrames.size()) {
10281049
throw std::runtime_error(
1029-
"Invalid frame index=" + std::to_string(frameIndex));
1050+
"Invalid frame index=" + std::to_string(indexInVideo));
10301051
}
1031-
DecodedOutput singleOut =
1032-
getFrameAtIndex(streamIndex, frameIndex, output.frames[f]);
1033-
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
1034-
output.frames[f] = singleOut.frame;
1052+
if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
1053+
// Avoid decoding the same frame twice
1054+
auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
1055+
output.frames[indexInOutput].copy_(output.frames[previousIndexInOutput]);
1056+
output.ptsSeconds[indexInOutput] =
1057+
output.ptsSeconds[previousIndexInOutput];
1058+
output.durationSeconds[indexInOutput] =
1059+
output.durationSeconds[previousIndexInOutput];
1060+
} else {
1061+
DecodedOutput singleOut = getFrameAtIndex(
1062+
streamIndex, indexInVideo, output.frames[indexInOutput]);
1063+
if (options.colorConversionLibrary ==
1064+
ColorConversionLibrary::FILTERGRAPH) {
1065+
output.frames[indexInOutput] = singleOut.frame;
1066+
}
1067+
output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
1068+
output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
10351069
}
1036-
// Note that for now we ignore the pts and duration parts of the output,
1037-
// because they're never used in any caller.
1070+
previousIndexInVideo = indexInVideo;
10381071
}
10391072
output.frames = MaybePermuteHWC2CHW(options, output.frames);
10401073
return output;

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4040
m.def(
4141
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
4242
m.def(
43-
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> Tensor");
43+
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
4444
m.def(
4545
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4646
m.def(
@@ -218,15 +218,15 @@ OpsDecodedOutput get_frame_at_index(
218218
return makeOpsDecodedOutput(result);
219219
}
220220

221-
at::Tensor get_frames_at_indices(
221+
OpsBatchDecodedOutput get_frames_at_indices(
222222
at::Tensor& decoder,
223223
int64_t stream_index,
224224
at::IntArrayRef frame_indices) {
225225
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
226226
std::vector<int64_t> frameIndicesVec(
227227
frame_indices.begin(), frame_indices.end());
228228
auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec);
229-
return result.frames;
229+
return makeOpsBatchDecodedOutput(result);
230230
}
231231

232232
OpsBatchDecodedOutput get_frames_in_range(

src/torchcodec/decoders/_core/VideoDecoderOps.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder);
8787

8888
// Return the frames at the given indices for a given stream as a batched
8989
// output of (frames, pts_seconds, duration_seconds) tensors.
90-
at::Tensor get_frames_at_indices(
90+
OpsBatchDecodedOutput get_frames_at_indices(
9191
at::Tensor& decoder,
9292
int64_t stream_index,
9393
at::IntArrayRef frame_indices);

src/torchcodec/decoders/_core/video_decoder_ops.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,13 @@ def get_frames_at_indices_abstract(
190190
*,
191191
stream_index: int,
192192
frame_indices: List[int],
193-
) -> torch.Tensor:
193+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
194194
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
195-
return torch.empty(image_size)
195+
return (
196+
torch.empty(image_size),
197+
torch.empty([], dtype=torch.float),
198+
torch.empty([], dtype=torch.float),
199+
)
196200

197201

198202
@register_fake("torchcodec_ns::get_frames_in_range")

test/decoders/test_video_decoder_ops.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,45 @@ def test_get_frames_at_indices(self):
116116
decoder = create_from_file(str(NASA_VIDEO.path))
117117
scan_all_streams_to_update_metadata(decoder)
118118
add_video_stream(decoder)
119-
frames0and180 = get_frames_at_indices(
119+
frames0and180, *_ = get_frames_at_indices(
120120
decoder, stream_index=3, frame_indices=[0, 180]
121121
)
122122
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
123123
reference_frame180 = NASA_VIDEO.get_frame_by_name("time6.000000")
124124
assert_tensor_equal(frames0and180[0], reference_frame0)
125125
assert_tensor_equal(frames0and180[1], reference_frame180)
126126

127+
def test_get_frames_at_indices_unsorted_indices(self):
128+
decoder = create_from_file(str(NASA_VIDEO.path))
129+
_add_video_stream(decoder)
130+
scan_all_streams_to_update_metadata(decoder)
131+
stream_index = 3
132+
133+
frame_indices = [2, 0, 1, 0, 2]
134+
135+
expected_frames = [
136+
get_frame_at_index(
137+
decoder, stream_index=stream_index, frame_index=frame_index
138+
)[0]
139+
for frame_index in frame_indices
140+
]
141+
142+
frames, *_ = get_frames_at_indices(
143+
decoder,
144+
stream_index=stream_index,
145+
frame_indices=frame_indices,
146+
)
147+
for frame, expected_frame in zip(frames, expected_frames):
148+
assert_tensor_equal(frame, expected_frame)
149+
150+
# first and last frame should be equal, at index 2. We then modify the
151+
# first frame and assert that it's now different from the last frame.
152+
# This ensures a copy was properly made during the de-duplication logic.
153+
assert_tensor_equal(frames[0], frames[-1])
154+
frames[0] += 20
155+
with pytest.raises(AssertionError):
156+
assert_tensor_equal(frames[0], frames[-1])
157+
127158
def test_get_frames_in_range(self):
128159
decoder = create_from_file(str(NASA_VIDEO.path))
129160
scan_all_streams_to_update_metadata(decoder)
@@ -425,7 +456,7 @@ def test_color_conversion_library_with_dimension_order(
425456
assert frames.shape[1:] == expected_shape
426457
assert_tensor_equal(frames[0], frame0_ref)
427458

428-
frames = get_frames_at_indices(
459+
frames, *_ = get_frames_at_indices(
429460
decoder, stream_index=stream_index, frame_indices=[0, 1, 3, 4]
430461
)
431462
assert frames.shape[1:] == expected_shape

0 commit comments

Comments
 (0)