bucket_weighted.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. ## @package bucket_weighted
  2. # Module caffe2.python.layers.bucket_weighted
  3. import logging
  4. import numpy as np
  5. from caffe2.python import core, schema
  6. from caffe2.python.layers.layers import (
  7. get_categorical_limit,
  8. ModelLayer,
  9. )
  10. from caffe2.python.layers.tags import Tags
  11. logger = logging.getLogger(__name__)
  12. class BucketWeighted(ModelLayer):
  13. def __init__(self, model, input_record, max_score=0, bucket_boundaries=None,
  14. hash_buckets=True, weight_optim=None, name="bucket_weighted"):
  15. super(BucketWeighted, self).__init__(model, name, input_record)
  16. assert isinstance(input_record, schema.List), "Incorrect input type"
  17. self.bucket_boundaries = bucket_boundaries
  18. self.hash_buckets = hash_buckets
  19. if bucket_boundaries is not None:
  20. self.shape = len(bucket_boundaries) + 1
  21. elif max_score > 0:
  22. self.shape = max_score
  23. else:
  24. self.shape = get_categorical_limit(input_record)
  25. self.bucket_w = self.create_param(param_name='bucket_w',
  26. shape=[self.shape, ],
  27. initializer=('ConstantFill', {'value': 1.0}),
  28. optimizer=weight_optim)
  29. self.output_schema = schema.Struct(
  30. ('bucket_weights',
  31. schema.Scalar((np.float32, self.shape),
  32. self.get_next_blob_reference("bucket_w_gather")))
  33. )
  34. self.tags.update({Tags.HANDLE_AS_SPARSE_LAYER})
  35. def get_memory_usage(self):
  36. return self.shape
  37. def add_ops(self, net):
  38. if self.bucket_boundaries is not None:
  39. buckets_int = net.Bucketize(
  40. self.input_record.values(),
  41. "buckets_int",
  42. boundaries=self.bucket_boundaries
  43. )
  44. else:
  45. buckets = self.input_record.values()
  46. buckets_int = net.Cast(
  47. buckets,
  48. "buckets_int",
  49. to=core.DataType.INT32
  50. )
  51. if self.hash_buckets:
  52. buckets_int = net.IndexHash(
  53. buckets_int, "hashed_buckets_int", seed=0, modulo=self.shape
  54. )
  55. net.Gather(
  56. [self.bucket_w, buckets_int],
  57. self.output_schema.bucket_weights.field_blobs())