concat.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. ## @package concat
  2. # Module caffe2.python.layers.concat
  3. from caffe2.python import schema
  4. from caffe2.python.layers.layers import (
  5. ModelLayer,
  6. )
  7. from future.utils import viewitems
  8. import numpy as np
  9. from collections import defaultdict
  10. import logging
  11. logger = logging.getLogger(__name__)
  12. def get_concatenated_feature_to_index(blobs_to_concat):
  13. concat_feature_to_index = defaultdict(list)
  14. start_pos = 0
  15. for scalar in blobs_to_concat:
  16. num_dims = scalar.dtype.shape[0]
  17. if hasattr(scalar, 'metadata') \
  18. and hasattr(scalar.metadata, 'feature_specs') \
  19. and hasattr(scalar.metadata.feature_specs, 'feature_to_index') \
  20. and isinstance(scalar.metadata.feature_specs.feature_to_index, dict): # noqa B950
  21. for k, v in scalar.metadata.feature_specs.feature_to_index.items():
  22. concat_feature_to_index[k].extend([start_pos + vi for vi in v])
  23. start_pos += num_dims
  24. return dict(concat_feature_to_index) if concat_feature_to_index.keys() else None
  25. class Concat(ModelLayer):
  26. """
  27. Construct Concat layer
  28. Assume that first dimension is batch,
  29. Example:
  30. embedding_dim = 64
  31. input_record = self.new_record(schema.Struct(
  32. ('input1', schema.Scalar((np.float32, (embedding_dim, )))),
  33. ('input2', schema.Scalar((np.float32, (embedding_dim, )))),
  34. ('input3', schema.Scalar((np.float32, (embedding_dim, )))),
  35. ))
  36. output = self.model.Concat(input_record)
  37. self.assertEqual(
  38. schema.Scalar((np.float32, ((len(input_record.fields) * embedding_dim, )))),
  39. output
  40. )
  41. # Note that in Concat layer we assume first dimension is batch.
  42. # so input is B * embedding_dim
  43. # add_axis=1 make it B * 1 * embedding_dim
  44. # Concat on axis=1 make it B * N * embedding_dim
  45. output = self.model.Concat(input_record, axis=1, add_axis=1)
  46. self.assertEqual(
  47. schema.Scalar((np.float32, ((len(input_record.fields), embedding_dim)))),
  48. output
  49. )
  50. """
  51. def __init__(self, model, input_record, axis=1, add_axis=0,
  52. name='concat', **kwargs):
  53. super(Concat, self).__init__(model, name, input_record, **kwargs)
  54. self.axis = axis
  55. self.add_axis = add_axis
  56. assert not (axis == 0 and add_axis == 1), \
  57. "It's not allowed to add axis=0"
  58. assert isinstance(input_record, schema.Struct),\
  59. "Incorrect input type. Expected Struct, but received: {0}".\
  60. format(input_record)
  61. shapes = []
  62. for field_name, field_type in viewitems(input_record.fields):
  63. assert isinstance(field_type, schema.Scalar),\
  64. "Incorrect input type for {}. Expected Scalar, but got: {}".\
  65. format(field_name, field_type)
  66. # Assume that first dimension is batch, so actual axis in shape is
  67. # axis - 1
  68. shape = list(field_type.field_type().shape)
  69. if add_axis:
  70. shape.insert(axis - 1, 1)
  71. assert len(shape) >= axis,\
  72. "Concat expects that limited dimensions of the input tensor"
  73. shapes.append(shape)
  74. logger.info('Concat Layer input shapes: ' + str(shapes))
  75. if axis == 0:
  76. self.output_schema = schema.from_blob_list(
  77. input_record[0],
  78. [self.get_next_blob_reference('output')]
  79. )
  80. return
  81. concat_dim = 0
  82. for shape in shapes:
  83. concat_dim += shape[axis - 1]
  84. shape[axis - 1] = 0
  85. assert shape == shapes[0],\
  86. "Shapes {0} and {1} are not compatible for Concat".\
  87. format(shape, shapes[0])
  88. output_dims = shapes[0]
  89. output_dims[axis - 1] = concat_dim
  90. logger.info('Concat Layer output_dims: ' + str(output_dims))
  91. self.output_schema = schema.Scalar(
  92. (np.float32, output_dims),
  93. self.get_next_blob_reference('output'))
  94. record_to_concat = input_record.fields.values()
  95. concated_feature_to_index = get_concatenated_feature_to_index(
  96. record_to_concat
  97. )
  98. if concated_feature_to_index:
  99. metadata = schema.Metadata(
  100. feature_specs=schema.FeatureSpec(
  101. feature_to_index=concated_feature_to_index
  102. )
  103. )
  104. self.output_schema.set_metadata(metadata)
  105. def add_ops(self, net):
  106. net.Concat(
  107. self.input_record.field_blobs(),
  108. [
  109. self.output_schema.field_blobs()[0],
  110. self.output_schema.field_blobs()[0] + "_concat_dims"
  111. ],
  112. axis=self.axis,
  113. add_axis=self.add_axis,
  114. )