diff --git a/src/torchcodec/_core/DeviceInterface.cpp b/src/torchcodec/_core/DeviceInterface.cpp index 382de621..7612334b 100644 --- a/src/torchcodec/_core/DeviceInterface.cpp +++ b/src/torchcodec/_core/DeviceInterface.cpp @@ -11,9 +11,8 @@ namespace facebook::torchcodec { namespace { -using DeviceInterfaceMap = std::map<torch::DeviceType, CreateDeviceInterfaceFn>; std::mutex g_interface_mutex; -std::unique_ptr<DeviceInterfaceMap> g_interface_map; +std::map<torch::DeviceType, CreateDeviceInterfaceFn> g_interface_map; std::string getDeviceType(const std::string& device) { size_t pos = device.find(':'); @@ -29,18 +28,11 @@ bool registerDeviceInterface( torch::DeviceType deviceType, CreateDeviceInterfaceFn createInterface) { std::scoped_lock lock(g_interface_mutex); - if (!g_interface_map) { - // We delay this initialization until runtime to avoid the Static - // Initialization Order Fiasco: - // - // https://en.cppreference.com/w/cpp/language/siof - g_interface_map = std::make_unique<DeviceInterfaceMap>(); - } TORCH_CHECK( - g_interface_map->find(deviceType) == g_interface_map->end(), + g_interface_map.find(deviceType) == g_interface_map.end(), "Device interface already registered for ", deviceType); - g_interface_map->insert({deviceType, createInterface}); + g_interface_map.insert({deviceType, createInterface}); return true; } @@ -53,16 +45,14 @@ torch::Device createTorchDevice(const std::string device) { std::scoped_lock lock(g_interface_mutex); std::string deviceType = getDeviceType(device); auto deviceInterface = std::find_if( - g_interface_map->begin(), - g_interface_map->end(), + g_interface_map.begin(), + g_interface_map.end(), [&](const std::pair<torch::DeviceType, CreateDeviceInterfaceFn>& arg) { return device.rfind( torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0; }); TORCH_CHECK( - deviceInterface != g_interface_map->end(), - "Unsupported device: ", - device); + deviceInterface != g_interface_map.end(), "Unsupported device: ", device); return torch::Device(device); } @@ -77,12 +67,11 @@ std::unique_ptr<DeviceInterface> createDeviceInterface( std::scoped_lock lock(g_interface_mutex); TORCH_CHECK( - g_interface_map->find(deviceType) != g_interface_map->end(), + g_interface_map.find(deviceType) != g_interface_map.end(), "Unsupported device: ", device); - return std::unique_ptr<DeviceInterface>( - (*g_interface_map)[deviceType](device)); + return std::unique_ptr<DeviceInterface>(g_interface_map[deviceType](device)); } } // namespace facebook::torchcodec