backend_rep.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # @package onnx
  2. # Module caffe2.python.onnx.backend_rep
  3. from caffe2.python import core
  4. from caffe2.proto import caffe2_pb2
  5. from onnx.backend.base import BackendRep, namedtupledict
  6. class Caffe2Rep(BackendRep):
  7. def __init__(self, init_net, predict_net, workspace, uninitialized):
  8. super(Caffe2Rep, self).__init__()
  9. self.init_net = init_net
  10. self.predict_net = predict_net
  11. self.workspace = workspace
  12. # The list of uninitialized external_inputs in workspace, we need this to
  13. # pair the name with given sequence inputs.
  14. self.uninitialized = uninitialized
  15. self.nets_created = False
  16. self.ran_init_net = False
  17. @property
  18. def _name_scope(self):
  19. if self.predict_net.device_option.device_type == caffe2_pb2.CUDA:
  20. return 'gpu_{}'.format(self.predict_net.device_option.device_id)
  21. return ''
  22. def run(self, inputs, **kwargs):
  23. super(Caffe2Rep, self).run(inputs, **kwargs)
  24. with core.DeviceScope(self.predict_net.device_option):
  25. if isinstance(inputs, dict):
  26. with core.NameScope(self._name_scope):
  27. for key, value in inputs.items():
  28. self.workspace.FeedBlob(key, value)
  29. elif isinstance(inputs, list) or isinstance(inputs, tuple):
  30. if len(self.uninitialized) != len(inputs):
  31. raise RuntimeError('Expected {} values for uninitialized '
  32. 'graph inputs ({}), but got {}.'.format(
  33. len(self.uninitialized),
  34. ', '.join(self.uninitialized),
  35. len(inputs)))
  36. for i, value in enumerate(inputs):
  37. # namescope already baked into protobuf
  38. self.workspace.FeedBlob(self.uninitialized[i], value)
  39. else:
  40. # single input
  41. self.workspace.FeedBlob(self.uninitialized[0], inputs)
  42. if not self.nets_created:
  43. self.workspace.CreateNet(self.init_net)
  44. self.workspace.CreateNet(self.predict_net)
  45. self.nets_created = True
  46. if not self.ran_init_net:
  47. self.workspace.RunNet(self.init_net.name)
  48. self.ran_init_net = True
  49. self.workspace.RunNet(self.predict_net.name)
  50. output_values = []
  51. for name in self.predict_net.external_output:
  52. try:
  53. output_values.append(self.workspace.FetchBlob(name))
  54. except Exception:
  55. output_values.append(self.workspace.FetchInt8Blob(name))
  56. return namedtupledict('Outputs',
  57. self.predict_net.external_output)(*output_values)