split.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. ## @package split
  2. # Module caffe2.python.layers.split
  3. from caffe2.python import schema
  4. from caffe2.python.layers.layers import (
  5. ModelLayer,
  6. )
  7. class Split(ModelLayer):
  8. def __init__(self, model, input_record, num_splits=1, axis=1,
  9. name='split', split=None, **kwargs):
  10. super(Split, self).__init__(model, name, input_record, **kwargs)
  11. self.axis = axis
  12. # Assume that first dimension is batch, so actual axis in shape is
  13. # axis - 1
  14. axis -= 1
  15. assert axis >= 0
  16. assert isinstance(input_record, schema.Scalar),\
  17. "Incorrect input type. Expected Scalar, but received: {0}".\
  18. format(input_record)
  19. input_shape = input_record.field_type().shape
  20. assert len(input_shape) >= axis
  21. if split is None:
  22. assert input_shape[axis] % num_splits == 0
  23. else:
  24. num_splits = len(split)
  25. assert input_shape[axis] == sum(split)
  26. if split is None:
  27. output_shape = list(input_shape)
  28. output_shape[axis] = int(output_shape[axis] / num_splits)
  29. else:
  30. output_shape = []
  31. for i in range(num_splits):
  32. output_shape_i = list(input_shape)
  33. output_shape_i[axis] = split[i]
  34. output_shape.append(output_shape_i)
  35. data_type = input_record.field_type().base
  36. if split is None:
  37. output_scalars = [
  38. schema.Scalar(
  39. (data_type, output_shape),
  40. self.get_next_blob_reference('output_{}'.format(i)),
  41. )
  42. for i in range(num_splits)
  43. ]
  44. else:
  45. output_scalars = [
  46. schema.Scalar(
  47. (data_type, output_shape[i]),
  48. self.get_next_blob_reference('output_{}'.format(i)),
  49. )
  50. for i in range(num_splits)
  51. ]
  52. self.output_schema = schema.Tuple(*output_scalars)
  53. self.split = split
  54. def add_ops(self, net):
  55. net.Split(
  56. self.input_record.field_blobs(),
  57. self.output_schema.field_blobs(),
  58. split=self.split,
  59. axis=self.axis,
  60. )