| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779 |
- ## @package workspace
- # Module caffe2.python.workspace
- import collections
- import contextlib
- from google.protobuf.message import Message
- from multiprocessing import Process
- import os
- from collections import defaultdict
- import logging
- import numpy as np
- from past.builtins import basestring
- import shutil
- import socket
- import tempfile
- from caffe2.proto import caffe2_pb2
- from caffe2.python import scope, utils
- from caffe2.python.lazy import TriggerLazyImport
- import caffe2.python._import_c_extension as C
# Module logger; warnings from CallWithExceptionIntercept and
# _Workspace_feed_blob are emitted through it.
logger = logging.getLogger(__name__)
# CamelCase aliases that re-export C-extension entry points as part of this
# module's public API.
Blobs = C.blobs
ResetBlob = C.reset_blob
CreateBlob = C.create_blob
CurrentWorkspace = C.current_workspace
DeserializeBlob = C.deserialize_blob
GlobalInit = C.global_init
HasBlob = C.has_blob
RegisteredOperators = C.registered_operators
SerializeBlob = C.serialize_blob
SwitchWorkspace = C.switch_workspace
RootFolder = C.root_folder
Workspaces = C.workspaces
BenchmarkNet = C.benchmark_net
BenchmarkNetOnce = C.benchmark_net_once
GetStats = C.get_stats
CreateOfflineTensor = C.create_offline_tensor
# net name -> {operator id -> recorded traceback entries}. This module only
# reads it (CallWithExceptionIntercept logs entry[0], entry[1], entry[2] as
# file, line and function); presumably it is populated where operators are
# created (core.py?) -- TODO confirm.
operator_tracebacks = defaultdict(dict)
# Capability flags exposed by the C extension (used as booleans below).
is_asan = C.is_asan
has_fbgemm = C.has_fbgemm
has_cuda_support = C.has_cuda_support
has_hip_support = C.has_hip_support
has_gpu_support = C.has_gpu_support
# GPU-related entry points. The CUDA section, the HIP section and the
# "no GPU" section run in this order, so when more than one applies the
# assignments made later win.
if has_cuda_support:
    GpuDeviceType = caffe2_pb2.CUDA
    NumCudaDevices = C.num_cuda_devices
    # This is a duplicate of NumCudaDevices. Remove
    # NumCudaDevices once replaced everywhere in the code
    NumGpuDevices = C.num_cuda_devices
    GetCUDAVersion = C.get_cuda_version
    GetCuDNNVersion = C.get_cudnn_version

    def GetGpuPeerAccessPattern():
        """Return C.get_cuda_peer_access_pattern() as a numpy array."""
        return np.asarray(C.get_cuda_peer_access_pattern())

    GetDeviceProperties = C.get_device_properties
    GetGPUMemoryInfo = C.get_gpu_memory_info
else:
    # Without CUDA the CUDA-specific queries are stubbed to return 0.
    # pyre-fixme[9]: incompatible type assignment
    NumCudaDevices = lambda: 0  # noqa
    # pyre-fixme[9]: incompatible type assignment
    GetCUDAVersion = lambda: 0  # noqa
    # pyre-fixme[9]: incompatible type assignment
    GetCuDNNVersion = lambda: 0  # noqa
if has_hip_support:
    # HIP build: the generic GPU names point at the HIP implementations.
    GpuDeviceType = caffe2_pb2.HIP
    # pyre-fixme[9]: incompatible type assignment
    NumGpuDevices = C.num_hip_devices
    # GetHIPVersion is only assigned in this branch.
    GetHIPVersion = C.get_hip_version

    def GetGpuPeerAccessPattern():
        """Return C.get_hip_peer_access_pattern() as a numpy array."""
        return np.asarray(C.get_hip_peer_access_pattern())

    GetDeviceProperties = C.get_device_properties
    GetGPUMemoryInfo = C.get_gpu_memory_info
if not has_gpu_support:
    # setting cuda as the default GpuDeviceType as some tests
    # like core, scope tests use GpuDeviceType even without gpu support
    GpuDeviceType = caffe2_pb2.CUDA
    # pyre-fixme[9]: incompatible type assignment
    NumGpuDevices = lambda: 0  # noqa
    GetDeviceProperties = lambda x: None  # noqa
    GetGpuPeerAccessPattern = lambda: np.array([])  # noqa
    # pyre-fixme[9]: incompatible type assignment
    GetGPUMemoryInfo = lambda: None  # noqa
