position_weighted.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
## @package position_weighted
# Module caffe2.python.layers.position_weighted
  3. import logging
  4. import numpy as np
  5. from caffe2.python import schema
  6. from caffe2.python.layers.layers import (
  7. get_categorical_limit,
  8. ModelLayer,
  9. )
  10. from caffe2.python.layers.tags import Tags
# Module-level logger; PositionWeighted uses it to warn when the lengths
# field has no categorical_limit and the keys' limit is used instead.
logger = logging.getLogger(__name__)
  12. class PositionWeighted(ModelLayer):
  13. def __init__(self, model, input_record, weight_optim=None,
  14. name="position_weights"):
  15. super(PositionWeighted, self).__init__(model, name, input_record)
  16. assert isinstance(input_record, schema.List), "Incorrect input type"
  17. length_metadata = input_record.lengths.metadata
  18. max_length = (length_metadata.categorical_limit if length_metadata is
  19. not None else None)
  20. if max_length is not None:
  21. self.shape = max_length
  22. else:
  23. self.shape = get_categorical_limit(input_record)
  24. logger.warning(
  25. '{}: categorical_limit of lengths is not available, using '
  26. 'categorical_limit of the keys: {}'.format(
  27. str(input_record.lengths()), self.shape))
  28. self.pos_w = self.create_param(param_name='pos_w',
  29. shape=[self.shape, ],
  30. initializer=('ConstantFill', {'value': 1.0}),
  31. optimizer=weight_optim)
  32. self.output_schema = schema.Struct(
  33. ('position_weights',
  34. schema.Scalar((np.float32, self.shape),
  35. self.get_next_blob_reference("pos_w_gather")))
  36. )
  37. self.tags.update({Tags.HANDLE_AS_SPARSE_LAYER})
  38. def get_memory_usage(self):
  39. return self.shape
  40. def add_ops(self, net):
  41. inc_seq = net.LengthsRangeFill(
  42. [self.input_record.lengths()],
  43. self.input_record.lengths() + '_pos_w_seq'
  44. )
  45. net.Gather(
  46. [self.pos_w, inc_seq],
  47. self.output_schema.position_weights.field_blobs())