# nomnigraph.py (4.1 KB)
  1. import errno
  2. import os
  3. from subprocess import PIPE, Popen
  4. import caffe2.python._import_c_extension as C
  5. from caffe2.proto import caffe2_pb2
  6. from caffe2.python import core
  7. class NNModule(object):
  8. def __init__(self, net=None, device_map=None):
  9. if net is not None:
  10. serialized_proto = None
  11. if isinstance(net, core.Net):
  12. serialized_proto = net.Proto().SerializeToString()
  13. elif isinstance(net, caffe2_pb2.NetDef):
  14. serialized_proto = net.SerializeToString()
  15. # Distributed
  16. if device_map is not None:
  17. serialized_device_map = {}
  18. for k in device_map:
  19. serialized_device_map[k] = device_map[k].SerializeToString()
  20. self._NNModule = C.NNModuleFromProtobufDistributed(
  21. serialized_proto, serialized_device_map
  22. )
  23. # Default
  24. elif serialized_proto:
  25. self._NNModule, self._OpList = C.NNModuleFromProtobuf(serialized_proto)
  26. else:
  27. raise Exception(
  28. "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
  29. )
  30. else:
  31. self._NNModule = C.NNModule()
  32. @property
  33. def dataFlow(self):
  34. return self._NNModule.dataFlow()
  35. @property
  36. def controlFlow(self):
  37. return self._NNModule.getExecutionOrder()
  38. @property
  39. def nodes(self):
  40. return self._NNModule.dataFlow().nodes
  41. @property
  42. def operators(self):
  43. return self._NNModule.dataFlow().operators
  44. @property
  45. def tensors(self):
  46. return self._NNModule.dataFlow().tensors
  47. def createNode(self, val):
  48. return self._NNModule.dataFlow().createNode(val)
  49. def deleteNode(self, node):
  50. return self._NNModule.dataFlow().deleteNode(node)
  51. def createEdge(self, a, b):
  52. return self._NNModule.dataFlow().createEdge(a, b)
  53. def deleteEdge(self, a, b=None):
  54. if b:
  55. self._NNModule.dataFlow().deleteEdge(a, b)
  56. else:
  57. self._NNModule.dataFlow().deleteEdge(a)
  58. def replaceNode(self, old_node, new_node):
  59. return self._NNModule.dataFlow().replaceNode(old_node, new_node)
  60. def replaceProducer(self, tensor, new_producer):
  61. C.replaceProducer(tensor, new_producer)
  62. def replaceAllUsesWith(self, old_tensor, new_tensor):
  63. C.replaceAllUsesWith(old_tensor, new_tensor)
  64. def replaceAsConsumer(self, old_consumer, new_consumer):
  65. C.replaceAsConsumer(old_consumer, new_consumer)
  66. def replaceSubgraph(self, subgraph, new_node, inputs, outputs):
  67. self._NNModule.replaceSubgraph(subgraph, new_node, inputs, outputs)
  68. def deleteSubgraph(self, subgraph):
  69. self._NNModule.deleteSubgraph(subgraph)
  70. def createUniqueDataNode(self, prefix="_unique"):
  71. return self._NNModule.createUniqueDataNode(prefix)
  72. def convertToCaffe2Proto(self, old_proto=None):
  73. if not old_proto:
  74. old_proto = caffe2_pb2.NetDef()
  75. output = self._NNModule.convertToCaffe2Proto(old_proto)
  76. new_proto = caffe2_pb2.NetDef()
  77. new_proto.ParseFromString(output)
  78. return new_proto
  79. def match(self, pattern):
  80. for n in self.dataFlow.getMutableNodes():
  81. m = C.matchSubgraph(n, pattern)
  82. if m:
  83. yield m
  84. def render(s):
  85. s = str(s)
  86. cmd_exists = lambda x: any(
  87. os.access(os.path.join(path, x), os.X_OK)
  88. for path in os.getenv("PATH", "").split(os.pathsep)
  89. )
  90. if cmd_exists("graph-easy"):
  91. p = Popen("graph-easy", stdin=PIPE)
  92. try:
  93. p.stdin.write(s.encode("utf-8"))
  94. except IOError as e:
  95. if e.errno == errno.EPIPE or e.errno == errno.EINVAL:
  96. pass
  97. else:
  98. # Raise any other error.
  99. raise
  100. p.stdin.close()
  101. p.wait()
  102. else:
  103. print(s)
  104. NeuralNetOperator = C.NeuralNetOperator
  105. Operator = C.NeuralNetOperator
  106. NeuralNetData = C.NeuralNetData
  107. Data = C.NeuralNetData
  108. NNSubgraph = C.NNSubgraph
  109. NNMatchGraph = C.NNMatchGraph
  110. Graph = C.Graph
  111. Annotation = C.Annotation