File indexing completed on 2024-08-06 22:36:40
0001 #ifndef HeterogeneousCore_SonicTriton_TritonService
0002 #define HeterogeneousCore_SonicTriton_TritonService
0003
0004 #include "FWCore/ParameterSet/interface/ParameterSet.h"
0005 #include "FWCore/Utilities/interface/GlobalIdentifier.h"
0006
0007 #include <vector>
0008 #include <unordered_set>
0009 #include <unordered_map>
0010 #include <string>
0011 #include <functional>
0012 #include <utility>
0013 #include <atomic>
0014
0015 #include "grpc_client.h"
0016
0017
// Forward declarations of the framework types that appear only by reference in
// TritonService's signatures, so their full headers need not be included here.
namespace edm {
  namespace service {
    class SystemBounds;
  }  // namespace service

  class ActivityRegistry;
  class ConfigurationDescriptions;
  class ModuleDescription;
  class PathsAndConsumesOfModulesBase;
  class ProcessContext;
}  // namespace edm
0028
0029 enum class TritonServerType { Remote = 0, LocalCPU = 1, LocalGPU = 2 };
0030
0031 class TritonService {
0032 public:
0033
0034 struct FallbackOpts {
0035 FallbackOpts(const edm::ParameterSet& pset)
0036 : enable(pset.getUntrackedParameter<bool>("enable")),
0037 debug(pset.getUntrackedParameter<bool>("debug")),
0038 verbose(pset.getUntrackedParameter<bool>("verbose")),
0039 container(pset.getUntrackedParameter<std::string>("container")),
0040 device(pset.getUntrackedParameter<std::string>("device")),
0041 retries(pset.getUntrackedParameter<int>("retries")),
0042 wait(pset.getUntrackedParameter<int>("wait")),
0043 instanceName(pset.getUntrackedParameter<std::string>("instanceName")),
0044 tempDir(pset.getUntrackedParameter<std::string>("tempDir")),
0045 imageName(pset.getUntrackedParameter<std::string>("imageName")),
0046 sandboxName(pset.getUntrackedParameter<std::string>("sandboxName")) {
0047
0048 if (instanceName.empty()) {
0049 instanceName =
0050 pset.getUntrackedParameter<std::string>("instanceBaseName") + "_" + edm::createGlobalIdentifier();
0051 }
0052 }
0053
0054 bool enable;
0055 bool debug;
0056 bool verbose;
0057 std::string container;
0058 std::string device;
0059 int retries;
0060 int wait;
0061 std::string instanceName;
0062 std::string tempDir;
0063 std::string imageName;
0064 std::string sandboxName;
0065 std::string command;
0066 };
0067 struct Server {
0068 Server(const edm::ParameterSet& pset)
0069 : url(pset.getUntrackedParameter<std::string>("address") + ":" +
0070 std::to_string(pset.getUntrackedParameter<unsigned>("port"))),
0071 isFallback(pset.getUntrackedParameter<std::string>("name") == fallbackName),
0072 useSsl(pset.getUntrackedParameter<bool>("useSsl")),
0073 type(TritonServerType::Remote) {
0074 if (useSsl) {
0075 sslOptions.root_certificates = pset.getUntrackedParameter<std::string>("rootCertificates");
0076 sslOptions.private_key = pset.getUntrackedParameter<std::string>("privateKey");
0077 sslOptions.certificate_chain = pset.getUntrackedParameter<std::string>("certificateChain");
0078 }
0079 }
0080 Server(const std::string& name_, const std::string& url_, TritonServerType type_)
0081 : url(url_), isFallback(name_ == fallbackName), useSsl(false), type(type_) {}
0082
0083
0084 std::string url;
0085 bool isFallback;
0086 bool useSsl;
0087 TritonServerType type;
0088 triton::client::SslOptions sslOptions;
0089 std::unordered_set<std::string> models;
0090 static const std::string fallbackName;
0091 static const std::string fallbackAddress;
0092 static const std::string siteconfName;
0093 };
0094 struct Model {
0095 Model(const std::string& path_ = "") : path(path_) {}
0096
0097
0098 std::string path;
0099 std::unordered_set<std::string> servers;
0100 std::unordered_set<unsigned> modules;
0101 };
0102 struct Module {
0103
0104 Module(const std::string& model_) : model(model_) {}
0105
0106
0107 std::string model;
0108 };
0109
0110 TritonService(const edm::ParameterSet& pset, edm::ActivityRegistry& areg);
0111 ~TritonService() = default;
0112
0113
0114 void addModel(const std::string& modelName, const std::string& path);
0115 Server serverInfo(const std::string& model, const std::string& preferred = "") const;
0116 const std::string& pid() const { return pid_; }
0117 void notifyCallStatus(bool status) const;
0118
0119 static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
0120
0121 private:
0122 void preallocate(edm::service::SystemBounds const&);
0123 void preModuleConstruction(edm::ModuleDescription const&);
0124 void postModuleConstruction(edm::ModuleDescription const&);
0125 void preModuleDestruction(edm::ModuleDescription const&);
0126 void preBeginJob(edm::PathsAndConsumesOfModulesBase const&, edm::ProcessContext const&);
0127 void postEndJob();
0128
0129
0130 template <typename LOG>
0131 void printFallbackServerLog() const;
0132
0133 bool verbose_;
0134 FallbackOpts fallbackOpts_;
0135 unsigned currentModuleId_;
0136 bool allowAddModel_;
0137 bool startedFallback_;
0138 mutable std::atomic<int> callFails_;
0139 std::string pid_;
0140 std::unordered_map<std::string, Model> unservedModels_;
0141
0142 std::unordered_map<std::string, Server> servers_;
0143 std::unordered_map<std::string, Model> models_;
0144 std::unordered_map<unsigned, Module> modules_;
0145 int numberOfThreads_;
0146 };
0147
0148 #endif