Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:24:07

0001 #include <torch/torch.h>
0002 #include <iostream>
0003 
0004 struct Net : torch::nn::Module {
0005   Net(int64_t N, int64_t M) : linear(register_module("linear", torch::nn::Linear(N, M))) {
0006     another_bias = register_parameter("b", torch::randn(M));
0007   }
0008   torch::Tensor forward(torch::Tensor input) { return linear(input) + another_bias; }
0009   torch::nn::Linear linear;
0010   torch::Tensor another_bias;
0011 };
0012 
0013 int main(int /*argc*/, char* /*argv*/[]) {
0014   // Use GPU when present, CPU otherwise.
0015   Net net(4, 5);
0016 
0017   torch::Device device(torch::kCPU);
0018   if (torch::cuda::is_available()) {
0019     device = torch::Device(torch::kCUDA);
0020     std::cout << "CUDA is available! Training on GPU." << std::endl;
0021   }
0022 
0023   net.to(device);
0024 
0025   for (const auto& pair : net.named_parameters()) {
0026     std::cout << pair.key() << ": " << pair.value() << std::endl;
0027   }
0028 
0029   std::cout << net.forward(torch::ones({2, 4})) << std::endl;
0030 
0031   return 0;
0032 }