** Warning **
Issuing rollback() due to DESTROY without explicit disconnect() of DBD::mysql::db handle dbname=lxr at /lxr/lib/LXR/Common.pm line 1113.
Last-Modified: Sun, 21 Jun 2025 01:29:21 GMT
Content-Type: text/html; charset=utf-8
/CMSSW_15_1_X_2025-06-20-2300/HeterogeneousCore/SonicTriton/interface/TritonService.h
File indexing completed on 2025-04-13 22:50:00
0001 #ifndef HeterogeneousCore_SonicTriton_TritonService
0002 #define HeterogeneousCore_SonicTriton_TritonService
0003
0004 #include "FWCore /ParameterSet /interface /ParameterSet.h "
0005 #include "FWCore /Utilities /interface /GlobalIdentifier.h "
0006
0007 #include <vector>
0008 #include <unordered_set>
0009 #include <unordered_map>
0010 #include <string>
0011 #include <functional>
0012 #include <utility>
0013 #include <atomic>
0014
0015 #include "grpc_client.h"
0016
0017
// Forward declarations of framework types that this header only needs by
// reference, so their full definitions are not pulled in here.
namespace edm {
  class ActivityRegistry;
  class ConfigurationDescriptions;
  class ModuleDescription;
  class ProcessContext;

  namespace service {
    class SystemBounds;
  }  // namespace service
}  // namespace edm
0027
/// Kind of Triton server an entry refers to.
/// Remote is the type assigned to servers built from a ParameterSet;
/// LocalCPU / LocalGPU presumably describe a locally launched (fallback)
/// server running on CPU or GPU — NOTE(review): confirm against TritonService.cc.
enum class TritonServerType {
  Remote = 0,
  LocalCPU = 1,
  LocalGPU = 2,
};
0029
0030 class TritonService {
0031 public :
0032
0033 struct FallbackOpts {
0034 FallbackOpts (const edm ::ParameterSet & pset )
0035 : enable (pset .getUntrackedParameter <bool >("enable" )),
0036 debug (pset .getUntrackedParameter <bool >("debug" )),
0037 verbose (pset .getUntrackedParameter <bool >("verbose" )),
0038 container (pset .getUntrackedParameter <std ::string >("container" )),
0039 device (pset .getUntrackedParameter <std ::string >("device" )),
0040 retries (pset .getUntrackedParameter <int >("retries" )),
0041 wait (pset .getUntrackedParameter <int >("wait" )),
0042 instanceName (pset .getUntrackedParameter <std ::string >("instanceName" )),
0043 tempDir (pset .getUntrackedParameter <std ::string >("tempDir" )),
0044 imageName (pset .getUntrackedParameter <std ::string >("imageName" )),
0045 sandboxName (pset .getUntrackedParameter <std ::string >("sandboxName" )) {
0046
0047 if (instanceName .empty ()) {
0048 instanceName =
0049 pset .getUntrackedParameter <std ::string >("instanceBaseName" ) + "_" + edm ::createGlobalIdentifier ();
0050 }
0051 }
0052
0053 bool enable ;
0054 bool debug ;
0055 bool verbose ;
0056 std ::string container ;
0057 std ::string device ;
0058 int retries ;
0059 int wait ;
0060 std ::string instanceName ;
0061 std ::string tempDir ;
0062 std ::string imageName ;
0063 std ::string sandboxName ;
0064 std ::string command ;
0065 };
0066 struct Server {
0067 Server (const edm ::ParameterSet & pset )
0068 : url (pset .getUntrackedParameter <std ::string >("address" ) + ":" +
0069 std ::to_string (pset .getUntrackedParameter <unsigned >("port" ))),
0070 isFallback (pset .getUntrackedParameter <std ::string >("name" ) == fallbackName ),
0071 useSsl (pset .getUntrackedParameter <bool >("useSsl" )),
0072 type (TritonServerType ::Remote ) {
0073 if (useSsl ) {
0074 sslOptions .root_certificates = pset .getUntrackedParameter <std ::string >("rootCertificates" );
0075 sslOptions .private_key = pset .getUntrackedParameter <std ::string >("privateKey" );
0076 sslOptions .certificate_chain = pset .getUntrackedParameter <std ::string >("certificateChain" );
0077 }
0078 }
0079 Server (const std ::string & name_ , const std ::string & url_, TritonServerType type_ )
0080 : url (url_), isFallback (name_ == fallbackName ), useSsl (false ), type (type_ ) {}
0081
0082
0083 std ::string url ;
0084 bool isFallback ;
0085 bool useSsl ;
0086 TritonServerType type ;
0087 triton::client ::SslOptions sslOptions ;
0088 std ::unordered_set<std ::string > models ;
0089 static const std ::string fallbackName ;
0090 static const std ::string fallbackAddress ;
0091 static const std ::string siteconfName ;
0092 };
0093 struct Model {
0094 Model (const std ::string & path_ = "" ) : path (path_ ) {}
0095
0096
0097 std ::string path ;
0098 std ::unordered_set<std ::string > servers ;
0099 std ::unordered_set<unsigned > modules ;
0100 };
0101 struct Module {
0102
0103 Module (const std ::string & model_ ) : model (model_ ) {}
0104
0105
0106 std ::string model ;
0107 };
0108
0109 TritonService (const edm ::ParameterSet & pset , edm ::ActivityRegistry & areg );
0110 ~TritonService () = default ;
0111
0112
0113 void addModel (const std ::string & modelName , const std ::string & path );
0114 Server serverInfo (const std ::string & model , const std ::string & preferred = "" ) const ;
0115 const std ::string & pid () const { return pid_ ; }
0116 void notifyCallStatus (bool status ) const ;
0117
0118 static void fillDescriptions (edm ::ConfigurationDescriptions & descriptions );
0119
0120 private :
0121 void preallocate (edm ::service ::SystemBounds const &);
0122 void preModuleConstruction (edm ::ModuleDescription const &);
0123 void postModuleConstruction (edm ::ModuleDescription const &);
0124 void preModuleDestruction (edm ::ModuleDescription const &);
0125 void preBeginJob (edm ::ProcessContext const &);
0126 void postEndJob ();
0127
0128
0129 template <typename LOG >
0130 void printFallbackServerLog () const ;
0131
0132 bool verbose_ ;
0133 FallbackOpts fallbackOpts_ ;
0134 unsigned currentModuleId_ ;
0135 bool allowAddModel_ ;
0136 bool startedFallback_ ;
0137 mutable std ::atomic<int > callFails_ ;
0138 std ::string pid_ ;
0139 std ::unordered_map<std ::string , Model > unservedModels_ ;
0140
0141 std ::unordered_map<std ::string , Server > servers_ ;
0142 std ::unordered_map<std ::string , Model > models_ ;
0143 std ::unordered_map<unsigned , Module > modules_ ;
0144 int numberOfThreads_ ;
0145 };
0146
0147 #endif