Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2023-10-25 09:50:30

0001 #include "HeterogeneousCore/SonicTriton/interface/TritonEDProducer.h"
0002 
0003 #include "FWCore/ParameterSet/interface/FileInPath.h"
0004 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0005 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0006 #include "FWCore/Framework/interface/MakerMacros.h"
0007 
0008 #include <sstream>
0009 #include <string>
0010 #include <vector>
0011 #include <map>
0012 #include <cmath>
0013 
0014 class TritonIdentityProducer : public TritonEDProducer<> {
0015 public:
0016   explicit TritonIdentityProducer(edm::ParameterSet const& cfg)
0017       : TritonEDProducer<>(cfg), batchSizes_{1, 2, 0}, batchCounter_(0) {}
0018   void acquire(edm::Event const& iEvent, edm::EventSetup const& iSetup, Input& iInput) override {
0019     //follow Triton QA tests for ragged input
0020     std::vector<std::vector<float>> value_lists{{2, 2}, {4, 4, 4, 4}, {1}, {3, 3, 3}};
0021 
0022     client_->setBatchSize(batchSizes_[batchCounter_]);
0023     batchCounter_ = (batchCounter_ + 1) % batchSizes_.size();
0024     auto& input1 = iInput.at("INPUT0");
0025     auto data1 = input1.allocate<float>();
0026     for (unsigned i = 0; i < client_->batchSize(); ++i) {
0027       (*data1)[i] = value_lists[i];
0028       input1.setShape(0, (*data1)[i].size(), i);
0029     }
0030 
0031     // convert to server format
0032     input1.toServer(data1);
0033   }
0034   void produce(edm::Event& iEvent, edm::EventSetup const& iSetup, Output const& iOutput) override {
0035     // check the results
0036     const auto& output1 = iOutput.at("OUTPUT0");
0037     // convert from server format
0038     const auto& tmp = output1.fromServer<float>();
0039     edm::LogInfo msg(debugName_);
0040     for (unsigned i = 0; i < client_->batchSize(); ++i) {
0041       msg << "output " << i << " (" << triton_utils::printColl(output1.shape(i)) << "): ";
0042       for (int j = 0; j < output1.shape(i)[0]; ++j) {
0043         msg << tmp[i][j] << " ";
0044       }
0045       msg << "\n";
0046     }
0047   }
0048   ~TritonIdentityProducer() override = default;
0049 
0050   static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0051     edm::ParameterSetDescription desc;
0052     TritonClient::fillPSetDescription(desc);
0053     //to ensure distinct cfi names
0054     descriptions.addWithDefaultLabel(desc);
0055   }
0056 
0057 private:
0058   std::vector<unsigned> batchSizes_;
0059   unsigned batchCounter_;
0060 };
0061 
0062 DEFINE_FWK_MODULE(TritonIdentityProducer);