# NUMA and blob-size introspection helpers re-exported from the C extension.
IsNUMAEnabled = C.is_numa_enabled
GetNumNUMANodes = C.get_num_numa_nodes
GetBlobNUMANode = C.get_blob_numa_node
GetBlobSizeBytes = C.get_blob_size_bytes
def FillRandomNetworkInputs(net, input_dims, input_types):
    """Delegate random input filling for ``net`` to the C extension.

    Inputs:
      net: an object with a Proto() method; its serialized NetDef is passed on.
      input_dims: forwarded unchanged to C.fill_random_network_inputs.
      input_types: forwarded unchanged to C.fill_random_network_inputs.
    Returns:
      None.
    """
    serialized_net = net.Proto().SerializeToString()
    C.fill_random_network_inputs(serialized_net, input_dims, input_types)
def _GetFreeFlaskPort():
    """Get a free flask port.

    Port 5000 is preferred. It is probed by connecting to it on 127.0.0.1.
    A successful connection means another process is already listening
    there, and in that case the OS is asked for an unused port instead.

    Returns:
      int: a TCP port that was free at the time of the check.
    """
    # connect_ex() returns 0 only when the connection succeeded, i.e. the port
    # is already taken. Any error code (e.g. ECONNREFUSED) means nobody is
    # listening, so 5000 is usable. The probe socket is always closed.
    with contextlib.closing(
            socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as probe:
        port_5000_in_use = probe.connect_ex(('127.0.0.1', 5000)) == 0
    if not port_5000_in_use:
        return 5000
    # Binding to port 0 lets the OS pick an unused port.
    with contextlib.closing(socket.socket()) as sock:
        sock.bind(('', 0))
        port = sock.getsockname()[1]
    # Race condition: between the interval we close the socket and actually
    # start a mint process, another process might have occupied the port. We
    # don't do much here as this is mostly for convenience in research
    # rather than 24x7 service.
    return port
def StartMint(root_folder=None, port=None):
    """Start a mint instance in a separate process.

    Inputs:
      root_folder: folder passed to mint via ``-r``; defaults to the current
        workspace's root folder (C.root_folder()).
      port: port passed to mint via ``-p``; defaults to _GetFreeFlaskPort().
    Returns:
      The started multiprocessing.Process running mint's app.main.

    TODO(Yangqing): this does not work well under ipython yet. According to
    https://github.com/ipython/ipython/issues/5862
    writing up some fix is a todo item.
    """
    from caffe2.python.mint import app
    folder = C.root_folder() if root_folder is None else root_folder
    chosen_port = _GetFreeFlaskPort() if port is None else port
    mint_argv = ['-p', str(chosen_port), '-r', folder]
    mint_process = Process(target=app.main, args=(mint_argv,))
    mint_process.start()
    print('Mint running at http://{}:{}'.format(socket.getfqdn(), chosen_port))
    return mint_process
def StringifyProto(obj):
    """Stringify a protocol buffer object.

    Inputs:
      obj: an already-serialized string (returned unchanged), a protocol
        buffer message, or a Pycaffe2 object that has a Proto() method.
    Outputs:
      string: the serialized protobuf (or ``obj`` itself if it is a string).
    Raises:
      ValueError: if ``obj`` is none of the supported kinds.
    """
    if isinstance(obj, basestring):
        # Already serialized: pass through untouched.
        return obj
    if isinstance(obj, Message):
        # A protocol buffer can be serialized directly.
        return obj.SerializeToString()
    if hasattr(obj, 'Proto'):
        return obj.Proto().SerializeToString()
    raise ValueError("Unexpected argument to StringifyProto of type " +
                     type(obj).__name__)
def ResetWorkspace(root_folder=None):
    """Reset the current workspace, optionally changing its root folder.

    Inputs:
      root_folder: new root folder. If None, the current root folder
        (C.root_folder()) is kept. A folder that does not exist is created
        before resetting.
    Returns:
      Whatever C.reset_workspace returns.
    """
    if root_folder is not None:
        if not os.path.exists(root_folder):
            os.makedirs(root_folder)
        target_folder = root_folder
    else:
        # Reset the workspace, but keep the current root folder setting.
        target_folder = C.root_folder()
    return C.reset_workspace(target_folder)
def CreateNet(net, overwrite=False, input_blobs=None):
    """Instantiate ``net`` in the current workspace.

    Inputs:
      net: a Net, a NetDef, or a string (see GetNetName / StringifyProto).
      overwrite: forwarded to C.create_net.
      input_blobs: optional iterable of blob names created before the net.
    Returns:
      The result of C.create_net. On failure the exception is re-raised
      after CallWithExceptionIntercept logs any recorded operator traceback.
    """
    TriggerLazyImport()
    blobs_to_create = [] if input_blobs is None else input_blobs
    for blob_name in blobs_to_create:
        C.create_blob(blob_name)
    return CallWithExceptionIntercept(
        C.create_net,
        C.Workspace.current._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net),
        overwrite,
    )
