Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:21:01

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
0002 
0003 #include "FWCore/MessageLogger/interface/MessageLogger.h"
0004 #include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
0005 
0006 #include "L1Trigger/L1TMuonEndCapPhase2/interface/EMTFContext.h"
0007 #include "L1Trigger/L1TMuonEndCapPhase2/interface/Utils/DataUtils.h"
0008 #include "L1Trigger/L1TMuonEndCapPhase2/interface/Utils/DebugUtils.h"
0009 
0010 #include "L1Trigger/L1TMuonEndCapPhase2/interface/Algo/ParameterAssignmentLayer.h"
0011 
0012 using namespace emtf::phase2;
0013 using namespace emtf::phase2::algo;
0014 
0015 ParameterAssignmentLayer::ParameterAssignmentLayer(const EMTFContext& context) : context_(context) {}
0016 
0017 void ParameterAssignmentLayer::apply(const bool& displaced_en, std::vector<track_t>& tracks) const {
0018   std::vector<int> feature_sites = {0, 1, 2,  3,  4, 5, 6, 7, 8, 9,  10, 11, 0, 1, 2, 3,  4,  5,  6,  7,
0019                                     8, 9, 10, 11, 0, 1, 2, 3, 4, 11, 0,  1,  2, 3, 4, 11, -1, -1, -1, -1};
0020 
0021   for (auto& track : tracks) {  // Begin loop tracks
0022     // Init Parameters
0023     track.pt = 0;
0024     track.rels = 0;
0025     track.dxy = 0;
0026     track.z0 = 0;
0027     track.beta = 0;
0028 
0029     track.pt_address = 0;
0030     track.rels_address = 0;
0031     track.dxy_address = 0;
0032 
0033     // Short-Circuit: Skip invalid tracks
0034     if (track.valid == 0) {
0035       continue;
0036     }
0037 
0038     // Get Features
0039     const auto& site_mask = track.site_mask;
0040     const auto& features = track.features;
0041 
0042     // Single batch of NTrackFeatures values
0043     tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, v3::kNumTrackFeatures});
0044 
0045     if (this->context_.config_.verbosity_ > 1) {
0046       edm::LogInfo("L1TEMTFpp") << "Parameter Assignment In"
0047                                 << " disp " << displaced_en << " zone " << track.zone << " col " << track.col << " pat "
0048                                 << track.pattern << " qual " << track.quality << " phi " << track.phi << " theta "
0049                                 << track.theta << " features " << std::endl;
0050     }
0051 
0052     // Prepare input tensor
0053     float* input_data = input.flat<float>().data();
0054 
0055     for (unsigned int i_feature = 0; i_feature < v3::kNumTrackFeatures; ++i_feature) {
0056       const auto& feature = features[i_feature];
0057       const auto& feature_site = feature_sites[i_feature];
0058 
0059       bool mask_value = false;
0060 
0061       // Debug Info
0062       if (this->context_.config_.verbosity_ > 1 && i_feature > 0) {
0063         edm::LogInfo("L1TEMTFpp") << " ";
0064       }
0065 
0066       // Mask invalid sites
0067       if (feature_site > -1) {
0068         mask_value = (site_mask[feature_site] == 0);
0069       }
0070 
0071       if (mask_value) {
0072         (*input_data) = 0.;
0073 
0074         // Debug Info
0075         if (this->context_.config_.verbosity_ > 1) {
0076           edm::LogInfo("L1TEMTFpp") << "0";
0077         }
0078       } else {
0079         (*input_data) = feature.to_float();
0080 
0081         // Debug Info
0082         if (this->context_.config_.verbosity_ > 1) {
0083           edm::LogInfo("L1TEMTFpp") << feature.to_float();
0084         }
0085       }
0086 
0087       input_data++;
0088     }
0089 
0090     // Debug Info
0091     if (this->context_.config_.verbosity_ > 1) {
0092       edm::LogInfo("L1TEMTFpp") << std::endl;
0093     }
0094 
0095     // Select TF Session
0096     auto* session_ptr = context_.prompt_session_ptr_;
0097 
0098     if (displaced_en) {
0099       session_ptr = context_.disp_session_ptr_;
0100     }
0101 
0102     // Evaluate Prompt
0103     std::vector<tensorflow::Tensor> outputs;
0104 
0105     tensorflow::run(session_ptr,
0106                     {{"inputs", input}},  // Input layer name
0107                     {"Identity"},         // Output layer name
0108                     &outputs);
0109 
0110     // Assign parameters
0111     if (displaced_en) {
0112       // Read displaced pb outputs
0113       auto pt_address = outputs[0].matrix<float>()(0, 0);
0114       auto rels_address = outputs[0].matrix<float>()(0, 1);
0115       auto dxy_address = outputs[0].matrix<float>()(0, 2);
0116 
0117       track.pt_address = std::clamp<float>(pt_address, -512, 511);
0118       track.rels_address = std::clamp<float>(rels_address, -512, 511);
0119       track.dxy_address = std::clamp<float>(dxy_address, -512, 511);
0120 
0121       track.q = (track.pt_address < 0);
0122       track.pt = context_.activation_lut_.lookupDispPt(track.pt_address);
0123       track.rels = context_.activation_lut_.lookupRels(track.rels_address);
0124       track.dxy = context_.activation_lut_.lookupDxy(track.dxy_address);
0125     } else {
0126       // Read prompt pb outputs
0127       auto pt_address = outputs[0].matrix<float>()(0, 0);
0128       auto rels_address = outputs[0].matrix<float>()(0, 1);
0129 
0130       track.pt_address = std::clamp<float>(pt_address, -512, 511);
0131       track.rels_address = std::clamp<float>(rels_address, -512, 511);
0132       track.dxy_address = 0;
0133 
0134       track.q = (track.pt_address < 0);
0135       track.pt = context_.activation_lut_.lookupPromptPt(track.pt_address);
0136       track.rels = context_.activation_lut_.lookupRels(track.rels_address);
0137       track.dxy = 0;
0138     }
0139 
0140     // DEBUG
0141     if (this->context_.config_.verbosity_ > 1) {
0142       edm::LogInfo("L1TEMTFpp") << "Parameter Assignment Out"
0143                                 << " disp " << displaced_en << " zone " << track.zone << " col " << track.col << " pat "
0144                                 << track.pattern << " qual " << track.quality << " q " << track.q << " pt " << track.pt
0145                                 << " rels " << track.rels << " dxy " << track.dxy << " z0 " << track.z0 << " phi "
0146                                 << track.phi << " theta " << track.theta << " beta " << track.beta << " pt_address "
0147                                 << track.pt_address << " rels_address " << track.rels_address << " dxy_address "
0148                                 << track.dxy_address << " valid " << track.valid << std::endl;
0149     }
0150   }  // End loop tracks
0151 }