| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- ## @package split
- # Module caffe2.python.layers.split
- from caffe2.python import schema
- from caffe2.python.layers.layers import (
- ModelLayer,
- )
class Split(ModelLayer):
    """Layer that splits one Scalar input record into several Scalar outputs.

    The input is cut along ``axis`` either into ``num_splits`` equally sized
    pieces or, when ``split`` is given, into ``len(split)`` pieces whose sizes
    along ``axis`` are the entries of ``split``.

    Args:
        model: Model the layer is attached to (forwarded to ``ModelLayer``).
        input_record: ``schema.Scalar`` describing the blob to split.
        num_splits: Number of equal pieces. Ignored when ``split`` is given,
            in which case it is replaced by ``len(split)``.
        axis: Axis to split along, counted with the batch dimension included.
            The record's shape omits the batch dimension, so ``axis - 1`` is
            used to index it; ``axis`` must therefore be at least 1.
        name: Layer name (forwarded to ``ModelLayer``).
        split: Optional sequence of sizes along ``axis``, one per output.
            Must sum to the input size along that axis.
        **kwargs: Forwarded to ``ModelLayer.__init__``.

    Raises:
        AssertionError: If ``axis`` is less than 1, the input is not a
            ``schema.Scalar``, the input shape has no dimension at
            ``axis - 1``, or the requested split does not match the input
            size along that dimension.
    """

    def __init__(self, model, input_record, num_splits=1, axis=1,
                 name='split', split=None, **kwargs):
        super(Split, self).__init__(model, name, input_record, **kwargs)
        # Keep the caller's axis untouched: this is the value handed to the
        # Split operator in add_ops.
        self.axis = axis
        # Assume that first dimension is batch, so actual axis in shape is
        # axis - 1
        axis -= 1
        assert axis >= 0

        assert isinstance(input_record, schema.Scalar),\
            "Incorrect input type. Expected Scalar, but received: {0}".\
            format(input_record)

        input_shape = input_record.field_type().shape
        # Strictly greater: input_shape[axis] is indexed below, so a shape of
        # length == axis has no such dimension.
        assert len(input_shape) > axis, \
            "Split axis {0} is out of range for input shape {1}".format(
                self.axis, input_shape)

        # Build one output shape per split so both modes share the code that
        # creates the output Scalars.
        if split is None:
            assert input_shape[axis] % num_splits == 0
            even_shape = list(input_shape)
            # Floor division is exact here because divisibility was asserted.
            even_shape[axis] = even_shape[axis] // num_splits
            output_shapes = [even_shape] * num_splits
        else:
            num_splits = len(split)
            assert input_shape[axis] == sum(split)
            output_shapes = []
            for size in split:
                piece_shape = list(input_shape)
                piece_shape[axis] = size
                output_shapes.append(piece_shape)

        data_type = input_record.field_type().base

        output_scalars = [
            schema.Scalar(
                (data_type, shape),
                self.get_next_blob_reference('output_{}'.format(i)),
            )
            for i, shape in enumerate(output_shapes)
        ]
        self.output_schema = schema.Tuple(*output_scalars)
        # May be None, in which case the operator receives split=None.
        self.split = split

    def add_ops(self, net):
        """Add the Split operator to ``net``.

        Reads the input record's blobs and writes the blobs of the output
        schema, passing the stored ``split`` and ``axis`` through unchanged.
        """
        net.Split(
            self.input_record.field_blobs(),
            self.output_schema.field_blobs(),
            split=self.split,
            axis=self.axis,
        )
|