def Predictor(init_net, predict_net):
    """Build a C.Predictor from an init net and a predict net.

    Both arguments may be anything StringifyProto accepts.
    """
    init_bytes = StringifyProto(init_net)
    predict_bytes = StringifyProto(predict_net)
    return C.Predictor(init_bytes, predict_bytes)
def GetOperatorCost(operator, blobs):
    """Return C.get_operator_cost for ``operator``.

    Inputs:
      operator: an OperatorDef or anything StringifyProto accepts.
      blobs: forwarded unchanged to C.get_operator_cost.
    """
    operator_bytes = StringifyProto(operator)
    return C.get_operator_cost(operator_bytes, blobs)
def RunOperatorOnce(operator):
    """Run ``operator`` once in the current workspace.

    Returns the value from C.run_operator_once; RunOperatorsOnce treats it
    as a success flag.
    """
    operator_bytes = StringifyProto(operator)
    return C.run_operator_once(operator_bytes)
def RunOperatorMultiple(operator, num_runs):
    """Run ``operator`` ``num_runs`` times via C.run_operator_multiple."""
    operator_bytes = StringifyProto(operator)
    return C.run_operator_multiple(operator_bytes, num_runs)
def RunOperatorsOnce(operators):
    """Run each operator in ``operators`` once, in order.

    Stops at the first operator whose RunOperatorOnce result is falsy.

    Returns:
      True if every operator succeeded, False otherwise.
    """
    # all() short-circuits exactly like an early `return False`.
    return all(RunOperatorOnce(op) for op in operators)
def ClearGlobalNetObserver():
    """Thin wrapper around C.clear_global_net_observer; returns its result."""
    outcome = C.clear_global_net_observer()
    return outcome
