Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:47

0001 #ifndef HeterogeneousCore_SonicTriton_TritonService
0002 #define HeterogeneousCore_SonicTriton_TritonService
0003 
0004 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0005 #include "FWCore/Utilities/interface/GlobalIdentifier.h"
0006 
0007 #include <vector>
0008 #include <unordered_set>
0009 #include <unordered_map>
0010 #include <string>
0011 #include <functional>
0012 #include <utility>
0013 #include <atomic>
0014 
0015 #include "grpc_client.h"
0016 
// Forward declarations of framework types that this header only refers to through
// references, so their full definitions need not be included here.
namespace edm {
  namespace service {
    class SystemBounds;
  }  // namespace service

  class ActivityRegistry;
  class ConfigurationDescriptions;
  class ModuleDescription;
  class PathsAndConsumesOfModulesBase;
  class ProcessContext;
}  // namespace edm
0028 
// Kind of Triton server an entry describes. Servers built from a ParameterSet are
// always Remote; the Local* values presumably describe a server started on this
// machine on CPU or GPU (TODO confirm against the service implementation).
// The underlying type is spelled out; `int` is also the implicit default for enum class.
enum class TritonServerType : int {
  Remote = 0,
  LocalCPU = 1,
  LocalGPU = 2,
};
0030 
0031 class TritonService {
0032 public:
0033   //classes and defs
0034   struct FallbackOpts {
0035     FallbackOpts(const edm::ParameterSet& pset)
0036         : enable(pset.getUntrackedParameter<bool>("enable")),
0037           debug(pset.getUntrackedParameter<bool>("debug")),
0038           verbose(pset.getUntrackedParameter<bool>("verbose")),
0039           useDocker(pset.getUntrackedParameter<bool>("useDocker")),
0040           useGPU(pset.getUntrackedParameter<bool>("useGPU")),
0041           retries(pset.getUntrackedParameter<int>("retries")),
0042           wait(pset.getUntrackedParameter<int>("wait")),
0043           instanceName(pset.getUntrackedParameter<std::string>("instanceName")),
0044           tempDir(pset.getUntrackedParameter<std::string>("tempDir")),
0045           imageName(pset.getUntrackedParameter<std::string>("imageName")),
0046           sandboxName(pset.getUntrackedParameter<std::string>("sandboxName")) {
0047       //randomize instance name
0048       if (instanceName.empty()) {
0049         instanceName =
0050             pset.getUntrackedParameter<std::string>("instanceBaseName") + "_" + edm::createGlobalIdentifier();
0051       }
0052     }
0053 
0054     bool enable;
0055     bool debug;
0056     bool verbose;
0057     bool useDocker;
0058     bool useGPU;
0059     int retries;
0060     int wait;
0061     std::string instanceName;
0062     std::string tempDir;
0063     std::string imageName;
0064     std::string sandboxName;
0065     std::string command;
0066   };
0067   struct Server {
0068     Server(const edm::ParameterSet& pset)
0069         : url(pset.getUntrackedParameter<std::string>("address") + ":" +
0070               std::to_string(pset.getUntrackedParameter<unsigned>("port"))),
0071           isFallback(pset.getUntrackedParameter<std::string>("name") == fallbackName),
0072           useSsl(pset.getUntrackedParameter<bool>("useSsl")),
0073           type(TritonServerType::Remote) {
0074       if (useSsl) {
0075         sslOptions.root_certificates = pset.getUntrackedParameter<std::string>("rootCertificates");
0076         sslOptions.private_key = pset.getUntrackedParameter<std::string>("privateKey");
0077         sslOptions.certificate_chain = pset.getUntrackedParameter<std::string>("certificateChain");
0078       }
0079     }
0080     Server(const std::string& name_, const std::string& url_, TritonServerType type_)
0081         : url(url_), isFallback(name_ == fallbackName), useSsl(false), type(type_) {}
0082 
0083     //members
0084     std::string url;
0085     bool isFallback;
0086     bool useSsl;
0087     TritonServerType type;
0088     triton::client::SslOptions sslOptions;
0089     std::unordered_set<std::string> models;
0090     static const std::string fallbackName;
0091     static const std::string fallbackAddress;
0092   };
0093   struct Model {
0094     Model(const std::string& path_ = "") : path(path_) {}
0095 
0096     //members
0097     std::string path;
0098     std::unordered_set<std::string> servers;
0099     std::unordered_set<unsigned> modules;
0100   };
0101   struct Module {
0102     //currently assumes that a module can only have one associated model
0103     Module(const std::string& model_) : model(model_) {}
0104 
0105     //members
0106     std::string model;
0107   };
0108 
0109   TritonService(const edm::ParameterSet& pset, edm::ActivityRegistry& areg);
0110   ~TritonService() = default;
0111 
0112   //accessors
0113   void addModel(const std::string& modelName, const std::string& path);
0114   Server serverInfo(const std::string& model, const std::string& preferred = "") const;
0115   const std::string& pid() const { return pid_; }
0116   void notifyCallStatus(bool status) const;
0117 
0118   static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
0119 
0120 private:
0121   void preallocate(edm::service::SystemBounds const&);
0122   void preModuleConstruction(edm::ModuleDescription const&);
0123   void postModuleConstruction(edm::ModuleDescription const&);
0124   void preModuleDestruction(edm::ModuleDescription const&);
0125   void preBeginJob(edm::PathsAndConsumesOfModulesBase const&, edm::ProcessContext const&);
0126   void postEndJob();
0127 
0128   //helper
0129   template <typename LOG>
0130   void printFallbackServerLog() const;
0131 
0132   bool verbose_;
0133   FallbackOpts fallbackOpts_;
0134   unsigned currentModuleId_;
0135   bool allowAddModel_;
0136   bool startedFallback_;
0137   mutable std::atomic<int> callFails_;
0138   std::string pid_;
0139   std::unordered_map<std::string, Model> unservedModels_;
0140   //this represents a many:many:many map
0141   std::unordered_map<std::string, Server> servers_;
0142   std::unordered_map<std::string, Model> models_;
0143   std::unordered_map<unsigned, Module> modules_;
0144   int numberOfThreads_;
0145 };
0146 
0147 #endif