Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:47

0001 #ifndef HeterogeneousCore_SonicTriton_triton_utils
0002 #define HeterogeneousCore_SonicTriton_triton_utils
0003 
0004 #include "FWCore/Utilities/interface/Exception.h"
0005 #include "FWCore/Utilities/interface/Span.h"
0006 #include "HeterogeneousCore/SonicTriton/interface/TritonException.h"
0007 
0008 #include <string>
0009 #include <string_view>
0010 #include <vector>
0011 #include <unordered_set>
0012 
0013 #include "grpc_client.h"
0014 
0015 namespace triton_utils {
0016   template <typename C>
0017   std::string printColl(const C& coll, const std::string& delim = ", ");
0018   //implemented as a standalone function to avoid repeated specializations for different TritonData types
0019   template <typename DT>
0020   bool checkType(inference::DataType dtype) {
0021     return false;
0022   }
0023   //turn CMS exceptions into warnings
0024   void convertToWarning(const cms::Exception& e);
0025 }  // namespace triton_utils
0026 
0027 //explicit specializations (inlined)
0028 //special cases:
0029 //bool: vector<bool> doesn't have data() accessor, so use char (same byte size)
0030 //FP16 (half precision): no C++ primitive exists, so use uint16_t (e.g. with libminifloat)
0031 template <>
0032 inline bool triton_utils::checkType<char>(inference::DataType dtype) {
0033   return dtype == inference::DataType::TYPE_BOOL or dtype == inference::DataType::TYPE_STRING;
0034 }
0035 template <>
0036 inline bool triton_utils::checkType<uint8_t>(inference::DataType dtype) {
0037   return dtype == inference::DataType::TYPE_UINT8;
0038 }
0039 template <>
0040 inline bool triton_utils::checkType<uint16_t>(inference::DataType dtype) {
0041   return dtype == inference::DataType::TYPE_UINT16 or dtype == inference::DataType::TYPE_FP16;
0042 }
0043 template <>
0044 inline bool triton_utils::checkType<uint32_t>(inference::DataType dtype) {
0045   return dtype == inference::DataType::TYPE_UINT32;
0046 }
0047 template <>
0048 inline bool triton_utils::checkType<uint64_t>(inference::DataType dtype) {
0049   return dtype == inference::DataType::TYPE_UINT64;
0050 }
0051 template <>
0052 inline bool triton_utils::checkType<int8_t>(inference::DataType dtype) {
0053   return dtype == inference::DataType::TYPE_INT8;
0054 }
0055 template <>
0056 inline bool triton_utils::checkType<int16_t>(inference::DataType dtype) {
0057   return dtype == inference::DataType::TYPE_INT16;
0058 }
0059 template <>
0060 inline bool triton_utils::checkType<int32_t>(inference::DataType dtype) {
0061   return dtype == inference::DataType::TYPE_INT32;
0062 }
0063 template <>
0064 inline bool triton_utils::checkType<int64_t>(inference::DataType dtype) {
0065   return dtype == inference::DataType::TYPE_INT64;
0066 }
0067 template <>
0068 inline bool triton_utils::checkType<float>(inference::DataType dtype) {
0069   return dtype == inference::DataType::TYPE_FP32;
0070 }
0071 template <>
0072 inline bool triton_utils::checkType<double>(inference::DataType dtype) {
0073   return dtype == inference::DataType::TYPE_FP64;
0074 }
0075 
//helper to turn triton error into exception
//implemented as a macro to avoid constructing the MSG string for successful function calls
//X: expression yielding a triton::client::Error; MSG: streamed into the exception only on failure;
//NOTIFY: forwarded as the second TritonException constructor argument
//wrapped in do { ... } while (false) so that an invocation followed by ';' is exactly one statement
//(safe as the body of an unbraced if/else); the local uses a deliberately unusual name so it cannot
//capture a caller variable named `err` appearing in X or MSG (a local's scope starts before its initializer)
#define TRITON_THROW_IF_ERROR(X, MSG, NOTIFY)                                                      \
  do {                                                                                             \
    triton::client::Error tritonThrowIfErrorStatus_ = (X);                                         \
    if (!tritonThrowIfErrorStatus_.IsOk())                                                         \
      throw TritonException("TritonFailure", NOTIFY)                                               \
          << (MSG)                                                                                 \
          << (tritonThrowIfErrorStatus_.Message().empty() ? "" : ": " + tritonThrowIfErrorStatus_.Message()); \
  } while (false)
0084 
0085 extern template std::string triton_utils::printColl(const edm::Span<std::vector<int64_t>::const_iterator>& coll,
0086                                                     const std::string& delim);
0087 extern template std::string triton_utils::printColl(const std::vector<uint8_t>& coll, const std::string& delim);
0088 extern template std::string triton_utils::printColl(const std::vector<float>& coll, const std::string& delim);
0089 extern template std::string triton_utils::printColl(const std::vector<std::string>& coll, const std::string& delim);
0090 extern template std::string triton_utils::printColl(const std::unordered_set<std::string>& coll,
0091                                                     const std::string& delim);
0092 
0093 #endif