Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-08-06 22:36:40

0001 #ifndef HeterogeneousCore_SonicTriton_TritonService
0002 #define HeterogeneousCore_SonicTriton_TritonService
0003 
0004 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0005 #include "FWCore/Utilities/interface/GlobalIdentifier.h"
0006 
0007 #include <vector>
0008 #include <unordered_set>
0009 #include <unordered_map>
0010 #include <string>
0011 #include <functional>
0012 #include <utility>
0013 #include <atomic>
0014 
0015 #include "grpc_client.h"
0016 
// Framework types that this header only names in function signatures.
// Declaring them here avoids including their full definitions.
namespace edm {
  namespace service {
    class SystemBounds;
  }  // namespace service

  class ActivityRegistry;
  class ConfigurationDescriptions;
  class ModuleDescription;
  class PathsAndConsumesOfModulesBase;
  class ProcessContext;
}  // namespace edm
0028 
// Category tag stored on each TritonService::Server entry.
// The numeric values are spelled out explicitly and must stay stable.
// NOTE(review): going by the names, "Remote" is an external server and the
// "Local*" values are a locally started server on CPU or GPU - confirm
// against the code that assigns them.
enum class TritonServerType {
  Remote = 0,
  LocalCPU = 1,
  LocalGPU = 2,
};
0030 
0031 class TritonService {
0032 public:
0033   //classes and defs
0034   struct FallbackOpts {
0035     FallbackOpts(const edm::ParameterSet& pset)
0036         : enable(pset.getUntrackedParameter<bool>("enable")),
0037           debug(pset.getUntrackedParameter<bool>("debug")),
0038           verbose(pset.getUntrackedParameter<bool>("verbose")),
0039           container(pset.getUntrackedParameter<std::string>("container")),
0040           device(pset.getUntrackedParameter<std::string>("device")),
0041           retries(pset.getUntrackedParameter<int>("retries")),
0042           wait(pset.getUntrackedParameter<int>("wait")),
0043           instanceName(pset.getUntrackedParameter<std::string>("instanceName")),
0044           tempDir(pset.getUntrackedParameter<std::string>("tempDir")),
0045           imageName(pset.getUntrackedParameter<std::string>("imageName")),
0046           sandboxName(pset.getUntrackedParameter<std::string>("sandboxName")) {
0047       //randomize instance name
0048       if (instanceName.empty()) {
0049         instanceName =
0050             pset.getUntrackedParameter<std::string>("instanceBaseName") + "_" + edm::createGlobalIdentifier();
0051       }
0052     }
0053 
0054     bool enable;
0055     bool debug;
0056     bool verbose;
0057     std::string container;
0058     std::string device;
0059     int retries;
0060     int wait;
0061     std::string instanceName;
0062     std::string tempDir;
0063     std::string imageName;
0064     std::string sandboxName;
0065     std::string command;
0066   };
0067   struct Server {
0068     Server(const edm::ParameterSet& pset)
0069         : url(pset.getUntrackedParameter<std::string>("address") + ":" +
0070               std::to_string(pset.getUntrackedParameter<unsigned>("port"))),
0071           isFallback(pset.getUntrackedParameter<std::string>("name") == fallbackName),
0072           useSsl(pset.getUntrackedParameter<bool>("useSsl")),
0073           type(TritonServerType::Remote) {
0074       if (useSsl) {
0075         sslOptions.root_certificates = pset.getUntrackedParameter<std::string>("rootCertificates");
0076         sslOptions.private_key = pset.getUntrackedParameter<std::string>("privateKey");
0077         sslOptions.certificate_chain = pset.getUntrackedParameter<std::string>("certificateChain");
0078       }
0079     }
0080     Server(const std::string& name_, const std::string& url_, TritonServerType type_)
0081         : url(url_), isFallback(name_ == fallbackName), useSsl(false), type(type_) {}
0082 
0083     //members
0084     std::string url;
0085     bool isFallback;
0086     bool useSsl;
0087     TritonServerType type;
0088     triton::client::SslOptions sslOptions;
0089     std::unordered_set<std::string> models;
0090     static const std::string fallbackName;
0091     static const std::string fallbackAddress;
0092     static const std::string siteconfName;
0093   };
0094   struct Model {
0095     Model(const std::string& path_ = "") : path(path_) {}
0096 
0097     //members
0098     std::string path;
0099     std::unordered_set<std::string> servers;
0100     std::unordered_set<unsigned> modules;
0101   };
0102   struct Module {
0103     //currently assumes that a module can only have one associated model
0104     Module(const std::string& model_) : model(model_) {}
0105 
0106     //members
0107     std::string model;
0108   };
0109 
0110   TritonService(const edm::ParameterSet& pset, edm::ActivityRegistry& areg);
0111   ~TritonService() = default;
0112 
0113   //accessors
0114   void addModel(const std::string& modelName, const std::string& path);
0115   Server serverInfo(const std::string& model, const std::string& preferred = "") const;
0116   const std::string& pid() const { return pid_; }
0117   void notifyCallStatus(bool status) const;
0118 
0119   static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
0120 
0121 private:
0122   void preallocate(edm::service::SystemBounds const&);
0123   void preModuleConstruction(edm::ModuleDescription const&);
0124   void postModuleConstruction(edm::ModuleDescription const&);
0125   void preModuleDestruction(edm::ModuleDescription const&);
0126   void preBeginJob(edm::PathsAndConsumesOfModulesBase const&, edm::ProcessContext const&);
0127   void postEndJob();
0128 
0129   //helper
0130   template <typename LOG>
0131   void printFallbackServerLog() const;
0132 
0133   bool verbose_;
0134   FallbackOpts fallbackOpts_;
0135   unsigned currentModuleId_;
0136   bool allowAddModel_;
0137   bool startedFallback_;
0138   mutable std::atomic<int> callFails_;
0139   std::string pid_;
0140   std::unordered_map<std::string, Model> unservedModels_;
0141   //this represents a many:many:many map
0142   std::unordered_map<std::string, Server> servers_;
0143   std::unordered_map<std::string, Model> models_;
0144   std::unordered_map<unsigned, Module> modules_;
0145   int numberOfThreads_;
0146 };
0147 
0148 #endif