workspace.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779
  1. ## @package workspace
  2. # Module caffe2.python.workspace
  3. import collections
  4. import contextlib
  5. from google.protobuf.message import Message
  6. from multiprocessing import Process
  7. import os
  8. from collections import defaultdict
  9. import logging
  10. import numpy as np
  11. from past.builtins import basestring
  12. import shutil
  13. import socket
  14. import tempfile
  15. from caffe2.proto import caffe2_pb2
  16. from caffe2.python import scope, utils
  17. from caffe2.python.lazy import TriggerLazyImport
  18. import caffe2.python._import_c_extension as C
  19. logger = logging.getLogger(__name__)
  20. Blobs = C.blobs
  21. ResetBlob = C.reset_blob
  22. CreateBlob = C.create_blob
  23. CurrentWorkspace = C.current_workspace
  24. DeserializeBlob = C.deserialize_blob
  25. GlobalInit = C.global_init
  26. HasBlob = C.has_blob
  27. RegisteredOperators = C.registered_operators
  28. SerializeBlob = C.serialize_blob
  29. SwitchWorkspace = C.switch_workspace
  30. RootFolder = C.root_folder
  31. Workspaces = C.workspaces
  32. BenchmarkNet = C.benchmark_net
  33. BenchmarkNetOnce = C.benchmark_net_once
  34. GetStats = C.get_stats
  35. CreateOfflineTensor = C.create_offline_tensor
  36. operator_tracebacks = defaultdict(dict)
  37. is_asan = C.is_asan
  38. has_fbgemm = C.has_fbgemm
  39. has_cuda_support = C.has_cuda_support
  40. has_hip_support = C.has_hip_support
  41. has_gpu_support = C.has_gpu_support
  42. if has_cuda_support:
  43. GpuDeviceType = caffe2_pb2.CUDA
  44. NumCudaDevices = C.num_cuda_devices
  45. # This is a duplicate of NumCudaDevices. Remove
  46. # NumCudaDevices once replaced everywhere in the code
  47. NumGpuDevices = C.num_cuda_devices
  48. GetCUDAVersion = C.get_cuda_version
  49. GetCuDNNVersion = C.get_cudnn_version
  50. def GetGpuPeerAccessPattern():
  51. return np.asarray(C.get_cuda_peer_access_pattern())
  52. GetDeviceProperties = C.get_device_properties
  53. GetGPUMemoryInfo = C.get_gpu_memory_info
  54. else:
  55. # pyre-fixme[9]: incompatible type assignment
  56. NumCudaDevices = lambda: 0 # noqa
  57. # pyre-fixme[9]: incompatible type assignment
  58. GetCUDAVersion = lambda: 0 # noqa
  59. # pyre-fixme[9]: incompatible type assignment
  60. GetCuDNNVersion = lambda: 0 # noqa
  61. if has_hip_support:
  62. GpuDeviceType = caffe2_pb2.HIP
  63. # pyre-fixme[9]: incompatible type assignment
  64. NumGpuDevices = C.num_hip_devices
  65. GetHIPVersion = C.get_hip_version
  66. def GetGpuPeerAccessPattern():
  67. return np.asarray(C.get_hip_peer_access_pattern())
  68. GetDeviceProperties = C.get_device_properties
  69. GetGPUMemoryInfo = C.get_gpu_memory_info
  70. if not has_gpu_support:
  71. # setting cuda as the default GpuDeviceType as some tests
  72. # like core, scope tests use GpuDeviceType even without gpu support
  73. GpuDeviceType = caffe2_pb2.CUDA
  74. # pyre-fixme[9]: incompatible type assignment
  75. NumGpuDevices = lambda: 0 # noqa
  76. GetDeviceProperties = lambda x: None # noqa
  77. GetGpuPeerAccessPattern = lambda: np.array([]) # noqa
  78. # pyre-fixme[9]: incompatible type assignment
  79. GetGPUMemoryInfo = lambda: None # noqa
  80. IsNUMAEnabled = C.is_numa_enabled
  81. GetNumNUMANodes = C.get_num_numa_nodes
  82. GetBlobNUMANode = C.get_blob_numa_node
  83. GetBlobSizeBytes = C.get_blob_size_bytes
  84. def FillRandomNetworkInputs(net, input_dims, input_types):
  85. C.fill_random_network_inputs(net.Proto().SerializeToString(), input_dims, input_types)
  86. def _GetFreeFlaskPort():
  87. """Get a free flask port."""
  88. # We will prefer to use 5000. If not, we will then pick a random port.
  89. sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  90. result = sock.connect_ex(('127.0.0.1', 5000))
  91. if result == 0:
  92. return 5000
  93. else:
  94. s = socket.socket()
  95. s.bind(('', 0))
  96. port = s.getsockname()[1]
  97. s.close()
  98. # Race condition: between the interval we close the socket and actually
  99. # start a mint process, another process might have occupied the port. We
  100. # don't do much here as this is mostly for convenience in research
  101. # rather than 24x7 service.
  102. return port
  103. def StartMint(root_folder=None, port=None):
  104. """Start a mint instance.
  105. TODO(Yangqing): this does not work well under ipython yet. According to
  106. https://github.com/ipython/ipython/issues/5862
  107. writing up some fix is a todo item.
  108. """
  109. from caffe2.python.mint import app
  110. if root_folder is None:
  111. # Get the root folder from the current workspace
  112. root_folder = C.root_folder()
  113. if port is None:
  114. port = _GetFreeFlaskPort()
  115. process = Process(
  116. target=app.main,
  117. args=(
  118. ['-p', str(port), '-r', root_folder],
  119. )
  120. )
  121. process.start()
  122. print('Mint running at http://{}:{}'.format(socket.getfqdn(), port))
  123. return process
  124. def StringifyProto(obj):
  125. """Stringify a protocol buffer object.
  126. Inputs:
  127. obj: a protocol buffer object, or a Pycaffe2 object that has a Proto()
  128. function.
  129. Outputs:
  130. string: the output protobuf string.
  131. Raises:
  132. AttributeError: if the passed in object does not have the right attribute.
  133. """
  134. if isinstance(obj, basestring):
  135. return obj
  136. else:
  137. if isinstance(obj, Message):
  138. # First, see if this object is a protocol buffer, which we can
  139. # simply serialize with the SerializeToString() call.
  140. return obj.SerializeToString()
  141. elif hasattr(obj, 'Proto'):
  142. return obj.Proto().SerializeToString()
  143. else:
  144. raise ValueError("Unexpected argument to StringifyProto of type " +
  145. type(obj).__name__)
  146. def ResetWorkspace(root_folder=None):
  147. if root_folder is None:
  148. # Reset the workspace, but keep the current root folder setting.
  149. return C.reset_workspace(C.root_folder())
  150. else:
  151. if not os.path.exists(root_folder):
  152. os.makedirs(root_folder)
  153. return C.reset_workspace(root_folder)
  154. def CreateNet(net, overwrite=False, input_blobs=None):
  155. TriggerLazyImport()
  156. if input_blobs is None:
  157. input_blobs = []
  158. for input_blob in input_blobs:
  159. C.create_blob(input_blob)
  160. return CallWithExceptionIntercept(
  161. C.create_net,
  162. C.Workspace.current._last_failed_op_net_position,
  163. GetNetName(net),
  164. StringifyProto(net), overwrite,
  165. )
  166. def Predictor(init_net, predict_net):
  167. return C.Predictor(StringifyProto(init_net), StringifyProto(predict_net))
  168. def GetOperatorCost(operator, blobs):
  169. return C.get_operator_cost(StringifyProto(operator), blobs)
  170. def RunOperatorOnce(operator):
  171. return C.run_operator_once(StringifyProto(operator))
  172. def RunOperatorMultiple(operator, num_runs):
  173. return C.run_operator_multiple(StringifyProto(operator), num_runs)
  174. def RunOperatorsOnce(operators):
  175. for op in operators:
  176. success = RunOperatorOnce(op)
  177. if not success:
  178. return False
  179. return True
  180. def ClearGlobalNetObserver():
  181. return C.clear_global_net_observer()
  182. def CallWithExceptionIntercept(func, op_id_fetcher, net_name, *args, **kwargs):
  183. try:
  184. return func(*args, **kwargs)
  185. except Exception:
  186. op_id = op_id_fetcher()
  187. net_tracebacks = operator_tracebacks.get(net_name, None)
  188. logger.warning(
  189. 'Original python traceback for operator `{}` in network '
  190. '`{}` in exception above (most recent call last):'.format(
  191. op_id, net_name))
  192. if net_tracebacks and op_id in net_tracebacks:
  193. tb = net_tracebacks[op_id]
  194. for line in reversed(tb):
  195. logger.warning(' File "{}", line {}, in {}'.format(
  196. line[0], line[1], line[2]))
  197. raise
  198. def RunNetOnce(net):
  199. return CallWithExceptionIntercept(
  200. C.run_net_once,
  201. C.Workspace.current._last_failed_op_net_position,
  202. GetNetName(net),
  203. StringifyProto(net),
  204. )
  205. def RunNet(name, num_iter=1, allow_fail=False):
  206. """Runs a given net.
  207. Inputs:
  208. name: the name of the net, or a reference to the net.
  209. num_iter: number of iterations to run
  210. allow_fail: if True, does not assert on net exec failure but returns False
  211. Returns:
  212. True or an exception.
  213. """
  214. return CallWithExceptionIntercept(
  215. C.run_net,
  216. C.Workspace.current._last_failed_op_net_position,
  217. GetNetName(name),
  218. StringifyNetName(name), num_iter, allow_fail,
  219. )
  220. def RunPlan(plan_or_step):
  221. # TODO(jiayq): refactor core.py/workspace.py to avoid circular deps
  222. import caffe2.python.core as core
  223. if isinstance(plan_or_step, core.ExecutionStep):
  224. plan_or_step = core.Plan(plan_or_step)
  225. return C.run_plan(StringifyProto(plan_or_step))
  226. def RunPlanInBackground(plan_or_step):
  227. # TODO(jiayq): refactor core.py/workspace.py to avoid circular deps
  228. import caffe2.python.core as core
  229. if isinstance(plan_or_step, core.ExecutionStep):
  230. plan_or_step = core.Plan(plan_or_step)
  231. return C.run_plan_in_background(StringifyProto(plan_or_step))
  232. def InferShapesAndTypes(nets, blob_dimensions=None, nets_proto=False,
  233. blob_types=None):
  234. """Infers the shapes and types for the specified nets.
  235. Inputs:
  236. nets: the list of nets
  237. blob_dimensions (optional): a dictionary of blobs and their dimensions.
  238. If not specified, the workspace blobs are used.
  239. nets_proto (optional): a boolean flag indicating whether the protobuffer
  240. representation is passed to the routine.
  241. Returns:
  242. A tuple of (shapes, types) dictionaries keyed by blob name.
  243. """
  244. if nets_proto:
  245. net_protos = [StringifyProto(n) for n in nets]
  246. else:
  247. net_protos = [StringifyProto(n.Proto()) for n in nets]
  248. if blob_dimensions is None:
  249. assert blob_types is None
  250. blobdesc_prototxt = C.infer_shapes_and_types_from_workspace(net_protos)
  251. elif blob_types is None:
  252. blobdesc_prototxt = C.infer_shapes_and_types_from_map(
  253. net_protos, blob_dimensions
  254. )
  255. else:
  256. blobdesc_prototxt = C.infer_shapes_and_types_from_map(
  257. net_protos, blob_dimensions, blob_types
  258. )
  259. blobdesc_proto = caffe2_pb2.TensorShapes()
  260. blobdesc_proto.ParseFromString(blobdesc_prototxt)
  261. shapes = {}
  262. types = {}
  263. for ts in blobdesc_proto.shapes:
  264. if not ts.unknown_shape:
  265. shapes[ts.name] = list(ts.dims)
  266. types[ts.name] = ts.data_type
  267. return (shapes, types)
  268. def _StringifyName(name, expected_type):
  269. if isinstance(name, basestring):
  270. return name
  271. assert type(name).__name__ == expected_type, \
  272. "Expected a string or %s" % expected_type
  273. return str(name)
  274. def StringifyBlobName(name):
  275. return _StringifyName(name, "BlobReference")
  276. def StringifyNetName(name):
  277. return _StringifyName(name, "Net")
  278. def GetNetName(net):
  279. if isinstance(net, basestring):
  280. return net
  281. if type(net).__name__ == "Net" or type(net).__name__ == "NetWithShapeInference":
  282. return net.Name()
  283. if isinstance(net, caffe2_pb2.NetDef):
  284. return net.name
  285. raise Exception("Not a Net object: {}".format(str(net)))
  286. def FeedBlob(name, arr, device_option=None):
  287. """Feeds a blob into the workspace.
  288. Inputs:
  289. name: the name of the blob.
  290. arr: either a TensorProto object or a numpy array object to be fed into
  291. the workspace.
  292. device_option (optional): the device option to feed the data with.
  293. Returns:
  294. True or False, stating whether the feed is successful.
  295. """
  296. ws = C.Workspace.current
  297. return _Workspace_feed_blob(ws, name, arr, device_option)
  298. def FetchBlobs(names):
  299. """Fetches a list of blobs from the workspace.
  300. Inputs:
  301. names: list of names of blobs - strings or BlobReferences
  302. Returns:
  303. list of fetched blobs
  304. """
  305. return [FetchBlob(name) for name in names]
  306. def FetchBlob(name):
  307. """Fetches a blob from the workspace.
  308. Inputs:
  309. name: the name of the blob - a string or a BlobReference
  310. Returns:
  311. Fetched blob (numpy array or string) if successful
  312. """
  313. result = C.fetch_blob(StringifyBlobName(name))
  314. if isinstance(result, tuple):
  315. raise TypeError(
  316. "Use FetchInt8Blob to fetch Int8 Blob {}".format(
  317. StringifyBlobName(name)
  318. )
  319. )
  320. return result
  321. def FetchTorch(name):
  322. ws = C.Workspace.current
  323. return ws.blobs[name].to_torch()
  324. Int8Tensor = collections.namedtuple(
  325. 'Int8Tensor', ['data', 'scale', 'zero_point']
  326. )
  327. def FetchInt8Blob(name):
  328. """Fetches an Int8 blob from the workspace. It shared backend implementation
  329. with FetchBlob but it is recommended when fetching Int8 Blobs
  330. Inputs:
  331. name: the name of the Int8 blob - a string or a BlobReference
  332. Returns:
  333. data: int8 numpy array, data
  334. scale: float, fake quantization scale
  335. zero_point: int, fake quantization offset
  336. """
  337. result = C.fetch_blob(StringifyBlobName(name))
  338. assert isinstance(result, tuple), \
  339. 'You are not fetching an Int8Blob {}. Please use FetchBlob'.format(
  340. StringifyBlobName(name))
  341. return Int8Tensor(*result)
  342. def FetchInt8BlobRealVal(name):
  343. """Fetches an Int8 blob from the workspace and return its real value representation.
  344. Inputs:
  345. name: the name of the Int8 blob - a string or a BlobReference
  346. Returns:
  347. real value representation of int8 numpy array
  348. """
  349. result = C.fetch_blob(StringifyBlobName(name))
  350. assert isinstance(result, tuple), \
  351. 'You are not fetching an Int8Blob {}. Please use FetchBlob'.format(
  352. StringifyBlobName(name))
  353. int8_blob = Int8Tensor(*result)
  354. return (int8_blob.data.astype(np.int32) - int(int8_blob.zero_point)).astype(
  355. np.float32) * int8_blob.scale
  356. def _Workspace_fetch_int8_blob(ws, name):
  357. """Fetches an Int8 blob from the workspace. It shared backend implementation
  358. with FetchBlob but it is recommended when fetching Int8 Blobs
  359. Inputs:
  360. name: the name of the Int8 blob - a string or a BlobReference
  361. Returns:
  362. data: int8 numpy array, data
  363. scale: float, fake quantization scale
  364. zero_point: int, fake quantization offset
  365. """
  366. result = ws.fetch_blob(name)
  367. assert isinstance(result, tuple), \
  368. 'You are not fetching an Int8Blob {}. Please use fetch_blob'.format(
  369. StringifyBlobName(name))
  370. return Int8Tensor(*result)
  371. C.Workspace.fetch_int8_blob = _Workspace_fetch_int8_blob
  372. def ApplyTransform(transform_key, net):
  373. """Apply a Transform to a NetDef protobuf object, and returns the new
  374. transformed NetDef.
  375. Inputs:
  376. transform_key: the name of the transform, as it is stored in the registry
  377. net: a NetDef protobuf object
  378. Returns:
  379. Transformed NetDef protobuf object.
  380. """
  381. transformed_net = caffe2_pb2.NetDef()
  382. transformed_str = C.apply_transform(
  383. str(transform_key).encode('utf-8'),
  384. net.SerializeToString(),
  385. )
  386. transformed_net.ParseFromString(transformed_str)
  387. return transformed_net
  388. def ApplyTransformIfFaster(transform_key, net, init_net, **kwargs):
  389. """Apply a Transform to a NetDef protobuf object, and returns the new
  390. transformed NetDef, only if it runs faster than the original.
  391. The runs are performed on the current active workspace (gWorkspace).
  392. You should initialize that workspace before making a call to this function.
  393. Inputs:
  394. transform_key: the name of the transform, as it is stored in the registry
  395. net: a NetDef protobuf object
  396. init_net: The net to initialize the workspace.
  397. warmup_runs (optional):
  398. Determines how many times the net is run before testing.
  399. Will be 5 by default.
  400. main_runs (optional):
  401. Determines how many times the net is run during testing.
  402. Will be 10 by default.
  403. improvement_threshold (optional):
  404. Determines the factor which the new net needs to be faster
  405. in order to replace the old. Will be 1.01 by default.
  406. Returns:
  407. Either a Transformed NetDef protobuf object, or the original netdef.
  408. """
  409. warmup_runs = kwargs['warmup_runs'] if 'warmup_runs' in kwargs else 5
  410. main_runs = kwargs['main_runs'] if 'main_runs' in kwargs else 10
  411. improvement_threshold = kwargs['improvement_threshold'] \
  412. if 'improvement_threshold' in kwargs else 1.01
  413. transformed_net = caffe2_pb2.NetDef()
  414. transformed_str = C.apply_transform_if_faster(
  415. str(transform_key).encode('utf-8'),
  416. net.SerializeToString(),
  417. init_net.SerializeToString(),
  418. warmup_runs,
  419. main_runs,
  420. float(improvement_threshold),
  421. )
  422. transformed_net.ParseFromString(transformed_str)
  423. return transformed_net
  424. def GetNameScope():
  425. """Return the current namescope string. To be used to fetch blobs"""
  426. return scope.CurrentNameScope()
  427. class _BlobDict(object):
  428. """Provides python dict compatible way to do fetching and feeding"""
  429. def __getitem__(self, key):
  430. return FetchBlob(key)
  431. def __setitem__(self, key, value):
  432. return FeedBlob(key, value)
  433. def __len__(self):
  434. return len(C.blobs())
  435. def __iter__(self):
  436. return C.blobs().__iter__()
  437. def __contains__(self, item):
  438. return C.has_blob(item)
  439. blobs = _BlobDict()
