Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:15:48

#include "HeterogeneousCore/SonicTriton/interface/TritonEDProducer.h"

#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
#include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
#include "FWCore/Framework/interface/MakerMacros.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <fstream>
#include <map>
#include <random>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
0015 
0016 class TritonImageProducer : public TritonEDProducer<> {
0017 public:
0018   explicit TritonImageProducer(edm::ParameterSet const& cfg)
0019       : TritonEDProducer<>(cfg),
0020         batchSize_(cfg.getParameter<int>("batchSize")),
0021         topN_(cfg.getParameter<unsigned>("topN")) {
0022     //load score list
0023     std::string imageListFile(cfg.getParameter<edm::FileInPath>("imageList").fullPath());
0024     std::ifstream ifile(imageListFile);
0025     if (ifile.is_open()) {
0026       std::string line;
0027       while (std::getline(ifile, line)) {
0028         imageList_.push_back(line);
0029       }
0030     } else {
0031       throw cms::Exception("MissingFile") << "Could not open image list file: " << imageListFile;
0032     }
0033   }
0034   void acquire(edm::Event const& iEvent, edm::EventSetup const& iSetup, Input& iInput) override {
0035     int actualBatchSize = batchSize_;
0036     //negative batch = generate random batch size from 1 to abs(batch)
0037     if (batchSize_ < 0) {
0038       //get event-based seed for RNG
0039       unsigned int runNum_uint = static_cast<unsigned int>(iEvent.id().run());
0040       unsigned int lumiNum_uint = static_cast<unsigned int>(iEvent.id().luminosityBlock());
0041       unsigned int evNum_uint = static_cast<unsigned int>(iEvent.id().event());
0042       std::uint32_t seed = (lumiNum_uint << 10) + (runNum_uint << 20) + evNum_uint;
0043       std::mt19937 rng(seed);
0044       std::uniform_int_distribution<int> randint(1, std::abs(batchSize_));
0045       actualBatchSize = randint(rng);
0046     }
0047 
0048     client_->setBatchSize(actualBatchSize);
0049     // create an npix x npix x ncol image w/ arbitrary color value
0050     // model only has one input, so just pick begin()
0051     auto& input1 = iInput.begin()->second;
0052     auto data1 = input1.allocate<float>();
0053     for (auto& vdata1 : *data1) {
0054       vdata1.assign(input1.sizeDims(), 0.5f);
0055     }
0056     // convert to server format
0057     input1.toServer(data1);
0058   }
0059   void produce(edm::Event& iEvent, edm::EventSetup const& iSetup, Output const& iOutput) override {
0060     // check the results
0061     findTopN(iOutput.begin()->second);
0062   }
0063   ~TritonImageProducer() override = default;
0064 
0065   static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0066     edm::ParameterSetDescription desc;
0067     TritonClient::fillPSetDescription(desc);
0068     desc.add<int>("batchSize", 1);
0069     desc.add<unsigned>("topN", 5);
0070     desc.add<edm::FileInPath>("imageList");
0071     //to ensure distinct cfi names
0072     descriptions.addWithDefaultLabel(desc);
0073   }
0074 
0075 private:
0076   void findTopN(const TritonOutputData& scores, unsigned n = 5) const {
0077     const auto& tmp = scores.fromServer<float>();
0078     auto dim = scores.sizeDims();
0079     for (unsigned i0 = 0; i0 < client_->batchSize(); i0++) {
0080       //match score to type by index, then put in largest-first map
0081       std::map<float, std::string, std::greater<float>> score_map;
0082       for (unsigned i = 0; i < std::min((unsigned)dim, (unsigned)imageList_.size()); ++i) {
0083         score_map.emplace(tmp[i0][i], imageList_[i]);
0084       }
0085       //get top n
0086       std::stringstream msg;
0087       msg << "Scores for image " << i0 << ":\n";
0088       unsigned counter = 0;
0089       for (const auto& item : score_map) {
0090         msg << item.second << " : " << item.first << "\n";
0091         ++counter;
0092         if (counter >= topN_)
0093           break;
0094       }
0095       edm::LogInfo(debugName_) << msg.str();
0096     }
0097   }
0098 
0099   int batchSize_;
0100   unsigned topN_;
0101   std::vector<std::string> imageList_;
0102 };
0103 
0104 DEFINE_FWK_MODULE(TritonImageProducer);