| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
- # @package adaptive_weight
- # Module caffe2.fb.python.layers.adaptive_weight
- import numpy as np
- from caffe2.python import core, schema
- from caffe2.python.layers.layers import ModelLayer
- from caffe2.python.regularizer import BoundedGradientProjection, LogBarrier
- """
- Implementation of adaptive weighting: https://arxiv.org/pdf/1705.07115.pdf
- """
class AdaptiveWeight(ModelLayer):
    """Layer that folds several scalar inputs into one adaptively weighted sum.

    Every field blob of the input record is reshaped to shape ``[1]`` and the
    results are concatenated into a vector ``x``. The layer output is
    ``sum_i(weight_i * x_i + reg_i)``, where ``weight`` and ``reg`` are
    derived from a trainable parameter according to ``estimation_method``:

    * ``"log_std"``: parameter ``mu``; ``weight = 0.5 * exp(-mu)`` and
      ``reg = 0.5 * mu``.
    * ``"inv_var"``: parameter ``k``; ``weight = 0.5 * k`` and
      ``reg = -0.5 * log(k)``.

    Both parameterizations are initialized so that the starting ``weight``
    equals the requested initial ``weights``.
    """

    def __init__(
        self,
        model,
        input_record,
        name="adaptive_weight",
        optimizer=None,
        weights=None,
        enable_diagnose=False,
        estimation_method="log_std",
        pos_optim_method="log_barrier",
        reg_lambda=0.1,
        **kwargs
    ):
        """Set up the output blob, the initial weights and the parameter.

        Args:
            model: forwarded to ``ModelLayer.__init__``; also used to register
                diagnostic plot blobs when ``enable_diagnose`` is set.
            input_record: record whose field blobs are the values to combine.
            name: layer name forwarded to ``ModelLayer.__init__``.
            optimizer: optimizer handed to ``create_param`` for the trainable
                parameter.
            weights: optional initial weight per input; must have one entry
                per input blob and be strictly positive. Defaults to the
                uniform value ``1 / num`` for every input.
            enable_diagnose: when true, one extra blob per input is created
                and registered with ``model.add_ad_hoc_plot_blob`` so the
                individual weights can be plotted.
            estimation_method: lower-cased and used as the prefix of the
                ``<method>_init`` / ``<method>_weight`` / ``<method>_reg``
                methods to dispatch to (``"log_std"`` or ``"inv_var"`` in
                this class).
            pos_optim_method: ``"log_barrier"`` or ``"pos_grad_proj"``; only
                consulted by the ``inv_var`` parameterization.
            reg_lambda: coefficient passed to ``LogBarrier``.
            **kwargs: forwarded to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if ``weights`` has the wrong length or contains a
                non-positive entry.
            AttributeError: if no methods exist for ``estimation_method``.
            TypeError: if ``inv_var`` is selected with an unknown
                ``pos_optim_method``.
        """
        super(AdaptiveWeight, self).__init__(model, name, input_record, **kwargs)

        output_blob = self.get_next_blob_reference("adaptive_weight")
        self.output_schema = schema.Scalar(np.float32, output_blob)

        self.data = self.input_record.field_blobs()
        self.num = len(self.data)
        self.optimizer = optimizer

        if weights is None:
            # No explicit weights: start from a uniform split over the inputs.
            weights = [1. / self.num for _ in range(self.num)]
        else:
            assert len(weights) == self.num
        assert min(weights) > 0, "initial weights must be positive"
        self.weights = np.array(weights).astype(np.float32)

        self.estimation_method = str(estimation_method).lower()
        # Only the inv_var parameterization needs its parameter kept positive;
        # this picks between a log barrier and a bounded gradient projection.
        self.pos_optim_method = str(pos_optim_method).lower()
        self.reg_lambda = float(reg_lambda)
        self.enable_diagnose = enable_diagnose

        # Resolve the three per-method hooks by naming convention.
        method = self.estimation_method
        self.init_func = getattr(self, method + "_init")
        self.weight_func = getattr(self, method + "_weight")
        self.reg_func = getattr(self, method + "_reg")
        self.init_func()

        if self.enable_diagnose:
            diagnose_blobs = []
            for i in range(self.num):
                diagnose_blobs.append(
                    self.get_next_blob_reference("adaptive_weight_%d" % i)
                )
            self.weight_i = diagnose_blobs
            for blob in diagnose_blobs:
                self.model.add_ad_hoc_plot_blob(blob)

    def concat_data(self, net):
        """Reshape every input blob to ``[1]`` and concatenate them on axis 0.

        Returns the blob holding the concatenated vector.
        """
        flat_blobs = [
            net.NextScopedBlob("reshaped_data_%d" % i) for i in range(self.num)
        ]
        # Force each single real value into a one-element vector.
        for i, flat in enumerate(flat_blobs):
            shape_blob = net.NextScopedBlob("new_shape_%d" % i)
            net.Reshape([self.data[i]], [flat, shape_blob], shape=[1])

        merged = net.NextScopedBlob("concated_data")
        concat_info = net.NextScopedBlob("concated_new_shape")
        net.Concat(flat_blobs, [merged, concat_info], axis=0)
        return merged

    def log_std_init(self):
        """Create the ``mu`` parameter for the ``log_std`` parameterization.

        mu = 2 log sigma, with sigma the standard deviation.
        Per-task objective:
            min 1 / 2 / e^mu X + mu / 2

        ``mu`` starts at ``log(0.5 / weights)`` so that the initial effective
        weight ``0.5 * exp(-mu)`` equals ``weights``.
        """
        init_mu = np.log(0.5 / self.weights)
        fill = ("GivenTensorFill", {"values": init_mu, "dtype": core.DataType.FLOAT})
        self.mu = self.create_param(
            param_name="mu",
            shape=[self.num],
            initializer=fill,
            optimizer=self.optimizer,
        )

    def log_std_weight(self, x, net, weight):
        """Write ``0.5 * exp(-mu)`` into ``weight``.

        Per-task objective: min 1 / 2 / e^mu X + mu / 2.
        ``x`` is accepted for signature parity with the other weight hooks
        and is not used here.
        """
        neg_mu = net.NextScopedBlob("mu_neg")
        net.Negative(self.mu, neg_mu)
        exp_neg_mu = net.NextScopedBlob("mu_neg_exp")
        net.Exp(neg_mu, exp_neg_mu)
        net.Scale(exp_neg_mu, weight, scale=0.5)

    def log_std_reg(self, net, reg):
        """Write the ``log_std`` regularization term ``0.5 * mu`` into ``reg``."""
        net.Scale(self.mu, reg, scale=0.5)

    def inv_var_init(self):
        """Create the ``k`` parameter for the ``inv_var`` parameterization.

        k = 1 / variance.
        Per-task objective:
            min 1 / 2 * k X - 1 / 2 * log k

        ``k`` starts at ``2 * weights`` so that the initial effective weight
        ``0.5 * k`` equals ``weights``. A regularizer chosen by
        ``pos_optim_method`` is attached to the parameter.

        Raises:
            TypeError: if ``pos_optim_method`` is not ``"log_barrier"`` or
                ``"pos_grad_proj"``.
        """
        init_k = 2. * self.weights
        fill = ("GivenTensorFill", {"values": init_k, "dtype": core.DataType.FLOAT})

        method = self.pos_optim_method
        if method not in ("log_barrier", "pos_grad_proj"):
            raise TypeError(
                "unknown positivity optimization method: {}".format(
                    self.pos_optim_method
                )
            )
        if method == "log_barrier":
            positivity = LogBarrier(reg_lambda=self.reg_lambda)
        else:
            positivity = BoundedGradientProjection(lb=0, left_open=True)

        self.k = self.create_param(
            param_name="k",
            shape=[self.num],
            initializer=fill,
            optimizer=self.optimizer,
            regularizer=positivity,
        )

    def inv_var_weight(self, x, net, weight):
        """Write ``0.5 * k`` into ``weight``; ``x`` is not used."""
        net.Scale(self.k, weight, scale=0.5)

    def inv_var_reg(self, net, reg):
        """Write the ``inv_var`` regularization term ``-0.5 * log(k)`` into ``reg``."""
        log_k_blob = net.NextScopedBlob("log_k")
        net.Log(self.k, log_k_blob)
        net.Scale(log_k_blob, reg, scale=-0.5)

    def _add_ops_impl(self, net, enable_diagnose):
        """Emit the ops computing ``sum(weight * x + reg)`` into the output blob.

        When ``enable_diagnose`` is true, each entry of the weight vector is
        additionally sliced out into its own diagnostic blob.
        """
        stacked = self.concat_data(net)
        weight_blob = net.NextScopedBlob("weight")
        reg_blob = net.NextScopedBlob("reg")
        weighted = net.NextScopedBlob("weighted_x")
        weighted_plus_reg = net.NextScopedBlob("weighted_x_add_reg")

        self.weight_func(stacked, net, weight_blob)
        self.reg_func(net, reg_blob)

        net.Mul([weight_blob, stacked], weighted)
        net.Add([weighted, reg_blob], weighted_plus_reg)
        net.SumElements(weighted_plus_reg, self.output_schema())

        if not enable_diagnose:
            return
        for idx in range(self.num):
            net.Slice(
                weight_blob, self.weight_i[idx], starts=[idx], ends=[idx + 1]
            )

    def add_ops(self, net):
        """Add this layer's ops to ``net``, honoring the diagnose setting."""
        self._add_ops_impl(net, self.enable_diagnose)
|