| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240 |
## @package cnn
# Module caffe2.python.cnn
- from caffe2.python import brew, workspace
- from caffe2.python.model_helper import ModelHelper
- from caffe2.proto import caffe2_pb2
- import logging
class CNNModelHelper(ModelHelper):
    """A helper model so we can write CNN models more easily, without having to
    manually define parameter initializations and operators separately.

    Deprecated: constructing an instance logs a warning asking callers to use
    a ``ModelHelper`` object together with the ``brew`` module instead.  Most
    methods are thin wrappers that forward to the matching ``brew`` helper,
    adding this instance's ``order`` and cuDNN settings where shown.
    """

    def __init__(self, order="NCHW", name=None,
                 use_cudnn=True, cudnn_exhaustive_search=False,
                 ws_nbytes_limit=None, init_params=True,
                 skip_sparse_optim=False,
                 param_model=None):
        """Build the helper and record the CNN-wide defaults.

        Args:
            order: storage order; must be ``"NCHW"`` or ``"NHWC"``.
            name: model name; ``"CNN"`` is used when ``None``.
            use_cudnn: stored and forwarded to the wrappers that accept it.
            cudnn_exhaustive_search: stored and forwarded to the conv wrappers.
            ws_nbytes_limit: stored and forwarded to the conv wrappers; only
                added to the arg scope when truthy.
            init_params: passed through to ``ModelHelper.__init__``.
            skip_sparse_optim: passed through to ``ModelHelper.__init__``.
            param_model: passed through to ``ModelHelper.__init__``.

        Raises:
            ValueError: if ``order`` is neither ``"NHWC"`` nor ``"NCHW"``.
        """
        logging.warning(
            "[====DEPRECATE WARNING====]: you are creating an "
            "object from CNNModelHelper class which will be deprecated soon. "
            "Please use ModelHelper object with brew module. For more "
            "information, please refer to caffe2.ai and python/brew.py, "
            "python/brew_test.py for more information."
        )

        # Validate before calling the base constructor so that an unsupported
        # order is rejected without first building a model around it.
        if order not in ("NHWC", "NCHW"):
            raise ValueError(
                "Cannot understand the CNN storage order %s." % order
            )

        cnn_arg_scope = {
            'order': order,
            'use_cudnn': use_cudnn,
            'cudnn_exhaustive_search': cudnn_exhaustive_search,
        }
        # Only a truthy limit is placed in the arg scope (None/0 are omitted).
        if ws_nbytes_limit:
            cnn_arg_scope['ws_nbytes_limit'] = ws_nbytes_limit

        super(CNNModelHelper, self).__init__(
            skip_sparse_optim=skip_sparse_optim,
            name="CNN" if name is None else name,
            init_params=init_params,
            param_model=param_model,
            arg_scope=cnn_arg_scope,
        )

        self.order = order
        self.use_cudnn = use_cudnn
        self.cudnn_exhaustive_search = cudnn_exhaustive_search
        self.ws_nbytes_limit = ws_nbytes_limit

    def ImageInput(self, blob_in, blob_out, use_gpu_transform=False, **kwargs):
        """Forward to ``brew.image_input`` with this model's storage order."""
        return brew.image_input(
            self,
            blob_in,
            blob_out,
            order=self.order,
            use_gpu_transform=use_gpu_transform,
            **kwargs
        )

    def VideoInput(self, blob_in, blob_out, **kwargs):
        """Forward to ``brew.video_input``."""
        return brew.video_input(
            self,
            blob_in,
            blob_out,
            **kwargs
        )

    def PadImage(self, blob_in, blob_out, **kwargs):
        """Add a ``PadImage`` op to ``self.net`` and return its result."""
        # TODO(wyiming): remove this dummy helper later
        # Return the op result so this wrapper behaves like its siblings.
        return self.net.PadImage(blob_in, blob_out, **kwargs)

    def ConvNd(self, *args, **kwargs):
        """Forward to ``brew.conv_nd`` with this model's order/cuDNN settings."""
        return brew.conv_nd(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def Conv(self, *args, **kwargs):
        """Forward to ``brew.conv`` with this model's order/cuDNN settings."""
        return brew.conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def ConvTranspose(self, *args, **kwargs):
        """Forward to ``brew.conv_transpose`` with this model's settings."""
        return brew.conv_transpose(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv(self, *args, **kwargs):
        """Forward to ``brew.group_conv`` with this model's settings."""
        return brew.group_conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv_Deprecated(self, *args, **kwargs):
        """Forward to ``brew.group_conv_deprecated`` with this model's settings."""
        return brew.group_conv_deprecated(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def FC(self, *args, **kwargs):
        """Forward to ``brew.fc``."""
        return brew.fc(self, *args, **kwargs)

    def PackedFC(self, *args, **kwargs):
        """Forward to ``brew.packed_fc``."""
        return brew.packed_fc(self, *args, **kwargs)

    def FC_Prune(self, *args, **kwargs):
        """Forward to ``brew.fc_prune``."""
        return brew.fc_prune(self, *args, **kwargs)

    def FC_Decomp(self, *args, **kwargs):
        """Forward to ``brew.fc_decomp``."""
        return brew.fc_decomp(self, *args, **kwargs)

    def FC_Sparse(self, *args, **kwargs):
        """Forward to ``brew.fc_sparse``."""
        return brew.fc_sparse(self, *args, **kwargs)

    def Dropout(self, *args, **kwargs):
        """Forward to ``brew.dropout`` with this model's order and use_cudnn."""
        return brew.dropout(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def LRN(self, *args, **kwargs):
        """Forward to ``brew.lrn`` with this model's order and use_cudnn."""
        return brew.lrn(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def Softmax(self, *args, **kwargs):
        """Forward to ``brew.softmax`` with this model's use_cudnn."""
        return brew.softmax(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def SpatialBN(self, *args, **kwargs):
        """Forward to ``brew.spatial_bn`` with this model's order."""
        return brew.spatial_bn(self, *args, order=self.order, **kwargs)

    def SpatialGN(self, *args, **kwargs):
        """Forward to ``brew.spatial_gn`` with this model's order."""
        return brew.spatial_gn(self, *args, order=self.order, **kwargs)

    def InstanceNorm(self, *args, **kwargs):
        """Forward to ``brew.instance_norm`` with this model's order."""
        return brew.instance_norm(self, *args, order=self.order, **kwargs)

    def Relu(self, *args, **kwargs):
        """Forward to ``brew.relu`` with this model's order and use_cudnn."""
        return brew.relu(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def PRelu(self, *args, **kwargs):
        """Forward to ``brew.prelu``."""
        return brew.prelu(self, *args, **kwargs)

    def Concat(self, *args, **kwargs):
        """Forward to ``brew.concat`` with this model's order."""
        return brew.concat(self, *args, order=self.order, **kwargs)

    def DepthConcat(self, *args, **kwargs):
        """The old depth concat function - we should move to use concat."""
        # Report the deprecation through logging, matching the constructor's
        # deprecation notice, instead of writing to stdout.
        logging.warning("DepthConcat is deprecated. use Concat instead.")
        return self.Concat(*args, **kwargs)

    def Sum(self, *args, **kwargs):
        """Forward to ``brew.sum``."""
        return brew.sum(self, *args, **kwargs)

    def Transpose(self, *args, **kwargs):
        """Forward to ``brew.transpose`` with this model's use_cudnn."""
        return brew.transpose(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def Iter(self, *args, **kwargs):
        """Forward to ``brew.iter``."""
        return brew.iter(self, *args, **kwargs)

    def Accuracy(self, *args, **kwargs):
        """Forward to ``brew.accuracy``."""
        return brew.accuracy(self, *args, **kwargs)

    def MaxPool(self, *args, **kwargs):
        """Forward to ``brew.max_pool`` with this model's order and use_cudnn."""
        return brew.max_pool(
            self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
        )

    def MaxPoolWithIndex(self, *args, **kwargs):
        """Forward to ``brew.max_pool_with_index`` with this model's order."""
        return brew.max_pool_with_index(self, *args, order=self.order, **kwargs)

    def AveragePool(self, *args, **kwargs):
        """Forward to ``brew.average_pool`` with this model's settings."""
        return brew.average_pool(
            self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
        )

    @property
    def XavierInit(self):
        """Initializer spec: the tuple ``('XavierFill', {})``."""
        return ('XavierFill', {})

    def ConstantInit(self, value):
        """Initializer spec: ``('ConstantFill', {'value': value})``."""
        return ('ConstantFill', dict(value=value))

    @property
    def MSRAInit(self):
        """Initializer spec: the tuple ``('MSRAFill', {})``."""
        return ('MSRAFill', {})

    @property
    def ZeroInit(self):
        """Initializer spec: ``('ConstantFill', {})`` with no explicit value."""
        return ('ConstantFill', {})

    def AddWeightDecay(self, weight_decay):
        """Forward to ``brew.add_weight_decay``."""
        return brew.add_weight_decay(self, weight_decay)

    @property
    def CPU(self):
        """A fresh ``DeviceOption`` whose device_type is ``caffe2_pb2.CPU``."""
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CPU
        return device_option

    @property
    def GPU(self, gpu_id=0):
        """A fresh ``DeviceOption`` for ``workspace.GpuDeviceType``.

        NOTE: this is a property, so attribute access (``model.GPU``) always
        calls the getter with only ``self``; ``gpu_id`` therefore keeps its
        default of 0 and the returned option always has ``device_id == 0``.
        The parameter is retained solely to keep the getter's signature
        unchanged.
        """
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = workspace.GpuDeviceType
        device_option.device_id = gpu_id
        return device_option
|