Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-08-02 05:16:40

0001 import sys
0002 import os
0003 import torch
0004 
# Output directory for the generated model file. An explicit path may be given
# as the first command-line argument; otherwise fall back to "bin/data" under
# the parent of the directory containing this script.
if len(sys.argv) < 2:
    thisdir = os.path.dirname(os.path.abspath(__file__))
    datadir = os.path.join(os.path.dirname(thisdir), "bin", "data")
else:
    datadir = sys.argv[1]

# Create the directory and any missing parents; an existing one is accepted.
os.makedirs(datadir, exist_ok=True)
0013 
class MyModule(torch.nn.Module):
    """Tiny single-layer network used to produce a TorchScript test model.

    ``forward`` computes the scalar ``sum(elu(weight @ input + bias))``.

    Args:
        N: number of output features (rows of ``weight``, length of ``bias``).
        M: number of input features (columns of ``weight``).
    """

    def __init__(self, N, M):
        super().__init__()
        # All-ones initialisation makes the model's output fully deterministic.
        self.weight = torch.nn.Parameter(torch.ones(N, M))
        self.bias = torch.nn.Parameter(torch.ones(N))

    def forward(self, input):
        """Return ``sum(elu(weight.mv(input) + bias))`` as a 0-d tensor.

        ``input`` must be a 1-D tensor of length ``M``: ``Tensor.mv`` only
        accepts a vector as its second operand.
        """
        hidden = self.weight.mv(input) + self.bias
        return torch.sum(torch.nn.functional.elu(hidden))
0022 
0023 
# Build the fixture model and an example input whose length matches the
# weight's column count (M), so the two sizes cannot drift apart.
module = MyModule(10, 10)
x = torch.ones(module.weight.shape[1])

# torch.jit.trace records the ops executed on `x`; the module is switched to
# eval mode first so it is traced for inference.
tm = torch.jit.trace(module.eval(), x)

# Join with os.path rather than a hard-coded "/" so path construction matches
# the way datadir itself is assembled.
tm.save(os.path.join(datadir, "simple_dnn.pt"))

print("simple_dnn.pt created successfully!")