def CallWithExceptionIntercept(func, op_id_fetcher, net_name, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and log the Python origin of failures.

    If ``func`` raises, the failing operator's id is obtained from
    ``op_id_fetcher()``. That id is logged, together with the traceback
    recorded for it in ``operator_tracebacks[net_name]`` when one exists.
    The original exception is then re-raised unchanged.

    Inputs:
      func: the callable to invoke.
      op_id_fetcher: zero-argument callable returning the failing op's id.
      net_name: key into ``operator_tracebacks``.
    Returns:
      Whatever ``func`` returns.
    """
    try:
        return func(*args, **kwargs)
    except Exception:
        failed_op = op_id_fetcher()
        recorded = operator_tracebacks.get(net_name, None)
        logger.warning(
            'Original python traceback for operator `{}` in network '
            '`{}` in exception above (most recent call last):'.format(
                failed_op, net_name))
        if not recorded or failed_op not in recorded:
            raise
        for frame in reversed(recorded[failed_op]):
            logger.warning(' File "{}", line {}, in {}'.format(
                frame[0], frame[1], frame[2]))
        raise
def RunNetOnce(net):
    """Run ``net`` once via C.run_net_once.

    Failures are re-raised after CallWithExceptionIntercept logs any recorded
    Python traceback of the failing operator.
    """
    return CallWithExceptionIntercept(
        C.run_net_once,
        C.Workspace.current._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net))
def RunNet(name, num_iter=1, allow_fail=False):
    """Runs a net that already exists in the current workspace.

    Inputs:
      name: the name of the net, or a reference to the net.
      num_iter: number of iterations to run.
      allow_fail: if True, does not assert on net exec failure but returns False.
    Returns:
      True or an exception.
    """
    runner = C.run_net
    op_position = C.Workspace.current._last_failed_op_net_position
    return CallWithExceptionIntercept(
        runner, op_position, GetNetName(name),
        StringifyNetName(name), num_iter, allow_fail)
def _PlanFromPlanOrStep(plan_or_step):
    """Wrap a bare core.ExecutionStep in a core.Plan; return anything else as-is."""
    # TODO(jiayq): refactor core.py/workspace.py to avoid circular deps
    import caffe2.python.core as core
    if isinstance(plan_or_step, core.ExecutionStep):
        return core.Plan(plan_or_step)
    return plan_or_step


def RunPlan(plan_or_step):
    """Run a plan, or a single ExecutionStep wrapped into one, via C.run_plan."""
    plan = _PlanFromPlanOrStep(plan_or_step)
    return C.run_plan(StringifyProto(plan))


def RunPlanInBackground(plan_or_step):
    """Hand a plan, or a single ExecutionStep wrapped into one, to
    C.run_plan_in_background and return its result.
    """
    plan = _PlanFromPlanOrStep(plan_or_step)
    return C.run_plan_in_background(StringifyProto(plan))
def InferShapesAndTypes(nets, blob_dimensions=None, nets_proto=False,
                        blob_types=None):
    """Infers the shapes and types for the specified nets.

    Inputs:
      nets: the list of nets.
      blob_dimensions (optional): a dictionary of blobs and their dimensions.
          If not specified, the workspace blobs are used.
      nets_proto (optional): a boolean flag indicating whether the protobuffer
          representation is passed to the routine. When False, every element
          of ``nets`` must have a Proto() method.
      blob_types (optional): per-blob data types forwarded to shape inference
          together with ``blob_dimensions`` (presumably a dict of blob name ->
          caffe2_pb2.TensorProto data type; confirm). Must be None when
          ``blob_dimensions`` is None.
    Returns:
      A tuple of (shapes, types) dictionaries keyed by blob name. ``shapes``
      maps to lists of dims and omits blobs whose shape is unknown; ``types``
      includes every blob reported by shape inference.
    Raises:
      AssertionError: if ``blob_types`` is given without ``blob_dimensions``.
    """
    if nets_proto:
        net_protos = [StringifyProto(n) for n in nets]
    else:
        net_protos = [StringifyProto(n.Proto()) for n in nets]

    if blob_dimensions is None:
        # Types can only be supplied alongside explicit dimensions.
        assert blob_types is None, \
            "blob_types requires blob_dimensions to be specified as well"
        blobdesc_prototxt = C.infer_shapes_and_types_from_workspace(net_protos)
    elif blob_types is None:
        blobdesc_prototxt = C.infer_shapes_and_types_from_map(
            net_protos, blob_dimensions
        )
    else:
        blobdesc_prototxt = C.infer_shapes_and_types_from_map(
            net_protos, blob_dimensions, blob_types
        )

    # Despite the variable name, the result is parsed as a binary serialized
    # TensorShapes proto (ParseFromString), not as text format.
    blobdesc_proto = caffe2_pb2.TensorShapes()
    blobdesc_proto.ParseFromString(blobdesc_prototxt)
    shapes = {}
    types = {}
    for ts in blobdesc_proto.shapes:
        if not ts.unknown_shape:
            shapes[ts.name] = list(ts.dims)
        types[ts.name] = ts.data_type
    return (shapes, types)
def _StringifyName(name, expected_type):
    """Return ``name`` as a string.

    Strings pass through untouched. Anything else must be an instance of a
    class whose ``__name__`` equals ``expected_type``, and is converted with
    str().
    """
    if not isinstance(name, basestring):
        assert type(name).__name__ == expected_type, \
            "Expected a string or %s" % expected_type
        name = str(name)
    return name
def StringifyBlobName(name):
    """Return a blob name (a string or a BlobReference) as a plain string."""
    return _StringifyName(name, expected_type="BlobReference")
def StringifyNetName(name):
    """Return a net name (a string or a Net) as a plain string."""
    return _StringifyName(name, expected_type="Net")
def GetNetName(net):
    """Return the name of ``net``.

    Inputs:
      net: a string (returned as-is), an object whose class is named "Net"
        or "NetWithShapeInference" (matched by class name, presumably to
        avoid importing core.py here), or a caffe2_pb2.NetDef.
    Returns:
      The net's name as a string.
    Raises:
      TypeError: for any other kind of object. TypeError subclasses
        Exception, so existing ``except Exception`` handlers still apply.
    """
    if isinstance(net, basestring):
        return net
    if type(net).__name__ in ("Net", "NetWithShapeInference"):
        return net.Name()
    if isinstance(net, caffe2_pb2.NetDef):
        return net.name
    raise TypeError("Not a Net object: {}".format(str(net)))
def FeedBlob(name, arr, device_option=None):
    """Feeds a blob into the current workspace.

    Inputs:
      name: the name of the blob.
      arr: either a TensorProto object or a numpy array object to be fed into
        the workspace.
      device_option (optional): the device option to feed the data with.
    Returns:
      True or False, stating whether the feed is successful.
    """
    current_ws = C.Workspace.current
    return _Workspace_feed_blob(current_ws, name, arr, device_option)
def FetchBlobs(names):
    """Fetch several blobs from the current workspace, in order.

    Inputs:
      names: iterable of blob names (strings or BlobReferences).
    Returns:
      A list with one FetchBlob result per name.
    """
    return list(map(FetchBlob, names))
def FetchBlob(name):
    """Fetches a blob from the current workspace.

    Inputs:
      name: the name of the blob - a string or a BlobReference.
    Returns:
      The fetched blob (numpy array or string).
    Raises:
      TypeError: if the blob is an Int8 blob; use FetchInt8Blob instead.
    """
    blob_name = StringifyBlobName(name)
    fetched = C.fetch_blob(blob_name)
    if not isinstance(fetched, tuple):
        return fetched
    # Int8 blobs come back as (data, scale, zero_point) tuples.
    raise TypeError(
        "Use FetchInt8Blob to fetch Int8 Blob {}".format(blob_name))
def FetchTorch(name):
    """Fetch blob ``name`` from the current workspace as a PyTorch tensor.

    Looks the name up in ``C.Workspace.current.blobs`` and converts it with
    the ``to_torch`` method installed on C.Blob in this module.
    """
    current_blobs = C.Workspace.current.blobs
    return current_blobs[name].to_torch()
# Result type of FetchInt8Blob: the int8 data plus the fake-quantization
# scale and zero point.
Int8Tensor = collections.namedtuple(
    'Int8Tensor', ('data', 'scale', 'zero_point'))
def FetchInt8Blob(name):
    """Fetches an Int8 blob from the current workspace.

    It shares its backend implementation with FetchBlob, but is the
    recommended call for Int8 blobs.

    Inputs:
      name: the name of the Int8 blob - a string or a BlobReference.
    Returns:
      Int8Tensor with fields:
        data: int8 numpy array, data
        scale: float, fake quantization scale
        zero_point: int, fake quantization offset
    """
    blob_name = StringifyBlobName(name)
    raw = C.fetch_blob(blob_name)
    assert isinstance(raw, tuple), \
        'You are not fetching an Int8Blob {}. Please use FetchBlob'.format(
            blob_name)
    return Int8Tensor(*raw)
def FetchInt8BlobRealVal(name):
    """Fetch an Int8 blob and return its dequantized float32 values.

    Inputs:
      name: the name of the Int8 blob - a string or a BlobReference.
    Returns:
      (data - zero_point) * scale, computed in int32 and then cast to float32
      before scaling.
    """
    quantized = FetchInt8Blob(name)
    centered = quantized.data.astype(np.int32) - int(quantized.zero_point)
    return centered.astype(np.float32) * quantized.scale
def _Workspace_fetch_int8_blob(ws, name):
    """Fetches an Int8 blob from workspace ``ws``.

    Installed as ``C.Workspace.fetch_int8_blob``. It shares its backend
    implementation with fetch_blob, but is the recommended call for Int8 blobs.

    Inputs:
      ws: the workspace to fetch from.
      name: the name of the Int8 blob, passed to ws.fetch_blob as given.
    Returns:
      Int8Tensor with fields:
        data: int8 numpy array, data
        scale: float, fake quantization scale
        zero_point: int, fake quantization offset
    """
    raw = ws.fetch_blob(name)
    assert isinstance(raw, tuple), \
        'You are not fetching an Int8Blob {}. Please use fetch_blob'.format(
            StringifyBlobName(name))
    return Int8Tensor(*raw)


C.Workspace.fetch_int8_blob = _Workspace_fetch_int8_blob
def ApplyTransform(transform_key, net):
    """Apply a registered Transform to a NetDef and return the result.

    Inputs:
      transform_key: the name of the transform, as it is stored in the registry.
      net: a NetDef protobuf object.
    Returns:
      A new, transformed NetDef protobuf object.
    """
    key_bytes = str(transform_key).encode('utf-8')
    result_bytes = C.apply_transform(key_bytes, net.SerializeToString())
    transformed_net = caffe2_pb2.NetDef()
    transformed_net.ParseFromString(result_bytes)
    return transformed_net
def ApplyTransformIfFaster(transform_key, net, init_net, **kwargs):
    """Apply a Transform to a NetDef, keeping it only if it runs faster.

    The runs are performed on the current active workspace (gWorkspace).
    You should initialize that workspace before making a call to this function.

    Inputs:
      transform_key: the name of the transform, as it is stored in the registry.
      net: a NetDef protobuf object.
      init_net: the net to initialize the workspace.
      warmup_runs (optional): how many times the net is run before testing.
        Defaults to 5.
      main_runs (optional): how many times the net is run during testing.
        Defaults to 10.
      improvement_threshold (optional): the factor by which the new net needs
        to be faster in order to replace the old. Defaults to 1.01.
    Returns:
      Either a transformed NetDef protobuf object, or the original netdef.
    """
    warmup_runs = kwargs.get('warmup_runs', 5)
    main_runs = kwargs.get('main_runs', 10)
    improvement_threshold = kwargs.get('improvement_threshold', 1.01)
    result_bytes = C.apply_transform_if_faster(
        str(transform_key).encode('utf-8'),
        net.SerializeToString(),
        init_net.SerializeToString(),
        warmup_runs,
        main_runs,
        float(improvement_threshold),
    )
    transformed_net = caffe2_pb2.NetDef()
    transformed_net.ParseFromString(result_bytes)
    return transformed_net
def GetNameScope():
    """Return the current namescope string. To be used to fetch blobs."""
    current_scope = scope.CurrentNameScope()
    return current_scope
class _BlobDict(object):
    """Dict-like facade over the blobs of the current workspace.

    ``blobs[name]`` fetches via FetchBlob, ``blobs[name] = value`` feeds via
    FeedBlob, and len() / iteration / ``in`` reflect C.blobs() and
    C.has_blob().
    """

    def __getitem__(self, key):
        return FetchBlob(key)

    def __setitem__(self, key, value):
        return FeedBlob(key, value)

    def __len__(self):
        return len(C.blobs())

    def __iter__(self):
        return iter(C.blobs())

    def __contains__(self, item):
        return C.has_blob(item)


# Module-level singleton, e.g. ``workspace.blobs['x'] = arr``.
blobs = _BlobDict()
################################################################################
# Utilities for immediate mode
#
# Caffe2's immediate mode implements the following behavior: between the two
# function calls StartImmediate() and StopImmediate(), any operator that is
# created through CreateOperator() is also run in a workspace that is
# specific to the immediate mode. The user is explicitly expected to make
# sure that these ops have proper inputs and outputs, i.e. one should not
# run an op whose external input has not been created or fed.
#
# Users can use FeedImmediate() and FetchImmediate() to interact with blobs
# in the immediate workspace.
#
# Once StopImmediate() is called, all contents of the immediate workspace are
# freed, so one can continue using normal runs.
#
# The immediate mode is solely for debugging purposes and support will be very
# sparse.
################################################################################
# True between StartImmediate() and StopImmediate().
_immediate_mode = False
# Name of the dedicated workspace used by immediate mode.
_immediate_workspace_name = "_CAFFE2_IMMEDIATE"
# Temporary root folder of the immediate workspace ('' when inactive).
_immediate_root_folder = ''
def IsImmediate():
    """Report whether immediate mode is currently active.

    Returns:
      The module-level ``_immediate_mode`` flag, which StartImmediate sets
      and StopImmediate clears.
    """
    active = _immediate_mode
    return active
@contextlib.contextmanager
def WorkspaceGuard(workspace_name):
    """Temporarily switch to ``workspace_name`` for the body of a with-block.

    The True second argument to SwitchWorkspace presumably asks it to create
    the workspace if it does not exist (confirm against C.switch_workspace).
    The previously current workspace is restored on exit, including when the
    with-block raises.
    """
    current = CurrentWorkspace()
    SwitchWorkspace(workspace_name, True)
    try:
        yield
    finally:
        # Without the finally, an exception inside the with-block (e.g. a
        # failing FetchImmediate) would leave the caller in the wrong
        # workspace.
        SwitchWorkspace(current)
def StartImmediate(i_know=False):
    """Enter immediate mode, restarting it if it is already active.

    A fresh temporary directory becomes the root folder of the dedicated
    immediate-mode workspace, and that workspace is reset.

    Inputs:
      i_know: when True, skip printing the warning banner about the
        caveats of immediate mode.
    """
    global _immediate_mode
    global _immediate_root_folder
    if IsImmediate():
        # already in immediate mode. We will kill the previous one
        # and start from fresh.
        StopImmediate()
    with WorkspaceGuard(_immediate_workspace_name):
        _immediate_root_folder = tempfile.mkdtemp()
        ResetWorkspace(_immediate_root_folder)
    # Only flag immediate mode once its workspace is set up. Otherwise a
    # failure above would leave IsImmediate() True with no usable root folder.
    _immediate_mode = True
    if i_know:
        # if the user doesn't want to see the warning message, sure...
        return
    print("""
    Enabling immediate mode in caffe2 python is an EXTREMELY EXPERIMENTAL
    feature and may very easily go wrong. This is because Caffe2 uses a
    declarative way of defining operators and models, which is essentially
    not meant to run things in an interactive way. Read the following carefully
    to make sure that you understand the caveats.
    (1) You need to make sure that the sequences of operators you create are
    actually runnable sequentially. For example, if you create an op that takes
    an input X, somewhere earlier you should have already created X.
    (2) Caffe2 immediate uses one single workspace, so if the set of operators
    you run are intended to be under different workspaces, they will not run.
    To create boundaries between such use cases, you can call StopImmediate()
    and StartImmediate() manually to flush out everything no longer needed.
    (3) Underlying objects held by the immediate mode may interfere with your
    normal run. For example, if there is a leveldb that you opened in immediate
    mode and did not close, your main run will fail because leveldb does not
    support double opening. Immediate mode may also occupy a lot of memory esp.
    on GPUs. Call StopImmediate() as soon as possible when you no longer
    need it.
    (4) Immediate is designed to be slow. Every immediate call implicitly
    creates a temp operator object, runs it, and destroys the operator. This
    slow-speed run is by design to discourage abuse. For most use cases other
    than debugging, do NOT turn on immediate mode.
    (5) If there is anything FATAL happening in the underlying C++ code, the
    immediate mode will immediately (pun intended) cause the runtime to crash.
    Thus you should use immediate mode with extra care. If you still would
    like to, have fun [https://xkcd.com/149/].
    """)
def StopImmediate():
    """Stops an immediate mode run.

    Resets the immediate-mode workspace, deletes its temporary root folder and
    clears the module-level state. Does nothing if immediate mode is off.
    """
    # Phew, that was a dangerous ride.
    global _immediate_mode
    global _immediate_root_folder
    if IsImmediate():
        with WorkspaceGuard(_immediate_workspace_name):
            ResetWorkspace()
        shutil.rmtree(_immediate_root_folder)
        _immediate_root_folder = ''
        _immediate_mode = False
def ImmediateBlobs():
    """Return Blobs() as seen from the immediate-mode workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        names = Blobs()
    return names
def RunOperatorImmediate(op):
    """Run ``op`` once inside the immediate-mode workspace.

    The result of RunOperatorOnce is discarded; this always returns None.
    """
    with WorkspaceGuard(_immediate_workspace_name):
        RunOperatorOnce(op)
    return None
def FetchImmediate(*args, **kwargs):
    """Call FetchBlob(*args, **kwargs) inside the immediate-mode workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        fetched = FetchBlob(*args, **kwargs)
    return fetched
def FeedImmediate(*args, **kwargs):
    """Call FeedBlob(*args, **kwargs) inside the immediate-mode workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        fed = FeedBlob(*args, **kwargs)
    return fed
# C.Workspace methods.
def _Workspace_create_net_with_exception_intercept(ws, net, overwrite=False):
    """Create ``net`` in workspace ``ws`` (installed as Workspace.create_net).

    This is the per-workspace counterpart of CreateNet. Failures are routed
    through CallWithExceptionIntercept so that the recorded operator
    traceback gets logged.
    """
    return CallWithExceptionIntercept(
        ws._create_net, ws._last_failed_op_net_position,
        GetNetName(net), StringifyProto(net), overwrite)
def _Workspace_run(ws, obj):
    """Run a plan, net or operator in workspace ``ws`` (Workspace.run).

    Inputs:
      ws: the C.Workspace to run in.
      obj: a PlanDef, NetDef or OperatorDef, or an object whose Proto()
        returns one of those.
    Returns:
      The result of the matching ``ws._run_*`` call.
    Raises:
      ValueError: if ``obj`` is none of the supported kinds.
    """
    proto = obj.Proto() if hasattr(obj, 'Proto') else obj
    if isinstance(proto, caffe2_pb2.PlanDef):
        return ws._run_plan(proto.SerializeToString())
    if isinstance(proto, caffe2_pb2.NetDef):
        # Nets go through the interceptor so that a failing operator's
        # recorded Python traceback is logged.
        return CallWithExceptionIntercept(
            ws._run_net,
            ws._last_failed_op_net_position,
            GetNetName(proto),
            proto.SerializeToString(),
        )
    if isinstance(proto, caffe2_pb2.OperatorDef):
        return ws._run_operator(proto.SerializeToString())
    raise ValueError(
        "Don't know how to do Workspace.run() on {}".format(type(proto)))
def _Workspace_feed_blob(ws, name, arr, device_option=None):
    """Feed ``arr`` into blob ``name`` of workspace ``ws`` (Workspace.feed_blob).

    Inputs:
      ws: the C.Workspace to feed into.
      name: blob name, a string or a BlobReference.
      arr: a numpy array, or a TensorProto (converted to a numpy array first).
        Other values are handed to Blob.feed unchanged.
      device_option (optional): DeviceOption to feed with; defaults to the
        current device scope.
    Returns:
      The result of Blob.feed.
    """
    if type(arr) is caffe2_pb2.TensorProto:
        arr = utils.Caffe2TensorToNumpyArray(arr)
    if type(arr) is np.ndarray and arr.dtype.kind in 'SU':
        # Plain NumPy strings are weird, let's use objects instead.
        # Builtin `object` is used because the `np.object` alias was
        # deprecated in NumPy 1.20 and removed in NumPy 1.24.
        arr = arr.astype(object)
    if device_option is None:
        device_option = scope.CurrentDeviceScope()
    if device_option and device_option.device_type == caffe2_pb2.CUDA:
        # Inputs without a dtype (e.g. plain Python values) used to raise
        # AttributeError here; they simply skip the float64 warning now.
        arr_dtype = getattr(arr, 'dtype', None)
        if arr_dtype is not None and arr_dtype == np.dtype('float64'):
            logger.warning(
                "CUDA operators do not support 64-bit doubles, " +
                "please use arr.astype(np.float32) or np.int32 for ints." +
                " Blob: {}".format(name) +
                " type: {}".format(str(arr_dtype))
            )
    name = StringifyBlobName(name)
    if device_option is not None:
        return ws.create_blob(name).feed(arr, device_option)
    else:
        return ws.create_blob(name).feed(arr)
def _Workspace_remove_blob(ws, blob):
    """Remove ``blob`` (a name, or anything str() turns into one) from ``ws``."""
    blob_name = str(blob)
    ws._remove_blob(blob_name)
# Public alias for the C-extension workspace class, plus the Python-level
# methods attached to it. create_net and run route NetDefs through
# CallWithExceptionIntercept.
Workspace = C.Workspace
Workspace.create_net = _Workspace_create_net_with_exception_intercept
Workspace.run = _Workspace_run
Workspace.feed_blob = _Workspace_feed_blob
Workspace.remove_blob = _Workspace_remove_blob
# C.Blob methods.
def _Blob_feed(blob, arg, device_option=None):
    """Feed ``arg`` into ``blob`` (installed as C.Blob.feed).

    torch.Tensor arguments are attached through their raw TensorImpl handle
    (presumably sharing storage - confirm); device_option must then be None.
    Everything else goes to blob._feed, with device_option serialized via
    StringifyProto when given.
    """
    arg_type = type(arg)
    # Conservative name/module check so torch is only imported when the
    # argument already looks like a torch.Tensor.
    looks_like_torch = (arg_type.__name__ == 'Tensor'
                        and arg_type.__module__ == 'torch')
    if looks_like_torch:
        import torch
        if isinstance(arg, torch.Tensor):
            assert device_option is None, \
                "device_option doesn't make sense with PyTorch tensors"
            blob._wrap_tensor_impl(torch._C._tensor_impl_raw_handle(arg))
            return True  # _feed() returns True for some reason
    serialized_option = (None if device_option is None
                         else StringifyProto(device_option))
    return blob._feed(arg, serialized_option)


C.Blob.feed = _Blob_feed
def _Tensor_to_torch(tensor):
    """
    PyTorch tensor interop (TensorCPU methods)
    Can be accessed as:
      workspace.Workspace.current.blobs['foo'].tensor().to_torch()
    Wraps the tensor's raw TensorImpl handle in a torch tensor (presumably
    without copying - confirm).
    """
    import torch  # imported lazily to avoid a circular dependency
    raw_handle = tensor._tensor_impl_raw_handle()
    return torch._C._wrap_tensor_impl(raw_handle)


C.TensorCPU.to_torch = _Tensor_to_torch
def _Blob_to_torch(blob):
    """Convert a tensor-holding Blob to a PyTorch tensor (C.Blob.to_torch).

    Raises:
      RuntimeError: if the blob does not hold a tensor.
    """
    if blob.is_tensor():
        return blob.as_tensor().to_torch()
    raise RuntimeError("Blob has to be a tensor")


C.Blob.to_torch = _Blob_to_torch
|