File indexing completed on 2024-04-06 12:15:48
0001 #include "HeterogeneousCore/SonicTriton/interface/TritonEDProducer.h"
0002
0003 #include "FWCore/ParameterSet/interface/FileInPath.h"
0004 #include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
0005 #include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
0006 #include "FWCore/Framework/interface/MakerMacros.h"
0007
0008 #include <sstream>
0009 #include <string>
0010 #include <vector>
0011 #include <map>
0012 #include <cmath>
0013
0014 class TritonIdentityProducer : public TritonEDProducer<> {
0015 public:
0016 explicit TritonIdentityProducer(edm::ParameterSet const& cfg)
0017 : TritonEDProducer<>(cfg), batchSizes_{1, 2, 0}, batchCounter_(0) {}
0018 void acquire(edm::Event const& iEvent, edm::EventSetup const& iSetup, Input& iInput) override {
0019
0020 std::vector<std::vector<float>> value_lists{{2, 2}, {4, 4, 4, 4}, {1}, {3, 3, 3}};
0021
0022 client_->setBatchSize(batchSizes_[batchCounter_]);
0023 batchCounter_ = (batchCounter_ + 1) % batchSizes_.size();
0024 auto& input1 = iInput.at("INPUT0");
0025 auto data1 = input1.allocate<float>();
0026 for (unsigned i = 0; i < client_->batchSize(); ++i) {
0027 (*data1)[i] = value_lists[i];
0028 input1.setShape(0, (*data1)[i].size(), i);
0029 }
0030
0031
0032 input1.toServer(data1);
0033 }
0034 void produce(edm::Event& iEvent, edm::EventSetup const& iSetup, Output const& iOutput) override {
0035
0036 const auto& output1 = iOutput.at("OUTPUT0");
0037
0038 const auto& tmp = output1.fromServer<float>();
0039 edm::LogInfo msg(debugName_);
0040 for (unsigned i = 0; i < client_->batchSize(); ++i) {
0041 msg << "output " << i << " (" << triton_utils::printColl(output1.shape(i)) << "): ";
0042 for (int j = 0; j < output1.shape(i)[0]; ++j) {
0043 msg << tmp[i][j] << " ";
0044 }
0045 msg << "\n";
0046 }
0047 }
0048 ~TritonIdentityProducer() override = default;
0049
0050 static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
0051 edm::ParameterSetDescription desc;
0052 TritonClient::fillPSetDescription(desc);
0053
0054 descriptions.addWithDefaultLabel(desc);
0055 }
0056
0057 private:
0058 std::vector<unsigned> batchSizes_;
0059 unsigned batchCounter_;
0060 };
0061
0062 DEFINE_FWK_MODULE(TritonIdentityProducer);