@@ -195,14 +195,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
195
195
int64_t numFrames,
196
196
const VideoStreamDecoderOptions& options,
197
197
const StreamMetadata& metadata)
198
- : frames(torch::empty(
199
- {numFrames,
200
- options.height .value_or (*metadata.height ),
201
- options.width .value_or (*metadata.width ),
202
- 3 },
203
- at::TensorOptions (options.device).dtype(torch::kUInt8 ))),
204
- ptsSeconds (torch::empty({numFrames}, {torch::kFloat64 })),
205
- durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })) {}
198
+ : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64 })),
199
+ durationSeconds (torch::empty({numFrames}, {torch::kFloat64 })) {
200
+ auto frameDims = getHeightAndWidthFromOptionsOrMetadata (options, metadata);
201
+ int height = frameDims.height ;
202
+ int width = frameDims.width ;
203
+ frames = allocateEmptyHWCTensor (height, width, options.device , numFrames);
204
+ }
206
205
207
206
// Default constructor. The body is empty: nothing is set up here explicitly.
VideoDecoder::VideoDecoder () {}
208
207
@@ -364,12 +363,11 @@ void VideoDecoder::initializeFilterGraphForStream(
364
363
inputs->pad_idx = 0 ;
365
364
inputs->next = nullptr ;
366
365
char description[512 ];
367
- int width = activeStream.codecContext ->width ;
368
- int height = activeStream.codecContext ->height ;
369
- if (options.height .has_value () && options.width .has_value ()) {
370
- width = *options.width ;
371
- height = *options.height ;
372
- }
366
+ auto frameDims = getHeightAndWidthFromOptionsOrMetadata (
367
+ options, containerMetadata_.streams [streamIndex]);
368
+ int height = frameDims.height ;
369
+ int width = frameDims.width ;
370
+
373
371
std::snprintf (
374
372
description,
375
373
sizeof (description),
@@ -869,7 +867,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
869
867
convertAVFrameToDecodedOutputOnCuda (
870
868
streamInfo.options .device ,
871
869
streamInfo.options ,
872
- streamInfo. codecContext . get () ,
870
+ containerMetadata_. streams [streamIndex] ,
873
871
rawOutput,
874
872
output,
875
873
preAllocatedOutputTensor);
@@ -899,8 +897,10 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
899
897
torch::Tensor tensor;
900
898
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
901
899
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
902
- int width = streamInfo.options .width .value_or (frame->width );
903
- int height = streamInfo.options .height .value_or (frame->height );
900
+ auto frameDims =
901
+ getHeightAndWidthFromOptionsOrAVFrame (streamInfo.options , *frame);
902
+ int height = frameDims.height ;
903
+ int width = frameDims.width ;
904
904
if (preAllocatedOutputTensor.has_value ()) {
905
905
tensor = preAllocatedOutputTensor.value ();
906
906
auto shape = tensor.sizes ();
@@ -914,8 +914,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
914
914
" x3, got " ,
915
915
shape);
916
916
} else {
917
- tensor = torch::empty (
918
- {height, width, 3 }, torch::TensorOptions ().dtype ({torch::kUInt8 }));
917
+ tensor = allocateEmptyHWCTensor (height, width, torch::kCPU );
919
918
}
920
919
rawOutput.data = tensor.data_ptr <uint8_t >();
921
920
convertFrameToBufferUsingSwsScale (rawOutput);
@@ -1315,8 +1314,10 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
1315
1314
enum AVPixelFormat frameFormat =
1316
1315
static_cast <enum AVPixelFormat>(frame->format );
1317
1316
StreamInfo& activeStream = streams_[streamIndex];
1318
- int outputWidth = activeStream.options .width .value_or (frame->width );
1319
- int outputHeight = activeStream.options .height .value_or (frame->height );
1317
+ auto frameDims =
1318
+ getHeightAndWidthFromOptionsOrAVFrame (activeStream.options , *frame);
1319
+ int outputHeight = frameDims.height ;
1320
+ int outputWidth = frameDims.width ;
1320
1321
if (activeStream.swsContext .get () == nullptr ) {
1321
1322
SwsContext* swsContext = sws_getContext (
1322
1323
frame->width ,
@@ -1382,7 +1383,11 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
1382
1383
ffmpegStatus =
1383
1384
av_buffersink_get_frame (filterState.sinkContext , filteredFrame.get ());
1384
1385
TORCH_CHECK_EQ (filteredFrame->format , AV_PIX_FMT_RGB24);
1385
- std::vector<int64_t > shape = {filteredFrame->height , filteredFrame->width , 3 };
1386
+ auto frameDims = getHeightAndWidthFromOptionsOrAVFrame (
1387
+ streams_[streamIndex].options , *filteredFrame.get ());
1388
+ int height = frameDims.height ;
1389
+ int width = frameDims.width ;
1390
+ std::vector<int64_t > shape = {height, width, 3 };
1386
1391
std::vector<int64_t > strides = {filteredFrame->linesize [0 ], 3 , 1 };
1387
1392
AVFrame* filteredFramePtr = filteredFrame.release ();
1388
1393
auto deleter = [filteredFramePtr](void *) {
@@ -1406,6 +1411,43 @@ VideoDecoder::~VideoDecoder() {
1406
1411
}
1407
1412
}
1408
1413
1414
// Picks the output frame dimensions for a stream: for each of height and
// width, the value from `options` is used when it is set, otherwise the value
// read from `metadata` is used. The result is built as FrameDims(height, width)
// -- height first, then width.
//
// NOTE(review): the fallback arguments `*metadata.height` / `*metadata.width`
// are evaluated before value_or() is called, i.e. they are dereferenced even
// when the option is set. Presumably these metadata fields are optionals that
// callers guarantee to be populated -- confirm, since dereferencing an empty
// one would be undefined behavior.
+ FrameDims getHeightAndWidthFromOptionsOrMetadata (
1415
+ const VideoDecoder::VideoStreamDecoderOptions& options,
1416
+ const VideoDecoder::StreamMetadata& metadata) {
1417
+ return FrameDims (
1418
+ options.height .value_or (*metadata.height ),
1419
+ options.width .value_or (*metadata.width ));
1420
+ }
1421
+
1422
// Picks the output frame dimensions for a decoded frame: for each of height
// and width, the value from `options` is used when it is set, otherwise the
// corresponding field of `avFrame` (avFrame.height / avFrame.width) is used.
// The result is built as FrameDims(height, width) -- height first, then width.
+ FrameDims getHeightAndWidthFromOptionsOrAVFrame (
1423
+ const VideoDecoder::VideoStreamDecoderOptions& options,
1424
+ const AVFrame& avFrame) {
1425
+ return FrameDims (
1426
+ options.height .value_or (avFrame.height ),
1427
+ options.width .value_or (avFrame.width ));
1428
+ }
1429
+
1430
// Allocates an uninitialized (torch::empty) uint8, strided tensor on `device`.
//
// Shape of the returned tensor:
//   - numFrames has a value: {numFrames, height, width, 3}
//   - numFrames is empty:    {height, width, 3}
//
// Argument validation (each enforced with TORCH_CHECK):
//   - height must be > 0
//   - width must be > 0
//   - numFrames, when provided, must be >= 0 (zero frames is accepted)
//
// NOTE(review): presumably the trailing dimension of 3 is the color channel
// count (the "C" of HWC) -- confirm against the callers' pixel format.
// NOTE(review): some call sites pass only (height, width, device), so a
// default for `numFrames` is presumably supplied by a declaration elsewhere.
+ torch::Tensor allocateEmptyHWCTensor (
1431
+ int height,
1432
+ int width,
1433
+ torch::Device device,
1434
+ std::optional<int > numFrames) {
1435
+ auto tensorOptions = torch::TensorOptions ()
1436
+ .dtype (torch::kUInt8 )
1437
+ .layout (torch::kStrided )
1438
+ .device (device);
1439
+ TORCH_CHECK (height > 0 , " height must be > 0, got: " , height);
1440
+ TORCH_CHECK (width > 0 , " width must be > 0, got: " , width);
1441
// Batched case: prepend the frame count as the leading dimension.
+ if (numFrames.has_value ()) {
1442
+ auto numFramesValue = numFrames.value ();
1443
+ TORCH_CHECK (
1444
+ numFramesValue >= 0 , " numFrames must be >= 0, got: " , numFramesValue);
1445
+ return torch::empty ({numFramesValue, height, width, 3 }, tensorOptions);
1446
// Single-frame case: no leading frame dimension.
+ } else {
1447
+ return torch::empty ({height, width, 3 }, tensorOptions);
1448
+ }
1449
+ }
1450
+
1409
1451
std::ostream& operator <<(
1410
1452
std::ostream& os,
1411
1453
const VideoDecoder::DecodeStats& stats) {
0 commit comments