sample.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import tempfile
  2. import numpy as np
  3. from torch import nn
  4. from torch.autograd import Variable, Function
  5. import torch.onnx
  6. import onnx
  7. import caffe2.python.onnx.backend
  8. class MyFunction(Function):
  9. @staticmethod
  10. def forward(ctx, x, y):
  11. return x * x + y
  12. @staticmethod
  13. def symbolic(graph, x, y):
  14. x2 = graph.at("mul", x, x)
  15. r = graph.at("add", x2, y)
  16. # x, y, x2, and r are 'Node' objects
  17. # print(r) or print(graph) will print out a textual representation for debugging.
  18. # this representation will be converted to ONNX protobufs on export.
  19. return r
  20. class MyModule(nn.Module):
  21. def forward(self, x, y):
  22. # you can combine your ATen ops with standard onnx ones
  23. x = nn.ReLU()(x)
  24. return MyFunction.apply(x, y)
  25. f = tempfile.NamedTemporaryFile()
  26. torch.onnx.export(MyModule(),
  27. (Variable(torch.ones(3, 4)), Variable(torch.ones(3, 4))),
  28. f, verbose=True)
  29. # prints the graph for debugging:
  30. # graph(%input : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
  31. # %y : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  32. # %2 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::Relu(%input)
  33. # %3 : Tensor = aten::ATen[operator="mul"](%2, %2)
  34. # %4 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::ATen[operator="add"](%3, %y)
  35. # return (%4)
  36. graph = onnx.load(f.name)
  37. a = np.random.randn(3, 4).astype(np.float32)
  38. b = np.random.randn(3, 4).astype(np.float32)
  39. prepared_backend = caffe2.python.onnx.backend.prepare(graph)
  40. W = {graph.graph.input[0].name: a, graph.graph.input[1].name: b}
  41. c2_out = prepared_backend.run(W)[0]
  42. x = np.maximum(a, 0)
  43. r = x * x + b
  44. np.testing.assert_array_almost_equal(r, c2_out)