Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-08-06 22:36:40

0001 #include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
0002 #include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"
0003 
0004 #include "DataFormats/Provenance/interface/ModuleDescription.h"
0005 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0006 #include "FWCore/ParameterSet/interface/allowedValues.h"
0007 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0008 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0009 #include "FWCore/ServiceRegistry/interface/ActivityRegistry.h"
0010 #include "FWCore/ServiceRegistry/interface/SystemBounds.h"
0011 #include "FWCore/ServiceRegistry/interface/ProcessContext.h"
0012 #include "FWCore/Utilities/interface/Exception.h"
0013 #include "FWCore/Utilities/interface/GetEnvironmentVariable.h"
0014 
0015 #include "grpc_client.h"
0016 #include "grpc_service.pb.h"
0017 
0018 #include <algorithm>
0019 #include <cctype>
0020 #include <cstdio>
0021 #include <cstdlib>
0022 #include <filesystem>
0023 #include <fstream>
0024 #include <utility>
0025 #include <tuple>
0026 #include <unistd.h>
0027 
0028 namespace tc = triton::client;
0029 
0030 const std::string TritonService::Server::fallbackName{"fallback"};
0031 const std::string TritonService::Server::fallbackAddress{"0.0.0.0"};
0032 const std::string TritonService::Server::siteconfName{"SONIC_LOCAL_BALANCER"};
0033 
0034 namespace {
0035   std::pair<std::string, int> execSys(const std::string& cmd) {
0036     //redirect stderr to stdout
0037     auto pipe = popen((cmd + " 2>&1").c_str(), "r");
0038     int thisErrno = errno;
0039     if (!pipe)
0040       throw cms::Exception("SystemError")
0041           << "TritonService: popen() failed with errno " << thisErrno << " for command: " << cmd;
0042 
0043     //extract output
0044     constexpr static unsigned buffSize = 128;
0045     std::array<char, buffSize> buffer;
0046     std::string result;
0047     while (!feof(pipe)) {
0048       if (fgets(buffer.data(), buffSize, pipe))
0049         result += buffer.data();
0050       else {
0051         thisErrno = ferror(pipe);
0052         if (thisErrno)
0053           throw cms::Exception("SystemError")
0054               << "TritonService: failed reading command output with errno " << thisErrno;
0055       }
0056     }
0057 
0058     int rv = pclose(pipe);
0059     return std::make_pair(result, rv);
0060   }
0061 
0062   //extract specific info from log
0063   std::string extractFromLog(const std::string& output, const std::string& indicator) {
0064     //find last instance in log (in case of multiple)
0065     auto pos = output.rfind(indicator);
0066     if (pos != std::string::npos) {
0067       auto pos2 = pos + indicator.size();
0068       auto pos3 = output.find('\n', pos2);
0069       return output.substr(pos2, pos3 - pos2);
0070     } else
0071       return "";
0072   }
0073 }  // namespace
0074 
0075 TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistry& areg)
0076     : verbose_(pset.getUntrackedParameter<bool>("verbose")),
0077       fallbackOpts_(pset.getParameterSet("fallback")),
0078       currentModuleId_(0),
0079       allowAddModel_(false),
0080       startedFallback_(false),
0081       callFails_(0),
0082       pid_(std::to_string(::getpid())) {
0083   //module construction is assumed to be serial (correct at the time this code was written)
0084 
0085   areg.watchPreallocate(this, &TritonService::preallocate);
0086 
0087   areg.watchPreModuleConstruction(this, &TritonService::preModuleConstruction);
0088   areg.watchPostModuleConstruction(this, &TritonService::postModuleConstruction);
0089   areg.watchPreModuleDestruction(this, &TritonService::preModuleDestruction);
0090   //fallback server will be launched (if needed) before beginJob
0091   areg.watchPreBeginJob(this, &TritonService::preBeginJob);
0092   areg.watchPostEndJob(this, &TritonService::postEndJob);
0093 
0094   //check for server specified in SITECONF
0095   //(temporary solution, to be replaced with entry in site-local-config.xml or similar)
0096   std::string siteconf_address(edm::getEnvironmentVariable(Server::siteconfName + "_HOST"));
0097   std::string siteconf_port(edm::getEnvironmentVariable(Server::siteconfName + "_PORT"));
0098   if (!siteconf_address.empty() and !siteconf_port.empty()) {
0099     servers_.emplace(
0100         std::piecewise_construct,
0101         std::forward_as_tuple(Server::siteconfName),
0102         std::forward_as_tuple(Server::siteconfName, siteconf_address + ":" + siteconf_port, TritonServerType::Remote));
0103     if (verbose_)
0104       edm::LogInfo("TritonDiscovery") << "Obtained server from SITECONF: "
0105                                       << servers_.find(Server::siteconfName)->second.url;
0106   } else if (siteconf_address.empty() != siteconf_port.empty()) {  //xor
0107     edm::LogWarning("TritonDiscovery") << "Incomplete server information from SITECONF: HOST = " << siteconf_address
0108                                        << ", PORT = " << siteconf_port;
0109   } else
0110     edm::LogWarning("TritonDiscovery") << "No server information from SITECONF";
0111 
0112   //finally, populate list of servers from config input
0113   for (const auto& serverPset : pset.getUntrackedParameterSetVector("servers")) {
0114     const std::string& serverName(serverPset.getUntrackedParameter<std::string>("name"));
0115     //ensure uniqueness
0116     auto [sit, unique] = servers_.emplace(serverName, serverPset);
0117     if (!unique)
0118       throw cms::Exception("DuplicateServer")
0119           << "TritonService: Not allowed to specify more than one server with same name (" << serverName << ")";
0120   }
0121 
0122   //loop over all servers: check which models they have
0123   std::string msg;
0124   if (verbose_)
0125     msg = "List of models for each server:\n";
0126   for (auto& [serverName, server] : servers_) {
0127     std::unique_ptr<tc::InferenceServerGrpcClient> client;
0128     TRITON_THROW_IF_ERROR(
0129         tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions),
0130         "TritonService(): unable to create inference context for " + serverName + " (" + server.url + ")",
0131         false);
0132 
0133     if (verbose_) {
0134       inference::ServerMetadataResponse serverMetaResponse;
0135       auto err = client->ServerMetadata(&serverMetaResponse);
0136       if (err.IsOk())
0137         edm::LogInfo("TritonService") << "Server " << serverName << ": url = " << server.url
0138                                       << ", version = " << serverMetaResponse.version();
0139       else
0140         edm::LogInfo("TritonService") << "unable to get metadata for " + serverName + " (" + server.url + ")";
0141     }
0142 
0143     //if this query fails, it indicates that the server is nonresponsive or saturated
0144     //in which case it should just be skipped
0145     inference::RepositoryIndexResponse repoIndexResponse;
0146     auto err = client->ModelRepositoryIndex(&repoIndexResponse);
0147 
0148     //servers keep track of models and vice versa
0149     if (verbose_)
0150       msg += serverName + ": ";
0151     if (err.IsOk()) {
0152       for (const auto& modelIndex : repoIndexResponse.models()) {
0153         const auto& modelName = modelIndex.name();
0154         auto mit = models_.find(modelName);
0155         if (mit == models_.end())
0156           mit = models_.emplace(modelName, "").first;
0157         auto& modelInfo(mit->second);
0158         modelInfo.servers.insert(serverName);
0159         server.models.insert(modelName);
0160         if (verbose_)
0161           msg += modelName + ", ";
0162       }
0163     } else {
0164       if (verbose_)
0165         msg += "unable to get repository index";
0166       else
0167         edm::LogWarning("TritonFailure") << "TritonService(): unable to get repository index for " + serverName + " (" +
0168                                                 server.url + ")";
0169     }
0170     if (verbose_)
0171       msg += "\n";
0172   }
0173   if (verbose_)
0174     edm::LogInfo("TritonDiscovery") << msg;
0175 }
0176 
0177 void TritonService::preallocate(edm::service::SystemBounds const& bounds) {
0178   numberOfThreads_ = bounds.maxNumberOfThreads();
0179 }
0180 
0181 void TritonService::preModuleConstruction(edm::ModuleDescription const& desc) {
0182   currentModuleId_ = desc.id();
0183   allowAddModel_ = true;
0184 }
0185 
0186 void TritonService::addModel(const std::string& modelName, const std::string& path) {
0187   //should only be called in module constructors
0188   if (!allowAddModel_)
0189     throw cms::Exception("DisallowedAddModel")
0190         << "TritonService: Attempt to call addModel() outside of module constructors";
0191   //if model is not in the list, then no specified server provides it
0192   auto mit = models_.find(modelName);
0193   if (mit == models_.end()) {
0194     auto& modelInfo(unservedModels_.emplace(modelName, path).first->second);
0195     modelInfo.modules.insert(currentModuleId_);
0196     //only keep track of modules that need unserved models
0197     modules_.emplace(currentModuleId_, modelName);
0198   }
0199 }
0200 
0201 void TritonService::postModuleConstruction(edm::ModuleDescription const& desc) { allowAddModel_ = false; }
0202 
0203 void TritonService::preModuleDestruction(edm::ModuleDescription const& desc) {
0204   //remove destructed modules from unserved list
0205   if (unservedModels_.empty())
0206     return;
0207   auto id = desc.id();
0208   auto oit = modules_.find(id);
0209   if (oit != modules_.end()) {
0210     const auto& moduleInfo(oit->second);
0211     auto mit = unservedModels_.find(moduleInfo.model);
0212     if (mit != unservedModels_.end()) {
0213       auto& modelInfo(mit->second);
0214       modelInfo.modules.erase(id);
0215       //remove a model if it is no longer needed by any modules
0216       if (modelInfo.modules.empty())
0217         unservedModels_.erase(mit);
0218     }
0219     modules_.erase(oit);
0220   }
0221 }
0222 
0223 //second return value is only true if fallback CPU server is being used
0224 TritonService::Server TritonService::serverInfo(const std::string& model, const std::string& preferred) const {
0225   auto mit = models_.find(model);
0226   if (mit == models_.end())
0227     throw cms::Exception("MissingModel") << "TritonService: There are no servers that provide model " << model;
0228   const auto& modelInfo(mit->second);
0229   const auto& modelServers = modelInfo.servers;
0230 
0231   auto msit = modelServers.end();
0232   if (!preferred.empty()) {
0233     msit = modelServers.find(preferred);
0234     //todo: add a "strict" parameter to stop execution if preferred server isn't found?
0235     if (msit == modelServers.end())
0236       edm::LogWarning("PreferredServer") << "Preferred server " << preferred << " for model " << model
0237                                          << " not available, will choose another server";
0238   }
0239   const auto& serverName(msit == modelServers.end() ? *modelServers.begin() : preferred);
0240 
0241   //todo: use some algorithm to select server rather than just picking arbitrarily
0242   const auto& server(servers_.find(serverName)->second);
0243   return server;
0244 }
0245 
0246 void TritonService::preBeginJob(edm::PathsAndConsumesOfModulesBase const&, edm::ProcessContext const&) {
0247   //only need fallback if there are unserved models
0248   if (!fallbackOpts_.enable or unservedModels_.empty())
0249     return;
0250 
0251   //include fallback server in set
0252   auto serverType = TritonServerType::LocalCPU;
0253   if (fallbackOpts_.device == "gpu")
0254     serverType = TritonServerType::LocalGPU;
0255   servers_.emplace(std::piecewise_construct,
0256                    std::forward_as_tuple(Server::fallbackName),
0257                    std::forward_as_tuple(Server::fallbackName, Server::fallbackAddress, serverType));
0258 
0259   std::string msg;
0260   if (verbose_)
0261     msg = "List of models for fallback server: ";
0262   //all unserved models are provided by fallback server
0263   auto& server(servers_.find(Server::fallbackName)->second);
0264   for (const auto& [modelName, model] : unservedModels_) {
0265     auto& modelInfo(models_.emplace(modelName, model).first->second);
0266     modelInfo.servers.insert(Server::fallbackName);
0267     server.models.insert(modelName);
0268     if (verbose_)
0269       msg += modelName + ", ";
0270   }
0271   if (verbose_)
0272     edm::LogInfo("TritonDiscovery") << msg;
0273 
0274   //assemble server start command
0275   fallbackOpts_.command = "cmsTriton -P -1 -p " + pid_;
0276   fallbackOpts_.command += " -g " + fallbackOpts_.device;
0277   fallbackOpts_.command += " -d " + fallbackOpts_.container;
0278   if (fallbackOpts_.debug)
0279     fallbackOpts_.command += " -c";
0280   if (fallbackOpts_.verbose)
0281     fallbackOpts_.command += " -v";
0282   if (!fallbackOpts_.instanceName.empty())
0283     fallbackOpts_.command += " -n " + fallbackOpts_.instanceName;
0284   if (fallbackOpts_.retries >= 0)
0285     fallbackOpts_.command += " -r " + std::to_string(fallbackOpts_.retries);
0286   if (fallbackOpts_.wait >= 0)
0287     fallbackOpts_.command += " -w " + std::to_string(fallbackOpts_.wait);
0288   for (const auto& [modelName, model] : unservedModels_) {
0289     fallbackOpts_.command += " -m " + model.path;
0290   }
0291   std::string thread_string = " -I " + std::to_string(numberOfThreads_);
0292   fallbackOpts_.command += thread_string;
0293   if (!fallbackOpts_.imageName.empty())
0294     fallbackOpts_.command += " -i " + fallbackOpts_.imageName;
0295   if (!fallbackOpts_.sandboxName.empty())
0296     fallbackOpts_.command += " -s " + fallbackOpts_.sandboxName;
0297   //don't need this anymore
0298   unservedModels_.clear();
0299 
0300   //get a random temporary directory if none specified
0301   if (fallbackOpts_.tempDir.empty()) {
0302     auto tmp_dir_path{std::filesystem::temp_directory_path() /= edm::createGlobalIdentifier()};
0303     fallbackOpts_.tempDir = tmp_dir_path.string();
0304   }
0305   //special case ".": use script default (temp dir = .$instanceName)
0306   if (fallbackOpts_.tempDir != ".")
0307     fallbackOpts_.command += " -t " + fallbackOpts_.tempDir;
0308 
0309   std::string command = fallbackOpts_.command + " start";
0310 
0311   if (fallbackOpts_.debug)
0312     edm::LogInfo("TritonService") << "Fallback server temporary directory: " << fallbackOpts_.tempDir;
0313   if (verbose_)
0314     edm::LogInfo("TritonService") << command;
0315 
0316   //mark as started before executing in case of ctrl+c while command is running
0317   startedFallback_ = true;
0318   const auto& [output, rv] = execSys(command);
0319   if (rv != 0) {
0320     edm::LogError("TritonService") << output;
0321     printFallbackServerLog<edm::LogError>();
0322     throw edm::Exception(edm::errors::ExternalFailure)
0323         << "TritonService: Starting the fallback server failed with exit code " << rv;
0324   } else if (verbose_)
0325     edm::LogInfo("TritonService") << output;
0326 
0327   //get the chosen device
0328   std::string chosenDevice(fallbackOpts_.device);
0329   if (chosenDevice == "auto") {
0330     chosenDevice = extractFromLog(output, "CMS_TRITON_CHOSEN_DEVICE: ");
0331     if (!chosenDevice.empty()) {
0332       if (chosenDevice == "cpu")
0333         server.type = TritonServerType::LocalCPU;
0334       else if (chosenDevice == "gpu")
0335         server.type = TritonServerType::LocalGPU;
0336       else
0337         throw edm::Exception(edm::errors::ExternalFailure)
0338             << "TritonService: unsupported device choice " << chosenDevice << " for fallback server, log follows:\n"
0339             << output;
0340     } else
0341       throw edm::Exception(edm::errors::ExternalFailure)
0342           << "TritonService: unknown device choice for fallback server, log follows:\n"
0343           << output;
0344   }
0345   //print server info
0346   std::transform(chosenDevice.begin(), chosenDevice.end(), chosenDevice.begin(), toupper);
0347   if (verbose_)
0348     edm::LogInfo("TritonDiscovery") << "Fallback server started: " << chosenDevice;
0349 
0350   //get the port
0351   const auto& portNum = extractFromLog(output, "CMS_TRITON_GRPC_PORT: ");
0352   if (!portNum.empty())
0353     server.url += ":" + portNum;
0354   else
0355     throw edm::Exception(edm::errors::ExternalFailure)
0356         << "TritonService: Unknown port for fallback server, log follows:\n"
0357         << output;
0358 }
0359 
0360 void TritonService::notifyCallStatus(bool status) const {
0361   if (status)
0362     --callFails_;
0363   else
0364     ++callFails_;
0365 }
0366 
0367 void TritonService::postEndJob() {
0368   if (!startedFallback_)
0369     return;
0370 
0371   std::string command = fallbackOpts_.command;
0372   //prevent log cleanup during server stop
0373   if (callFails_ > 0)
0374     command += " -c";
0375   command += " stop";
0376   if (verbose_)
0377     edm::LogInfo("TritonService") << command;
0378 
0379   const auto& [output, rv] = execSys(command);
0380   if (rv != 0 or callFails_ > 0) {
0381     //print logs if cmsRun is currently exiting because of a TritonException
0382     edm::LogError("TritonService") << output;
0383     printFallbackServerLog<edm::LogError>();
0384     if (rv != 0) {
0385       std::string stopCat("FallbackFailed");
0386       std::string stopMsg = fmt::format("TritonService: Stopping the fallback server failed with exit code {}", rv);
0387       //avoid throwing if the stack is already unwinding
0388       if (callFails_ > 0)
0389         edm::LogWarning(stopCat) << stopMsg;
0390       else
0391         throw cms::Exception(stopCat) << stopMsg;
0392     }
0393   } else if (verbose_) {
0394     edm::LogInfo("TritonService") << output;
0395     printFallbackServerLog<edm::LogInfo>();
0396   }
0397 }
0398 
0399 template <typename LOG>
0400 void TritonService::printFallbackServerLog() const {
0401   std::vector<std::string> logNames{"log_" + fallbackOpts_.instanceName + ".log"};
0402   //cmsTriton script moves log from temp to current dir in verbose mode or in some cases when auto_stop is called
0403   // -> check both places
0404   logNames.push_back(fallbackOpts_.tempDir + "/" + logNames[0]);
0405   bool foundLog = false;
0406   for (const auto& logName : logNames) {
0407     std::ifstream infile(logName);
0408     if (infile.is_open()) {
0409       LOG("TritonService") << "TritonService: server log " << logName << "\n" << infile.rdbuf();
0410       foundLog = true;
0411       break;
0412     }
0413   }
0414   if (!foundLog)
0415     LOG("TritonService") << "TritonService: could not find server log " << logNames[0] << " in current directory or "
0416                          << fallbackOpts_.tempDir;
0417 }
0418 
0419 void TritonService::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0420   edm::ParameterSetDescription desc;
0421   desc.addUntracked<bool>("verbose", false);
0422 
0423   edm::ParameterSetDescription validator;
0424   validator.addUntracked<std::string>("name");
0425   validator.addUntracked<std::string>("address");
0426   validator.addUntracked<unsigned>("port");
0427   validator.addUntracked<bool>("useSsl", false);
0428   validator.addUntracked<std::string>("rootCertificates", "");
0429   validator.addUntracked<std::string>("privateKey", "");
0430   validator.addUntracked<std::string>("certificateChain", "");
0431 
0432   desc.addVPSetUntracked("servers", validator, {});
0433 
0434   edm::ParameterSetDescription fallbackDesc;
0435   fallbackDesc.addUntracked<bool>("enable", false);
0436   fallbackDesc.addUntracked<bool>("debug", false);
0437   fallbackDesc.addUntracked<bool>("verbose", false);
0438   fallbackDesc.ifValue(edm::ParameterDescription<std::string>("container", "apptainer", false),
0439                        edm::allowedValues<std::string>("apptainer", "docker", "podman"));
0440   fallbackDesc.ifValue(edm::ParameterDescription<std::string>("device", "auto", false),
0441                        edm::allowedValues<std::string>("auto", "cpu", "gpu"));
0442   fallbackDesc.addUntracked<int>("retries", -1);
0443   fallbackDesc.addUntracked<int>("wait", -1);
0444   fallbackDesc.addUntracked<std::string>("instanceBaseName", "triton_server_instance");
0445   fallbackDesc.addUntracked<std::string>("instanceName", "");
0446   fallbackDesc.addUntracked<std::string>("tempDir", "");
0447   fallbackDesc.addUntracked<std::string>("imageName", "");
0448   fallbackDesc.addUntracked<std::string>("sandboxName", "");
0449   desc.add<edm::ParameterSetDescription>("fallback", fallbackDesc);
0450 
0451   descriptions.addWithDefaultLabel(desc);
0452 }