Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-02-14 13:32:43

0001 /*
0002  * MXNetCppPredictor.cc
0003  *
0004  *  Created on: Jul 19, 2018
0005  *      Author: hqu
0006  */
0007 
0008 #include "PhysicsTools/MXNet/interface/Predictor.h"
0009 
0010 #include <cassert>
0011 #include <memory>
0012 
0013 #include "FWCore/Utilities/interface/Exception.h"
0014 
0015 namespace mxnet {
0016 
0017   namespace cpp {
0018 
0019     Block::Block() {}
0020 
0021     Block::Block(const std::string& symbol_file, const std::string& param_file) {
0022       // load the symbol
0023       sym_ = Symbol::Load(symbol_file);
0024       // load the parameters
0025       load_parameters(param_file);
0026     }
0027 
0028     Block::~Block() {}
0029 
0030     void Block::load_parameters(const std::string& param_file) {
0031       std::map<std::string, NDArray> paramters;
0032       NDArray::Load(param_file, nullptr, &paramters);
0033       for (const auto& k : paramters) {
0034         if (k.first.substr(0, 4) == "aux:") {
0035           auto name = k.first.substr(4, k.first.size() - 4);
0036           aux_map_[name] = k.second;
0037         }
0038         if (k.first.substr(0, 4) == "arg:") {
0039           auto name = k.first.substr(4, k.first.size() - 4);
0040           arg_map_[name] = k.second;
0041         }
0042       }
0043     }
0044 
0045     std::mutex Predictor::mutex_;
0046     const Context Predictor::context_ = Context(DeviceType::kCPU, 0);
0047 
0048     Predictor::Predictor() {}
0049 
0050     Predictor::Predictor(const Block& block)
0051         : sym_(block.symbol()), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {}
0052 
0053     Predictor::Predictor(const Block& block, const std::string& output_node)
0054         : sym_(block.symbol(output_node)), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {}
0055 
0056     Predictor::~Predictor() {}
0057 
0058     void Predictor::set_input_shapes(const std::vector<std::string>& input_names,
0059                                      const std::vector<std::vector<mx_uint> >& input_shapes) {
0060       assert(input_names.size() == input_shapes.size());
0061       input_names_ = input_names;
0062       // init the input NDArrays and add them to the arg_map
0063       for (unsigned i = 0; i < input_names_.size(); ++i) {
0064         const auto& name = input_names_[i];
0065         arg_map_.emplace(name, NDArray(input_shapes[i], context_, false));
0066       }
0067     }
0068 
0069     const std::vector<float>& Predictor::predict(const std::vector<std::vector<mx_float> >& input_data) {
0070       assert(input_names_.size() == input_data.size());
0071 
0072       try {
0073         // create the executor (if not done yet)
0074         if (!exec_) {
0075           bind_executor();
0076         }
0077         assert(exec_);
0078         // set the inputs
0079         for (unsigned i = 0; i < input_names_.size(); ++i) {
0080           const auto& name = input_names_[i];
0081           arg_map_[name].SyncCopyFromCPU(input_data[i]);
0082         }
0083         // run forward
0084         exec_->Forward(false);
0085         // copy the output to pred_
0086         exec_->outputs[0].SyncCopyToCPU(&pred_);
0087         return pred_;
0088       } catch (const dmlc::Error& e) {
0089         throw cms::Exception("RuntimeError") << e.what() << MXGetLastError();
0090       }
0091     }
0092 
0093     void Predictor::bind_executor() {
0094       // acquire lock
0095       std::lock_guard<std::mutex> lock(mutex_);
0096 
0097       // infer shapes
0098       const auto arg_name_list = sym_.ListArguments();
0099       std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes;
0100       std::map<std::string, std::vector<mx_uint> > arg_shapes;
0101 
0102       for (const auto& arg_name : arg_name_list) {
0103         auto iter = arg_map_.find(arg_name);
0104         if (iter != arg_map_.end()) {
0105           arg_shapes[arg_name] = iter->second.GetShape();
0106         }
0107       }
0108       sym_.InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);
0109 
0110       // init argument arrays
0111       std::vector<NDArray> arg_arrays;
0112       for (size_t i = 0; i < in_shapes.size(); ++i) {
0113         const auto& shape = in_shapes[i];
0114         const auto& arg_name = arg_name_list[i];
0115         auto iter_arg = arg_map_.find(arg_name);
0116         if (iter_arg != arg_map_.end()) {
0117           arg_arrays.push_back(iter_arg->second);
0118         } else {
0119           arg_arrays.push_back(NDArray(shape, context_, false));
0120         }
0121       }
0122       std::vector<NDArray> grad_arrays(arg_arrays.size());
0123       std::vector<OpReqType> grad_reqs(arg_arrays.size(), kNullOp);
0124 
0125       // init auxiliary array
0126       std::vector<NDArray> aux_arrays;
0127       const auto aux_name_list = sym_.ListAuxiliaryStates();
0128       for (size_t i = 0; i < aux_shapes.size(); ++i) {
0129         const auto& shape = aux_shapes[i];
0130         const auto& aux_name = aux_name_list[i];
0131         auto iter_aux = aux_map_.find(aux_name);
0132         if (iter_aux != aux_map_.end()) {
0133           aux_arrays.push_back(iter_aux->second);
0134         } else {
0135           aux_arrays.push_back(NDArray(shape, context_, false));
0136         }
0137       }
0138 
0139       // bind executor
0140       exec_ = std::make_unique<Executor>(sym_, context_, arg_arrays, grad_arrays, grad_reqs, aux_arrays);
0141     }
0142 
0143   }  // namespace cpp
0144 
0145 } /* namespace mxnet */