Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-10-19 10:28:52

0001 /*
0002  * \class DeepTauBase
0003  *
0004  * Implementation of the base class for tau identification using Deep NN.
0005  *
0006  * \author Konstantin Androsov, INFN Pisa
0007  * \author Maria Rosaria Di Domenico, University of Siena & INFN Pisa
0008  */
0009 
0010 //TODO: port to offline RECO/AOD inputs to allow usage with offline AOD
0011 //TODO: Take into account that PFTaus can also be built with pat::PackedCandidates
0012 
0013 #include "RecoTauTag/RecoTau/interface/DeepTauBase.h"
0014 
0015 namespace deep_tau {
0016 
0017   TauWPThreshold::TauWPThreshold(const std::string& cut_str) {
0018     bool simple_value = false;
0019     try {
0020       size_t pos = 0;
0021       value_ = std::stod(cut_str, &pos);
0022       simple_value = (pos == cut_str.size());
0023     } catch (std::invalid_argument&) {
0024     } catch (std::out_of_range&) {
0025     }
0026     if (!simple_value) {
0027       static const std::string prefix =
0028           "[&](double *x, double *p) { const int decayMode = p[0];"
0029           "const double pt = p[1]; const double eta = p[2];";
0030       static const int n_params = 3;
0031       static const auto handler = [](int, Bool_t, const char*, const char*) -> void {};
0032 
0033       std::string fn_str = prefix;
0034       if (cut_str.find("return") == std::string::npos)
0035         fn_str += " return " + cut_str + ";}";
0036       else
0037         fn_str += cut_str + "}";
0038       auto old_handler = SetErrorHandler(handler);
0039       fn_ = std::make_unique<TF1>("fn_", fn_str.c_str(), 0, 1, n_params);
0040       SetErrorHandler(old_handler);
0041       if (!fn_->IsValid())
0042         throw cms::Exception("TauWPThreshold: invalid formula") << "Invalid WP cut formula = '" << cut_str << "'.";
0043     }
0044   }
0045 
0046   double TauWPThreshold::operator()(const reco::BaseTau& tau, bool isPFTau) const {
0047     if (!fn_) {
0048       return value_;
0049     }
0050 
0051     if (isPFTau)
0052       fn_->SetParameter(0, dynamic_cast<const reco::PFTau&>(tau).decayMode());
0053     else
0054       fn_->SetParameter(0, dynamic_cast<const pat::Tau&>(tau).decayMode());
0055     fn_->SetParameter(1, tau.pt());
0056     fn_->SetParameter(2, tau.eta());
0057     return fn_->Eval(0);
0058   }
0059 
0060   std::unique_ptr<DeepTauBase::TauDiscriminator> DeepTauBase::Output::get_value(const edm::Handle<TauCollection>& taus,
0061                                                                                 const tensorflow::Tensor& pred,
0062                                                                                 const WPList* working_points,
0063                                                                                 bool is_online) const {
0064     std::vector<reco::SingleTauDiscriminatorContainer> outputbuffer(taus->size());
0065 
0066     for (size_t tau_index = 0; tau_index < taus->size(); ++tau_index) {
0067       float x = 0;
0068       for (size_t num_elem : num_)
0069         x += pred.matrix<float>()(tau_index, num_elem);
0070       if (x != 0 && !den_.empty()) {
0071         float den_val = 0;
0072         for (size_t den_elem : den_)
0073           den_val += pred.matrix<float>()(tau_index, den_elem);
0074         x = den_val != 0 ? x / den_val : std::numeric_limits<float>::max();
0075       }
0076       outputbuffer[tau_index].rawValues.push_back(x);
0077       if (working_points) {
0078         for (const auto& wp : *working_points) {
0079           const bool pass = x > (*wp)(taus->at(tau_index), is_online);
0080           outputbuffer[tau_index].workingPoints.push_back(pass);
0081         }
0082       }
0083     }
0084     std::unique_ptr<TauDiscriminator> output = std::make_unique<TauDiscriminator>();
0085     reco::TauDiscriminatorContainer::Filler filler(*output);
0086     filler.insert(taus, outputbuffer.begin(), outputbuffer.end());
0087     filler.fill();
0088     return output;
0089   }
0090 
0091   DeepTauBase::DeepTauBase(const edm::ParameterSet& cfg,
0092                            const OutputCollection& outputCollection,
0093                            const DeepTauCache* cache)
0094       : tausToken_(consumes<TauCollection>(cfg.getParameter<edm::InputTag>("taus"))),
0095         pfcandToken_(consumes<CandidateCollection>(cfg.getParameter<edm::InputTag>("pfcands"))),
0096         vtxToken_(consumes<reco::VertexCollection>(cfg.getParameter<edm::InputTag>("vertices"))),
0097         is_online_(cfg.getParameter<bool>("is_online")),
0098         outputs_(outputCollection),
0099         cache_(cache) {
0100     for (const auto& output_desc : outputs_) {
0101       produces<TauDiscriminator>(output_desc.first);
0102       const auto& cut_list = cfg.getParameter<std::vector<std::string>>(output_desc.first + "WP");
0103       for (const std::string& cut_str : cut_list) {
0104         workingPoints_[output_desc.first].push_back(std::make_unique<Cutter>(cut_str));
0105       }
0106     }
0107 
0108     // prediscriminant operator
0109     // require the tau to pass the following prediscriminants
0110     const edm::ParameterSet& prediscriminantConfig = cfg.getParameter<edm::ParameterSet>("Prediscriminants");
0111 
0112     // determine boolean operator used on the prediscriminants
0113     std::string pdBoolOperator = prediscriminantConfig.getParameter<std::string>("BooleanOperator");
0114     // convert string to lowercase
0115     transform(pdBoolOperator.begin(), pdBoolOperator.end(), pdBoolOperator.begin(), ::tolower);
0116 
0117     if (pdBoolOperator == "and") {
0118       andPrediscriminants_ = 0x1;  //use chars instead of bools so we can do a bitwise trick later
0119     } else if (pdBoolOperator == "or") {
0120       andPrediscriminants_ = 0x0;
0121     } else {
0122       throw cms::Exception("TauDiscriminationProducerBase")
0123           << "PrediscriminantBooleanOperator defined incorrectly, options are: AND,OR";
0124     }
0125 
0126     // get the list of prediscriminants
0127     std::vector<std::string> prediscriminantsNames =
0128         prediscriminantConfig.getParameterNamesForType<edm::ParameterSet>();
0129 
0130     for (auto const& iDisc : prediscriminantsNames) {
0131       const edm::ParameterSet& iPredisc = prediscriminantConfig.getParameter<edm::ParameterSet>(iDisc);
0132       const edm::InputTag& label = iPredisc.getParameter<edm::InputTag>("Producer");
0133       double cut = iPredisc.getParameter<double>("cut");
0134 
0135       if (is_online_) {
0136         TauDiscInfo<reco::PFTauDiscriminator> thisDiscriminator;
0137         thisDiscriminator.label = label;
0138         thisDiscriminator.cut = cut;
0139         thisDiscriminator.disc_token = consumes<reco::PFTauDiscriminator>(label);
0140         recoPrediscriminants_.push_back(thisDiscriminator);
0141       } else {
0142         TauDiscInfo<pat::PATTauDiscriminator> thisDiscriminator;
0143         thisDiscriminator.label = label;
0144         thisDiscriminator.cut = cut;
0145         thisDiscriminator.disc_token = consumes<pat::PATTauDiscriminator>(label);
0146         patPrediscriminants_.push_back(thisDiscriminator);
0147       }
0148     }
0149   }
0150 
0151   void DeepTauBase::produce(edm::Event& event, const edm::EventSetup& es) {
0152     edm::Handle<TauCollection> taus;
0153     event.getByToken(tausToken_, taus);
0154     edm::ProductID tauProductID = taus.id();
0155 
0156     // load prediscriminators
0157     size_t nPrediscriminants =
0158         patPrediscriminants_.empty() ? recoPrediscriminants_.size() : patPrediscriminants_.size();
0159     for (size_t iDisc = 0; iDisc < nPrediscriminants; ++iDisc) {
0160       edm::ProductID discKeyId;
0161       if (is_online_) {
0162         recoPrediscriminants_[iDisc].fill(event);
0163         discKeyId = recoPrediscriminants_[iDisc].handle->keyProduct().id();
0164       } else {
0165         patPrediscriminants_[iDisc].fill(event);
0166         discKeyId = patPrediscriminants_[iDisc].handle->keyProduct().id();
0167       }
0168 
0169       // Check to make sure the product is correct for the discriminator.
0170       // If not, throw a more informative exception.
0171       if (tauProductID != discKeyId) {
0172         throw cms::Exception("MisconfiguredPrediscriminant")
0173             << "The tau collection has product ID: " << tauProductID
0174             << " but the pre-discriminator is keyed with product ID: " << discKeyId << std::endl;
0175       }
0176     }
0177 
0178     const tensorflow::Tensor& pred = getPredictions(event, taus);
0179     createOutputs(event, pred, taus);
0180   }
0181 
0182   void DeepTauBase::createOutputs(edm::Event& event, const tensorflow::Tensor& pred, edm::Handle<TauCollection> taus) {
0183     for (const auto& output_desc : outputs_) {
0184       const WPList* working_points = nullptr;
0185       if (workingPoints_.find(output_desc.first) != workingPoints_.end()) {
0186         working_points = &workingPoints_.at(output_desc.first);
0187       }
0188       auto result = output_desc.second.get_value(taus, pred, working_points, is_online_);
0189       event.put(std::move(result), output_desc.first);
0190     }
0191   }
0192 
0193   std::unique_ptr<DeepTauCache> DeepTauBase::initializeGlobalCache(const edm::ParameterSet& cfg) {
0194     const auto graph_name_vector = cfg.getParameter<std::vector<std::string>>("graph_file");
0195     std::map<std::string, std::string> graph_names;
0196     for (const auto& entry : graph_name_vector) {
0197       const size_t sep_pos = entry.find(':');
0198       std::string entry_name, graph_file;
0199       if (sep_pos != std::string::npos) {
0200         entry_name = entry.substr(0, sep_pos);
0201         graph_file = entry.substr(sep_pos + 1);
0202       } else {
0203         entry_name = "";
0204         graph_file = entry;
0205       }
0206       graph_file = edm::FileInPath(graph_file).fullPath();
0207       if (graph_names.count(entry_name))
0208         throw cms::Exception("DeepTauCache") << "Duplicated graph entries";
0209       graph_names[entry_name] = graph_file;
0210     }
0211     bool mem_mapped = cfg.getParameter<bool>("mem_mapped");
0212     return std::make_unique<DeepTauCache>(graph_names, mem_mapped);
0213   }
0214 
0215   DeepTauCache::DeepTauCache(const std::map<std::string, std::string>& graph_names, bool mem_mapped) {
0216     for (const auto& graph_entry : graph_names) {
0217       tensorflow::SessionOptions options;
0218       tensorflow::setThreading(options, 1);
0219 
0220       const std::string& entry_name = graph_entry.first;
0221       const std::string& graph_file = graph_entry.second;
0222       if (mem_mapped) {
0223         memmappedEnv_[entry_name] = std::make_unique<tensorflow::MemmappedEnv>(tensorflow::Env::Default());
0224         const tensorflow::Status mmap_status = memmappedEnv_.at(entry_name)->InitializeFromFile(graph_file);
0225         if (!mmap_status.ok()) {
0226           throw cms::Exception("DeepTauCache: unable to initalize memmapped environment for ")
0227               << graph_file << ". \n"
0228               << mmap_status.ToString();
0229         }
0230 
0231         graphs_[entry_name] = std::make_unique<tensorflow::GraphDef>();
0232         const tensorflow::Status load_graph_status =
0233             ReadBinaryProto(memmappedEnv_.at(entry_name).get(),
0234                             tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
0235                             graphs_.at(entry_name).get());
0236         if (!load_graph_status.ok())
0237           throw cms::Exception("DeepTauCache: unable to load graph from ") << graph_file << ". \n"
0238                                                                            << load_graph_status.ToString();
0239 
0240         options.config.mutable_graph_options()->mutable_optimizer_options()->set_opt_level(
0241             ::tensorflow::OptimizerOptions::L0);
0242         options.env = memmappedEnv_.at(entry_name).get();
0243 
0244         sessions_[entry_name] = tensorflow::createSession(graphs_.at(entry_name).get(), options);
0245 
0246       } else {
0247         graphs_[entry_name].reset(tensorflow::loadGraphDef(graph_file));
0248         sessions_[entry_name] = tensorflow::createSession(graphs_.at(entry_name).get(), options);
0249       }
0250     }
0251   }
0252 
0253   DeepTauCache::~DeepTauCache() {
0254     for (auto& session_entry : sessions_)
0255       tensorflow::closeSession(session_entry.second);
0256   }
0257 
0258 }  // namespace deep_tau