File indexing completed on 2025-04-13 22:50:00
0001 #include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
0002 #include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"
0003
0004 #include "DataFormats/Provenance/interface/ModuleDescription.h"
0005 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0006 #include "FWCore/ParameterSet/interface/allowedValues.h"
0007 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0008 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0009 #include "FWCore/ServiceRegistry/interface/ActivityRegistry.h"
0010 #include "FWCore/ServiceRegistry/interface/SystemBounds.h"
0011 #include "FWCore/ServiceRegistry/interface/ProcessContext.h"
0012 #include "FWCore/Utilities/interface/Exception.h"
0013 #include "FWCore/Utilities/interface/GetEnvironmentVariable.h"
0014
0015 #include "grpc_client.h"
0016 #include "grpc_service.pb.h"
0017
#include <algorithm>
#include <array>
#include <cctype>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <sys/wait.h>
#include <unistd.h>
0027
0028 namespace tc = triton::client;
0029
0030 const std::string TritonService::Server::fallbackName{"fallback"};
0031 const std::string TritonService::Server::fallbackAddress{"0.0.0.0"};
0032 const std::string TritonService::Server::siteconfName{"SONIC_LOCAL_BALANCER"};
0033
0034 namespace {
0035 std::pair<std::string, int> execSys(const std::string& cmd) {
0036
0037 auto pipe = popen((cmd + " 2>&1").c_str(), "r");
0038 int thisErrno = errno;
0039 if (!pipe)
0040 throw cms::Exception("SystemError")
0041 << "TritonService: popen() failed with errno " << thisErrno << " for command: " << cmd;
0042
0043
0044 constexpr static unsigned buffSize = 128;
0045 std::array<char, buffSize> buffer;
0046 std::string result;
0047 while (!feof(pipe)) {
0048 if (fgets(buffer.data(), buffSize, pipe))
0049 result += buffer.data();
0050 else {
0051 thisErrno = ferror(pipe);
0052 if (thisErrno)
0053 throw cms::Exception("SystemError")
0054 << "TritonService: failed reading command output with errno " << thisErrno;
0055 }
0056 }
0057
0058 int rv = pclose(pipe);
0059 return std::make_pair(result, rv);
0060 }
0061
0062
0063 std::string extractFromLog(const std::string& output, const std::string& indicator) {
0064
0065 auto pos = output.rfind(indicator);
0066 if (pos != std::string::npos) {
0067 auto pos2 = pos + indicator.size();
0068 auto pos3 = output.find('\n', pos2);
0069 return output.substr(pos2, pos3 - pos2);
0070 } else
0071 return "";
0072 }
0073 }
0074
0075 TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistry& areg)
0076 : verbose_(pset.getUntrackedParameter<bool>("verbose")),
0077 fallbackOpts_(pset.getParameterSet("fallback")),
0078 currentModuleId_(0),
0079 allowAddModel_(false),
0080 startedFallback_(false),
0081 callFails_(0),
0082 pid_(std::to_string(::getpid())) {
0083
0084
0085 areg.watchPreallocate(this, &TritonService::preallocate);
0086
0087 areg.watchPreModuleConstruction(this, &TritonService::preModuleConstruction);
0088 areg.watchPostModuleConstruction(this, &TritonService::postModuleConstruction);
0089 areg.watchPreModuleDestruction(this, &TritonService::preModuleDestruction);
0090
0091 areg.watchPreBeginJob(this, &TritonService::preBeginJob);
0092 areg.watchPostEndJob(this, &TritonService::postEndJob);
0093
0094
0095
0096 std::string siteconf_address(edm::getEnvironmentVariable(Server::siteconfName + "_HOST"));
0097 std::string siteconf_port(edm::getEnvironmentVariable(Server::siteconfName + "_PORT"));
0098 if (!siteconf_address.empty() and !siteconf_port.empty()) {
0099 servers_.emplace(
0100 std::piecewise_construct,
0101 std::forward_as_tuple(Server::siteconfName),
0102 std::forward_as_tuple(Server::siteconfName, siteconf_address + ":" + siteconf_port, TritonServerType::Remote));
0103 if (verbose_)
0104 edm::LogInfo("TritonDiscovery") << "Obtained server from SITECONF: "
0105 << servers_.find(Server::siteconfName)->second.url;
0106 } else if (siteconf_address.empty() != siteconf_port.empty()) {
0107 edm::LogWarning("TritonDiscovery") << "Incomplete server information from SITECONF: HOST = " << siteconf_address
0108 << ", PORT = " << siteconf_port;
0109 } else
0110 edm::LogWarning("TritonDiscovery") << "No server information from SITECONF";
0111
0112
0113 for (const auto& serverPset : pset.getUntrackedParameterSetVector("servers")) {
0114 const std::string& serverName(serverPset.getUntrackedParameter<std::string>("name"));
0115
0116 auto [sit, unique] = servers_.emplace(serverName, serverPset);
0117 if (!unique)
0118 throw cms::Exception("DuplicateServer")
0119 << "TritonService: Not allowed to specify more than one server with same name (" << serverName << ")";
0120 }
0121
0122
0123 std::string msg;
0124 if (verbose_)
0125 msg = "List of models for each server:\n";
0126 for (auto& [serverName, server] : servers_) {
0127 std::unique_ptr<tc::InferenceServerGrpcClient> client;
0128 TRITON_THROW_IF_ERROR(
0129 tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions),
0130 "TritonService(): unable to create inference context for " + serverName + " (" + server.url + ")",
0131 false);
0132
0133 if (verbose_) {
0134 inference::ServerMetadataResponse serverMetaResponse;
0135 auto err = client->ServerMetadata(&serverMetaResponse);
0136 if (err.IsOk())
0137 edm::LogInfo("TritonService") << "Server " << serverName << ": url = " << server.url
0138 << ", version = " << serverMetaResponse.version();
0139 else
0140 edm::LogInfo("TritonService") << "unable to get metadata for " + serverName + " (" + server.url + ")";
0141 }
0142
0143
0144
0145 inference::RepositoryIndexResponse repoIndexResponse;
0146 auto err = client->ModelRepositoryIndex(&repoIndexResponse);
0147
0148
0149 if (verbose_)
0150 msg += serverName + ": ";
0151 if (err.IsOk()) {
0152 for (const auto& modelIndex : repoIndexResponse.models()) {
0153 const auto& modelName = modelIndex.name();
0154 auto mit = models_.find(modelName);
0155 if (mit == models_.end())
0156 mit = models_.emplace(modelName, "").first;
0157 auto& modelInfo(mit->second);
0158 modelInfo.servers.insert(serverName);
0159 server.models.insert(modelName);
0160 if (verbose_)
0161 msg += modelName + ", ";
0162 }
0163 } else {
0164 const std::string& baseMsg = "unable to get repository index";
0165 const std::string& extraMsg = err.Message().empty() ? "" : ": " + err.Message();
0166 if (verbose_)
0167 msg += baseMsg + extraMsg;
0168 else
0169 edm::LogWarning("TritonFailure") << "TritonService(): " << baseMsg << " for " << serverName << " ("
0170 << server.url << ")" << extraMsg;
0171 }
0172 if (verbose_)
0173 msg += "\n";
0174 }
0175 if (verbose_)
0176 edm::LogInfo("TritonDiscovery") << msg;
0177 }
0178
0179 void TritonService::preallocate(edm::service::SystemBounds const& bounds) {
0180 numberOfThreads_ = bounds.maxNumberOfThreads();
0181 }
0182
0183 void TritonService::preModuleConstruction(edm::ModuleDescription const& desc) {
0184 currentModuleId_ = desc.id();
0185 allowAddModel_ = true;
0186 }
0187
0188 void TritonService::addModel(const std::string& modelName, const std::string& path) {
0189
0190 if (!allowAddModel_)
0191 throw cms::Exception("DisallowedAddModel")
0192 << "TritonService: Attempt to call addModel() outside of module constructors";
0193
0194 auto mit = models_.find(modelName);
0195 if (mit == models_.end()) {
0196 auto& modelInfo(unservedModels_.emplace(modelName, path).first->second);
0197 modelInfo.modules.insert(currentModuleId_);
0198
0199 modules_.emplace(currentModuleId_, modelName);
0200 }
0201 }
0202
0203 void TritonService::postModuleConstruction(edm::ModuleDescription const& desc) { allowAddModel_ = false; }
0204
0205 void TritonService::preModuleDestruction(edm::ModuleDescription const& desc) {
0206
0207 if (unservedModels_.empty())
0208 return;
0209 auto id = desc.id();
0210 auto oit = modules_.find(id);
0211 if (oit != modules_.end()) {
0212 const auto& moduleInfo(oit->second);
0213 auto mit = unservedModels_.find(moduleInfo.model);
0214 if (mit != unservedModels_.end()) {
0215 auto& modelInfo(mit->second);
0216 modelInfo.modules.erase(id);
0217
0218 if (modelInfo.modules.empty())
0219 unservedModels_.erase(mit);
0220 }
0221 modules_.erase(oit);
0222 }
0223 }
0224
0225
0226 TritonService::Server TritonService::serverInfo(const std::string& model, const std::string& preferred) const {
0227 auto mit = models_.find(model);
0228 if (mit == models_.end())
0229 throw cms::Exception("MissingModel") << "TritonService: There are no servers that provide model " << model;
0230 const auto& modelInfo(mit->second);
0231 const auto& modelServers = modelInfo.servers;
0232
0233 auto msit = modelServers.end();
0234 if (!preferred.empty()) {
0235 msit = modelServers.find(preferred);
0236
0237 if (msit == modelServers.end())
0238 edm::LogWarning("PreferredServer") << "Preferred server " << preferred << " for model " << model
0239 << " not available, will choose another server";
0240 }
0241 const auto& serverName(msit == modelServers.end() ? *modelServers.begin() : preferred);
0242
0243
0244 const auto& server(servers_.find(serverName)->second);
0245 return server;
0246 }
0247
0248 void TritonService::preBeginJob(edm::ProcessContext const&) {
0249
0250 if (!fallbackOpts_.enable or unservedModels_.empty())
0251 return;
0252
0253
0254 auto serverType = TritonServerType::LocalCPU;
0255 if (fallbackOpts_.device == "gpu")
0256 serverType = TritonServerType::LocalGPU;
0257 servers_.emplace(std::piecewise_construct,
0258 std::forward_as_tuple(Server::fallbackName),
0259 std::forward_as_tuple(Server::fallbackName, Server::fallbackAddress, serverType));
0260
0261 std::string msg;
0262 if (verbose_)
0263 msg = "List of models for fallback server: ";
0264
0265 auto& server(servers_.find(Server::fallbackName)->second);
0266 for (const auto& [modelName, model] : unservedModels_) {
0267 auto& modelInfo(models_.emplace(modelName, model).first->second);
0268 modelInfo.servers.insert(Server::fallbackName);
0269 server.models.insert(modelName);
0270 if (verbose_)
0271 msg += modelName + ", ";
0272 }
0273 if (verbose_)
0274 edm::LogInfo("TritonDiscovery") << msg;
0275
0276
0277 fallbackOpts_.command = "cmsTriton -P -1 -p " + pid_;
0278 fallbackOpts_.command += " -g " + fallbackOpts_.device;
0279 fallbackOpts_.command += " -d " + fallbackOpts_.container;
0280 if (fallbackOpts_.debug)
0281 fallbackOpts_.command += " -c";
0282 if (fallbackOpts_.verbose)
0283 fallbackOpts_.command += " -v";
0284 if (!fallbackOpts_.instanceName.empty())
0285 fallbackOpts_.command += " -n " + fallbackOpts_.instanceName;
0286 if (fallbackOpts_.retries >= 0)
0287 fallbackOpts_.command += " -r " + std::to_string(fallbackOpts_.retries);
0288 if (fallbackOpts_.wait >= 0)
0289 fallbackOpts_.command += " -w " + std::to_string(fallbackOpts_.wait);
0290 for (const auto& [modelName, model] : unservedModels_) {
0291 fallbackOpts_.command += " -m " + model.path;
0292 }
0293 std::string thread_string = " -I " + std::to_string(numberOfThreads_);
0294 fallbackOpts_.command += thread_string;
0295 if (!fallbackOpts_.imageName.empty())
0296 fallbackOpts_.command += " -i " + fallbackOpts_.imageName;
0297 if (!fallbackOpts_.sandboxName.empty())
0298 fallbackOpts_.command += " -s " + fallbackOpts_.sandboxName;
0299
0300 unservedModels_.clear();
0301
0302
0303 if (fallbackOpts_.tempDir.empty()) {
0304 auto tmp_dir_path{std::filesystem::temp_directory_path() /= edm::createGlobalIdentifier()};
0305 fallbackOpts_.tempDir = tmp_dir_path.string();
0306 }
0307
0308 if (fallbackOpts_.tempDir != ".")
0309 fallbackOpts_.command += " -t " + fallbackOpts_.tempDir;
0310
0311 std::string command = fallbackOpts_.command + " start";
0312
0313 if (fallbackOpts_.debug)
0314 edm::LogInfo("TritonService") << "Fallback server temporary directory: " << fallbackOpts_.tempDir;
0315 if (verbose_)
0316 edm::LogInfo("TritonService") << command;
0317
0318
0319 startedFallback_ = true;
0320 const auto& [output, rv] = execSys(command);
0321 if (rv != 0) {
0322 edm::LogError("TritonService") << output;
0323 printFallbackServerLog<edm::LogError>();
0324 throw edm::Exception(edm::errors::ExternalFailure)
0325 << "TritonService: Starting the fallback server failed with exit code " << rv;
0326 } else if (verbose_)
0327 edm::LogInfo("TritonService") << output;
0328
0329
0330 std::string chosenDevice(fallbackOpts_.device);
0331 if (chosenDevice == "auto") {
0332 chosenDevice = extractFromLog(output, "CMS_TRITON_CHOSEN_DEVICE: ");
0333 if (!chosenDevice.empty()) {
0334 if (chosenDevice == "cpu")
0335 server.type = TritonServerType::LocalCPU;
0336 else if (chosenDevice == "gpu")
0337 server.type = TritonServerType::LocalGPU;
0338 else
0339 throw edm::Exception(edm::errors::ExternalFailure)
0340 << "TritonService: unsupported device choice " << chosenDevice << " for fallback server, log follows:\n"
0341 << output;
0342 } else
0343 throw edm::Exception(edm::errors::ExternalFailure)
0344 << "TritonService: unknown device choice for fallback server, log follows:\n"
0345 << output;
0346 }
0347
0348 std::transform(chosenDevice.begin(), chosenDevice.end(), chosenDevice.begin(), toupper);
0349 if (verbose_)
0350 edm::LogInfo("TritonDiscovery") << "Fallback server started: " << chosenDevice;
0351
0352
0353 const auto& portNum = extractFromLog(output, "CMS_TRITON_GRPC_PORT: ");
0354 if (!portNum.empty())
0355 server.url += ":" + portNum;
0356 else
0357 throw edm::Exception(edm::errors::ExternalFailure)
0358 << "TritonService: Unknown port for fallback server, log follows:\n"
0359 << output;
0360 }
0361
0362 void TritonService::notifyCallStatus(bool status) const {
0363 if (status)
0364 --callFails_;
0365 else
0366 ++callFails_;
0367 }
0368
0369 void TritonService::postEndJob() {
0370 if (!startedFallback_)
0371 return;
0372
0373 std::string command = fallbackOpts_.command;
0374
0375 if (callFails_ > 0)
0376 command += " -c";
0377 command += " stop";
0378 if (verbose_)
0379 edm::LogInfo("TritonService") << command;
0380
0381 const auto& [output, rv] = execSys(command);
0382 if (rv != 0 or callFails_ > 0) {
0383
0384 edm::LogError("TritonService") << output;
0385 printFallbackServerLog<edm::LogError>();
0386 if (rv != 0) {
0387 std::string stopCat("FallbackFailed");
0388 std::string stopMsg = fmt::format("TritonService: Stopping the fallback server failed with exit code {}", rv);
0389
0390 if (callFails_ > 0)
0391 edm::LogWarning(stopCat) << stopMsg;
0392 else
0393 throw cms::Exception(stopCat) << stopMsg;
0394 }
0395 } else if (verbose_) {
0396 edm::LogInfo("TritonService") << output;
0397 printFallbackServerLog<edm::LogInfo>();
0398 }
0399 }
0400
0401 template <typename LOG>
0402 void TritonService::printFallbackServerLog() const {
0403 std::vector<std::string> logNames{"log_" + fallbackOpts_.instanceName + ".log"};
0404
0405
0406 logNames.push_back(fallbackOpts_.tempDir + "/" + logNames[0]);
0407 bool foundLog = false;
0408 for (const auto& logName : logNames) {
0409 std::ifstream infile(logName);
0410 if (infile.is_open()) {
0411 LOG("TritonService") << "TritonService: server log " << logName << "\n" << infile.rdbuf();
0412 foundLog = true;
0413 break;
0414 }
0415 }
0416 if (!foundLog)
0417 LOG("TritonService") << "TritonService: could not find server log " << logNames[0] << " in current directory or "
0418 << fallbackOpts_.tempDir;
0419 }
0420
0421 void TritonService::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0422 edm::ParameterSetDescription desc;
0423 desc.addUntracked<bool>("verbose", false);
0424
0425 edm::ParameterSetDescription validator;
0426 validator.addUntracked<std::string>("name");
0427 validator.addUntracked<std::string>("address");
0428 validator.addUntracked<unsigned>("port");
0429 validator.addUntracked<bool>("useSsl", false);
0430 validator.addUntracked<std::string>("rootCertificates", "");
0431 validator.addUntracked<std::string>("privateKey", "");
0432 validator.addUntracked<std::string>("certificateChain", "");
0433
0434 desc.addVPSetUntracked("servers", validator, {});
0435
0436 edm::ParameterSetDescription fallbackDesc;
0437 fallbackDesc.addUntracked<bool>("enable", false);
0438 fallbackDesc.addUntracked<bool>("debug", false);
0439 fallbackDesc.addUntracked<bool>("verbose", false);
0440 fallbackDesc.ifValue(edm::ParameterDescription<std::string>("container", "apptainer", false),
0441 edm::allowedValues<std::string>("apptainer", "docker", "podman"));
0442 fallbackDesc.ifValue(edm::ParameterDescription<std::string>("device", "auto", false),
0443 edm::allowedValues<std::string>("auto", "cpu", "gpu"));
0444 fallbackDesc.addUntracked<int>("retries", -1);
0445 fallbackDesc.addUntracked<int>("wait", -1);
0446 fallbackDesc.addUntracked<std::string>("instanceBaseName", "triton_server_instance");
0447 fallbackDesc.addUntracked<std::string>("instanceName", "");
0448 fallbackDesc.addUntracked<std::string>("tempDir", "");
0449 fallbackDesc.addUntracked<std::string>("imageName", "");
0450 fallbackDesc.addUntracked<std::string>("sandboxName", "");
0451 desc.add<edm::ParameterSetDescription>("fallback", fallbackDesc);
0452
0453 descriptions.addWithDefaultLabel(desc);
0454 }