Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:48

0001 #include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
0002 #include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"
0003 
0004 #include "DataFormats/Provenance/interface/ModuleDescription.h"
0005 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0006 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0007 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0008 #include "FWCore/ServiceRegistry/interface/ActivityRegistry.h"
0009 #include "FWCore/ServiceRegistry/interface/SystemBounds.h"
0010 #include "FWCore/ServiceRegistry/interface/ProcessContext.h"
0011 #include "FWCore/Utilities/interface/Exception.h"
0012 
0013 #include "grpc_client.h"
0014 #include "grpc_service.pb.h"
0015 
0016 #include <cstdio>
0017 #include <cstdlib>
0018 #include <filesystem>
0019 #include <fstream>
0020 #include <utility>
0021 #include <tuple>
0022 #include <unistd.h>
0023 
//shorthand for the Triton client library namespace
namespace tc = triton::client;

//identity of the locally-launched fallback server; the gRPC port is appended to
//the server url later (in preBeginJob, after the cmsTriton script reports it)
//NOTE(review): "0.0.0.0" presumably resolves to the local host for client connections — confirm
const std::string TritonService::Server::fallbackName{"fallback"};
const std::string TritonService::Server::fallbackAddress{"0.0.0.0"};
0028 
0029 namespace {
0030   std::pair<std::string, int> execSys(const std::string& cmd) {
0031     //redirect stderr to stdout
0032     auto pipe = popen((cmd + " 2>&1").c_str(), "r");
0033     int thisErrno = errno;
0034     if (!pipe)
0035       throw cms::Exception("SystemError")
0036           << "TritonService: popen() failed with errno " << thisErrno << " for command: " << cmd;
0037 
0038     //extract output
0039     constexpr static unsigned buffSize = 128;
0040     std::array<char, buffSize> buffer;
0041     std::string result;
0042     while (!feof(pipe)) {
0043       if (fgets(buffer.data(), buffSize, pipe))
0044         result += buffer.data();
0045       else {
0046         thisErrno = ferror(pipe);
0047         if (thisErrno)
0048           throw cms::Exception("SystemError")
0049               << "TritonService: failed reading command output with errno " << thisErrno;
0050       }
0051     }
0052 
0053     int rv = pclose(pipe);
0054     return std::make_pair(result, rv);
0055   }
0056 }  // namespace
0057 
//Service constructor: registers framework callbacks, optionally registers the
//fallback server, then contacts every configured remote server over gRPC to
//discover which models each one provides.
TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistry& areg)
    : verbose_(pset.getUntrackedParameter<bool>("verbose")),
      fallbackOpts_(pset.getParameterSet("fallback")),
      currentModuleId_(0),
      allowAddModel_(false),
      startedFallback_(false),
      callFails_(0),
      pid_(std::to_string(::getpid())) {
  //module construction is assumed to be serial (correct at the time this code was written)

  areg.watchPreallocate(this, &TritonService::preallocate);

  //pre/post module-construction hooks bound the window in which addModel() is legal
  areg.watchPreModuleConstruction(this, &TritonService::preModuleConstruction);
  areg.watchPostModuleConstruction(this, &TritonService::postModuleConstruction);
  areg.watchPreModuleDestruction(this, &TritonService::preModuleDestruction);
  //fallback server will be launched (if needed) before beginJob
  areg.watchPreBeginJob(this, &TritonService::preBeginJob);
  areg.watchPostEndJob(this, &TritonService::postEndJob);

  //include fallback server in set if enabled
  if (fallbackOpts_.enable) {
    auto serverType = TritonServerType::Remote;
    if (!fallbackOpts_.useGPU)
      serverType = TritonServerType::LocalCPU;
#ifdef TRITON_ENABLE_GPU
    //LocalGPU only possible when the client library was built with GPU support
    else
      serverType = TritonServerType::LocalGPU;
#endif

    servers_.emplace(std::piecewise_construct,
                     std::forward_as_tuple(Server::fallbackName),
                     std::forward_as_tuple(Server::fallbackName, Server::fallbackAddress, serverType));
  }

  //loop over input servers: check which models they have
  std::string msg;
  if (verbose_)
    msg = "List of models for each server:\n";
  for (const auto& serverPset : pset.getUntrackedParameterSetVector("servers")) {
    const std::string& serverName(serverPset.getUntrackedParameter<std::string>("name"));
    //ensure uniqueness
    auto [sit, unique] = servers_.emplace(serverName, serverPset);
    if (!unique)
      throw cms::Exception("DuplicateServer")
          << "TritonService: Not allowed to specify more than one server with same name (" << serverName << ")";
    auto& server(sit->second);

    //create a short-lived gRPC client just to interrogate this server
    std::unique_ptr<tc::InferenceServerGrpcClient> client;
    TRITON_THROW_IF_ERROR(
        tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions),
        "TritonService(): unable to create inference context for " + serverName + " (" + server.url + ")",
        false);

    if (verbose_) {
      inference::ServerMetadataResponse serverMetaResponse;
      TRITON_THROW_IF_ERROR(client->ServerMetadata(&serverMetaResponse),
                            "TritonService(): unable to get metadata for " + serverName + " (" + server.url + ")",
                            false);
      edm::LogInfo("TritonService") << "Server " << serverName << ": url = " << server.url
                                    << ", version = " << serverMetaResponse.version();
    }

    //ask the server for the index of its model repository
    inference::RepositoryIndexResponse repoIndexResponse;
    TRITON_THROW_IF_ERROR(client->ModelRepositoryIndex(&repoIndexResponse),
                          "TritonService(): unable to get repository index for " + serverName + " (" + server.url + ")",
                          false);

    //servers keep track of models and vice versa
    if (verbose_)
      msg += serverName + ": ";
    for (const auto& modelIndex : repoIndexResponse.models()) {
      const auto& modelName = modelIndex.name();
      //insert the model with an empty path if not seen before (path unknown for remote models)
      auto mit = models_.find(modelName);
      if (mit == models_.end())
        mit = models_.emplace(modelName, "").first;
      auto& modelInfo(mit->second);
      modelInfo.servers.insert(serverName);
      server.models.insert(modelName);
      if (verbose_)
        msg += modelName + ", ";
    }
    if (verbose_)
      msg += "\n";
  }
  if (verbose_)
    edm::LogInfo("TritonService") << msg;
}
0145 
0146 void TritonService::preallocate(edm::service::SystemBounds const& bounds) {
0147   numberOfThreads_ = bounds.maxNumberOfThreads();
0148 }
0149 
0150 void TritonService::preModuleConstruction(edm::ModuleDescription const& desc) {
0151   currentModuleId_ = desc.id();
0152   allowAddModel_ = true;
0153 }
0154 
0155 void TritonService::addModel(const std::string& modelName, const std::string& path) {
0156   //should only be called in module constructors
0157   if (!allowAddModel_)
0158     throw cms::Exception("DisallowedAddModel")
0159         << "TritonService: Attempt to call addModel() outside of module constructors";
0160   //if model is not in the list, then no specified server provides it
0161   auto mit = models_.find(modelName);
0162   if (mit == models_.end()) {
0163     auto& modelInfo(unservedModels_.emplace(modelName, path).first->second);
0164     modelInfo.modules.insert(currentModuleId_);
0165     //only keep track of modules that need unserved models
0166     modules_.emplace(currentModuleId_, modelName);
0167   }
0168 }
0169 
0170 void TritonService::postModuleConstruction(edm::ModuleDescription const& desc) { allowAddModel_ = false; }
0171 
0172 void TritonService::preModuleDestruction(edm::ModuleDescription const& desc) {
0173   //remove destructed modules from unserved list
0174   if (unservedModels_.empty())
0175     return;
0176   auto id = desc.id();
0177   auto oit = modules_.find(id);
0178   if (oit != modules_.end()) {
0179     const auto& moduleInfo(oit->second);
0180     auto mit = unservedModels_.find(moduleInfo.model);
0181     if (mit != unservedModels_.end()) {
0182       auto& modelInfo(mit->second);
0183       modelInfo.modules.erase(id);
0184       //remove a model if it is no longer needed by any modules
0185       if (modelInfo.modules.empty())
0186         unservedModels_.erase(mit);
0187     }
0188     modules_.erase(oit);
0189   }
0190 }
0191 
0192 //second return value is only true if fallback CPU server is being used
0193 TritonService::Server TritonService::serverInfo(const std::string& model, const std::string& preferred) const {
0194   auto mit = models_.find(model);
0195   if (mit == models_.end())
0196     throw cms::Exception("MissingModel") << "TritonService: There are no servers that provide model " << model;
0197   const auto& modelInfo(mit->second);
0198   const auto& modelServers = modelInfo.servers;
0199 
0200   auto msit = modelServers.end();
0201   if (!preferred.empty()) {
0202     msit = modelServers.find(preferred);
0203     //todo: add a "strict" parameter to stop execution if preferred server isn't found?
0204     if (msit == modelServers.end())
0205       edm::LogWarning("PreferredServer") << "Preferred server " << preferred << " for model " << model
0206                                          << " not available, will choose another server";
0207   }
0208   const auto& serverName(msit == modelServers.end() ? *modelServers.begin() : preferred);
0209 
0210   //todo: use some algorithm to select server rather than just picking arbitrarily
0211   const auto& server(servers_.find(serverName)->second);
0212   return server;
0213 }
0214 
//Before beginJob: if any requested model is not served by a configured server,
//launch a local fallback Triton server via the cmsTriton script and point the
//fallback Server entry at the gRPC port the script reports.
void TritonService::preBeginJob(edm::PathsAndConsumesOfModulesBase const&, edm::ProcessContext const&) {
  //only need fallback if there are unserved models
  if (!fallbackOpts_.enable or unservedModels_.empty())
    return;

  std::string msg;
  if (verbose_)
    msg = "List of models for fallback server: ";
  //all unserved models are provided by fallback server
  auto& server(servers_.find(Server::fallbackName)->second);
  for (const auto& [modelName, model] : unservedModels_) {
    auto& modelInfo(models_.emplace(modelName, model).first->second);
    modelInfo.servers.insert(Server::fallbackName);
    server.models.insert(modelName);
    if (verbose_)
      msg += modelName + ", ";
  }
  if (verbose_)
    edm::LogInfo("TritonService") << msg;

  //assemble server start command
  //NOTE(review): flag meanings are inferred from usage here (e.g. -p ties the server
  //to this process id) — confirm against the cmsTriton script itself
  fallbackOpts_.command = "cmsTriton -P -1 -p " + pid_;
  if (fallbackOpts_.debug)
    fallbackOpts_.command += " -c";
  if (fallbackOpts_.verbose)
    fallbackOpts_.command += " -v";
  if (fallbackOpts_.useDocker)
    fallbackOpts_.command += " -d";
  if (fallbackOpts_.useGPU)
    fallbackOpts_.command += " -g";
  if (!fallbackOpts_.instanceName.empty())
    fallbackOpts_.command += " -n " + fallbackOpts_.instanceName;
  if (fallbackOpts_.retries >= 0)
    fallbackOpts_.command += " -r " + std::to_string(fallbackOpts_.retries);
  if (fallbackOpts_.wait >= 0)
    fallbackOpts_.command += " -w " + std::to_string(fallbackOpts_.wait);
  //one -m flag per unserved model path
  for (const auto& [modelName, model] : unservedModels_) {
    fallbackOpts_.command += " -m " + model.path;
  }
  //pass along the thread budget recorded in preallocate()
  std::string thread_string = " -I " + std::to_string(numberOfThreads_);
  fallbackOpts_.command += thread_string;
  if (!fallbackOpts_.imageName.empty())
    fallbackOpts_.command += " -i " + fallbackOpts_.imageName;
  if (!fallbackOpts_.sandboxName.empty())
    fallbackOpts_.command += " -s " + fallbackOpts_.sandboxName;
  //don't need this anymore
  unservedModels_.clear();

  //get a random temporary directory if none specified
  if (fallbackOpts_.tempDir.empty()) {
    auto tmp_dir_path{std::filesystem::temp_directory_path() /= edm::createGlobalIdentifier()};
    fallbackOpts_.tempDir = tmp_dir_path.string();
  }
  //special case ".": use script default (temp dir = .$instanceName)
  if (fallbackOpts_.tempDir != ".")
    fallbackOpts_.command += " -t " + fallbackOpts_.tempDir;

  //the same base command (minus "start") is reused in postEndJob with "stop"
  std::string command = fallbackOpts_.command + " start";

  if (fallbackOpts_.debug)
    edm::LogInfo("TritonService") << "Fallback server temporary directory: " << fallbackOpts_.tempDir;
  if (verbose_)
    edm::LogInfo("TritonService") << command;

  //mark as started before executing in case of ctrl+c while command is running
  startedFallback_ = true;
  const auto& [output, rv] = execSys(command);
  if (rv != 0) {
    edm::LogError("TritonService") << output;
    printFallbackServerLog<edm::LogError>();
    throw cms::Exception("FallbackFailed")
        << "TritonService: Starting the fallback server failed with exit code " << rv;
  } else if (verbose_)
    edm::LogInfo("TritonService") << output;
  //get the port
  const std::string& portIndicator("CMS_TRITON_GRPC_PORT: ");
  //find last instance in log in case multiple ports were tried
  auto pos = output.rfind(portIndicator);
  if (pos != std::string::npos) {
    auto pos2 = pos + portIndicator.size();
    auto pos3 = output.find('\n', pos2);
    const auto& portNum = output.substr(pos2, pos3 - pos2);
    server.url += ":" + portNum;
  } else
    throw cms::Exception("FallbackFailed") << "TritonService: Unknown port for fallback server, log follows:\n"
                                           << output;
}
0302 
0303 void TritonService::notifyCallStatus(bool status) const {
0304   if (status)
0305     --callFails_;
0306   else
0307     ++callFails_;
0308 }
0309 
//After endJob: stop the fallback server (if one was started), surfacing its logs
//on failure. Avoids throwing when the job is already failing (callFails_ > 0),
//since throwing during stack unwinding would terminate the process.
void TritonService::postEndJob() {
  if (!startedFallback_)
    return;

  std::string command = fallbackOpts_.command;
  //prevent log cleanup during server stop
  //(callFails_ > 0 means Triton calls failed, so the logs are needed for diagnosis)
  if (callFails_ > 0)
    command += " -c";
  command += " stop";
  if (verbose_)
    edm::LogInfo("TritonService") << command;

  const auto& [output, rv] = execSys(command);
  if (rv != 0 or callFails_ > 0) {
    //print logs if cmsRun is currently exiting because of a TritonException
    edm::LogError("TritonService") << output;
    printFallbackServerLog<edm::LogError>();
    if (rv != 0) {
      std::string stopCat("FallbackFailed");
      //NOTE(review): fmt::format is used without a visible include in this file —
      //presumably provided transitively; verify
      std::string stopMsg = fmt::format("TritonService: Stopping the fallback server failed with exit code {}", rv);
      //avoid throwing if the stack is already unwinding
      if (callFails_ > 0)
        edm::LogWarning(stopCat) << stopMsg;
      else
        throw cms::Exception(stopCat) << stopMsg;
    }
  } else if (verbose_) {
    edm::LogInfo("TritonService") << output;
    printFallbackServerLog<edm::LogInfo>();
  }
}
0341 
0342 template <typename LOG>
0343 void TritonService::printFallbackServerLog() const {
0344   std::vector<std::string> logNames{"log_" + fallbackOpts_.instanceName + ".log"};
0345   //cmsTriton script moves log from temp to current dir in verbose mode or in some cases when auto_stop is called
0346   // -> check both places
0347   logNames.push_back(fallbackOpts_.tempDir + "/" + logNames[0]);
0348   bool foundLog = false;
0349   for (const auto& logName : logNames) {
0350     std::ifstream infile(logName);
0351     if (infile.is_open()) {
0352       LOG("TritonService") << "TritonService: server log " << logName << "\n" << infile.rdbuf();
0353       foundLog = true;
0354       break;
0355     }
0356   }
0357   if (!foundLog)
0358     LOG("TritonService") << "TritonService: could not find server log " << logNames[0] << " in current directory or "
0359                          << fallbackOpts_.tempDir;
0360 }
0361 
0362 void TritonService::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0363   edm::ParameterSetDescription desc;
0364   desc.addUntracked<bool>("verbose", false);
0365 
0366   edm::ParameterSetDescription validator;
0367   validator.addUntracked<std::string>("name");
0368   validator.addUntracked<std::string>("address");
0369   validator.addUntracked<unsigned>("port");
0370   validator.addUntracked<bool>("useSsl", false);
0371   validator.addUntracked<std::string>("rootCertificates", "");
0372   validator.addUntracked<std::string>("privateKey", "");
0373   validator.addUntracked<std::string>("certificateChain", "");
0374 
0375   desc.addVPSetUntracked("servers", validator, {});
0376 
0377   edm::ParameterSetDescription fallbackDesc;
0378   fallbackDesc.addUntracked<bool>("enable", false);
0379   fallbackDesc.addUntracked<bool>("debug", false);
0380   fallbackDesc.addUntracked<bool>("verbose", false);
0381   fallbackDesc.addUntracked<bool>("useDocker", false);
0382   fallbackDesc.addUntracked<bool>("useGPU", false);
0383   fallbackDesc.addUntracked<int>("retries", -1);
0384   fallbackDesc.addUntracked<int>("wait", -1);
0385   fallbackDesc.addUntracked<std::string>("instanceBaseName", "triton_server_instance");
0386   fallbackDesc.addUntracked<std::string>("instanceName", "");
0387   fallbackDesc.addUntracked<std::string>("tempDir", "");
0388   fallbackDesc.addUntracked<std::string>("imageName", "");
0389   fallbackDesc.addUntracked<std::string>("sandboxName", "");
0390   desc.add<edm::ParameterSetDescription>("fallback", fallbackDesc);
0391 
0392   descriptions.addWithDefaultLabel(desc);
0393 }