| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- import errno
- import os
- from subprocess import PIPE, Popen
- import caffe2.python._import_c_extension as C
- from caffe2.proto import caffe2_pb2
- from caffe2.python import core
class NNModule(object):
    """Python wrapper around the ``C`` extension's ``NNModule`` object.

    The wrapper either starts empty or is built from a Caffe2 net. Almost
    every method forwards directly to the wrapped object
    (``self._NNModule``), to its data-flow graph (``dataFlow()``), or to a
    module-level function of the ``C`` extension.
    """

    def __init__(self, net=None, device_map=None):
        """Create the wrapped module.

        Args:
            net: ``None`` to create an empty module, or a ``core.Net`` /
                ``caffe2_pb2.NetDef`` whose serialized proto is converted.
            device_map: optional mapping whose values expose
                ``SerializeToString()``; when given (and ``net`` is not
                None) the distributed constructor is used instead of the
                default one. Ignored when ``net`` is None.

        Raises:
            TypeError: if ``net`` is neither None, a ``core.Net`` nor a
                ``caffe2_pb2.NetDef``. ``TypeError`` subclasses
                ``Exception``, so existing ``except Exception`` handlers
                still catch it.
        """
        if net is None:
            self._NNModule = C.NNModule()
            return

        if isinstance(net, core.Net):
            serialized_proto = net.Proto().SerializeToString()
        elif isinstance(net, caffe2_pb2.NetDef):
            serialized_proto = net.SerializeToString()
        else:
            # Validate the type up front. Previously an unsupported type was
            # only rejected on the default path; with a device_map it was
            # forwarded to the C extension as a ``None`` proto.
            raise TypeError(
                "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
            )

        # Note: an empty NetDef serializes to b"", which is still a valid
        # proto, so it is accepted here rather than rejected by a truthiness
        # test.
        if device_map is not None:
            # Distributed: serialize every device-map entry alongside the net.
            serialized_device_map = {
                k: device_map[k].SerializeToString() for k in device_map
            }
            self._NNModule = C.NNModuleFromProtobufDistributed(
                serialized_proto, serialized_device_map
            )
        else:
            # Default: the C extension returns the module plus an op list.
            self._NNModule, self._OpList = C.NNModuleFromProtobuf(serialized_proto)

    @property
    def dataFlow(self):
        """The wrapped module's data-flow graph (``dataFlow()``)."""
        return self._NNModule.dataFlow()

    @property
    def controlFlow(self):
        """The wrapped module's execution order (``getExecutionOrder()``)."""
        return self._NNModule.getExecutionOrder()

    @property
    def nodes(self):
        """The ``nodes`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().nodes

    @property
    def operators(self):
        """The ``operators`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().operators

    @property
    def tensors(self):
        """The ``tensors`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().tensors

    def createNode(self, val):
        """Create a node holding ``val`` in the data-flow graph and return it."""
        return self._NNModule.dataFlow().createNode(val)

    def deleteNode(self, node):
        """Delete ``node`` from the data-flow graph."""
        return self._NNModule.dataFlow().deleteNode(node)

    def createEdge(self, a, b):
        """Create an edge from ``a`` to ``b`` in the data-flow graph."""
        return self._NNModule.dataFlow().createEdge(a, b)

    def deleteEdge(self, a, b=None):
        """Delete an edge from the data-flow graph.

        With two arguments, ``deleteEdge(a, b)`` is forwarded to the graph;
        with only ``a``, the single-argument form is used (presumably ``a``
        is then an edge object -- confirm against the C extension).
        """
        # Compare against None explicitly so a falsy-but-valid ``b`` still
        # selects the two-argument form.
        if b is not None:
            self._NNModule.dataFlow().deleteEdge(a, b)
        else:
            self._NNModule.dataFlow().deleteEdge(a)

    def replaceNode(self, old_node, new_node):
        """Replace ``old_node`` with ``new_node`` in the data-flow graph."""
        return self._NNModule.dataFlow().replaceNode(old_node, new_node)

    def replaceProducer(self, tensor, new_producer):
        """Forward to ``C.replaceProducer(tensor, new_producer)``."""
        C.replaceProducer(tensor, new_producer)

    def replaceAllUsesWith(self, old_tensor, new_tensor):
        """Forward to ``C.replaceAllUsesWith(old_tensor, new_tensor)``."""
        C.replaceAllUsesWith(old_tensor, new_tensor)

    def replaceAsConsumer(self, old_consumer, new_consumer):
        """Forward to ``C.replaceAsConsumer(old_consumer, new_consumer)``."""
        C.replaceAsConsumer(old_consumer, new_consumer)

    def replaceSubgraph(self, subgraph, new_node, inputs, outputs):
        """Forward to the wrapped module's ``replaceSubgraph``."""
        self._NNModule.replaceSubgraph(subgraph, new_node, inputs, outputs)

    def deleteSubgraph(self, subgraph):
        """Forward to the wrapped module's ``deleteSubgraph``."""
        self._NNModule.deleteSubgraph(subgraph)

    def createUniqueDataNode(self, prefix="_unique"):
        """Create a data node whose name starts with ``prefix`` and return it."""
        return self._NNModule.createUniqueDataNode(prefix)

    def convertToCaffe2Proto(self, old_proto=None):
        """Convert the module back into a new ``caffe2_pb2.NetDef``.

        Args:
            old_proto: optional ``NetDef`` passed to the C extension's
                converter; a fresh empty ``NetDef`` is used when omitted.

        Returns:
            A new ``caffe2_pb2.NetDef`` parsed from the converter's
            serialized output.
        """
        if old_proto is None:
            old_proto = caffe2_pb2.NetDef()
        output = self._NNModule.convertToCaffe2Proto(old_proto)
        new_proto = caffe2_pb2.NetDef()
        new_proto.ParseFromString(output)
        return new_proto

    def match(self, pattern):
        """Yield every truthy ``C.matchSubgraph(node, pattern)`` result.

        Nodes are taken from ``self.dataFlow.getMutableNodes()``; this is a
        generator, so matching happens lazily as it is iterated.
        """
        for n in self.dataFlow.getMutableNodes():
            m = C.matchSubgraph(n, pattern)
            if m:
                yield m
def render(s):
    """Display ``str(s)`` through ``graph-easy`` when available, else print it.

    If an executable named ``graph-easy`` is found in a ``PATH`` directory,
    the text is UTF-8 encoded and piped to its stdin (its stdout is
    inherited from this process). Otherwise the text is printed.
    """
    text = str(s)
    search_dirs = os.getenv("PATH", "").split(os.pathsep)
    has_graph_easy = any(
        os.access(os.path.join(d, "graph-easy"), os.X_OK) for d in search_dirs
    )
    if not has_graph_easy:
        print(text)
        return

    proc = Popen(["graph-easy"], stdin=PIPE)
    # communicate() writes the input, closes stdin and waits for the child.
    # It ignores EPIPE/EINVAL if graph-easy exits early, on the write and on
    # the close. The previous manual write()/close() only guarded write(),
    # so a buffered flush during close() could still raise BrokenPipeError.
    proc.communicate(text.encode("utf-8"))
# Module-level names for the graph types exposed by the C extension ``C``.
NeuralNetOperator = C.NeuralNetOperator
NeuralNetData = C.NeuralNetData
NNSubgraph = C.NNSubgraph
NNMatchGraph = C.NNMatchGraph
Graph = C.Graph
Annotation = C.Annotation

# Shorter aliases; each is bound to the very same object as its long name.
Operator = NeuralNetOperator
Data = NeuralNetData
|