Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2025-04-13 22:50:00

0001 #include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
0002 #include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"
0003 
0004 #include "DataFormats/Provenance/interface/ModuleDescription.h"
0005 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0006 #include "FWCore/ParameterSet/interface/allowedValues.h"
0007 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0008 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0009 #include "FWCore/ServiceRegistry/interface/ActivityRegistry.h"
0010 #include "FWCore/ServiceRegistry/interface/SystemBounds.h"
0011 #include "FWCore/ServiceRegistry/interface/ProcessContext.h"
0012 #include "FWCore/Utilities/interface/Exception.h"
0013 #include "FWCore/Utilities/interface/GetEnvironmentVariable.h"
0014 
0015 #include "grpc_client.h"
0016 #include "grpc_service.pb.h"
0017 
#include <algorithm>
#include <array>
#include <cctype>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <sys/wait.h>
#include <unistd.h>
0027 
0028 namespace tc = triton::client;
0029 
0030 const std::string TritonService::Server::fallbackName{"fallback"};
0031 const std::string TritonService::Server::fallbackAddress{"0.0.0.0"};
0032 const std::string TritonService::Server::siteconfName{"SONIC_LOCAL_BALANCER"};
0033 
0034 namespace {
0035   std::pair<std::string, int> execSys(const std::string& cmd) {
0036     //redirect stderr to stdout
0037     auto pipe = popen((cmd + " 2>&1").c_str(), "r");
0038     int thisErrno = errno;
0039     if (!pipe)
0040       throw cms::Exception("SystemError")
0041           << "TritonService: popen() failed with errno " << thisErrno << " for command: " << cmd;
0042 
0043     //extract output
0044     constexpr static unsigned buffSize = 128;
0045     std::array<char, buffSize> buffer;
0046     std::string result;
0047     while (!feof(pipe)) {
0048       if (fgets(buffer.data(), buffSize, pipe))
0049         result += buffer.data();
0050       else {
0051         thisErrno = ferror(pipe);
0052         if (thisErrno)
0053           throw cms::Exception("SystemError")
0054               << "TritonService: failed reading command output with errno " << thisErrno;
0055       }
0056     }
0057 
0058     int rv = pclose(pipe);
0059     return std::make_pair(result, rv);
0060   }
0061 
0062   //extract specific info from log
0063   std::string extractFromLog(const std::string& output, const std::string& indicator) {
0064     //find last instance in log (in case of multiple)
0065     auto pos = output.rfind(indicator);
0066     if (pos != std::string::npos) {
0067       auto pos2 = pos + indicator.size();
0068       auto pos3 = output.find('\n', pos2);
0069       return output.substr(pos2, pos3 - pos2);
0070     } else
0071       return "";
0072   }
0073 }  // namespace
0074 
0075 TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistry& areg)
0076     : verbose_(pset.getUntrackedParameter<bool>("verbose")),
0077       fallbackOpts_(pset.getParameterSet("fallback")),
0078       currentModuleId_(0),
0079       allowAddModel_(false),
0080       startedFallback_(false),
0081       callFails_(0),
0082       pid_(std::to_string(::getpid())) {
0083   //module construction is assumed to be serial (correct at the time this code was written)
0084 
0085   areg.watchPreallocate(this, &TritonService::preallocate);
0086 
0087   areg.watchPreModuleConstruction(this, &TritonService::preModuleConstruction);
0088   areg.watchPostModuleConstruction(this, &TritonService::postModuleConstruction);
0089   areg.watchPreModuleDestruction(this, &TritonService::preModuleDestruction);
0090   //fallback server will be launched (if needed) before beginJob
0091   areg.watchPreBeginJob(this, &TritonService::preBeginJob);
0092   areg.watchPostEndJob(this, &TritonService::postEndJob);
0093 
0094   //check for server specified in SITECONF
0095   //(temporary solution, to be replaced with entry in site-local-config.xml or similar)
0096   std::string siteconf_address(edm::getEnvironmentVariable(Server::siteconfName + "_HOST"));
0097   std::string siteconf_port(edm::getEnvironmentVariable(Server::siteconfName + "_PORT"));
0098   if (!siteconf_address.empty() and !siteconf_port.empty()) {
0099     servers_.emplace(
0100         std::piecewise_construct,
0101         std::forward_as_tuple(Server::siteconfName),
0102         std::forward_as_tuple(Server::siteconfName, siteconf_address + ":" + siteconf_port, TritonServerType::Remote));
0103     if (verbose_)
0104       edm::LogInfo("TritonDiscovery") << "Obtained server from SITECONF: "
0105                                       << servers_.find(Server::siteconfName)->second.url;
0106   } else if (siteconf_address.empty() != siteconf_port.empty()) {  //xor
0107     edm::LogWarning("TritonDiscovery") << "Incomplete server information from SITECONF: HOST = " << siteconf_address
0108                                        << ", PORT = " << siteconf_port;
0109   } else
0110     edm::LogWarning("TritonDiscovery") << "No server information from SITECONF";
0111 
0112   //finally, populate list of servers from config input
0113   for (const auto& serverPset : pset.getUntrackedParameterSetVector("servers")) {
0114     const std::string& serverName(serverPset.getUntrackedParameter<std::string>("name"));
0115     //ensure uniqueness
0116     auto [sit, unique] = servers_.emplace(serverName, serverPset);
0117     if (!unique)
0118       throw cms::Exception("DuplicateServer")
0119           << "TritonService: Not allowed to specify more than one server with same name (" << serverName << ")";
0120   }
0121 
0122   //loop over all servers: check which models they have
0123   std::string msg;
0124   if (verbose_)
0125     msg = "List of models for each server:\n";
0126   for (auto& [serverName, server] : servers_) {
0127     std::unique_ptr<tc::InferenceServerGrpcClient> client;
0128     TRITON_THROW_IF_ERROR(
0129         tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions),
0130         "TritonService(): unable to create inference context for " + serverName + " (" + server.url + ")",
0131         false);
0132 
0133     if (verbose_) {
0134       inference::ServerMetadataResponse serverMetaResponse;
0135       auto err = client->ServerMetadata(&serverMetaResponse);
0136       if (err.IsOk())
0137         edm::LogInfo("TritonService") << "Server " << serverName << ": url = " << server.url
0138                                       << ", version = " << serverMetaResponse.version();
0139       else
0140         edm::LogInfo("TritonService") << "unable to get metadata for " + serverName + " (" + server.url + ")";
0141     }
0142 
0143     //if this query fails, it indicates that the server is nonresponsive or saturated
0144     //in which case it should just be skipped
0145     inference::RepositoryIndexResponse repoIndexResponse;
0146     auto err = client->ModelRepositoryIndex(&repoIndexResponse);
0147 
0148     //servers keep track of models and vice versa
0149     if (verbose_)
0150       msg += serverName + ": ";
0151     if (err.IsOk()) {
0152       for (const auto& modelIndex : repoIndexResponse.models()) {
0153         const auto& modelName = modelIndex.name();
0154         auto mit = models_.find(modelName);
0155         if (mit == models_.end())
0156           mit = models_.emplace(modelName, "").first;
0157         auto& modelInfo(mit->second);
0158         modelInfo.servers.insert(serverName);
0159         server.models.insert(modelName);
0160         if (verbose_)
0161           msg += modelName + ", ";
0162       }
0163     } else {
0164       const std::string& baseMsg = "unable to get repository index";
0165       const std::string& extraMsg = err.Message().empty() ? "" : ": " + err.Message();
0166       if (verbose_)
0167         msg += baseMsg + extraMsg;
0168       else
0169         edm::LogWarning("TritonFailure") << "TritonService(): " << baseMsg << " for " << serverName << " ("
0170                                          << server.url << ")" << extraMsg;
0171     }
0172     if (verbose_)
0173       msg += "\n";
0174   }
0175   if (verbose_)
0176     edm::LogInfo("TritonDiscovery") << msg;
0177 }
0178 
0179 void TritonService::preallocate(edm::service::SystemBounds const& bounds) {
0180   numberOfThreads_ = bounds.maxNumberOfThreads();
0181 }
0182 
0183 void TritonService::preModuleConstruction(edm::ModuleDescription const& desc) {
0184   currentModuleId_ = desc.id();
0185   allowAddModel_ = true;
0186 }
0187 
0188 void TritonService::addModel(const std::string& modelName, const std::string& path) {
0189   //should only be called in module constructors
0190   if (!allowAddModel_)
0191     throw cms::Exception("DisallowedAddModel")
0192         << "TritonService: Attempt to call addModel() outside of module constructors";
0193   //if model is not in the list, then no specified server provides it
0194   auto mit = models_.find(modelName);
0195   if (mit == models_.end()) {
0196     auto& modelInfo(unservedModels_.emplace(modelName, path).first->second);
0197     modelInfo.modules.insert(currentModuleId_);
0198     //only keep track of modules that need unserved models
0199     modules_.emplace(currentModuleId_, modelName);
0200   }
0201 }
0202 
0203 void TritonService::postModuleConstruction(edm::ModuleDescription const& desc) { allowAddModel_ = false; }
0204 
0205 void TritonService::preModuleDestruction(edm::ModuleDescription const& desc) {
0206   //remove destructed modules from unserved list
0207   if (unservedModels_.empty())
0208     return;
0209   auto id = desc.id();
0210   auto oit = modules_.find(id);
0211   if (oit != modules_.end()) {
0212     const auto& moduleInfo(oit->second);
0213     auto mit = unservedModels_.find(moduleInfo.model);
0214     if (mit != unservedModels_.end()) {
0215       auto& modelInfo(mit->second);
0216       modelInfo.modules.erase(id);
0217       //remove a model if it is no longer needed by any modules
0218       if (modelInfo.modules.empty())
0219         unservedModels_.erase(mit);
0220     }
0221     modules_.erase(oit);
0222   }
0223 }
0224 
0225 //second return value is only true if fallback CPU server is being used
0226 TritonService::Server TritonService::serverInfo(const std::string& model, const std::string& preferred) const {
0227   auto mit = models_.find(model);
0228   if (mit == models_.end())
0229     throw cms::Exception("MissingModel") << "TritonService: There are no servers that provide model " << model;
0230   const auto& modelInfo(mit->second);
0231   const auto& modelServers = modelInfo.servers;
0232 
0233   auto msit = modelServers.end();
0234   if (!preferred.empty()) {
0235     msit = modelServers.find(preferred);
0236     //todo: add a "strict" parameter to stop execution if preferred server isn't found?
0237     if (msit == modelServers.end())
0238       edm::LogWarning("PreferredServer") << "Preferred server " << preferred << " for model " << model
0239                                          << " not available, will choose another server";
0240   }
0241   const auto& serverName(msit == modelServers.end() ? *modelServers.begin() : preferred);
0242 
0243   //todo: use some algorithm to select server rather than just picking arbitrarily
0244   const auto& server(servers_.find(serverName)->second);
0245   return server;
0246 }
0247 
0248 void TritonService::preBeginJob(edm::ProcessContext const&) {
0249   //only need fallback if there are unserved models
0250   if (!fallbackOpts_.enable or unservedModels_.empty())
0251     return;
0252 
0253   //include fallback server in set
0254   auto serverType = TritonServerType::LocalCPU;
0255   if (fallbackOpts_.device == "gpu")
0256     serverType = TritonServerType::LocalGPU;
0257   servers_.emplace(std::piecewise_construct,
0258                    std::forward_as_tuple(Server::fallbackName),
0259                    std::forward_as_tuple(Server::fallbackName, Server::fallbackAddress, serverType));
0260 
0261   std::string msg;
0262   if (verbose_)
0263     msg = "List of models for fallback server: ";
0264   //all unserved models are provided by fallback server
0265   auto& server(servers_.find(Server::fallbackName)->second);
0266   for (const auto& [modelName, model] : unservedModels_) {
0267     auto& modelInfo(models_.emplace(modelName, model).first->second);
0268     modelInfo.servers.insert(Server::fallbackName);
0269     server.models.insert(modelName);
0270     if (verbose_)
0271       msg += modelName + ", ";
0272   }
0273   if (verbose_)
0274     edm::LogInfo("TritonDiscovery") << msg;
0275 
0276   //assemble server start command
0277   fallbackOpts_.command = "cmsTriton -P -1 -p " + pid_;
0278   fallbackOpts_.command += " -g " + fallbackOpts_.device;
0279   fallbackOpts_.command += " -d " + fallbackOpts_.container;
0280   if (fallbackOpts_.debug)
0281     fallbackOpts_.command += " -c";
0282   if (fallbackOpts_.verbose)
0283     fallbackOpts_.command += " -v";
0284   if (!fallbackOpts_.instanceName.empty())
0285     fallbackOpts_.command += " -n " + fallbackOpts_.instanceName;
0286   if (fallbackOpts_.retries >= 0)
0287     fallbackOpts_.command += " -r " + std::to_string(fallbackOpts_.retries);
0288   if (fallbackOpts_.wait >= 0)
0289     fallbackOpts_.command += " -w " + std::to_string(fallbackOpts_.wait);
0290   for (const auto& [modelName, model] : unservedModels_) {
0291     fallbackOpts_.command += " -m " + model.path;
0292   }
0293   std::string thread_string = " -I " + std::to_string(numberOfThreads_);
0294   fallbackOpts_.command += thread_string;
0295   if (!fallbackOpts_.imageName.empty())
0296     fallbackOpts_.command += " -i " + fallbackOpts_.imageName;
0297   if (!fallbackOpts_.sandboxName.empty())
0298     fallbackOpts_.command += " -s " + fallbackOpts_.sandboxName;
0299   //don't need this anymore
0300   unservedModels_.clear();
0301 
0302   //get a random temporary directory if none specified
0303   if (fallbackOpts_.tempDir.empty()) {
0304     auto tmp_dir_path{std::filesystem::temp_directory_path() /= edm::createGlobalIdentifier()};
0305     fallbackOpts_.tempDir = tmp_dir_path.string();
0306   }
0307   //special case ".": use script default (temp dir = .$instanceName)
0308   if (fallbackOpts_.tempDir != ".")
0309     fallbackOpts_.command += " -t " + fallbackOpts_.tempDir;
0310 
0311   std::string command = fallbackOpts_.command + " start";
0312 
0313   if (fallbackOpts_.debug)
0314     edm::LogInfo("TritonService") << "Fallback server temporary directory: " << fallbackOpts_.tempDir;
0315   if (verbose_)
0316     edm::LogInfo("TritonService") << command;
0317 
0318   //mark as started before executing in case of ctrl+c while command is running
0319   startedFallback_ = true;
0320   const auto& [output, rv] = execSys(command);
0321   if (rv != 0) {
0322     edm::LogError("TritonService") << output;
0323     printFallbackServerLog<edm::LogError>();
0324     throw edm::Exception(edm::errors::ExternalFailure)
0325         << "TritonService: Starting the fallback server failed with exit code " << rv;
0326   } else if (verbose_)
0327     edm::LogInfo("TritonService") << output;
0328 
0329   //get the chosen device
0330   std::string chosenDevice(fallbackOpts_.device);
0331   if (chosenDevice == "auto") {
0332     chosenDevice = extractFromLog(output, "CMS_TRITON_CHOSEN_DEVICE: ");
0333     if (!chosenDevice.empty()) {
0334       if (chosenDevice == "cpu")
0335         server.type = TritonServerType::LocalCPU;
0336       else if (chosenDevice == "gpu")
0337         server.type = TritonServerType::LocalGPU;
0338       else
0339         throw edm::Exception(edm::errors::ExternalFailure)
0340             << "TritonService: unsupported device choice " << chosenDevice << " for fallback server, log follows:\n"
0341             << output;
0342     } else
0343       throw edm::Exception(edm::errors::ExternalFailure)
0344           << "TritonService: unknown device choice for fallback server, log follows:\n"
0345           << output;
0346   }
0347   //print server info
0348   std::transform(chosenDevice.begin(), chosenDevice.end(), chosenDevice.begin(), toupper);
0349   if (verbose_)
0350     edm::LogInfo("TritonDiscovery") << "Fallback server started: " << chosenDevice;
0351 
0352   //get the port
0353   const auto& portNum = extractFromLog(output, "CMS_TRITON_GRPC_PORT: ");
0354   if (!portNum.empty())
0355     server.url += ":" + portNum;
0356   else
0357     throw edm::Exception(edm::errors::ExternalFailure)
0358         << "TritonService: Unknown port for fallback server, log follows:\n"
0359         << output;
0360 }
0361 
0362 void TritonService::notifyCallStatus(bool status) const {
0363   if (status)
0364     --callFails_;
0365   else
0366     ++callFails_;
0367 }
0368 
0369 void TritonService::postEndJob() {
0370   if (!startedFallback_)
0371     return;
0372 
0373   std::string command = fallbackOpts_.command;
0374   //prevent log cleanup during server stop
0375   if (callFails_ > 0)
0376     command += " -c";
0377   command += " stop";
0378   if (verbose_)
0379     edm::LogInfo("TritonService") << command;
0380 
0381   const auto& [output, rv] = execSys(command);
0382   if (rv != 0 or callFails_ > 0) {
0383     //print logs if cmsRun is currently exiting because of a TritonException
0384     edm::LogError("TritonService") << output;
0385     printFallbackServerLog<edm::LogError>();
0386     if (rv != 0) {
0387       std::string stopCat("FallbackFailed");
0388       std::string stopMsg = fmt::format("TritonService: Stopping the fallback server failed with exit code {}", rv);
0389       //avoid throwing if the stack is already unwinding
0390       if (callFails_ > 0)
0391         edm::LogWarning(stopCat) << stopMsg;
0392       else
0393         throw cms::Exception(stopCat) << stopMsg;
0394     }
0395   } else if (verbose_) {
0396     edm::LogInfo("TritonService") << output;
0397     printFallbackServerLog<edm::LogInfo>();
0398   }
0399 }
0400 
0401 template <typename LOG>
0402 void TritonService::printFallbackServerLog() const {
0403   std::vector<std::string> logNames{"log_" + fallbackOpts_.instanceName + ".log"};
0404   //cmsTriton script moves log from temp to current dir in verbose mode or in some cases when auto_stop is called
0405   // -> check both places
0406   logNames.push_back(fallbackOpts_.tempDir + "/" + logNames[0]);
0407   bool foundLog = false;
0408   for (const auto& logName : logNames) {
0409     std::ifstream infile(logName);
0410     if (infile.is_open()) {
0411       LOG("TritonService") << "TritonService: server log " << logName << "\n" << infile.rdbuf();
0412       foundLog = true;
0413       break;
0414     }
0415   }
0416   if (!foundLog)
0417     LOG("TritonService") << "TritonService: could not find server log " << logNames[0] << " in current directory or "
0418                          << fallbackOpts_.tempDir;
0419 }
0420 
0421 void TritonService::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0422   edm::ParameterSetDescription desc;
0423   desc.addUntracked<bool>("verbose", false);
0424 
0425   edm::ParameterSetDescription validator;
0426   validator.addUntracked<std::string>("name");
0427   validator.addUntracked<std::string>("address");
0428   validator.addUntracked<unsigned>("port");
0429   validator.addUntracked<bool>("useSsl", false);
0430   validator.addUntracked<std::string>("rootCertificates", "");
0431   validator.addUntracked<std::string>("privateKey", "");
0432   validator.addUntracked<std::string>("certificateChain", "");
0433 
0434   desc.addVPSetUntracked("servers", validator, {});
0435 
0436   edm::ParameterSetDescription fallbackDesc;
0437   fallbackDesc.addUntracked<bool>("enable", false);
0438   fallbackDesc.addUntracked<bool>("debug", false);
0439   fallbackDesc.addUntracked<bool>("verbose", false);
0440   fallbackDesc.ifValue(edm::ParameterDescription<std::string>("container", "apptainer", false),
0441                        edm::allowedValues<std::string>("apptainer", "docker", "podman"));
0442   fallbackDesc.ifValue(edm::ParameterDescription<std::string>("device", "auto", false),
0443                        edm::allowedValues<std::string>("auto", "cpu", "gpu"));
0444   fallbackDesc.addUntracked<int>("retries", -1);
0445   fallbackDesc.addUntracked<int>("wait", -1);
0446   fallbackDesc.addUntracked<std::string>("instanceBaseName", "triton_server_instance");
0447   fallbackDesc.addUntracked<std::string>("instanceName", "");
0448   fallbackDesc.addUntracked<std::string>("tempDir", "");
0449   fallbackDesc.addUntracked<std::string>("imageName", "");
0450   fallbackDesc.addUntracked<std::string>("sandboxName", "");
0451   desc.add<edm::ParameterSetDescription>("fallback", fallbackDesc);
0452 
0453   descriptions.addWithDefaultLabel(desc);
0454 }