#include "HeterogeneousCore/SonicTriton/interface/TritonService.h"
#include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"

#include "DataFormats/Provenance/interface/ModuleDescription.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
#include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
#include "FWCore/ServiceRegistry/interface/ActivityRegistry.h"
#include "FWCore/ServiceRegistry/interface/ProcessContext.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Utilities/interface/GlobalIdentifier.h"  //declares edm::createGlobalIdentifier(), used below for the temp dir name

#include "grpc_client.h"
#include "grpc_service.pb.h"

#include <array>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <utility>
#include <tuple>
#include <unistd.h>

namespace tc = triton::client;

const std::string TritonService::Server::fallbackName{"fallback"};
const std::string TritonService::Server::fallbackAddress{"0.0.0.0"};

namespace {
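  //execute a shell command with popen(), capturing stdout and stderr together;
  //returns the combined output and the exit status reported by pclose()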
  std::pair<std::string, int> execSys(const std::string& cmd) {
    //redirect stderr to stdout
    auto pipe = popen((cmd + " 2>&1").c_str(), "r");
    int thisErrno = errno;
    if (!pipe)
      throw cms::Exception("SystemError")
          << "TritonService: popen() failed with errno " << thisErrno << " for command: " << cmd;

    //extract output
    constexpr static unsigned buffSize = 128;
    std::array<char, buffSize> buffer;
    std::string result;
    while (!feof(pipe)) {
      if (fgets(buffer.data(), buffSize, pipe))
        result += buffer.data();
      else {
        thisErrno = ferror(pipe);
        if (thisErrno)
          throw cms::Exception("SystemError")
              << "TritonService: failed reading command output with errno " << thisErrno;
      }
    }

    int rv = pclose(pipe);
    return std::make_pair(result, rv);
  }
}  // namespace

TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistry& areg)
    : verbose_(pset.getUntrackedParameter<bool>("verbose")),
      fallbackOpts_(pset.getParameterSet("fallback")),
      currentModuleId_(0),
      allowAddModel_(false),
      startedFallback_(false),
      pid_(std::to_string(::getpid())) {
  //module construction is assumed to be serial (correct at the time this code was written)
  areg.watchPreModuleConstruction(this, &TritonService::preModuleConstruction);
  areg.watchPostModuleConstruction(this, &TritonService::postModuleConstruction);
  areg.watchPreModuleDestruction(this, &TritonService::preModuleDestruction);
  //fallback server will be launched (if needed) before beginJob
  areg.watchPreBeginJob(this, &TritonService::preBeginJob);
  areg.watchPostEndJob(this, &TritonService::postEndJob);

  //include fallback server in set if enabled
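  //server type: LocalCPU if the GPU is not requested, LocalGPU if it is requested and the build
  //defines TRITON_ENABLE_GPU; otherwise the type remains Remote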
  if (fallbackOpts_.enable) {
    auto serverType = TritonServerType::Remote;
    if (!fallbackOpts_.useGPU)
      serverType = TritonServerType::LocalCPU;
#ifdef TRITON_ENABLE_GPU
    else
      serverType = TritonServerType::LocalGPU;
#endif

    servers_.emplace(std::piecewise_construct,
                     std::forward_as_tuple(Server::fallbackName),
                     std::forward_as_tuple(Server::fallbackName, Server::fallbackAddress, serverType));
  }

  //loop over input servers: check which models they have
  std::string msg;
  if (verbose_)
    msg = "List of models for each server:\n";
  for (const auto& serverPset : pset.getUntrackedParameterSetVector("servers")) {
    const std::string& serverName(serverPset.getUntrackedParameter<std::string>("name"));
    //ensure uniqueness
    auto [sit, unique] = servers_.emplace(serverName, serverPset);
    if (!unique)
      throw cms::Exception("DuplicateServer")
          << "TritonService: Not allowed to specify more than one server with same name (" << serverName << ")";
    auto& server(sit->second);

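    //create a temporary gRPC client to query this server's metadata and model repository index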
    std::unique_ptr<tc::InferenceServerGrpcClient> client;
    TRITON_THROW_IF_ERROR(
        tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions),
        "TritonService(): unable to create inference context for " + serverName + " (" + server.url + ")");

    if (verbose_) {
      inference::ServerMetadataResponse serverMetaResponse;
      TRITON_THROW_IF_ERROR(client->ServerMetadata(&serverMetaResponse),
                            "TritonService(): unable to get metadata for " + serverName + " (" + server.url + ")");
      edm::LogInfo("TritonService") << "Server " << serverName << ": url = " << server.url
                                    << ", version = " << serverMetaResponse.version();
    }

    inference::RepositoryIndexResponse repoIndexResponse;
    TRITON_THROW_IF_ERROR(
        client->ModelRepositoryIndex(&repoIndexResponse),
        "TritonService(): unable to get repository index for " + serverName + " (" + server.url + ")");

    //servers keep track of models and vice versa
    if (verbose_)
      msg += serverName + ": ";
    for (const auto& modelIndex : repoIndexResponse.models()) {
      const auto& modelName = modelIndex.name();
      auto mit = models_.find(modelName);
      if (mit == models_.end())
        mit = models_.emplace(modelName, "").first;
      auto& modelInfo(mit->second);
      modelInfo.servers.insert(serverName);
      server.models.insert(modelName);
      if (verbose_)
        msg += modelName + ", ";
    }
    if (verbose_)
      msg += "\n";
  }
  if (verbose_)
    edm::LogInfo("TritonService") << msg;
}

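//record which module is currently being constructed so that addModel() calls can be attributed to it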
void TritonService::preModuleConstruction(edm::ModuleDescription const& desc) {
  currentModuleId_ = desc.id();
  allowAddModel_ = true;
}

void TritonService::addModel(const std::string& modelName, const std::string& path) {
  //should only be called in module constructors
  if (!allowAddModel_)
    throw cms::Exception("DisallowedAddModel")
        << "TritonService: Attempt to call addModel() outside of module constructors";
  //if model is not in the list, then no specified server provides it
  auto mit = models_.find(modelName);
  if (mit == models_.end()) {
    auto& modelInfo(unservedModels_.emplace(modelName, path).first->second);
    modelInfo.modules.insert(currentModuleId_);
    //only keep track of modules that need unserved models
    modules_.emplace(currentModuleId_, modelName);
  }
}

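//module construction is finished: addModel() is disallowed again until the next module is constructed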
void TritonService::postModuleConstruction(edm::ModuleDescription const& desc) { allowAddModel_ = false; }

void TritonService::preModuleDestruction(edm::ModuleDescription const& desc) {
  //remove destructed modules from unserved list
  if (unservedModels_.empty())
    return;
  auto id = desc.id();
  auto oit = modules_.find(id);
  if (oit != modules_.end()) {
    const auto& moduleInfo(oit->second);
    auto mit = unservedModels_.find(moduleInfo.model);
    if (mit != unservedModels_.end()) {
      auto& modelInfo(mit->second);
      modelInfo.modules.erase(id);
      //remove a model if it is no longer needed by any modules
      if (modelInfo.modules.empty())
        unservedModels_.erase(mit);
    }
    modules_.erase(oit);
  }
}

//return the server that will be used to provide the requested model (the preferred server if specified and available)
TritonService::Server TritonService::serverInfo(const std::string& model, const std::string& preferred) const {
  auto mit = models_.find(model);
  if (mit == models_.end())
    throw cms::Exception("MissingModel") << "TritonService: There are no servers that provide model " << model;
  const auto& modelInfo(mit->second);
  const auto& modelServers = modelInfo.servers;

  auto msit = modelServers.end();
  if (!preferred.empty()) {
    msit = modelServers.find(preferred);
    //todo: add a "strict" parameter to stop execution if preferred server isn't found?
    if (msit == modelServers.end())
      edm::LogWarning("PreferredServer") << "Preferred server " << preferred << " for model " << model
                                         << " not available, will choose another server";
  }
  const auto& serverName(msit == modelServers.end() ? *modelServers.begin() : preferred);

  //todo: use some algorithm to select server rather than just picking arbitrarily
  const auto& server(servers_.find(serverName)->second);
  return server;
}

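//launch the fallback server (via the cmsTriton script) before beginJob if any module requested a model
//that none of the configured servers provide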
void TritonService::preBeginJob(edm::PathsAndConsumesOfModulesBase const&, edm::ProcessContext const&) {
  //only need fallback if there are unserved models
  if (!fallbackOpts_.enable or unservedModels_.empty())
    return;

  std::string msg;
  if (verbose_)
    msg = "List of models for fallback server: ";
  //all unserved models are provided by fallback server
  auto& server(servers_.find(Server::fallbackName)->second);
  for (const auto& [modelName, model] : unservedModels_) {
    auto& modelInfo(models_.emplace(modelName, model).first->second);
    modelInfo.servers.insert(Server::fallbackName);
    server.models.insert(modelName);
    if (verbose_)
      msg += modelName + ", ";
  }
  if (verbose_)
    edm::LogInfo("TritonService") << msg;

  //assemble server start command
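  //the command invokes the cmsTriton helper script: "-P -1" lets the script pick the gRPC port
  //(reported back via CMS_TRITON_GRPC_PORT and parsed below), and "-p" passes this process's PID to the script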
  fallbackOpts_.command = "cmsTriton -P -1 -p " + pid_;
  if (fallbackOpts_.debug)
    fallbackOpts_.command += " -c";
  if (fallbackOpts_.verbose)
    fallbackOpts_.command += " -v";
  if (fallbackOpts_.useDocker)
    fallbackOpts_.command += " -d";
  if (fallbackOpts_.useGPU)
    fallbackOpts_.command += " -g";
  if (!fallbackOpts_.instanceName.empty())
    fallbackOpts_.command += " -n " + fallbackOpts_.instanceName;
  if (fallbackOpts_.retries >= 0)
    fallbackOpts_.command += " -r " + std::to_string(fallbackOpts_.retries);
  if (fallbackOpts_.wait >= 0)
    fallbackOpts_.command += " -w " + std::to_string(fallbackOpts_.wait);
  for (const auto& [modelName, model] : unservedModels_) {
    fallbackOpts_.command += " -m " + model.path;
  }
  if (!fallbackOpts_.imageName.empty())
    fallbackOpts_.command += " -i " + fallbackOpts_.imageName;
  if (!fallbackOpts_.sandboxName.empty())
    fallbackOpts_.command += " -s " + fallbackOpts_.sandboxName;
  //all unserved models have now been registered to the fallback server, so this bookkeeping is no longer needed
  unservedModels_.clear();

  //get a random temporary directory if none specified
  if (fallbackOpts_.tempDir.empty()) {
    auto tmp_dir_path{std::filesystem::temp_directory_path() /= edm::createGlobalIdentifier()};
    fallbackOpts_.tempDir = tmp_dir_path.string();
  }
  //special case ".": use script default (temp dir = .$instanceName)
  if (fallbackOpts_.tempDir != ".")
    fallbackOpts_.command += " -t " + fallbackOpts_.tempDir;

  std::string command = fallbackOpts_.command + " start";

  if (fallbackOpts_.debug)
    edm::LogInfo("TritonService") << "Fallback server temporary directory: " << fallbackOpts_.tempDir;
  if (verbose_)
    edm::LogInfo("TritonService") << command;

  //mark as started before executing in case of ctrl+c while command is running
  startedFallback_ = true;
  const auto& [output, rv] = execSys(command);
  if (rv != 0) {
    edm::LogError("TritonService") << output;
    printFallbackServerLog<edm::LogError>();
    throw cms::Exception("FallbackFailed")
        << "TritonService: Starting the fallback server failed with exit code " << rv;
  } else if (verbose_)
    edm::LogInfo("TritonService") << output;
  //get the port
  const std::string& portIndicator("CMS_TRITON_GRPC_PORT: ");
  //find last instance in log in case multiple ports were tried
  auto pos = output.rfind(portIndicator);
  if (pos != std::string::npos) {
    auto pos2 = pos + portIndicator.size();
    auto pos3 = output.find('\n', pos2);
    const auto& portNum = output.substr(pos2, pos3 - pos2);
    server.url += ":" + portNum;
  } else
    throw cms::Exception("FallbackFailed") << "TritonService: Unknown port for fallback server, log follows:\n"
                                           << output;
}

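//stop the fallback server at the end of the job and dump its log on failure (or in verbose mode)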
void TritonService::postEndJob() {
  if (!startedFallback_)
    return;

  std::string command = fallbackOpts_.command + " stop";
  if (verbose_)
    edm::LogInfo("TritonService") << command;

  const auto& [output, rv] = execSys(command);
  if (rv != 0) {
    edm::LogError("TritonService") << output;
    printFallbackServerLog<edm::LogError>();
    throw cms::Exception("FallbackFailed")
        << "TritonService: Stopping the fallback server failed with exit code " << rv;
  } else if (verbose_) {
    edm::LogInfo("TritonService") << output;
    printFallbackServerLog<edm::LogInfo>();
  }
}

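//print the fallback server log (looked up in the current directory and in the temp dir) to the given MessageLogger stream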
template <typename LOG>
void TritonService::printFallbackServerLog() const {
  std::vector<std::string> logNames{"log_" + fallbackOpts_.instanceName + ".log"};
  //cmsTriton script moves log from temp to current dir in verbose mode or in some cases when auto_stop is called
  // -> check both places
  logNames.push_back(fallbackOpts_.tempDir + "/" + logNames[0]);
  bool foundLog = false;
  for (const auto& logName : logNames) {
    std::ifstream infile(logName);
    if (infile.is_open()) {
      LOG("TritonService") << "TritonService: server log " << logName << "\n" << infile.rdbuf();
      foundLog = true;
      break;
    }
  }
  if (!foundLog)
    LOG("TritonService") << "TritonService: could not find server log " << logNames[0] << " in current directory or "
                         << fallbackOpts_.tempDir;
}

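//declare the allowed configuration: a verbose flag, an untracked VPSet of remote servers
//(name/address/port plus optional SSL settings), and a "fallback" PSet controlling the local cmsTriton server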
void TritonService::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
  edm::ParameterSetDescription desc;
  desc.addUntracked<bool>("verbose", false);

  edm::ParameterSetDescription validator;
  validator.addUntracked<std::string>("name");
  validator.addUntracked<std::string>("address");
  validator.addUntracked<unsigned>("port");
  validator.addUntracked<bool>("useSsl", false);
  validator.addUntracked<std::string>("rootCertificates", "");
  validator.addUntracked<std::string>("privateKey", "");
  validator.addUntracked<std::string>("certificateChain", "");

  desc.addVPSetUntracked("servers", validator, {});

  edm::ParameterSetDescription fallbackDesc;
  fallbackDesc.addUntracked<bool>("enable", false);
  fallbackDesc.addUntracked<bool>("debug", false);
  fallbackDesc.addUntracked<bool>("verbose", false);
  fallbackDesc.addUntracked<bool>("useDocker", false);
  fallbackDesc.addUntracked<bool>("useGPU", false);
  fallbackDesc.addUntracked<int>("retries", -1);
  fallbackDesc.addUntracked<int>("wait", -1);
  fallbackDesc.addUntracked<std::string>("instanceBaseName", "triton_server_instance");
  fallbackDesc.addUntracked<std::string>("instanceName", "");
  fallbackDesc.addUntracked<std::string>("tempDir", "");
  fallbackDesc.addUntracked<std::string>("imageName", "");
  fallbackDesc.addUntracked<std::string>("sandboxName", "");
  desc.add<edm::ParameterSetDescription>("fallback", fallbackDesc);

  descriptions.addWithDefaultLabel(desc);
}