Skip to content

Commit 882ef9f

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into mac_wheels_ci
2 parents a889d52 + bb29228 commit 882ef9f

12 files changed

+288
-13
lines changed
+144
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
name: Build and test Linux CUDA wheels
2+
3+
on:
4+
pull_request:
5+
push:
6+
branches:
7+
- nightly
8+
- main
9+
- release/*
10+
tags:
11+
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
12+
workflow_dispatch:
13+
14+
concurrency:
15+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
16+
cancel-in-progress: true
17+
18+
permissions:
19+
id-token: write
20+
contents: write
21+
22+
defaults:
23+
run:
24+
shell: bash -l -eo pipefail {0}
25+
26+
jobs:
27+
generate-matrix:
28+
uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
29+
with:
30+
package-type: wheel
31+
os: linux
32+
test-infra-repository: pytorch/test-infra
33+
test-infra-ref: main
34+
with-cpu: disable
35+
with-xpu: disable
36+
with-rocm: disable
37+
with-cuda: enable
38+
build-python-only: "disable"
39+
build:
40+
needs: generate-matrix
41+
strategy:
42+
fail-fast: false
43+
name: Build and Upload wheel
44+
uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
45+
with:
46+
repository: pytorch/torchcodec
47+
ref: ""
48+
test-infra-repository: pytorch/test-infra
49+
test-infra-ref: main
50+
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
51+
post-script: packaging/post_build_script.sh
52+
smoke-test-script: packaging/fake_smoke_test.py
53+
package-name: torchcodec
54+
trigger-event: ${{ github.event_name }}
55+
build-platform: "python-build-package"
56+
build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 ENABLE_CUDA=1 python -m build --wheel -vvv --no-isolation"
57+
58+
install-and-test:
59+
runs-on: linux.4xlarge.nvidia.gpu
60+
strategy:
61+
fail-fast: false
62+
matrix:
63+
# 3.9 corresponds to the minimum python version for which we build
64+
# the wheel unless the label cliflow/binaries/all is present in the
65+
# PR.
66+
# For the actual release we should add that label and change this to
67+
# include more python versions.
68+
python-version: ['3.9']
69+
cuda-version: ['11.8', '12.1', '12.4']
70+
ffmpeg-version-for-tests: ['5', '6', '7']
71+
container:
72+
image: "pytorch/manylinux-builder:cuda${{ matrix.cuda-version }}"
73+
options: "--gpus all -e NVIDIA_DRIVER_CAPABILITIES=video,compute,utility"
74+
if: ${{ always() }}
75+
needs: build
76+
steps:
77+
- name: Setup env vars
78+
run: |
79+
cuda_version_without_periods=$(echo "${{ matrix.cuda-version }}" | sed 's/\.//g')
80+
echo cuda_version_without_periods=${cuda_version_without_periods} >> $GITHUB_ENV
81+
- uses: actions/download-artifact@v3
82+
with:
83+
name: pytorch_torchcodec__3.9_cu${{ env.cuda_version_without_periods }}_x86_64
84+
path: pytorch/torchcodec/dist/
85+
- name: Setup miniconda using test-infra
86+
uses: ahmadsharif1/test-infra/.github/actions/setup-miniconda@14bc3c29f88d13b0237ab4ddf00aa409e45ade40
87+
with:
88+
python-version: ${{ matrix.python-version }}
89+
default-packages: "conda-forge::ffmpeg=${{ matrix.ffmpeg-version-for-tests }}"
90+
- name: Check env
91+
run: |
92+
${CONDA_RUN} env
93+
${CONDA_RUN} conda info
94+
${CONDA_RUN} nvidia-smi
95+
- name: Update pip
96+
run: ${CONDA_RUN} python -m pip install --upgrade pip
97+
- name: Install PyTorch
98+
run: |
99+
${CONDA_RUN} python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu${{ env.cuda_version_without_periods }}
100+
${CONDA_RUN} python -c 'import torch; print(f"{torch.__version__}"); print(f"{torch.__file__}"); print(f"{torch.cuda.is_available()=}")'
101+
- name: Install torchcodec from the wheel
102+
run: |
103+
wheel_path=`find pytorch/torchcodec/dist -type f -name "*.whl"`
104+
echo Installing $wheel_path
105+
${CONDA_RUN} python -m pip install $wheel_path -vvv
106+
107+
- name: Check out repo
108+
uses: actions/checkout@v3
109+
110+
- name: Install cuda runtime dependencies
111+
run: |
112+
# For some reason nvidia::libnpp=12.4 doesn't install but nvidia/label/cuda-12.4.0::libnpp does.
113+
# So we use the latter convention for libnpp.
114+
${CONDA_RUN} conda install --yes nvidia/label/cuda-${{ matrix.cuda-version }}.0::libnpp nvidia::cuda-nvrtc=${{ matrix.cuda-version }} nvidia::cuda-toolkit=${{ matrix.cuda-version }} nvidia::cuda-cudart=${{ matrix.cuda-version }} nvidia::cuda-driver-dev=${{ matrix.cuda-version }}
115+
- name: Install test dependencies
116+
run: |
117+
${CONDA_RUN} python -m pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
118+
# Ideally we would find a way to get those dependencies from pyproject.toml
119+
${CONDA_RUN} python -m pip install numpy pytest pillow
120+
121+
- name: Delete the src/ folder just for fun
122+
run: |
123+
# The only reason we checked-out the repo is to get access to the
124+
# tests. We don't care about the rest. Out of precaution, we delete
125+
# the src/ folder to be extra sure that we're running the code from
126+
# the installed wheel rather than from the source.
127+
# This is just to be extra cautious and very overkill because a)
128+
# there's no way the `torchcodec` package from src/ can be found from
129+
# the PythonPath: the main point of `src/` is precisely to protect
130+
# against that and b) if we ever were to execute code from
131+
# `src/torchcodec`, it would fail loudly because the built .so files
132+
# aren't present there.
133+
rm -r src/
134+
ls
135+
- name: Smoke test
136+
run: |
137+
${CONDA_RUN} python test/decoders/manual_smoke_test.py
138+
- name: Run Python tests
139+
run: |
140+
# We skip test_get_ffmpeg_version because it may not have a micro version.
141+
${CONDA_RUN} FAIL_WITHOUT_CUDA=1 pytest test -k "not test_get_ffmpeg_version" -vvv
142+
- name: Run Python benchmark
143+
run: |
144+
${CONDA_RUN} time python benchmarks/decoders/gpu_benchmark.py --devices=cuda:0,cpu --resize_devices=none

.github/workflows/macos_wheel.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ defaults:
2626
jobs:
2727

2828
generate-matrix:
29-
uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@macbuildwheel
29+
uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
3030
with:
3131
package-type: wheel
3232
os: macos-arm64
@@ -42,12 +42,12 @@ jobs:
4242
strategy:
4343
fail-fast: false
4444
name: Build and Upload Mac wheel
45-
uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@macbuildwheel
45+
uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@main
4646
with:
4747
repository: pytorch/torchcodec
4848
ref: ""
4949
test-infra-repository: pytorch/test-infra
50-
test-infra-ref: macbuildwheel
50+
test-infra-ref: main
5151
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
5252
post-script: packaging/post_build_script.sh
5353
smoke-test-script: packaging/fake_smoke_test.py
@@ -58,7 +58,7 @@ jobs:
5858
build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 python -m build --wheel -vvv --no-isolation"
5959

6060
install-and-test:
61-
runs-on: macos-m1-stable
61+
runs-on: macos-14-xlarge
6262
strategy:
6363
fail-fast: false
6464
matrix:

packaging/post_build_script.sh

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash
22

3-
set -eux
3+
set -ex
44

55
source packaging/helpers.sh
66

@@ -18,7 +18,12 @@ else
1818
exit 1
1919
fi
2020

21-
for ffmpeg_major_version in 4 5 6 7; do
21+
# TODO: Make ffmpeg4 work with nvcc.
22+
if [[ "$ENABLE_CUDA" -eq 1 ]]; then
23+
ffmpeg_versions=(5 6 7)
24+
fi
25+
26+
for ffmpeg_major_version in ${ffmpeg_versions[@]}; do
2227
assert_in_wheel $wheel_path torchcodec/libtorchcodec${ffmpeg_major_version}.${ext}
2328
done
2429
assert_not_in_wheel $wheel_path libtorchcodec.${ext}

src/torchcodec/decoders/_core/CMakeLists.txt

+10-5
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ function(make_torchcodec_library library_name ffmpeg_target)
3737
set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES}
3838
${Python3_LIBRARIES})
3939
if(ENABLE_CUDA)
40-
list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY}
40+
list(APPEND NEEDED_LIBRARIES
4141
${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} )
4242
endif()
4343
target_link_libraries(
@@ -76,10 +76,15 @@ if(DEFINED ENV{BUILD_AGAINST_ALL_FFMPEG_FROM_S3})
7676
${CMAKE_CURRENT_SOURCE_DIR}/fetch_and_expose_non_gpl_ffmpeg_libs.cmake
7777
)
7878

79-
make_torchcodec_library(libtorchcodec4 ffmpeg4)
80-
make_torchcodec_library(libtorchcodec5 ffmpeg5)
81-
make_torchcodec_library(libtorchcodec6 ffmpeg6)
82-
make_torchcodec_library(libtorchcodec7 ffmpeg7)
79+
80+
if(NOT ENABLE_CUDA)
81+
# TODO: Enable more ffmpeg versions for cuda.
82+
make_torchcodec_library(libtorchcodec4 ffmpeg4)
83+
endif()
84+
make_torchcodec_library(libtorchcodec7 ffmpeg7)
85+
make_torchcodec_library(libtorchcodec6 ffmpeg6)
86+
make_torchcodec_library(libtorchcodec5 ffmpeg5)
87+
8388
else()
8489
message(
8590
STATUS

src/torchcodec/decoders/_core/VideoDecoder.cpp

+47
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,53 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10731073
return output;
10741074
}
10751075

1076+
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps(
1077+
int streamIndex,
1078+
const std::vector<double>& timestamps) {
1079+
validateUserProvidedStreamIndex(streamIndex);
1080+
validateScannedAllStreams("getFramesDisplayedByTimestamps");
1081+
1082+
// The frame displayed at timestamp t and the one displayed at timestamp `t +
1083+
// eps` are probably the same frame, with the same index. The easiest way to
1084+
// avoid decoding that unique frame twice is to convert the input timestamps
1085+
// to indices, and leverage the de-duplication logic of getFramesAtIndices.
1086+
// This means this function requires a scan.
1087+
// TODO: longer term, we should implement this without requiring a scan
1088+
1089+
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
1090+
const auto& stream = streams_[streamIndex];
1091+
double minSeconds = streamMetadata.minPtsSecondsFromScan.value();
1092+
double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value();
1093+
1094+
std::vector<int64_t> frameIndices(timestamps.size());
1095+
for (auto i = 0; i < timestamps.size(); ++i) {
1096+
auto framePts = timestamps[i];
1097+
TORCH_CHECK(
1098+
framePts >= minSeconds && framePts < maxSeconds,
1099+
"frame pts is " + std::to_string(framePts) + "; must be in range [" +
1100+
std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) +
1101+
").");
1102+
1103+
auto it = std::lower_bound(
1104+
stream.allFrames.begin(),
1105+
stream.allFrames.end(),
1106+
framePts,
1107+
[&stream](const FrameInfo& info, double framePts) {
1108+
return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts;
1109+
});
1110+
int64_t frameIndex = it - stream.allFrames.begin();
1111+
// If the frame index is larger than the size of allFrames, that means we
1112+
// couldn't match the pts value to the pts value of a NEXT FRAME. And
1113+
// that means that this timestamp falls during the time between when the
1114+
// last frame is displayed, and the video ends. Hence, it should map to the
1115+
// index of the last frame.
1116+
frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1);
1117+
frameIndices[i] = frameIndex;
1118+
}
1119+
1120+
return getFramesAtIndices(streamIndex, frameIndices);
1121+
}
1122+
10761123
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
10771124
int streamIndex,
10781125
int64_t start,

src/torchcodec/decoders/_core/VideoDecoder.h

+6
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ class VideoDecoder {
223223
// i.e. it will be returned when this function is called with seconds=5.0 or
224224
// seconds=5.999, etc.
225225
DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds);
226+
226227
DecodedOutput getFrameAtIndex(
227228
int streamIndex,
228229
int64_t frameIndex,
@@ -242,6 +243,11 @@ class VideoDecoder {
242243
BatchDecodedOutput getFramesAtIndices(
243244
int streamIndex,
244245
const std::vector<int64_t>& frameIndices);
246+
247+
BatchDecodedOutput getFramesDisplayedByTimestamps(
248+
int streamIndex,
249+
const std::vector<double>& timestamps);
250+
245251
// Returns frames within a given range for a given stream as a single stacked
246252
// Tensor. The range is defined by [start, stop). The values retrieved from
247253
// the range are:

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4545
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4646
m.def(
4747
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
48+
m.def(
49+
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
4850
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
4951
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
5052
m.def(
@@ -240,6 +242,16 @@ OpsBatchDecodedOutput get_frames_in_range(
240242
stream_index, start, stop, step.value_or(1));
241243
return makeOpsBatchDecodedOutput(result);
242244
}
245+
OpsBatchDecodedOutput get_frames_by_pts(
246+
at::Tensor& decoder,
247+
int64_t stream_index,
248+
at::ArrayRef<double> timestamps) {
249+
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
250+
std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
251+
auto result =
252+
videoDecoder->getFramesDisplayedByTimestamps(stream_index, timestampsVec);
253+
return makeOpsBatchDecodedOutput(result);
254+
}
243255

244256
OpsBatchDecodedOutput get_frames_by_pts_in_range(
245257
at::Tensor& decoder,
@@ -479,6 +491,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
479491
m.impl("get_frames_at_indices", &get_frames_at_indices);
480492
m.impl("get_frames_in_range", &get_frames_in_range);
481493
m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range);
494+
m.impl("get_frames_by_pts", &get_frames_by_pts);
482495
m.impl("_test_frame_pts_equality", &_test_frame_pts_equality);
483496
m.impl(
484497
"scan_all_streams_to_update_metadata",

src/torchcodec/decoders/_core/VideoDecoderOps.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ using OpsBatchDecodedOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
7575
// given timestamp T has T >= PTS and T < PTS + Duration.
7676
OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds);
7777

78+
// Return the frames at given ptss for a given stream
79+
OpsBatchDecodedOutput get_frames_by_pts(
80+
at::Tensor& decoder,
81+
int64_t stream_index,
82+
at::ArrayRef<double> timestamps);
83+
7884
// Return the frame that is visible at a given index in the video.
7985
OpsDecodedOutput get_frame_at_index(
8086
at::Tensor& decoder,
@@ -85,8 +91,7 @@ OpsDecodedOutput get_frame_at_index(
8591
// duration as tensors.
8692
OpsDecodedOutput get_next_frame(at::Tensor& decoder);
8793

88-
// Return the frames at a given index for a given stream as a single stacked
89-
// Tensor.
94+
// Return the frames at given indices for a given stream
9095
OpsBatchDecodedOutput get_frames_at_indices(
9196
at::Tensor& decoder,
9297
int64_t stream_index,

src/torchcodec/decoders/_core/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
get_frame_at_index,
2323
get_frame_at_pts,
2424
get_frames_at_indices,
25+
get_frames_by_pts,
2526
get_frames_by_pts_in_range,
2627
get_frames_in_range,
2728
get_json_metadata,

src/torchcodec/decoders/_core/video_decoder_ops.py

+16
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def load_torchcodec_extension():
7171
get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default
7272
get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default
7373
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
74+
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
7475
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
7576
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
7677
get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
@@ -172,6 +173,21 @@ def get_frame_at_pts_abstract(
172173
)
173174

174175

176+
@register_fake("torchcodec_ns::get_frames_by_pts")
177+
def get_frames_by_pts_abstract(
178+
decoder: torch.Tensor,
179+
*,
180+
stream_index: int,
181+
timestamps: List[float],
182+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
183+
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
184+
return (
185+
torch.empty(image_size),
186+
torch.empty([], dtype=torch.float),
187+
torch.empty([], dtype=torch.float),
188+
)
189+
190+
175191
@register_fake("torchcodec_ns::get_frame_at_index")
176192
def get_frame_at_index_abstract(
177193
decoder: torch.Tensor, *, stream_index: int, frame_index: int

0 commit comments

Comments
 (0)