conversion.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. ## @package onnx
  2. # Module caffe2.python.onnx.bin.conversion
  3. import json
  4. from caffe2.proto import caffe2_pb2
  5. import click
  6. from onnx import ModelProto
  7. from caffe2.python.onnx.backend import Caffe2Backend as c2
  8. import caffe2.python.onnx.frontend as c2_onnx
  9. @click.command(
  10. help='convert caffe2 net to onnx model',
  11. context_settings={
  12. 'help_option_names': ['-h', '--help']
  13. }
  14. )
  15. @click.argument('caffe2_net', type=click.File('rb'))
  16. @click.option('--caffe2-net-name',
  17. type=str,
  18. help="Name of the caffe2 net")
  19. @click.option('--caffe2-init-net',
  20. type=click.File('rb'),
  21. help="Path of the caffe2 init net pb file")
  22. @click.option('--value-info',
  23. type=str,
  24. help='A json string providing the '
  25. 'type and shape information of the inputs')
  26. @click.option('-o', '--output', required=True,
  27. type=click.File('wb'),
  28. help='Output path for the onnx model pb file')
  29. def caffe2_to_onnx(caffe2_net,
  30. caffe2_net_name,
  31. caffe2_init_net,
  32. value_info,
  33. output):
  34. c2_net_proto = caffe2_pb2.NetDef()
  35. c2_net_proto.ParseFromString(caffe2_net.read())
  36. if not c2_net_proto.name and not caffe2_net_name:
  37. raise click.BadParameter(
  38. 'The input caffe2 net does not have name, '
  39. '--caffe2-net-name must be provided')
  40. c2_net_proto.name = caffe2_net_name or c2_net_proto.name
  41. if caffe2_init_net:
  42. c2_init_net_proto = caffe2_pb2.NetDef()
  43. c2_init_net_proto.ParseFromString(caffe2_init_net.read())
  44. c2_init_net_proto.name = '{}_init'.format(caffe2_net_name)
  45. else:
  46. c2_init_net_proto = None
  47. if value_info:
  48. value_info = json.loads(value_info)
  49. onnx_model = c2_onnx.caffe2_net_to_onnx_model(
  50. predict_net=c2_net_proto,
  51. init_net=c2_init_net_proto,
  52. value_info=value_info)
  53. output.write(onnx_model.SerializeToString())
  54. @click.command(
  55. help='convert onnx model to caffe2 net',
  56. context_settings={
  57. 'help_option_names': ['-h', '--help']
  58. }
  59. )
  60. @click.argument('onnx_model', type=click.File('rb'))
  61. @click.option('-o', '--output', required=True,
  62. type=click.File('wb'),
  63. help='Output path for the caffe2 net file')
  64. @click.option('--init-net-output',
  65. required=True,
  66. type=click.File('wb'),
  67. help='Output path for the caffe2 init net file')
  68. def onnx_to_caffe2(onnx_model, output, init_net_output):
  69. onnx_model_proto = ModelProto()
  70. onnx_model_proto.ParseFromString(onnx_model.read())
  71. init_net, predict_net = c2.onnx_graph_to_caffe2_net(onnx_model_proto)
  72. init_net_output.write(init_net.SerializeToString())
  73. output.write(predict_net.SerializeToString())