Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2022-05-12 01:51:29

0001 #ifndef HeterogeneousCore_SonicTriton_TritonService
0002 #define HeterogeneousCore_SonicTriton_TritonService
0003 
0004 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0005 #include "FWCore/Utilities/interface/GlobalIdentifier.h"
0006 
0007 #include <vector>
0008 #include <unordered_set>
0009 #include <unordered_map>
0010 #include <string>
0011 #include <functional>
0012 #include <utility>
0013 
0014 #include "grpc_client.h"
0015 
// Forward declarations of framework types. TritonService only refers to
// these through references, so their full definitions are not required here.
namespace edm {
  class ActivityRegistry;
  class ConfigurationDescriptions;
  class ModuleDescription;
  class PathsAndConsumesOfModulesBase;
  class ProcessContext;
}  // namespace edm
0024 
// Kind of Triton server an entry refers to.
//  - Remote:   the type assigned to servers built from a ParameterSet
//              (see TritonService::Server).
//  - LocalCPU / LocalGPU: presumably a server run locally on CPU or GPU
//              (e.g. the fallback server) -- TODO confirm in TritonService.cc.
// The numeric values are spelled out explicitly so they stay stable.
enum class TritonServerType {
  Remote = 0,
  LocalCPU = 1,
  LocalGPU = 2,
};
0026 
0027 class TritonService {
0028 public:
0029   //classes and defs
0030   struct FallbackOpts {
0031     FallbackOpts(const edm::ParameterSet& pset)
0032         : enable(pset.getUntrackedParameter<bool>("enable")),
0033           debug(pset.getUntrackedParameter<bool>("debug")),
0034           verbose(pset.getUntrackedParameter<bool>("verbose")),
0035           useDocker(pset.getUntrackedParameter<bool>("useDocker")),
0036           useGPU(pset.getUntrackedParameter<bool>("useGPU")),
0037           retries(pset.getUntrackedParameter<int>("retries")),
0038           wait(pset.getUntrackedParameter<int>("wait")),
0039           instanceName(pset.getUntrackedParameter<std::string>("instanceName")),
0040           tempDir(pset.getUntrackedParameter<std::string>("tempDir")),
0041           imageName(pset.getUntrackedParameter<std::string>("imageName")),
0042           sandboxName(pset.getUntrackedParameter<std::string>("sandboxName")) {
0043       //randomize instance name
0044       if (instanceName.empty()) {
0045         instanceName =
0046             pset.getUntrackedParameter<std::string>("instanceBaseName") + "_" + edm::createGlobalIdentifier();
0047       }
0048     }
0049 
0050     bool enable;
0051     bool debug;
0052     bool verbose;
0053     bool useDocker;
0054     bool useGPU;
0055     int retries;
0056     int wait;
0057     std::string instanceName;
0058     std::string tempDir;
0059     std::string imageName;
0060     std::string sandboxName;
0061     std::string command;
0062   };
0063   struct Server {
0064     Server(const edm::ParameterSet& pset)
0065         : url(pset.getUntrackedParameter<std::string>("address") + ":" +
0066               std::to_string(pset.getUntrackedParameter<unsigned>("port"))),
0067           isFallback(pset.getUntrackedParameter<std::string>("name") == fallbackName),
0068           useSsl(pset.getUntrackedParameter<bool>("useSsl")),
0069           type(TritonServerType::Remote) {
0070       if (useSsl) {
0071         sslOptions.root_certificates = pset.getUntrackedParameter<std::string>("rootCertificates");
0072         sslOptions.private_key = pset.getUntrackedParameter<std::string>("privateKey");
0073         sslOptions.certificate_chain = pset.getUntrackedParameter<std::string>("certificateChain");
0074       }
0075     }
0076     Server(const std::string& name_, const std::string& url_, TritonServerType type_)
0077         : url(url_), isFallback(name_ == fallbackName), useSsl(false), type(type_) {}
0078 
0079     //members
0080     std::string url;
0081     bool isFallback;
0082     bool useSsl;
0083     TritonServerType type;
0084     triton::client::SslOptions sslOptions;
0085     std::unordered_set<std::string> models;
0086     static const std::string fallbackName;
0087     static const std::string fallbackAddress;
0088   };
0089   struct Model {
0090     Model(const std::string& path_ = "") : path(path_) {}
0091 
0092     //members
0093     std::string path;
0094     std::unordered_set<std::string> servers;
0095     std::unordered_set<unsigned> modules;
0096   };
0097   struct Module {
0098     //currently assumes that a module can only have one associated model
0099     Module(const std::string& model_) : model(model_) {}
0100 
0101     //members
0102     std::string model;
0103   };
0104 
0105   TritonService(const edm::ParameterSet& pset, edm::ActivityRegistry& areg);
0106   ~TritonService() = default;
0107 
0108   //accessors
0109   void addModel(const std::string& modelName, const std::string& path);
0110   Server serverInfo(const std::string& model, const std::string& preferred = "") const;
0111   const std::string& pid() const { return pid_; }
0112 
0113   static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
0114 
0115 private:
0116   void preModuleConstruction(edm::ModuleDescription const&);
0117   void postModuleConstruction(edm::ModuleDescription const&);
0118   void preModuleDestruction(edm::ModuleDescription const&);
0119   void preBeginJob(edm::PathsAndConsumesOfModulesBase const&, edm::ProcessContext const&);
0120   void postEndJob();
0121 
0122   //helper
0123   template <typename LOG>
0124   void printFallbackServerLog() const;
0125 
0126   bool verbose_;
0127   FallbackOpts fallbackOpts_;
0128   unsigned currentModuleId_;
0129   bool allowAddModel_;
0130   bool startedFallback_;
0131   std::string pid_;
0132   std::unordered_map<std::string, Model> unservedModels_;
0133   //this represents a many:many:many map
0134   std::unordered_map<std::string, Server> servers_;
0135   std::unordered_map<std::string, Model> models_;
0136   std::unordered_map<unsigned, Module> modules_;
0137 };
0138 
0139 #endif