Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-08-19 22:31:31

0001 #ifndef PHYSICSTOOLS_TENSORFLOWAOT_WRAPPER_H
0002 #define PHYSICSTOOLS_TENSORFLOWAOT_WRAPPER_H
0003 
0004 /*
0005  * AOT wrapper interface for interacting with xla functions compiled for different batch sizes.
0006  *
0007  * Author: Marcel Rieger, Bogdan Wiederspan
0008  */
0009 
0010 #include <map>
0011 #include <algorithm>
0012 
0013 #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
0014 #include "tensorflow/core/platform/types.h"
0015 #include "eigen3/unsupported/Eigen/CXX11/Tensor"
0016 
0017 #include "FWCore/Utilities/interface/Exception.h"
0018 
0019 #include "PhysicsTools/TensorFlowAOT/interface/Util.h"
0020 
0021 namespace tfaot {
0022 
0023   // object that wraps multiple variants of the same xla function, but each compiled for a different
0024   // batch size, and providing access to arguments (inputs) and results (outputs) by index
0025   class Wrapper {
0026   public:
0027     // constructor
0028     explicit Wrapper(const std::string& name)
0029         : name_(name), allocMode_(AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS) {}
0030 
0031     // disable copy constructor
0032     Wrapper(const Wrapper&) = delete;
0033 
0034     // disable assignment operator
0035     Wrapper& operator=(const Wrapper&) = delete;
0036 
0037     // disable move operator
0038     Wrapper& operator=(Wrapper&&) = delete;
0039 
0040     // destructor
0041     virtual ~Wrapper() = default;
0042 
0043     // getter for the name
0044     const std::string& name() const { return name_; }
0045 
0046     // getter for the allocation mode
0047     AllocMode allocMode() const { return allocMode_; }
0048 
0049     // getter for the compiled batch sizes
0050     virtual const std::vector<size_t>& batchSizes() const = 0;
0051 
0052     // returns the number of compiled batch sizes
0053     size_t nBatchSizes() const { return batchSizes().size(); }
0054 
0055     // returns whether a compiled xla function exists for a certain batch size
0056     // (batchSizes is sorted by default)
0057     bool hasBatchSize(size_t batchSize) const {
0058       const auto& bs = batchSizes();
0059       return std::binary_search(bs.begin(), bs.end(), batchSize);
0060     }
0061 
0062     // getter for the number of arguments (inputs)
0063     virtual size_t nArgs() const = 0;
0064 
0065     // number of elements in arguments per batch size
0066     virtual const std::map<size_t, std::vector<size_t>>& argCounts() const = 0;
0067 
0068     // number of elements in arguments, divided by batch size
0069     virtual const std::vector<size_t>& argCountsNoBatch() const = 0;
0070 
0071     // returns a pointer to the argument data at a certain index for the xla function at some batch
0072     // size
0073     template <typename T>
0074     T* argData(size_t batchSize, size_t argIndex);
0075 
0076     // returns a const pointer to the argument data at a certain index for the xla function at some
0077     // batch size
0078     template <typename T>
0079     const T* argData(size_t batchSize, size_t argIndex) const;
0080 
0081     // returns the total number of values in the argument data at a certain index for the xla
0082     // function at some batch size
0083     int argCount(size_t batchSize, size_t argIndex) const;
0084 
0085     // returns the number of values excluding the leading batch axis in the argument data at a
0086     // certain index for the xla function at some batch size
0087     int argCountNoBatch(size_t argIndex) const;
0088 
0089     // getter for the number of results (outputs)
0090     virtual size_t nResults() const = 0;
0091 
0092     // number of elements in results per batch size
0093     virtual const std::map<size_t, std::vector<size_t>>& resultCounts() const = 0;
0094 
0095     // number of elements in results, divided by batch size
0096     virtual const std::vector<size_t>& resultCountsNoBatch() const = 0;
0097 
0098     // returns a pointer to the result data at a certain index for the xla function at some batch
0099     // size
0100     template <typename T>
0101     T* resultData(size_t batchSize, size_t resultIndex);
0102 
0103     // returns a const pointer to the result data at a certain index for the xla function at some
0104     // batch size
0105     template <typename T>
0106     const T* resultData(size_t batchSize, size_t resultIndex) const;
0107 
0108     // returns the total number of values in the result data at a certain index for the xla function
0109     // at some batch size
0110     int resultCount(size_t batchSize, size_t resultIndex) const;
0111 
0112     // returns the number of values excluding the leading batch axis in the result data at a
0113     // certain index for the xla function at some batch size
0114     int resultCountNoBatch(size_t resultIndex) const;
0115 
0116     // evaluates the xla function at some batch size and returns whether the call succeeded
0117     virtual bool runSilent(size_t batchSize) = 0;
0118 
0119     // evaluates the xla function at some batch size and throws an exception in case of an error
0120     void run(size_t batchSize);
0121 
0122   protected:
0123     // throws an exception for the case where an unknown batch size was requested
0124     void unknownBatchSize(size_t batchSize, const std::string& method) const {
0125       throw cms::Exception("UnknownBatchSize")
0126           << "batch size " << batchSize << " not known to model '" << name_ << "' in '" << method << "'";
0127     }
0128 
0129     // throws an exception for the case where an unknown argument index was requested
0130     void unknownArgument(size_t argIndex, const std::string& method) const {
0131       throw cms::Exception("UnknownArgument")
0132           << "argument " << argIndex << " not known to model '" << name_ << "' in '" << method << "'";
0133     }
0134 
0135     // throws an exception for the case where an unknown result index was requested
0136     void unknownResult(size_t resultIndex, const std::string& method) const {
0137       throw cms::Exception("UnknownResult")
0138           << "result " << resultIndex << " not known to model '" << name_ << "' in '" << method << "'";
0139     }
0140 
0141   private:
0142     std::string name_;
0143     AllocMode allocMode_;
0144   };
0145 
0146 }  // namespace tfaot
0147 
0148 #endif  // PHYSICSTOOLS_TENSORFLOWAOT_WRAPPER_H