################################################################################
# Utilities for immediate mode
#
# Caffe2's immediate mode implements the following behavior: between the two
# function calls StartImmediate() and StopImmediate(), for any operator that is
# called through CreateOperator(), we will also run that operator in a workspace
# that is specific to the immediate mode. The user is explicitly expected to
# make sure that these ops have proper inputs and outputs, i.e. one should not
# run an op where an external input is not created or fed.
#
# Users can use FeedImmediate() and FetchImmediate() to interact with blobs
# in the immediate workspace.
#
# Once StopImmediate() is called, all contents of the immediate workspace are
# freed up so one can continue using normal runs.
#
# The immediate mode is solely for debugging purposes and support will be very
# sparse.
################################################################################
  459. _immediate_mode = False
  460. _immediate_workspace_name = "_CAFFE2_IMMEDIATE"
  461. _immediate_root_folder = ''
  462. def IsImmediate():
  463. return _immediate_mode
  464. @contextlib.contextmanager
  465. def WorkspaceGuard(workspace_name):
  466. current = CurrentWorkspace()
  467. SwitchWorkspace(workspace_name, True)
  468. yield
  469. SwitchWorkspace(current)
  470. def StartImmediate(i_know=False):
  471. global _immediate_mode
  472. global _immediate_root_folder
  473. if IsImmediate():
  474. # already in immediate mode. We will kill the previous one
  475. # and start from fresh.
  476. StopImmediate()
  477. _immediate_mode = True
  478. with WorkspaceGuard(_immediate_workspace_name):
  479. _immediate_root_folder = tempfile.mkdtemp()
  480. ResetWorkspace(_immediate_root_folder)
  481. if i_know:
  482. # if the user doesn't want to see the warning message, sure...
  483. return
  484. print("""
  485. Enabling immediate mode in caffe2 python is an EXTREMELY EXPERIMENTAL
  486. feature and may very easily go wrong. This is because Caffe2 uses a
  487. declarative way of defining operators and models, which is essentially
  488. not meant to run things in an interactive way. Read the following carefully
  489. to make sure that you understand the caveats.
  490. (1) You need to make sure that the sequences of operators you create are
  491. actually runnable sequentially. For example, if you create an op that takes
  492. an input X, somewhere earlier you should have already created X.
  493. (2) Caffe2 immediate uses one single workspace, so if the set of operators
  494. you run are intended to be under different workspaces, they will not run.
  495. To create boundaries between such use cases, you can call FinishImmediate()
  496. and StartImmediate() manually to flush out everything no longer needed.
  497. (3) Underlying objects held by the immediate mode may interfere with your
  498. normal run. For example, if there is a leveldb that you opened in immediate
  499. mode and did not close, your main run will fail because leveldb does not
  500. support double opening. Immediate mode may also occupy a lot of memory esp.
  501. on GPUs. Call FinishImmediate() as soon as possible when you no longer
  502. need it.
  503. (4) Immediate is designed to be slow. Every immediate call implicitly
  504. creates a temp operator object, runs it, and destroys the operator. This
  505. slow-speed run is by design to discourage abuse. For most use cases other
  506. than debugging, do NOT turn on immediate mode.
  507. (5) If there is anything FATAL happening in the underlying C++ code, the
  508. immediate mode will immediately (pun intended) cause the runtime to crash.
  509. Thus you should use immediate mode with extra care. If you still would
  510. like to, have fun [https://xkcd.com/149/].
  511. """)
  512. def StopImmediate():
  513. """Stops an immediate mode run."""
  514. # Phew, that was a dangerous ride.
  515. global _immediate_mode
  516. global _immediate_root_folder
  517. if not IsImmediate():
  518. return
  519. with WorkspaceGuard(_immediate_workspace_name):
  520. ResetWorkspace()
  521. shutil.rmtree(_immediate_root_folder)
  522. _immediate_root_folder = ''
  523. _immediate_mode = False
  524. def ImmediateBlobs():
  525. with WorkspaceGuard(_immediate_workspace_name):
  526. return Blobs()
  527. def RunOperatorImmediate(op):
  528. with WorkspaceGuard(_immediate_workspace_name):
  529. RunOperatorOnce(op)
  530. def FetchImmediate(*args, **kwargs):
  531. with WorkspaceGuard(_immediate_workspace_name):
  532. return FetchBlob(*args, **kwargs)
  533. def FeedImmediate(*args, **kwargs):
  534. with WorkspaceGuard(_immediate_workspace_name):
  535. return FeedBlob(*args, **kwargs)
  536. # C.Workspace methods.
  537. def _Workspace_create_net_with_exception_intercept(ws, net, overwrite=False):
  538. return CallWithExceptionIntercept(
  539. ws._create_net,
  540. ws._last_failed_op_net_position,
  541. GetNetName(net),
  542. StringifyProto(net), overwrite,
  543. )
  544. def _Workspace_run(ws, obj):
  545. if hasattr(obj, 'Proto'):
  546. obj = obj.Proto()
  547. if isinstance(obj, caffe2_pb2.PlanDef):
  548. return ws._run_plan(obj.SerializeToString())
  549. if isinstance(obj, caffe2_pb2.NetDef):
  550. return CallWithExceptionIntercept(
  551. ws._run_net,
  552. ws._last_failed_op_net_position,
  553. GetNetName(obj),
  554. obj.SerializeToString(),
  555. )
  556. # return ws._run_net(obj.SerializeToString())
  557. if isinstance(obj, caffe2_pb2.OperatorDef):
  558. return ws._run_operator(obj.SerializeToString())
  559. raise ValueError(
  560. "Don't know how to do Workspace.run() on {}".format(type(obj)))
  561. def _Workspace_feed_blob(ws, name, arr, device_option=None):
  562. if type(arr) is caffe2_pb2.TensorProto:
  563. arr = utils.Caffe2TensorToNumpyArray(arr)
  564. if type(arr) is np.ndarray and arr.dtype.kind in 'SU':
  565. # Plain NumPy strings are weird, let's use objects instead
  566. arr = arr.astype(np.object)
  567. if device_option is None:
  568. device_option = scope.CurrentDeviceScope()
  569. if device_option and device_option.device_type == caffe2_pb2.CUDA:
  570. if arr.dtype == np.dtype('float64'):
  571. logger.warning(
  572. "CUDA operators do not support 64-bit doubles, " +
  573. "please use arr.astype(np.float32) or np.int32 for ints." +
  574. " Blob: {}".format(name) +
  575. " type: {}".format(str(arr.dtype))
  576. )
  577. name = StringifyBlobName(name)
  578. if device_option is not None:
  579. return ws.create_blob(name).feed(arr, device_option)
  580. else:
  581. return ws.create_blob(name).feed(arr)
  582. def _Workspace_remove_blob(ws, blob):
  583. ws._remove_blob(str(blob))
  584. Workspace = C.Workspace
  585. Workspace.create_net = _Workspace_create_net_with_exception_intercept
  586. Workspace.run = _Workspace_run
  587. Workspace.feed_blob = _Workspace_feed_blob
  588. Workspace.remove_blob = _Workspace_remove_blob
  589. # C.Blob methods.
  590. def _Blob_feed(blob, arg, device_option=None):
  591. # conservative type check to avoid unnecessary import
  592. if type(arg).__name__ == 'Tensor' and type(arg).__module__ == 'torch':
  593. import torch
  594. if isinstance(arg, torch.Tensor):
  595. assert device_option is None, \
  596. "device_option doesn't make sense with PyTorch tensors"
  597. handle = torch._C._tensor_impl_raw_handle(arg)
  598. blob._wrap_tensor_impl(handle)
  599. return True # _feed() returns True for some reason
  600. if device_option is not None:
  601. device_option = StringifyProto(device_option)
  602. return blob._feed(arg, device_option)
  603. C.Blob.feed = _Blob_feed
  604. def _Tensor_to_torch(tensor):
  605. """
  606. PyTorch tensor interop (TensorCPU methods)
  607. Can be accessed as:
  608. workspace.Workspace.current.blobs['foo'].tensor().to_torch()
  609. """
  610. # avoiding circular dependency
  611. import torch
  612. handle = tensor._tensor_impl_raw_handle()
  613. return torch._C._wrap_tensor_impl(handle)
  614. C.TensorCPU.to_torch = _Tensor_to_torch
  615. def _Blob_to_torch(blob):
  616. if not blob.is_tensor():
  617. raise RuntimeError("Blob has to be a tensor")
  618. return blob.as_tensor().to_torch()
  619. C.Blob.to_torch = _Blob_to_torch