Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-08-02 05:16:40

0001 // Based on https://github.com/Maverobot/libtorch_examples/blob/master/src/simple_optimization_example.cpp
0002 #include <torch/torch.h>
0003 #include <cstdlib>
0004 #include <iostream>
0005 
// Step size applied to the gradient, shared by the hand-written update and the SGD optimizer.
constexpr double kLearningRate{1e-3};
// Upper bound on the number of descent steps before a run gives up.
constexpr int kMaxIterations{100'000};
0008 
0009 void native_run(double minimal) {
0010   // Initial x value
0011   auto x = torch::randn({1, 1}, torch::requires_grad(true));
0012 
0013   for (size_t t = 0; t < kMaxIterations; t++) {
0014     // Expression/value to be minimized
0015     auto y = (x - minimal) * (x - minimal);
0016     if (y.item<double>() < 1e-3) {
0017       break;
0018     }
0019     // Calculate gradient
0020     y.backward();
0021 
0022     // Step x value without considering gradient
0023     torch::NoGradGuard no_grad_guard;
0024     x -= kLearningRate * x.grad();
0025 
0026     // Reset the gradient of variable x
0027     x.mutable_grad().reset();
0028   }
0029 
0030   std::cout << "[native] Actual minimal x value: " << minimal << ", calculated optimal x value: " << x.item<double>()
0031             << std::endl;
0032 }
0033 
0034 void optimizer_run(double minimal) {
0035   // Initial x value
0036   std::vector<torch::Tensor> x;
0037   x.push_back(torch::randn({1, 1}, torch::requires_grad(true)));
0038   auto opt = torch::optim::SGD(x, torch::optim::SGDOptions(kLearningRate));
0039 
0040   for (size_t t = 0; t < kMaxIterations; t++) {
0041     // Expression/value to be minimized
0042     auto y = (x[0] - minimal) * (x[0] - minimal);
0043     if (y.item<double>() < 1e-3) {
0044       break;
0045     }
0046     // Calculate gradient
0047     y.backward();
0048 
0049     // Step x value without considering gradient
0050     opt.step();
0051     // Reset the gradient of variable x
0052     opt.zero_grad();
0053   }
0054 
0055   std::cout << "[optimizer] Actual minimal x value: " << minimal
0056             << ", calculated optimal x value: " << x[0].item<double>() << std::endl;
0057 }
0058 
0059 // optimize y = (x - 10)^2
0060 int main(int argc, char* argv[]) {
0061   native_run(0.01);
0062   optimizer_run(0.01);
0063   return 0;
0064 }