utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. # @package utils
  2. # Module caffe2.python.utils
  3. from caffe2.proto import caffe2_pb2
  4. from future.utils import viewitems
  5. from google.protobuf.message import DecodeError, Message
  6. from google.protobuf import text_format
  7. import sys
  8. import collections
  9. import copy
  10. import functools
  11. import numpy as np
  12. from six import integer_types, binary_type, text_type, string_types
# Default blob name for the iteration counter; BuildUniqueMutexIter falls back
# to it when no `iter` name is supplied.
OPTIMIZER_ITERATION_NAME = "optimizer_iteration"
# Default blob name for the mutex guarding that counter; BuildUniqueMutexIter
# falls back to it when no `iter_mutex` name is supplied.
ITERATION_MUTEX_NAME = "iteration_mutex"
  15. def OpAlmostEqual(op_a, op_b, ignore_fields=None):
  16. '''
  17. Two ops are identical except for each field in the `ignore_fields`.
  18. '''
  19. ignore_fields = ignore_fields or []
  20. if not isinstance(ignore_fields, list):
  21. ignore_fields = [ignore_fields]
  22. assert all(isinstance(f, text_type) for f in ignore_fields), (
  23. 'Expect each field is text type, but got {}'.format(ignore_fields))
  24. def clean_op(op):
  25. op = copy.deepcopy(op)
  26. for field in ignore_fields:
  27. if op.HasField(field):
  28. op.ClearField(field)
  29. return op
  30. op_a = clean_op(op_a)
  31. op_b = clean_op(op_b)
  32. return op_a == op_b or str(op_a) == str(op_b)
  33. def CaffeBlobToNumpyArray(blob):
  34. if (blob.num != 0):
  35. # old style caffe blob.
  36. return (np.asarray(blob.data, dtype=np.float32)
  37. .reshape(blob.num, blob.channels, blob.height, blob.width))
  38. else:
  39. # new style caffe blob.
  40. return (np.asarray(blob.data, dtype=np.float32)
  41. .reshape(blob.shape.dim))
  42. def Caffe2TensorToNumpyArray(tensor):
  43. if tensor.data_type == caffe2_pb2.TensorProto.FLOAT:
  44. return np.asarray(
  45. tensor.float_data, dtype=np.float32).reshape(tensor.dims)
  46. elif tensor.data_type == caffe2_pb2.TensorProto.DOUBLE:
  47. return np.asarray(
  48. tensor.double_data, dtype=np.float64).reshape(tensor.dims)
  49. elif tensor.data_type == caffe2_pb2.TensorProto.INT64:
  50. return np.asarray(
  51. tensor.int64_data, dtype=np.int64).reshape(tensor.dims)
  52. elif tensor.data_type == caffe2_pb2.TensorProto.INT32:
  53. return np.asarray(
  54. tensor.int32_data, dtype=np.int).reshape(tensor.dims) # pb.INT32=>np.int use int32_data
  55. elif tensor.data_type == caffe2_pb2.TensorProto.INT16:
  56. return np.asarray(
  57. tensor.int32_data, dtype=np.int16).reshape(tensor.dims) # pb.INT16=>np.int16 use int32_data
  58. elif tensor.data_type == caffe2_pb2.TensorProto.UINT16:
  59. return np.asarray(
  60. tensor.int32_data, dtype=np.uint16).reshape(tensor.dims) # pb.UINT16=>np.uint16 use int32_data
  61. elif tensor.data_type == caffe2_pb2.TensorProto.INT8:
  62. return np.asarray(
  63. tensor.int32_data, dtype=np.int8).reshape(tensor.dims) # pb.INT8=>np.int8 use int32_data
  64. elif tensor.data_type == caffe2_pb2.TensorProto.UINT8:
  65. return np.asarray(
  66. tensor.int32_data, dtype=np.uint8).reshape(tensor.dims) # pb.UINT8=>np.uint8 use int32_data
  67. else:
  68. # TODO: complete the data type: bool, float16, byte, int64, string
  69. raise RuntimeError(
  70. "Tensor data type not supported yet: " + str(tensor.data_type))
  71. def NumpyArrayToCaffe2Tensor(arr, name=None):
  72. tensor = caffe2_pb2.TensorProto()
  73. tensor.dims.extend(arr.shape)
  74. if name:
  75. tensor.name = name
  76. if arr.dtype == np.float32:
  77. tensor.data_type = caffe2_pb2.TensorProto.FLOAT
  78. tensor.float_data.extend(list(arr.flatten().astype(float)))
  79. elif arr.dtype == np.float64:
  80. tensor.data_type = caffe2_pb2.TensorProto.DOUBLE
  81. tensor.double_data.extend(list(arr.flatten().astype(np.float64)))
  82. elif arr.dtype == np.int64:
  83. tensor.data_type = caffe2_pb2.TensorProto.INT64
  84. tensor.int64_data.extend(list(arr.flatten().astype(np.int64)))
  85. elif arr.dtype == np.int or arr.dtype == np.int32:
  86. tensor.data_type = caffe2_pb2.TensorProto.INT32
  87. tensor.int32_data.extend(arr.flatten().astype(np.int).tolist())
  88. elif arr.dtype == np.int16:
  89. tensor.data_type = caffe2_pb2.TensorProto.INT16
  90. tensor.int32_data.extend(list(arr.flatten().astype(np.int16))) # np.int16=>pb.INT16 use int32_data
  91. elif arr.dtype == np.uint16:
  92. tensor.data_type = caffe2_pb2.TensorProto.UINT16
  93. tensor.int32_data.extend(list(arr.flatten().astype(np.uint16))) # np.uint16=>pb.UNIT16 use int32_data
  94. elif arr.dtype == np.int8:
  95. tensor.data_type = caffe2_pb2.TensorProto.INT8
  96. tensor.int32_data.extend(list(arr.flatten().astype(np.int8))) # np.int8=>pb.INT8 use int32_data
  97. elif arr.dtype == np.uint8:
  98. tensor.data_type = caffe2_pb2.TensorProto.UINT8
  99. tensor.int32_data.extend(list(arr.flatten().astype(np.uint8))) # np.uint8=>pb.UNIT8 use int32_data
  100. else:
  101. # TODO: complete the data type: bool, float16, byte, string
  102. raise RuntimeError(
  103. "Numpy data type not supported yet: " + str(arr.dtype))
  104. return tensor
  105. def MakeArgument(key, value):
  106. """Makes an argument based on the value type."""
  107. argument = caffe2_pb2.Argument()
  108. argument.name = key
  109. iterable = isinstance(value, collections.abc.Iterable)
  110. # Fast tracking common use case where a float32 array of tensor parameters
  111. # needs to be serialized. The entire array is guaranteed to have the same
  112. # dtype, so no per-element checking necessary and no need to convert each
  113. # element separately.
  114. if isinstance(value, np.ndarray) and value.dtype.type is np.float32:
  115. argument.floats.extend(value.flatten().tolist())
  116. return argument
  117. if isinstance(value, np.ndarray):
  118. value = value.flatten().tolist()
  119. elif isinstance(value, np.generic):
  120. # convert numpy scalar to native python type
  121. value = np.asscalar(value)
  122. if type(value) is float:
  123. argument.f = value
  124. elif type(value) in integer_types or type(value) is bool:
  125. # We make a relaxation that a boolean variable will also be stored as
  126. # int.
  127. argument.i = value
  128. elif isinstance(value, binary_type):
  129. argument.s = value
  130. elif isinstance(value, text_type):
  131. argument.s = value.encode('utf-8')
  132. elif isinstance(value, caffe2_pb2.NetDef):
  133. argument.n.CopyFrom(value)
  134. elif isinstance(value, Message):
  135. argument.s = value.SerializeToString()
  136. elif iterable and all(type(v) in [float, np.float_] for v in value):
  137. argument.floats.extend(
  138. v.item() if type(v) is np.float_ else v for v in value
  139. )
  140. elif iterable and all(
  141. type(v) in integer_types or type(v) in [bool, np.int_] for v in value
  142. ):
  143. argument.ints.extend(
  144. v.item() if type(v) is np.int_ else v for v in value
  145. )
  146. elif iterable and all(
  147. isinstance(v, binary_type) or isinstance(v, text_type) for v in value
  148. ):
  149. argument.strings.extend(
  150. v.encode('utf-8') if isinstance(v, text_type) else v
  151. for v in value
  152. )
  153. elif iterable and all(isinstance(v, caffe2_pb2.NetDef) for v in value):
  154. argument.nets.extend(value)
  155. elif iterable and all(isinstance(v, Message) for v in value):
  156. argument.strings.extend(v.SerializeToString() for v in value)
  157. else:
  158. if iterable:
  159. raise ValueError(
  160. "Unknown iterable argument type: key={} value={}, value "
  161. "type={}[{}]".format(
  162. key, value, type(value), set(type(v) for v in value)
  163. )
  164. )
  165. else:
  166. raise ValueError(
  167. "Unknown argument type: key={} value={}, value type={}".format(
  168. key, value, type(value)
  169. )
  170. )
  171. return argument
  172. def TryReadProtoWithClass(cls, s):
  173. """Reads a protobuffer with the given proto class.
  174. Inputs:
  175. cls: a protobuffer class.
  176. s: a string of either binary or text protobuffer content.
  177. Outputs:
  178. proto: the protobuffer of cls
  179. Throws:
  180. google.protobuf.message.DecodeError: if we cannot decode the message.
  181. """
  182. obj = cls()
  183. try:
  184. text_format.Parse(s, obj)
  185. return obj
  186. except (text_format.ParseError, UnicodeDecodeError):
  187. obj.ParseFromString(s)
  188. return obj
  189. def GetContentFromProto(obj, function_map):
  190. """Gets a specific field from a protocol buffer that matches the given class
  191. """
  192. for cls, func in viewitems(function_map):
  193. if type(obj) is cls:
  194. return func(obj)
  195. def GetContentFromProtoString(s, function_map):
  196. for cls, func in viewitems(function_map):
  197. try:
  198. obj = TryReadProtoWithClass(cls, s)
  199. return func(obj)
  200. except DecodeError:
  201. continue
  202. else:
  203. raise DecodeError("Cannot find a fit protobuffer class.")
  204. def ConvertProtoToBinary(proto_class, filename, out_filename):
  205. """Convert a text file of the given protobuf class to binary."""
  206. with open(filename) as f:
  207. proto = TryReadProtoWithClass(proto_class, f.read())
  208. with open(out_filename, 'w') as fid:
  209. fid.write(proto.SerializeToString())
  210. def GetGPUMemoryUsageStats():
  211. """Get GPU memory usage stats from CUDAContext/HIPContext. This requires flag
  212. --caffe2_gpu_memory_tracking to be enabled"""
  213. from caffe2.python import workspace, core
  214. workspace.RunOperatorOnce(
  215. core.CreateOperator(
  216. "GetGPUMemoryUsage",
  217. [],
  218. ["____mem____"],
  219. device_option=core.DeviceOption(workspace.GpuDeviceType, 0),
  220. ),
  221. )
  222. b = workspace.FetchBlob("____mem____")
  223. return {
  224. 'total_by_gpu': b[0, :],
  225. 'max_by_gpu': b[1, :],
  226. 'total': np.sum(b[0, :]),
  227. 'max_total': np.sum(b[1, :])
  228. }
  229. def ResetBlobs(blobs):
  230. from caffe2.python import workspace, core
  231. workspace.RunOperatorOnce(
  232. core.CreateOperator(
  233. "Free",
  234. list(blobs),
  235. list(blobs),
  236. device_option=core.DeviceOption(caffe2_pb2.CPU),
  237. ),
  238. )
  239. class DebugMode(object):
  240. '''
  241. This class allows to drop you into an interactive debugger
  242. if there is an unhandled exception in your python script
  243. Example of usage:
  244. def main():
  245. # your code here
  246. pass
  247. if __name__ == '__main__':
  248. from caffe2.python.utils import DebugMode
  249. DebugMode.run(main)
  250. '''
  251. @classmethod
  252. def run(cls, func):
  253. try:
  254. return func()
  255. except KeyboardInterrupt:
  256. raise
  257. except Exception:
  258. import pdb
  259. print(
  260. 'Entering interactive debugger. Type "bt" to print '
  261. 'the full stacktrace. Type "help" to see command listing.')
  262. print(sys.exc_info()[1])
  263. print
  264. pdb.post_mortem()
  265. sys.exit(1)
  266. raise
  267. def raiseIfNotEqual(a, b, msg):
  268. if a != b:
  269. raise Exception("{}. {} != {}".format(msg, a, b))
  270. def debug(f):
  271. '''
  272. Use this method to decorate your function with DebugMode's functionality
  273. Example:
  274. @debug
  275. def test_foo(self):
  276. raise Exception("Bar")
  277. '''
  278. @functools.wraps(f)
  279. def wrapper(*args, **kwargs):
  280. def func():
  281. return f(*args, **kwargs)
  282. return DebugMode.run(func)
  283. return wrapper
  284. def BuildUniqueMutexIter(
  285. init_net,
  286. net,
  287. iter=None,
  288. iter_mutex=None,
  289. iter_val=0
  290. ):
  291. '''
  292. Often, a mutex guarded iteration counter is needed. This function creates a
  293. mutex iter in the net uniquely (if the iter already existing, it does
  294. nothing)
  295. This function returns the iter blob
  296. '''
  297. iter = iter if iter is not None else OPTIMIZER_ITERATION_NAME
  298. iter_mutex = iter_mutex if iter_mutex is not None else ITERATION_MUTEX_NAME
  299. from caffe2.python import core
  300. if not init_net.BlobIsDefined(iter):
  301. # Add training operators.
  302. with core.DeviceScope(
  303. core.DeviceOption(caffe2_pb2.CPU,
  304. extra_info=["device_type_override:cpu"])
  305. ):
  306. iteration = init_net.ConstantFill(
  307. [],
  308. iter,
  309. shape=[1],
  310. value=iter_val,
  311. dtype=core.DataType.INT64,
  312. )
  313. iter_mutex = init_net.CreateMutex([], [iter_mutex])
  314. net.AtomicIter([iter_mutex, iteration], [iteration])
  315. else:
  316. iteration = init_net.GetBlobRef(iter)
  317. return iteration
  318. def EnumClassKeyVals(cls):
  319. # cls can only be derived from object
  320. assert type(cls) == type
  321. # Enum attribute keys are all capitalized and values are strings
  322. enum = {}
  323. for k in dir(cls):
  324. if k == k.upper():
  325. v = getattr(cls, k)
  326. if isinstance(v, string_types):
  327. assert v not in enum.values(), (
  328. "Failed to resolve {} as Enum: "
  329. "duplicate entries {}={}, {}={}".format(
  330. cls, k, v, [key for key in enum if enum[key] == v][0], v
  331. )
  332. )
  333. enum[k] = v
  334. return enum
  335. def ArgsToDict(args):
  336. """
  337. Convert a list of arguments to a name, value dictionary. Assumes that
  338. each argument has a name. Otherwise, the argument is skipped.
  339. """
  340. ans = {}
  341. for arg in args:
  342. if not arg.HasField("name"):
  343. continue
  344. for d in arg.DESCRIPTOR.fields:
  345. if d.name == "name":
  346. continue
  347. if d.label == d.LABEL_OPTIONAL and arg.HasField(d.name):
  348. ans[arg.name] = getattr(arg, d.name)
  349. break
  350. elif d.label == d.LABEL_REPEATED:
  351. list_ = getattr(arg, d.name)
  352. if len(list_) > 0:
  353. ans[arg.name] = list_
  354. break
  355. else:
  356. ans[arg.name] = None
  357. return ans
  358. def NHWC2NCHW(tensor):
  359. assert tensor.ndim >= 1
  360. return tensor.transpose((0, tensor.ndim - 1) + tuple(range(1, tensor.ndim - 1)))
  361. def NCHW2NHWC(tensor):
  362. assert tensor.ndim >= 2
  363. return tensor.transpose((0,) + tuple(range(2, tensor.ndim)) + (1,))