File indexing completed on 2024-08-02 05:16:40
0001 import sys
0002 import os
0003 import torch
0004
0005
# Resolve the directory that will receive the exported model. An explicit
# path passed as the first command-line argument wins; otherwise fall back to
# "<parent of this script's directory>/bin/data".
if len(sys.argv) >= 2:
    datadir = sys.argv[1]
else:
    thisdir = os.path.dirname(os.path.abspath(__file__))
    datadir = os.path.join(os.path.dirname(thisdir), "bin", "data")

# Create the target directory (including missing parents); an already
# existing directory is not treated as an error.
os.makedirs(datadir, exist_ok=True)
0013
class MyModule(torch.nn.Module):
    """Toy module reducing ``elu(weight.mv(input) + bias)`` to a scalar.

    Both learnable parameters are initialised to ones: ``weight`` has shape
    ``(N, M)`` and ``bias`` has shape ``(N,)``.
    """

    def __init__(self, N, M):
        """Create the ``N x M`` weight matrix and the length-``N`` bias.

        Args:
            N: Number of rows of ``weight`` and length of ``bias``.
            M: Number of columns of ``weight``.
        """
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(N, M))
        self.bias = torch.nn.Parameter(torch.ones(N))

    def forward(self, input):
        """Return the sum over all elements of ``elu(weight.mv(input) + bias)``.

        Args:
            input: Vector multiplied by ``weight`` through ``Tensor.mv``, so
                it has to be 1-D with as many elements as ``weight`` has
                columns.

        Returns:
            A zero-dimensional tensor holding the summed activations.
        """
        # NOTE: the parameter name ``input`` shadows the builtin but is kept
        # unchanged so keyword callers keep working.
        return torch.sum(torch.nn.functional.elu(self.weight.mv(input) + self.bias))
0022
0023
# Build a 10x10 module together with a matching all-ones example input.
module = MyModule(10, 10)
x = torch.ones(10)

# Trace the module (switched to eval mode) with the example input to obtain
# a TorchScript module.
tm = torch.jit.trace(module.eval(), x)

# Write the traced module into the output directory; the path is joined with
# os.path.join, the same way datadir itself is assembled.
tm.save(os.path.join(datadir, "simple_dnn.pt"))

print("simple_dnn.pt created successfully!")