File indexing completed on 2024-08-02 05:16:40
0001 #include <torch/torch.h>
0002 #include <iostream>
0003
0004 struct Net : torch::nn::Module {
0005 Net(int64_t N, int64_t M) : linear(register_module("linear", torch::nn::Linear(N, M))) {
0006 another_bias = register_parameter("b", torch::randn(M));
0007 }
0008 torch::Tensor forward(torch::Tensor input) { return linear(input) + another_bias; }
0009 torch::nn::Linear linear;
0010 torch::Tensor another_bias;
0011 };
0012
0013 int main(int , char* []) {
0014
0015 Net net(4, 5);
0016
0017 torch::Device device(torch::kCPU);
0018 if (torch::cuda::is_available()) {
0019 device = torch::Device(torch::kCUDA);
0020 std::cout << "CUDA is available! Training on GPU." << std::endl;
0021 }
0022
0023 net.to(device);
0024
0025 for (const auto& pair : net.named_parameters()) {
0026 std::cout << pair.key() << ": " << pair.value() << std::endl;
0027 }
0028
0029 std::cout << net.forward(torch::ones({2, 4}).to(device)) << std::endl;
0030
0031 return 0;
0032 }