SparseTransformer.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # Copyright (c) 2016-present, Facebook, Inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. ##############################################################################
  15. ## @package SparseTransformer
  16. # Module caffe2.experiments.python.SparseTransformer
  17. from caffe2.python import workspace
  18. import scipy.sparse
  19. class NetDefNode():
  20. def __init__(self, name, optype, p=None, op=None):
  21. self.name = name
  22. self.optype = optype
  23. self.ops = {}
  24. self.prev = {}
  25. self.insertInput(p)
  26. self.visited = False
  27. self.op = op
  28. def insertInput(self, p):
  29. """
  30. Insert input of this op
  31. also maintain the output of previous op
  32. p: a node or a list of node
  33. """
  34. if isinstance(p, list):
  35. for i in p:
  36. self.prev[i.name] = i
  37. i.ops[self.name] = self
  38. elif isinstance(p, NetDefNode):
  39. self.prev[p.name] = p
  40. p.ops[self.name] = self
  41. def deleteInput(self, p):
  42. if isinstance(p, NetDefNode):
  43. del self.prev[p.name]
  44. del p.ops[self.name]
  45. def maskNallocate(weight_name):
  46. """
  47. Combine mask and weights
  48. create wcsr, iw, jw, return their names
  49. """
  50. w = workspace.FetchBlob(weight_name)
  51. w_csr = scipy.sparse.csr_matrix(w)
  52. wcsr = w_csr.data
  53. iw = w_csr.indptr
  54. jw = w_csr.indices
  55. workspace.FeedBlob(weight_name + "wcsr", wcsr)
  56. workspace.FeedBlob(weight_name + "iw", iw)
  57. workspace.FeedBlob(weight_name + "jw", jw)
  58. return weight_name + "wcsr", weight_name + "iw", weight_name + "jw"
  59. def transFCRelu(cur, id2node, name2id, ops, model):
  60. """
  61. Add trans before and after this FC_Prune->(Relu)->FC_Prune chain.
  62. """
  63. # 1. add trans before the start of this chain
  64. # assuming that cur is a FC_Prune, and it has only one input
  65. pre = cur.prev.itervalues().next()
  66. # Create a node /op and insert it.
  67. # TODO(wyiming): check whether it is correct here
  68. current_blob = model.Transpose(cur.op.input[0], cur.op.input[0] + "_trans")
  69. # print model.net.Proto()
  70. trans_op = model.net.Proto().op[-1]
  71. trans_node = NetDefNode(trans_op.output[0], "Transpose", pre, trans_op)
  72. trans_node.visited = True
  73. pre_new = trans_node
  74. # 2. use while loop to visit the chain
  75. while True:
  76. # breakup with the parent
  77. cur.deleteInput(pre)
  78. if not (cur.optype == "FC_Prune" or cur.optype == "Relu"):
  79. print("Reaching the end of the chain")
  80. break
  81. if len(cur.ops) > 1:
  82. print("A FC/Relu giving more than 1 useful outputs")
  83. if cur.optype == "FC_Prune":
  84. op = cur.op
  85. wcsr, iw, jw = maskNallocate(op.input[1])
  86. bias_name = op.input[3]
  87. # TODO(wyiming): create a new Op here
  88. current_blob = model.FC_Sparse(current_blob,
  89. cur.op.output[0] + "_Sparse",
  90. wcsr, iw, jw, bias_name)
  91. sps_op = model.net.Proto().op[-1]
  92. sps_node = NetDefNode(cur.op.output[0] + "_Sparse",
  93. "FC_Sparse",
  94. pre_new, sps_op)
  95. sps_node.visited = True
  96. pre_new = sps_node
  97. if cur.optype == "Relu":
  98. op = cur.op
  99. current_blob = model.Relu(current_blob, current_blob)
  100. rel_op = model.net.Proto().op[-1]
  101. rel_node = NetDefNode(str(current_blob), "Relu",
  102. pre_new, rel_op)
  103. rel_node.visited = True
  104. pre_new = rel_node
  105. cur.visited = True
  106. pre = cur
  107. flag = False
  108. for _, temp in cur.ops.iteritems():
  109. if temp.optype == "Relu" or temp.optype == "FC_Prune":
  110. flag = True
  111. cur = temp
  112. if not flag:
  113. # assume that there is only 1 output that is not PrintOP
  114. cur = cur.ops.itervalues().next()
  115. cur.deleteInput(pre)
  116. print("No FC/RElu children")
  117. print(cur.op.type)
  118. break
  119. # 3. add trans after this chain like 1.
  120. current_blob = model.Transpose(current_blob, pre.op.output[0])
  121. trans_op = model.net.Proto().op[-1]
  122. trans_node = NetDefNode(str(current_blob), "Transpose", pre_new, trans_op)
  123. trans_node.visited = True
  124. cur.insertInput(trans_node)
  125. print(cur.prev)
  126. print(trans_node.ops)
  127. def Prune2Sparse(cur, id2node, name2id, ops, model):
  128. # Assume that FC and Relu takes in only 1 input;
  129. # If not raise warning
  130. if not cur.visited and cur.optype == "FC_Prune":
  131. transFCRelu(cur, id2node, name2id, ops, model)
  132. cur.visited = True
  133. for name, n in cur.ops.iteritems():
  134. Prune2Sparse(n, id2node, name2id, ops, model)
  135. def net2list(net_root):
  136. """
  137. Use topological order(BFS) to print the op of a net in a list
  138. """
  139. bfs_queue = []
  140. op_list = []
  141. cur = net_root
  142. for _, n in cur.ops.iteritems():
  143. bfs_queue.append(n)
  144. while bfs_queue:
  145. node = bfs_queue[0]
  146. bfs_queue = bfs_queue[1:]
  147. op_list.append(node.op)
  148. for _, n in node.ops.iteritems():
  149. bfs_queue.append(n)
  150. return op_list
  151. def netbuilder(model):
  152. print("Welcome to model checker")
  153. proto = model.net.Proto()
  154. net_name2id = {}
  155. net_id2node = {}
  156. net_root = NetDefNode("net_root", "root", None)
  157. for op_id, op in enumerate(proto.op):
  158. if op.type == "Print":
  159. continue
  160. op_name = '%s/%s (op#%d)' % (op.name, op.type, op_id) \
  161. if op.name else '%s (op#%d)' % (op.type, op_id)
  162. # print(op_name)
  163. op_node = NetDefNode(op_name, op.type, op=op)
  164. net_id2node[op_id] = op_node
  165. if_has_layer_input = False
  166. for input_name in op.input:
  167. if input_name not in net_name2id:
  168. # assume that un_occured name are non_layers
  169. # TODO: write a non-layer checker and log it
  170. continue
  171. op_node.insertInput(net_id2node[net_name2id[input_name]])
  172. if_has_layer_input = True
  173. if not if_has_layer_input:
  174. op_node.insertInput(net_root)
  175. for output_name in op.output:
  176. net_name2id[output_name] = op_id
  177. return net_root, net_name2id, net_id2node