Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2021-02-14 13:32:43

0001 /*
0002  * MXNetCppPredictor.h
0003  *
0004  *  Created on: Jul 19, 2018
0005  *      Author: hqu
0006  */
0007 
0008 #ifndef PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_
0009 #define PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_
0010 
0011 #include <map>
0012 #include <vector>
0013 #include <memory>
0014 #include <mutex>
0015 
0016 #include "mxnet-cpp/MxNetCpp.h"
0017 
0018 namespace mxnet {
0019 
0020   namespace cpp {
0021 
0022     // note: Most of the objects in mxnet::cpp are effectively just shared_ptrs
0023 
0024     // Simple class to hold MXNet model (symbol + params)
0025     // designed to be sharable by multiple threads
0026     class Block {
0027     public:
0028       Block();
0029       Block(const std::string& symbol_file, const std::string& param_file);
0030       virtual ~Block();
0031 
0032       const Symbol& symbol() const { return sym_; }
0033       Symbol symbol(const std::string& output_node) const { return sym_.GetInternals()[output_node]; }
0034       const std::map<std::string, NDArray>& arg_map() const { return arg_map_; }
0035       const std::map<std::string, NDArray>& aux_map() const { return aux_map_; }
0036 
0037     private:
0038       void load_parameters(const std::string& param_file);
0039 
0040       // symbol
0041       Symbol sym_;
0042       // argument arrays
0043       std::map<std::string, NDArray> arg_map_;
0044       // auxiliary arrays
0045       std::map<std::string, NDArray> aux_map_;
0046     };
0047 
0048     // Simple helper class to run prediction
0049     // this cannot be shared between threads
0050     class Predictor {
0051     public:
0052       Predictor();
0053       Predictor(const Block& block);
0054       Predictor(const Block& block, const std::string& output_node);
0055       virtual ~Predictor();
0056 
0057       // set input array shapes
0058       void set_input_shapes(const std::vector<std::string>& input_names,
0059                             const std::vector<std::vector<mx_uint>>& input_shapes);
0060 
0061       // run prediction
0062       const std::vector<float>& predict(const std::vector<std::vector<mx_float>>& input_data);
0063 
0064     private:
0065       static std::mutex mutex_;
0066 
0067       void bind_executor();
0068 
0069       // context
0070       static const Context context_;
0071       // executor
0072       std::unique_ptr<Executor> exec_;
0073       // symbol
0074       Symbol sym_;
0075       // argument arrays
0076       std::map<std::string, NDArray> arg_map_;
0077       // auxiliary arrays
0078       std::map<std::string, NDArray> aux_map_;
0079       // output of the prediction
0080       std::vector<float> pred_;
0081       // names of the input nodes
0082       std::vector<std::string> input_names_;
0083     };
0084 
0085   } /* namespace cpp */
0086 } /* namespace mxnet */
0087 
0088 #endif /* PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_ */