Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-03-17 11:05:50

0001 #ifndef HeterogeneousCore_SonicTriton_triton_utils
0002 #define HeterogeneousCore_SonicTriton_triton_utils
0003 
#include "FWCore/Utilities/interface/Span.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonException.h"

#include <cstdint>
#include <string>
#include <string_view>
#include <vector>
#include <unordered_set>

#include "grpc_client.h"
0013 
0014 namespace triton_utils {
0015   template <typename C>
0016   std::string printColl(const C& coll, const std::string& delim = ", ");
0017   //implemented as a standalone function to avoid repeated specializations for different TritonData types
0018   template <typename DT>
0019   bool checkType(inference::DataType dtype) {
0020     return false;
0021   }
0022 }  // namespace triton_utils
0023 
0024 //explicit specializations (inlined)
0025 //special cases:
0026 //bool: vector<bool> doesn't have data() accessor, so use char (same byte size)
0027 //FP16 (half precision): no C++ primitive exists, so use uint16_t (e.g. with libminifloat)
0028 template <>
0029 inline bool triton_utils::checkType<char>(inference::DataType dtype) {
0030   return dtype == inference::DataType::TYPE_BOOL or dtype == inference::DataType::TYPE_STRING;
0031 }
0032 template <>
0033 inline bool triton_utils::checkType<uint8_t>(inference::DataType dtype) {
0034   return dtype == inference::DataType::TYPE_UINT8;
0035 }
0036 template <>
0037 inline bool triton_utils::checkType<uint16_t>(inference::DataType dtype) {
0038   return dtype == inference::DataType::TYPE_UINT16 or dtype == inference::DataType::TYPE_FP16;
0039 }
0040 template <>
0041 inline bool triton_utils::checkType<uint32_t>(inference::DataType dtype) {
0042   return dtype == inference::DataType::TYPE_UINT32;
0043 }
0044 template <>
0045 inline bool triton_utils::checkType<uint64_t>(inference::DataType dtype) {
0046   return dtype == inference::DataType::TYPE_UINT64;
0047 }
0048 template <>
0049 inline bool triton_utils::checkType<int8_t>(inference::DataType dtype) {
0050   return dtype == inference::DataType::TYPE_INT8;
0051 }
0052 template <>
0053 inline bool triton_utils::checkType<int16_t>(inference::DataType dtype) {
0054   return dtype == inference::DataType::TYPE_INT16;
0055 }
0056 template <>
0057 inline bool triton_utils::checkType<int32_t>(inference::DataType dtype) {
0058   return dtype == inference::DataType::TYPE_INT32;
0059 }
0060 template <>
0061 inline bool triton_utils::checkType<int64_t>(inference::DataType dtype) {
0062   return dtype == inference::DataType::TYPE_INT64;
0063 }
0064 template <>
0065 inline bool triton_utils::checkType<float>(inference::DataType dtype) {
0066   return dtype == inference::DataType::TYPE_FP32;
0067 }
0068 template <>
0069 inline bool triton_utils::checkType<double>(inference::DataType dtype) {
0070   return dtype == inference::DataType::TYPE_FP64;
0071 }
0072 
//helper to turn triton error into exception
//implemented as a macro to avoid constructing the MSG string for successful function calls
//usage: TRITON_THROW_IF_ERROR(expr, msg);  -- expr must yield a triton::client::Error
//the body is wrapped in do { ... } while (false) so the invocation is a single statement:
//it is safe inside an unbraced if/else and consumes the caller's trailing semicolon
//the local variable uses a collision-unlikely name so that X or MSG can refer to a caller variable named "err"
#define TRITON_THROW_IF_ERROR(X, MSG)                                                                  \
  do {                                                                                                 \
    triton::client::Error tritonUtilsErr_ = (X);                                                       \
    if (!tritonUtilsErr_.IsOk())                                                                       \
      throw TritonException("TritonFailure")                                                           \
          << (MSG) << (tritonUtilsErr_.Message().empty() ? "" : ": " + tritonUtilsErr_.Message());    \
  } while (false)
0081 
0082 extern template std::string triton_utils::printColl(const edm::Span<std::vector<int64_t>::const_iterator>& coll,
0083                                                     const std::string& delim);
0084 extern template std::string triton_utils::printColl(const std::vector<uint8_t>& coll, const std::string& delim);
0085 extern template std::string triton_utils::printColl(const std::vector<float>& coll, const std::string& delim);
0086 extern template std::string triton_utils::printColl(const std::unordered_set<std::string>& coll,
0087                                                     const std::string& delim);
0088 
0089 #endif