Skip to content

Commit 2a19014

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into loglevelencoder
2 parents 7921558 + 44bef81 commit 2a19014

15 files changed

+277
-222
lines changed

src/torchcodec/_core/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ function(make_torchcodec_libraries
6868
)
6969

7070
if(ENABLE_CUDA)
71-
list(APPEND decoder_sources CudaDevice.cpp)
71+
list(APPEND decoder_sources CudaDeviceInterface.cpp)
7272
endif()
7373

7474
set(decoder_library_dependencies

src/torchcodec/_core/CudaDevice.cpp renamed to src/torchcodec/_core/CudaDeviceInterface.cpp

+14-11
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <torch/types.h>
55
#include <mutex>
66

7-
#include "src/torchcodec/_core/CudaDevice.h"
7+
#include "src/torchcodec/_core/CudaDeviceInterface.h"
88
#include "src/torchcodec/_core/FFMPEGCommon.h"
99
#include "src/torchcodec/_core/SingleStreamDecoder.h"
1010

@@ -16,9 +16,10 @@ extern "C" {
1616
namespace facebook::torchcodec {
1717
namespace {
1818

19-
bool g_cuda = registerDeviceInterface(
20-
torch::kCUDA,
21-
[](const torch::Device& device) { return new CudaDevice(device); });
19+
bool g_cuda =
20+
registerDeviceInterface(torch::kCUDA, [](const torch::Device& device) {
21+
return new CudaDeviceInterface(device);
22+
});
2223

2324
// We reuse cuda contexts across VideoDecoder instances. This is because
2425
// creating a cuda context is expensive. The cache mechanism is as follows:
@@ -163,20 +164,21 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
163164
}
164165
} // namespace
165166

166-
CudaDevice::CudaDevice(const torch::Device& device) : DeviceInterface(device) {
167+
CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
168+
: DeviceInterface(device) {
167169
if (device_.type() != torch::kCUDA) {
168170
throw std::runtime_error("Unsupported device: " + device_.str());
169171
}
170172
}
171173

172-
CudaDevice::~CudaDevice() {
174+
CudaDeviceInterface::~CudaDeviceInterface() {
173175
if (ctx_) {
174176
addToCacheIfCacheHasCapacity(device_, ctx_);
175177
av_buffer_unref(&ctx_);
176178
}
177179
}
178180

179-
void CudaDevice::initializeContext(AVCodecContext* codecContext) {
181+
void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
180182
TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized");
181183

182184
// It is important for pytorch itself to create the cuda context. If ffmpeg
@@ -189,10 +191,10 @@ void CudaDevice::initializeContext(AVCodecContext* codecContext) {
189191
return;
190192
}
191193

192-
void CudaDevice::convertAVFrameToFrameOutput(
193-
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
194+
void CudaDeviceInterface::convertAVFrameToFrameOutput(
195+
const VideoStreamOptions& videoStreamOptions,
194196
UniqueAVFrame& avFrame,
195-
SingleStreamDecoder::FrameOutput& frameOutput,
197+
FrameOutput& frameOutput,
196198
std::optional<torch::Tensor> preAllocatedOutputTensor) {
197199
TORCH_CHECK(
198200
avFrame->format == AV_PIX_FMT_CUDA,
@@ -263,7 +265,8 @@ void CudaDevice::convertAVFrameToFrameOutput(
263265
// we have to do this because of an FFmpeg bug where hardware decoding is not
264266
// appropriately set, so we just go off and find the matching codec for the CUDA
265267
// device
266-
std::optional<const AVCodec*> CudaDevice::findCodec(const AVCodecID& codecId) {
268+
std::optional<const AVCodec*> CudaDeviceInterface::findCodec(
269+
const AVCodecID& codecId) {
267270
void* i = nullptr;
268271
const AVCodec* codec = nullptr;
269272
while ((codec = av_codec_iterate(&i)) != nullptr) {

src/torchcodec/_core/CudaDevice.h renamed to src/torchcodec/_core/CudaDeviceInterface.h

+5-5
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,20 @@
1010

1111
namespace facebook::torchcodec {
1212

13-
class CudaDevice : public DeviceInterface {
13+
class CudaDeviceInterface : public DeviceInterface {
1414
public:
15-
CudaDevice(const torch::Device& device);
15+
CudaDeviceInterface(const torch::Device& device);
1616

17-
virtual ~CudaDevice();
17+
virtual ~CudaDeviceInterface();
1818

1919
std::optional<const AVCodec*> findCodec(const AVCodecID& codecId) override;
2020

2121
void initializeContext(AVCodecContext* codecContext) override;
2222

2323
void convertAVFrameToFrameOutput(
24-
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
24+
const VideoStreamOptions& videoStreamOptions,
2525
UniqueAVFrame& avFrame,
26-
SingleStreamDecoder::FrameOutput& frameOutput,
26+
FrameOutput& frameOutput,
2727
std::optional<torch::Tensor> preAllocatedOutputTensor =
2828
std::nullopt) override;
2929

src/torchcodec/_core/DeviceInterface.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
#include <stdexcept>
1313
#include <string>
1414
#include "FFMPEGCommon.h"
15-
#include "src/torchcodec/_core/SingleStreamDecoder.h"
15+
#include "src/torchcodec/_core/Frame.h"
16+
#include "src/torchcodec/_core/StreamOptions.h"
1617

1718
namespace facebook::torchcodec {
1819

@@ -41,9 +42,9 @@ class DeviceInterface {
4142
virtual void initializeContext(AVCodecContext* codecContext) = 0;
4243

4344
virtual void convertAVFrameToFrameOutput(
44-
const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
45+
const VideoStreamOptions& videoStreamOptions,
4546
UniqueAVFrame& avFrame,
46-
SingleStreamDecoder::FrameOutput& frameOutput,
47+
FrameOutput& frameOutput,
4748
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;
4849

4950
protected:

src/torchcodec/_core/Encoder.cpp

+8-6
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ AudioEncoder::~AudioEncoder() {}
88
AudioEncoder::AudioEncoder(
99
const torch::Tensor wf,
1010
int sampleRate,
11-
std::string_view fileName)
11+
std::string_view fileName,
12+
std::optional<int64_t> bit_rate)
1213
: wf_(wf), sampleRate_(sampleRate) {
1314
TORCH_CHECK(
1415
wf_.dtype() == torch::kFloat32,
@@ -49,11 +50,12 @@ AudioEncoder::AudioEncoder(
4950
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
5051
avCodecContext_.reset(avCodecContext);
5152

52-
// TODO-ENCODING I think this sets the bit rate to the minimum supported.
53-
// That's not what the ffmpeg CLI would choose by default, so we should try to
54-
// do the same.
55-
// TODO-ENCODING Should also let user choose for compressed formats like mp3.
56-
avCodecContext_->bit_rate = 0;
53+
if (bit_rate.has_value()) {
54+
TORCH_CHECK(*bit_rate >= 0, "bit_rate=", *bit_rate, " must be >= 0.");
55+
}
56+
// bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as
57+
// well when "-b:a" isn't specified.
58+
avCodecContext_->bit_rate = bit_rate.value_or(0);
5759

5860
avCodecContext_->sample_rate = sampleRate_;
5961

src/torchcodec/_core/Encoder.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,16 @@ class AudioEncoder {
77
public:
88
~AudioEncoder();
99

10+
// TODO-ENCODING: document in public docs that bit_rate value is only
11+
// best-effort, matching to the closest supported bit_rate. I.e. passing 1 is
12+
// like passing 0, which results in choosing the minimum supported bit rate.
13+
// Passing 44_100 could result in output being 44000 if only 44000 is
14+
// supported.
1015
AudioEncoder(
1116
const torch::Tensor wf,
1217
int sampleRate,
13-
std::string_view fileName);
18+
std::string_view fileName,
19+
std::optional<int64_t> bit_rate = std::nullopt);
1420
void encode();
1521

1622
private:

src/torchcodec/_core/Frame.h

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <torch/types.h>
10+
#include "src/torchcodec/_core/Metadata.h"
11+
#include "src/torchcodec/_core/StreamOptions.h"
12+
13+
namespace facebook::torchcodec {
14+
15+
// All public video decoding entry points return either a FrameOutput or a
16+
// FrameBatchOutput.
17+
// They are the equivalent of the user-facing Frame and FrameBatch classes in
18+
// Python. They contain RGB decoded frames along with some associated data
19+
// like PTS and duration.
20+
// FrameOutput is also relevant for audio decoding, typically as the output of
21+
// getNextFrame(), or as a temporary output variable.
22+
struct FrameOutput {
23+
// data shape is:
24+
// - 3D (C, H, W) or (H, W, C) for videos
25+
// - 2D (numChannels, numSamples) for audio
26+
torch::Tensor data;
27+
double ptsSeconds;
28+
double durationSeconds;
29+
};
30+
31+
struct FrameBatchOutput {
32+
torch::Tensor data; // 4D: of shape NCHW or NHWC.
33+
torch::Tensor ptsSeconds; // 1D of shape (N,)
34+
torch::Tensor durationSeconds; // 1D of shape (N,)
35+
36+
explicit FrameBatchOutput(
37+
int64_t numFrames,
38+
const VideoStreamOptions& videoStreamOptions,
39+
const StreamMetadata& streamMetadata);
40+
};
41+
42+
struct AudioFramesOutput {
43+
torch::Tensor data; // shape is (numChannels, numSamples)
44+
double ptsSeconds;
45+
};
46+
47+
} // namespace facebook::torchcodec

src/torchcodec/_core/Metadata.h

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <optional>
10+
#include <string>
11+
#include <vector>
12+
13+
extern "C" {
14+
#include <libavcodec/avcodec.h>
15+
#include <libavutil/avutil.h>
16+
}
17+
18+
namespace facebook::torchcodec {
19+
20+
struct StreamMetadata {
21+
// Common (video and audio) fields derived from the AVStream.
22+
int streamIndex;
23+
// See this link for the list of available values:
24+
// https://ffmpeg.org/doxygen/trunk/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48
25+
AVMediaType mediaType;
26+
std::optional<AVCodecID> codecId;
27+
std::optional<std::string> codecName;
28+
std::optional<double> durationSeconds;
29+
std::optional<double> beginStreamFromHeader;
30+
std::optional<int64_t> numFrames;
31+
std::optional<int64_t> numKeyFrames;
32+
std::optional<double> averageFps;
33+
std::optional<double> bitRate;
34+
35+
// More accurate duration, obtained by scanning the file.
36+
// These presentation timestamps are in time base.
37+
std::optional<int64_t> minPtsFromScan;
38+
std::optional<int64_t> maxPtsFromScan;
39+
// These presentation timestamps are in seconds.
40+
std::optional<double> minPtsSecondsFromScan;
41+
std::optional<double> maxPtsSecondsFromScan;
42+
// This can be useful for index-based seeking.
43+
std::optional<int64_t> numFramesFromScan;
44+
45+
// Video-only fields derived from the AVCodecContext.
46+
std::optional<int64_t> width;
47+
std::optional<int64_t> height;
48+
49+
// Audio-only fields
50+
std::optional<int64_t> sampleRate;
51+
std::optional<int64_t> numChannels;
52+
std::optional<std::string> sampleFormat;
53+
};
54+
55+
struct ContainerMetadata {
56+
std::vector<StreamMetadata> allStreamMetadata;
57+
int numAudioStreams = 0;
58+
int numVideoStreams = 0;
59+
// Note that this is the container-level duration, which is usually the max
60+
// of all stream durations available in the container.
61+
std::optional<double> durationSeconds;
62+
// Total bit rate at the container level, in bit/s.
63+
std::optional<double> bitRate;
64+
// If set, this is the index to the default audio stream.
65+
std::optional<int> bestAudioStreamIndex;
66+
// If set, this is the index to the default video stream.
67+
std::optional<int> bestVideoStreamIndex;
68+
};
69+
70+
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)