diff --git a/src/torchcodec/_core/DeviceInterface.cpp b/src/torchcodec/_core/DeviceInterface.cpp index 7612334b..382de621 100644 --- a/src/torchcodec/_core/DeviceInterface.cpp +++ b/src/torchcodec/_core/DeviceInterface.cpp @@ -11,8 +11,9 @@ namespace facebook::torchcodec { namespace { +using DeviceInterfaceMap = std::map; std::mutex g_interface_mutex; -std::map g_interface_map; +std::unique_ptr g_interface_map; std::string getDeviceType(const std::string& device) { size_t pos = device.find(':'); @@ -28,11 +29,18 @@ bool registerDeviceInterface( torch::DeviceType deviceType, CreateDeviceInterfaceFn createInterface) { std::scoped_lock lock(g_interface_mutex); + if (!g_interface_map) { + // We delay this initialization until runtime to avoid the Static + // Initialization Order Fiasco: + // + // https://en.cppreference.com/w/cpp/language/siof + g_interface_map = std::make_unique(); + } TORCH_CHECK( - g_interface_map.find(deviceType) == g_interface_map.end(), + g_interface_map->find(deviceType) == g_interface_map->end(), "Device interface already registered for ", deviceType); - g_interface_map.insert({deviceType, createInterface}); + g_interface_map->insert({deviceType, createInterface}); return true; } @@ -45,14 +53,16 @@ torch::Device createTorchDevice(const std::string device) { std::scoped_lock lock(g_interface_mutex); std::string deviceType = getDeviceType(device); auto deviceInterface = std::find_if( - g_interface_map.begin(), - g_interface_map.end(), + g_interface_map->begin(), + g_interface_map->end(), [&](const std::pair& arg) { return device.rfind( torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0; }); TORCH_CHECK( - deviceInterface != g_interface_map.end(), "Unsupported device: ", device); + deviceInterface != g_interface_map->end(), + "Unsupported device: ", + device); return torch::Device(device); } @@ -67,11 +77,12 @@ std::unique_ptr createDeviceInterface( std::scoped_lock lock(g_interface_mutex); TORCH_CHECK( - g_interface_map.find(deviceType) != g_interface_map.end(), + g_interface_map->find(deviceType) != g_interface_map->end(), "Unsupported device: ", device); - return std::unique_ptr(g_interface_map[deviceType](device)); + return std::unique_ptr( + (*g_interface_map)[deviceType](device)); } } // namespace facebook::torchcodec