File indexing completed on 2024-04-06 12:21:01
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>

#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

#include "L1Trigger/L1TMuonEndCapPhase2/interface/EMTFContext.h"
#include "L1Trigger/L1TMuonEndCapPhase2/interface/Utils/DataUtils.h"
#include "L1Trigger/L1TMuonEndCapPhase2/interface/Utils/DebugUtils.h"

#include "L1Trigger/L1TMuonEndCapPhase2/interface/Algo/ParameterAssignmentLayer.h"
0011
0012 using namespace emtf::phase2;
0013 using namespace emtf::phase2::algo;
0014
0015 ParameterAssignmentLayer::ParameterAssignmentLayer(const EMTFContext& context) : context_(context) {}
0016
0017 void ParameterAssignmentLayer::apply(const bool& displaced_en, std::vector<track_t>& tracks) const {
0018 std::vector<int> feature_sites = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7,
0019 8, 9, 10, 11, 0, 1, 2, 3, 4, 11, 0, 1, 2, 3, 4, 11, -1, -1, -1, -1};
0020
0021 for (auto& track : tracks) {
0022
0023 track.pt = 0;
0024 track.rels = 0;
0025 track.dxy = 0;
0026 track.z0 = 0;
0027 track.beta = 0;
0028
0029 track.pt_address = 0;
0030 track.rels_address = 0;
0031 track.dxy_address = 0;
0032
0033
0034 if (track.valid == 0) {
0035 continue;
0036 }
0037
0038
0039 const auto& site_mask = track.site_mask;
0040 const auto& features = track.features;
0041
0042
0043 tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, v3::kNumTrackFeatures});
0044
0045 if (this->context_.config_.verbosity_ > 1) {
0046 edm::LogInfo("L1TEMTFpp") << "Parameter Assignment In"
0047 << " disp " << displaced_en << " zone " << track.zone << " col " << track.col << " pat "
0048 << track.pattern << " qual " << track.quality << " phi " << track.phi << " theta "
0049 << track.theta << " features " << std::endl;
0050 }
0051
0052
0053 float* input_data = input.flat<float>().data();
0054
0055 for (unsigned int i_feature = 0; i_feature < v3::kNumTrackFeatures; ++i_feature) {
0056 const auto& feature = features[i_feature];
0057 const auto& feature_site = feature_sites[i_feature];
0058
0059 bool mask_value = false;
0060
0061
0062 if (this->context_.config_.verbosity_ > 1 && i_feature > 0) {
0063 edm::LogInfo("L1TEMTFpp") << " ";
0064 }
0065
0066
0067 if (feature_site > -1) {
0068 mask_value = (site_mask[feature_site] == 0);
0069 }
0070
0071 if (mask_value) {
0072 (*input_data) = 0.;
0073
0074
0075 if (this->context_.config_.verbosity_ > 1) {
0076 edm::LogInfo("L1TEMTFpp") << "0";
0077 }
0078 } else {
0079 (*input_data) = feature.to_float();
0080
0081
0082 if (this->context_.config_.verbosity_ > 1) {
0083 edm::LogInfo("L1TEMTFpp") << feature.to_float();
0084 }
0085 }
0086
0087 input_data++;
0088 }
0089
0090
0091 if (this->context_.config_.verbosity_ > 1) {
0092 edm::LogInfo("L1TEMTFpp") << std::endl;
0093 }
0094
0095
0096 auto* session_ptr = context_.prompt_session_ptr_;
0097
0098 if (displaced_en) {
0099 session_ptr = context_.disp_session_ptr_;
0100 }
0101
0102
0103 std::vector<tensorflow::Tensor> outputs;
0104
0105 tensorflow::run(session_ptr,
0106 {{"inputs", input}},
0107 {"Identity"},
0108 &outputs);
0109
0110
0111 if (displaced_en) {
0112
0113 auto pt_address = outputs[0].matrix<float>()(0, 0);
0114 auto rels_address = outputs[0].matrix<float>()(0, 1);
0115 auto dxy_address = outputs[0].matrix<float>()(0, 2);
0116
0117 track.pt_address = std::clamp<float>(pt_address, -512, 511);
0118 track.rels_address = std::clamp<float>(rels_address, -512, 511);
0119 track.dxy_address = std::clamp<float>(dxy_address, -512, 511);
0120
0121 track.q = (track.pt_address < 0);
0122 track.pt = context_.activation_lut_.lookupDispPt(track.pt_address);
0123 track.rels = context_.activation_lut_.lookupRels(track.rels_address);
0124 track.dxy = context_.activation_lut_.lookupDxy(track.dxy_address);
0125 } else {
0126
0127 auto pt_address = outputs[0].matrix<float>()(0, 0);
0128 auto rels_address = outputs[0].matrix<float>()(0, 1);
0129
0130 track.pt_address = std::clamp<float>(pt_address, -512, 511);
0131 track.rels_address = std::clamp<float>(rels_address, -512, 511);
0132 track.dxy_address = 0;
0133
0134 track.q = (track.pt_address < 0);
0135 track.pt = context_.activation_lut_.lookupPromptPt(track.pt_address);
0136 track.rels = context_.activation_lut_.lookupRels(track.rels_address);
0137 track.dxy = 0;
0138 }
0139
0140
0141 if (this->context_.config_.verbosity_ > 1) {
0142 edm::LogInfo("L1TEMTFpp") << "Parameter Assignment Out"
0143 << " disp " << displaced_en << " zone " << track.zone << " col " << track.col << " pat "
0144 << track.pattern << " qual " << track.quality << " q " << track.q << " pt " << track.pt
0145 << " rels " << track.rels << " dxy " << track.dxy << " z0 " << track.z0 << " phi "
0146 << track.phi << " theta " << track.theta << " beta " << track.beta << " pt_address "
0147 << track.pt_address << " rels_address " << track.rels_address << " dxy_address "
0148 << track.dxy_address << " valid " << track.valid << std::endl;
0149 }
0150 }
0151 }