Skip to content

Commit

Permalink
codeformat
Browse files Browse the repository at this point in the history
  • Loading branch information
valsdav committed Dec 1, 2023
1 parent baa5bc9 commit 0bfcc4a
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 35 deletions.
20 changes: 20 additions & 0 deletions PhysicsTools/PyTorch/test/create_simple_dnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch

class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.ones(N, M))
self.bias = torch.nn.Parameter(torch.ones(N))

def forward(self, input):
return torch.sum(torch.nn.functional.elu(self.weight.mv(input) + self.bias))


module = MyModule(10, 10)
x = torch.ones(10)

tm = torch.jit.trace(module.eval(), x)

print(tm.graph)

tm.save("simple_dnn.pt")
1 change: 0 additions & 1 deletion PhysicsTools/PyTorch/test/testBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ class testBasePyTorch : public CppUnit::TestFixture {
std::string cmsswPath(std::string path);

virtual void test() = 0;

};

void testBasePyTorch::setUp() {
Expand Down
1 change: 0 additions & 1 deletion PhysicsTools/PyTorch/test/testBaseCUDA.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class testBasePyTorchCUDA : public CppUnit::TestFixture {
std::string cmsswPath(std::string path);

virtual void test() = 0;

};

void testBasePyTorchCUDA::setUp() {
Expand Down
12 changes: 3 additions & 9 deletions PhysicsTools/PyTorch/test/testTorchSimpleDnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <memory>
#include <vector>


class testSimpleDNN : public testBasePyTorch {
CPPUNIT_TEST_SUITE(testSimpleDNN);
CPPUNIT_TEST(test);
Expand All @@ -24,25 +23,20 @@ void testSimpleDNN::test() {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(model_path);
module.to(device);
}
catch (const c10::Error& e) {

} catch (const c10::Error& e) {
std::cerr << "error loading the model\n" << e.what() << std::endl;
}
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones(10, device));


// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << "output: "<< output << '\n';
std::cout << "output: " << output << '\n';
CPPUNIT_ASSERT(output.item<float_t>() == 110.);
std::cout << "ok\n";
}



// int main(int argc, const char* argv[]) {
// std::cout << "Running model on CPU" << std::endl;
// torch::Device cpu(torch::kCPU);
Expand All @@ -51,6 +45,6 @@ void testSimpleDNN::test() {
// std::cout << "Running model on CUDA" << std::endl;
// torch::Device cuda(torch::kCUDA);
// runModel("/data/user/dvalsecc/simple_dnn.pt", cuda);

// return 0;
// }
16 changes: 4 additions & 12 deletions PhysicsTools/PyTorch/test/testTorchSimpleDnnCUDA.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <memory>
#include <vector>


class testSimpleDNNCUDA : public testBasePyTorchCUDA {
CPPUNIT_TEST_SUITE(testSimpleDNNCUDA);
CPPUNIT_TEST(test);
Expand Down Expand Up @@ -41,34 +40,27 @@ process.add_(cms.Service('CUDAService'))

std::cout << "Testing CUDA backend" << std::endl;



std::string model_path = testPath_ + "/simple_dnn.pt";
torch::Device device(torch::kCUDA );
torch::Device device(torch::kCUDA);
torch::jit::script::Module module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(model_path);
module.to(device);
}
catch (const c10::Error& e) {

} catch (const c10::Error& e) {
std::cerr << "error loading the model\n" << e.what() << std::endl;
}
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones(10, device));


// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << "output: "<< output << '\n';
std::cout << "output: " << output << '\n';
CPPUNIT_ASSERT(output.item<float_t>() == 110.);
std::cout << "ok\n";
}



// int main(int argc, const char* argv[]) {
// std::cout << "Running model on CPU" << std::endl;
// torch::Device cpu(torch::kCPU);
Expand All @@ -77,6 +69,6 @@ process.add_(cms.Service('CUDAService'))
// std::cout << "Running model on CUDA" << std::endl;
// torch::Device cuda(torch::kCUDA);
// runModel("/data/user/dvalsecc/simple_dnn.pt", cuda);

// return 0;
// }
16 changes: 4 additions & 12 deletions PhysicsTools/PyTorch/test/testTorchSimpleDnnCuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <memory>
#include <vector>


class testSimpleDNNCUDA : public testBasePyTorchCUDA {
CPPUNIT_TEST_SUITE(testSimpleDNNCUDA);
CPPUNIT_TEST(test);
Expand Down Expand Up @@ -41,34 +40,27 @@ process.add_(cms.Service('CUDAService'))

std::cout << "Testing CUDA backend" << std::endl;



std::string model_path = testPath_ + "/simple_dnn.pt";
torch::Device device(torch::kCUDA );
torch::Device device(torch::kCUDA);
torch::jit::script::Module module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(model_path);
module.to(device);
}
catch (const c10::Error& e) {

} catch (const c10::Error& e) {
std::cerr << "error loading the model\n" << e.what() << std::endl;
}
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones(10, device));


// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << "output: "<< output << '\n';
std::cout << "output: " << output << '\n';
CPPUNIT_ASSERT(output.item<float_t>() == 110.);
std::cout << "ok\n";
}



// int main(int argc, const char* argv[]) {
// std::cout << "Running model on CPU" << std::endl;
// torch::Device cpu(torch::kCPU);
Expand All @@ -77,6 +69,6 @@ process.add_(cms.Service('CUDAService'))
// std::cout << "Running model on CUDA" << std::endl;
// torch::Device cuda(torch::kCUDA);
// runModel("/data/user/dvalsecc/simple_dnn.pt", cuda);

// return 0;
// }

0 comments on commit 0bfcc4a

Please sign in to comment.