| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- ## @package bucket_weighted
- # Module caffe2.python.layers.bucket_weighted
- import logging
- import numpy as np
- from caffe2.python import core, schema
- from caffe2.python.layers.layers import (
- get_categorical_limit,
- ModelLayer,
- )
- from caffe2.python.layers.tags import Tags
- logger = logging.getLogger(__name__)
- class BucketWeighted(ModelLayer):
- def __init__(self, model, input_record, max_score=0, bucket_boundaries=None,
- hash_buckets=True, weight_optim=None, name="bucket_weighted"):
- super(BucketWeighted, self).__init__(model, name, input_record)
- assert isinstance(input_record, schema.List), "Incorrect input type"
- self.bucket_boundaries = bucket_boundaries
- self.hash_buckets = hash_buckets
- if bucket_boundaries is not None:
- self.shape = len(bucket_boundaries) + 1
- elif max_score > 0:
- self.shape = max_score
- else:
- self.shape = get_categorical_limit(input_record)
- self.bucket_w = self.create_param(param_name='bucket_w',
- shape=[self.shape, ],
- initializer=('ConstantFill', {'value': 1.0}),
- optimizer=weight_optim)
- self.output_schema = schema.Struct(
- ('bucket_weights',
- schema.Scalar((np.float32, self.shape),
- self.get_next_blob_reference("bucket_w_gather")))
- )
- self.tags.update({Tags.HANDLE_AS_SPARSE_LAYER})
- def get_memory_usage(self):
- return self.shape
- def add_ops(self, net):
- if self.bucket_boundaries is not None:
- buckets_int = net.Bucketize(
- self.input_record.values(),
- "buckets_int",
- boundaries=self.bucket_boundaries
- )
- else:
- buckets = self.input_record.values()
- buckets_int = net.Cast(
- buckets,
- "buckets_int",
- to=core.DataType.INT32
- )
- if self.hash_buckets:
- buckets_int = net.IndexHash(
- buckets_int, "hashed_buckets_int", seed=0, modulo=self.shape
- )
- net.Gather(
- [self.bucket_w, buckets_int],
- self.output_schema.bucket_weights.field_blobs())
|