Skip to content

Commit 851b399

Browse files
committed
.
1 parent ddf3e54 commit 851b399

File tree

4 files changed

+49
-8
lines changed

4 files changed

+49
-8
lines changed

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -201,7 +201,8 @@ void convertAVFrameToDecodedOutputOnCuda(
201201
const VideoDecoder::VideoStreamDecoderOptions& options,
202202
AVCodecContext* codecContext,
203203
VideoDecoder::RawDecodedOutput& rawOutput,
204-
VideoDecoder::DecodedOutput& output) {
204+
VideoDecoder::DecodedOutput& output,
205+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
205206
AVFrame* src = rawOutput.frame.get();
206207

207208
TORCH_CHECK(
@@ -213,7 +214,21 @@ void convertAVFrameToDecodedOutputOnCuda(
213214
NppiSize oSizeROI = {width, height};
214215
Npp8u* input[2] = {src->data[0], src->data[1]};
215216
torch::Tensor& dst = output.frame;
216-
dst = allocateDeviceTensor({height, width, 3}, options.device);
217+
if (preAllocatedOutputTensor.has_value()) {
218+
dst = preAllocatedOutputTensor.value();
219+
auto shape = dst.sizes();
220+
TORCH_CHECK(
221+
(shape.size() == 3) && (shape[0] == height) && (shape[1] == width) &&
222+
(shape[2] == 3),
223+
"Expected tensor of shape ",
224+
height,
225+
"x",
226+
width,
227+
"x3, got ",
228+
shape);
229+
} else {
230+
dst = allocateDeviceTensor({height, width, 3}, options.device);
231+
}
217232

218233
// Use the user-requested GPU for running the NPP kernel.
219234
c10::cuda::CUDAGuard deviceGuard(device);

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,8 @@ void convertAVFrameToDecodedOutputOnCuda(
3737
const VideoDecoder::VideoStreamDecoderOptions& options,
3838
AVCodecContext* codecContext,
3939
VideoDecoder::RawDecodedOutput& rawOutput,
40-
VideoDecoder::DecodedOutput& output);
40+
VideoDecoder::DecodedOutput& output,
41+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
4142

4243
void releaseContextOnCuda(
4344
const torch::Device& device,

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -196,7 +196,7 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
196196
options.height.value_or(*metadata.height),
197197
options.width.value_or(*metadata.width),
198198
3},
199-
{torch::kUInt8})),
199+
at::TensorOptions(options.device).dtype(torch::kUInt8))),
200200
ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
201201
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}
202202

@@ -859,13 +859,14 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
859859
convertAVFrameToDecodedOutputOnCPU(
860860
rawOutput, output, preAllocatedOutputTensor);
861861
} else if (streamInfo.options.device.type() == torch::kCUDA) {
862-
// TODO: handle pre-allocated output tensor
862+
// TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput.
863863
convertAVFrameToDecodedOutputOnCuda(
864864
streamInfo.options.device,
865865
streamInfo.options,
866866
streamInfo.codecContext.get(),
867867
rawOutput,
868-
output);
868+
output,
869+
preAllocatedOutputTensor);
869870
} else {
870871
TORCH_CHECK(
871872
false, "Invalid device type: " + streamInfo.options.device.str());

test/decoders/test_video_decoder_ops.py

Lines changed: 26 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,12 @@ def seek(self, pts: float):
5757
seek_to_pts(self.decoder, pts)
5858

5959

60+
# Asserts that at most `percentage` percent of the elements differ by more than `abs_tolerance`.
61+
def assert_tensor_nearly_equal(frame1, frame2, percentage=0.3, abs_tolerance=20):
62+
diff = (frame2.float() - frame1.float()).abs()
63+
assert (diff > abs_tolerance).float().mean() <= percentage / 100.0
64+
65+
6066
class TestOps:
6167
def test_seek_and_next(self):
6268
decoder = create_from_file(str(NASA_VIDEO.path))
@@ -137,6 +143,24 @@ def test_get_frames_at_indices(self):
137143
assert_tensor_equal(frames0and180[0], reference_frame0)
138144
assert_tensor_equal(frames0and180[1], reference_frame180)
139145

146+
@needs_cuda
147+
def test_get_frames_at_indices_with_cuda(self):
148+
decoder = create_from_file(str(NASA_VIDEO.path))
149+
scan_all_streams_to_update_metadata(decoder)
150+
add_video_stream(decoder, device="cuda")
151+
frames0and180, *_ = get_frames_at_indices(
152+
decoder, stream_index=3, frame_indices=[0, 180]
153+
)
154+
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
155+
reference_frame180 = NASA_VIDEO.get_frame_data_by_index(
156+
INDEX_OF_FRAME_AT_6_SECONDS
157+
)
158+
assert frames0and180.device.type == "cuda"
159+
assert_tensor_nearly_equal(frames0and180[0].to("cpu"), reference_frame0)
160+
assert_tensor_nearly_equal(
161+
frames0and180[1].to("cpu"), reference_frame180, 0.3, 30
162+
)
163+
140164
def test_get_frames_at_indices_unsorted_indices(self):
141165
decoder = create_from_file(str(NASA_VIDEO.path))
142166
_add_video_stream(decoder)
@@ -657,8 +681,8 @@ def test_cuda_decoder(self):
657681
assert frame0.device.type == "cuda"
658682
frame0_cpu = frame0.to("cpu")
659683
reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
660-
# GPU decode is not bit-accurate. In the following assertion we ensure
661-
# not more than 0.3% of values have a difference greater than 20.
684+
# GPU decode is not bit-accurate. So we allow some tolerance.
685+
assert_tensor_nearly_equal(frame0_cpu, reference_frame0)
662686
diff = (reference_frame0.float() - frame0_cpu.float()).abs()
663687
assert (diff > 20).float().mean() <= 0.003
664688
assert pts == torch.tensor([0])

0 commit comments

Comments
 (0)