| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- ## @package concat
- # Module caffe2.python.layers.concat
- from caffe2.python import schema
- from caffe2.python.layers.layers import (
- ModelLayer,
- )
- from future.utils import viewitems
- import numpy as np
- from collections import defaultdict
- import logging
- logger = logging.getLogger(__name__)
def get_concatenated_feature_to_index(blobs_to_concat):
    """Merge per-blob ``feature_to_index`` maps into one map for the concat.

    Blobs are visited in order. Each blob contributes ``dtype.shape[0]``
    positions to the concatenated layout, so the indices recorded for a blob
    are shifted by the total width of all blobs that precede it.

    Only blobs whose ``metadata.feature_specs.feature_to_index`` exists and is
    a ``dict`` contribute entries; every blob (with or without such a map)
    still advances the running offset.

    Args:
        blobs_to_concat: iterable of objects exposing ``dtype.shape`` and,
            optionally, ``metadata.feature_specs.feature_to_index``.

    Returns:
        A plain ``dict`` mapping each feature key to the list of its shifted
        indices, or ``None`` when no blob contributed any key.
    """
    merged = defaultdict(list)
    offset = 0
    for blob in blobs_to_concat:
        # Read the width first so a malformed dtype fails before any lookup.
        width = blob.dtype.shape[0]

        # Walk metadata -> feature_specs -> feature_to_index; any missing
        # link collapses to None and is rejected by the isinstance check.
        metadata = getattr(blob, 'metadata', None)
        feature_specs = getattr(metadata, 'feature_specs', None)
        mapping = getattr(feature_specs, 'feature_to_index', None)

        if isinstance(mapping, dict):
            for feature, positions in mapping.items():
                merged[feature].extend(offset + pos for pos in positions)

        offset += width

    if not merged:
        return None
    return dict(merged)
class Concat(ModelLayer):
    """
    Layer that concatenates the Scalar fields of a Struct along one axis.

    The first dimension of every input is assumed to be the batch dimension.
    `axis` therefore counts from the batch axis, while the per-field shapes
    handled in the constructor exclude it (position `axis - 1`).

    Example:

        embedding_dim = 64
        input_record = self.new_record(schema.Struct(
            ('input1', schema.Scalar((np.float32, (embedding_dim, )))),
            ('input2', schema.Scalar((np.float32, (embedding_dim, )))),
            ('input3', schema.Scalar((np.float32, (embedding_dim, )))),
        ))

        output = self.model.Concat(input_record)
        self.assertEqual(
            schema.Scalar((np.float32, ((len(input_record.fields) * embedding_dim, )))),
            output
        )

        # The first dimension is the batch, so each input is
        # B * embedding_dim.
        # add_axis=1 turns each input into B * 1 * embedding_dim, and
        # concatenating on axis=1 then yields B * N * embedding_dim.
        output = self.model.Concat(input_record, axis=1, add_axis=1)
        self.assertEqual(
            schema.Scalar((np.float32, ((len(input_record.fields), embedding_dim)))),
            output
        )
    """

    def __init__(self, model, input_record, axis=1, add_axis=0,
                 name='concat', **kwargs):
        """Validate the inputs and derive the output schema.

        Args:
            model: passed through to the ModelLayer constructor.
            input_record: schema.Struct whose fields are all schema.Scalar.
            axis: axis to concatenate on, counted with the batch axis as 0.
            add_axis: when truthy, a length-1 dimension is inserted at
                `axis` in every input shape before concatenating.
            name: layer name passed through to ModelLayer.
            **kwargs: forwarded to the ModelLayer constructor.

        Raises:
            AssertionError: on `axis == 0` with `add_axis == 1`, a non-Struct
                input record, a non-Scalar field, a field with fewer than
                `axis` dimensions, or fields whose shapes differ outside the
                concatenation axis.
        """
        super(Concat, self).__init__(model, name, input_record, **kwargs)
        self.axis = axis
        self.add_axis = add_axis

        assert axis != 0 or add_axis != 1, \
            "It's not allowed to add axis=0"
        assert isinstance(input_record, schema.Struct), (
            "Incorrect input type. Expected Struct, but received: {0}"
            .format(input_record)
        )

        # Schema shapes do not include the batch dimension, so the
        # concatenation axis sits at position axis - 1 within them.
        dim_index = axis - 1

        shapes = []
        for field_name, field_type in viewitems(input_record.fields):
            assert isinstance(field_type, schema.Scalar), (
                "Incorrect input type for {}. Expected Scalar, but got: {}"
                .format(field_name, field_type)
            )
            dims = list(field_type.field_type().shape)
            if add_axis:
                dims.insert(dim_index, 1)
            assert len(dims) >= axis, \
                "Concat expects that limited dimensions of the input tensor"
            shapes.append(dims)
        logger.info('Concat Layer input shapes: ' + str(shapes))

        if axis == 0:
            # Concatenating along the batch axis: the output schema is built
            # from the first field with a fresh output blob.
            first_field = input_record[0]
            output_blob = self.get_next_blob_reference('output')
            self.output_schema = schema.from_blob_list(
                first_field, [output_blob])
            return

        # Sum the sizes along the concat axis. Each shape has that entry
        # zeroed in place so the remaining dimensions can be compared
        # directly against the (already zeroed) first shape.
        concat_dim = 0
        for dims in shapes:
            concat_dim += dims[dim_index]
            dims[dim_index] = 0
            assert dims == shapes[0], (
                "Shapes {0} and {1} are not compatible for Concat"
                .format(dims, shapes[0])
            )

        output_dims = shapes[0]
        output_dims[dim_index] = concat_dim
        logger.info('Concat Layer output_dims: ' + str(output_dims))

        self.output_schema = schema.Scalar(
            (np.float32, output_dims),
            self.get_next_blob_reference('output'))

        # Carry over feature_to_index metadata, shifted to the concatenated
        # layout, when at least one input field provides it.
        concated_feature_to_index = get_concatenated_feature_to_index(
            input_record.fields.values()
        )
        if not concated_feature_to_index:
            return
        feature_spec = schema.FeatureSpec(
            feature_to_index=concated_feature_to_index
        )
        self.output_schema.set_metadata(
            schema.Metadata(feature_specs=feature_spec)
        )

    def add_ops(self, net):
        """Add the Concat operator for this layer to `net`.

        The operator takes every blob of the input record and writes two
        outputs: the layer's output blob and a companion blob named after it
        with a "_concat_dims" suffix.
        """
        output_blob = self.output_schema.field_blobs()[0]
        net.Concat(
            self.input_record.field_blobs(),
            [output_blob, output_blob + "_concat_dims"],
            axis=self.axis,
            add_axis=self.add_axis,
        )
|