sparse.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. from typing import Optional
  2. import torch
  3. from torch import Tensor
  4. from torch.nn.parameter import Parameter
  5. from .module import Module
  6. from .. import functional as F
  7. from .. import init
  8. class Embedding(Module):
  9. r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
  10. This module is often used to store word embeddings and retrieve them using indices.
  11. The input to the module is a list of indices, and the output is the corresponding
  12. word embeddings.
  13. Args:
  14. num_embeddings (int): size of the dictionary of embeddings
  15. embedding_dim (int): the size of each embedding vector
  16. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
  17. therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
  18. i.e. it remains as a fixed "pad". For a newly constructed Embedding,
  19. the embedding vector at :attr:`padding_idx` will default to all zeros,
  20. but can be updated to another value to be used as the padding vector.
  21. max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
  22. is renormalized to have norm :attr:`max_norm`.
  23. norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
  24. scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
  25. the words in the mini-batch. Default ``False``.
  26. sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
  27. See Notes for more details regarding sparse gradients.
  28. Attributes:
  29. weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
  30. initialized from :math:`\mathcal{N}(0, 1)`
  31. Shape:
  32. - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
  33. - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
  34. .. note::
  35. Keep in mind that only a limited number of optimizers support
  36. sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
  37. :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
  38. .. note::
  39. When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the
  40. :attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be
  41. modified in-place, performing a differentiable operation on ``Embedding.weight`` before
  42. calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when
  43. :attr:`max_norm` is not ``None``. For example::
  44. n, d, m = 3, 5, 7
  45. embedding = nn.Embedding(n, d, max_norm=True)
  46. W = torch.randn((m, d), requires_grad=True)
  47. idx = torch.tensor([1, 2])
  48. a = embedding.weight.clone() @ W.t() # weight must be cloned for this to be differentiable
  49. b = embedding(idx) @ W.t() # modifies weight in-place
  50. out = (a.unsqueeze(0) + b.unsqueeze(1))
  51. loss = out.sigmoid().prod()
  52. loss.backward()
  53. Examples::
  54. >>> # an Embedding module containing 10 tensors of size 3
  55. >>> embedding = nn.Embedding(10, 3)
  56. >>> # a batch of 2 samples of 4 indices each
  57. >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
  58. >>> embedding(input)
  59. tensor([[[-0.0251, -1.6902, 0.7172],
  60. [-0.6431, 0.0748, 0.6969],
  61. [ 1.4970, 1.3448, -0.9685],
  62. [-0.3677, -2.7265, -0.1685]],
  63. [[ 1.4970, 1.3448, -0.9685],
  64. [ 0.4362, -0.4004, 0.9400],
  65. [-0.6431, 0.0748, 0.6969],
  66. [ 0.9124, -2.3616, 1.1151]]])
  67. >>> # example with padding_idx
  68. >>> embedding = nn.Embedding(10, 3, padding_idx=0)
  69. >>> input = torch.LongTensor([[0,2,0,5]])
  70. >>> embedding(input)
  71. tensor([[[ 0.0000, 0.0000, 0.0000],
  72. [ 0.1535, -2.0309, 0.9315],
  73. [ 0.0000, 0.0000, 0.0000],
  74. [-0.1655, 0.9897, 0.0635]]])
  75. >>> # example of changing `pad` vector
  76. >>> padding_idx = 0
  77. >>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
  78. >>> embedding.weight
  79. Parameter containing:
  80. tensor([[ 0.0000, 0.0000, 0.0000],
  81. [-0.7895, -0.7089, -0.0364],
  82. [ 0.6778, 0.5803, 0.2678]], requires_grad=True)
  83. >>> with torch.no_grad():
  84. ... embedding.weight[padding_idx] = torch.ones(3)
  85. >>> embedding.weight
  86. Parameter containing:
  87. tensor([[ 1.0000, 1.0000, 1.0000],
  88. [-0.7895, -0.7089, -0.0364],
  89. [ 0.6778, 0.5803, 0.2678]], requires_grad=True)
  90. """
  91. __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
  92. 'norm_type', 'scale_grad_by_freq', 'sparse']
  93. num_embeddings: int
  94. embedding_dim: int
  95. padding_idx: Optional[int]
  96. max_norm: Optional[float]
  97. norm_type: float
  98. scale_grad_by_freq: bool
  99. weight: Tensor
  100. sparse: bool
  101. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
  102. max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
  103. sparse: bool = False, _weight: Optional[Tensor] = None,
  104. device=None, dtype=None) -> None:
  105. factory_kwargs = {'device': device, 'dtype': dtype}
  106. super(Embedding, self).__init__()
  107. self.num_embeddings = num_embeddings
  108. self.embedding_dim = embedding_dim
  109. if padding_idx is not None:
  110. if padding_idx > 0:
  111. assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
  112. elif padding_idx < 0:
  113. assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
  114. padding_idx = self.num_embeddings + padding_idx
  115. self.padding_idx = padding_idx
  116. self.max_norm = max_norm
  117. self.norm_type = norm_type
  118. self.scale_grad_by_freq = scale_grad_by_freq
  119. if _weight is None:
  120. self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
  121. self.reset_parameters()
  122. else:
  123. assert list(_weight.shape) == [num_embeddings, embedding_dim], \
  124. 'Shape of weight does not match num_embeddings and embedding_dim'
  125. self.weight = Parameter(_weight)
  126. self.sparse = sparse
  127. def reset_parameters(self) -> None:
  128. init.normal_(self.weight)
  129. self._fill_padding_idx_with_zero()
  130. def _fill_padding_idx_with_zero(self) -> None:
  131. if self.padding_idx is not None:
  132. with torch.no_grad():
  133. self.weight[self.padding_idx].fill_(0)
  134. def forward(self, input: Tensor) -> Tensor:
  135. return F.embedding(
  136. input, self.weight, self.padding_idx, self.max_norm,
  137. self.norm_type, self.scale_grad_by_freq, self.sparse)
  138. def extra_repr(self) -> str:
  139. s = '{num_embeddings}, {embedding_dim}'
  140. if self.padding_idx is not None:
  141. s += ', padding_idx={padding_idx}'
  142. if self.max_norm is not None:
  143. s += ', max_norm={max_norm}'
  144. if self.norm_type != 2:
  145. s += ', norm_type={norm_type}'
  146. if self.scale_grad_by_freq is not False:
  147. s += ', scale_grad_by_freq={scale_grad_by_freq}'
  148. if self.sparse is not False:
  149. s += ', sparse=True'
  150. return s.format(**self.__dict__)
  151. @classmethod
  152. def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
  153. max_norm=None, norm_type=2., scale_grad_by_freq=False,
  154. sparse=False):
  155. r"""Creates Embedding instance from given 2-dimensional FloatTensor.
  156. Args:
  157. embeddings (Tensor): FloatTensor containing weights for the Embedding.
  158. First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
  159. freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
  160. Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
  161. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
  162. therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
  163. i.e. it remains as a fixed "pad".
  164. max_norm (float, optional): See module initialization documentation.
  165. norm_type (float, optional): See module initialization documentation. Default ``2``.
  166. scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
  167. sparse (bool, optional): See module initialization documentation.
  168. Examples::
  169. >>> # FloatTensor containing pretrained weights
  170. >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
  171. >>> embedding = nn.Embedding.from_pretrained(weight)
  172. >>> # Get embeddings for index 1
  173. >>> input = torch.LongTensor([1])
  174. >>> embedding(input)
  175. tensor([[ 4.0000, 5.1000, 6.3000]])
  176. """
  177. assert embeddings.dim() == 2, \
  178. 'Embeddings parameter is expected to be 2-dimensional'
  179. rows, cols = embeddings.shape
  180. embedding = cls(
  181. num_embeddings=rows,
  182. embedding_dim=cols,
  183. _weight=embeddings,
  184. padding_idx=padding_idx,
  185. max_norm=max_norm,
  186. norm_type=norm_type,
  187. scale_grad_by_freq=scale_grad_by_freq,
  188. sparse=sparse)
  189. embedding.weight.requires_grad = not freeze
  190. return embedding
  191. class EmbeddingBag(Module):
  192. r"""Computes sums or means of 'bags' of embeddings, without instantiating the
  193. intermediate embeddings.
  194. For bags of constant length, no :attr:`per_sample_weights`, no indices equal to :attr:`padding_idx`,
  195. and with 2D inputs, this class
  196. * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=1)``,
  197. * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=1)``,
  198. * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=1)``.
  199. However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
  200. operations.
  201. EmbeddingBag also supports per-sample weights as an argument to the forward
  202. pass. This scales the output of the Embedding before performing a weighted
  203. reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
  204. only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
  205. :attr:`per_sample_weights`.
  206. Args:
  207. num_embeddings (int): size of the dictionary of embeddings
  208. embedding_dim (int): the size of each embedding vector
  209. max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
  210. is renormalized to have norm :attr:`max_norm`.
  211. norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
  212. scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the inverse of frequency of
  213. the words in the mini-batch. Default ``False``.
  214. Note: this option is not supported when ``mode="max"``.
  215. mode (string, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
  216. ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
  217. into consideration. ``"mean"`` computes the average of the values
  218. in the bag, ``"max"`` computes the max value over each bag.
  219. Default: ``"mean"``
  220. sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
  221. Notes for more details regarding sparse gradients. Note: this option is not
  222. supported when ``mode="max"``.
  223. include_last_offset (bool, optional): if ``True``, :attr:`offsets` has one additional element, where the last element
  224. is equivalent to the size of `indices`. This matches the CSR format.
  225. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the
  226. gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated
  227. during training, i.e. it remains as a fixed "pad". For a newly constructed
  228. EmbeddingBag, the embedding vector at :attr:`padding_idx` will default to all
  229. zeros, but can be updated to another value to be used as the padding vector.
  230. Note that the embedding vector at :attr:`padding_idx` is excluded from the
  231. reduction.
  232. Attributes:
  233. weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
  234. initialized from :math:`\mathcal{N}(0, 1)`.
  235. Examples::
  236. >>> # an EmbeddingBag module containing 10 tensors of size 3
  237. >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
  238. >>> # a batch of 2 samples of 4 indices each
  239. >>> input = torch.tensor([1,2,4,5,4,3,2,9], dtype=torch.long)
  240. >>> offsets = torch.tensor([0,4], dtype=torch.long)
  241. >>> embedding_sum(input, offsets)
  242. tensor([[-0.8861, -5.4350, -0.0523],
  243. [ 1.1306, -2.5798, -1.0044]])
  244. >>> # Example with padding_idx
  245. >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum', padding_idx=2)
  246. >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9], dtype=torch.long)
  247. >>> offsets = torch.tensor([0,4], dtype=torch.long)
  248. >>> embedding_sum(input, offsets)
  249. tensor([[ 0.0000, 0.0000, 0.0000],
  250. [-0.7082, 3.2145, -2.6251]])
  251. >>> # An EmbeddingBag can be loaded from an Embedding like so
  252. >>> embedding = nn.Embedding(10, 3, padding_idx=2)
  253. >>> embedding_sum = nn.EmbeddingBag.from_pretrained(
  254. embedding.weight,
  255. padding_idx=embedding.padding_idx,
  256. mode='sum')
  257. """
  258. __constants__ = ['num_embeddings', 'embedding_dim', 'max_norm', 'norm_type',
  259. 'scale_grad_by_freq', 'mode', 'sparse', 'include_last_offset',
  260. 'padding_idx']
  261. num_embeddings: int
  262. embedding_dim: int
  263. max_norm: Optional[float]
  264. norm_type: float
  265. scale_grad_by_freq: bool
  266. weight: Tensor
  267. mode: str
  268. sparse: bool
  269. include_last_offset: bool
  270. padding_idx: Optional[int]
  271. def __init__(self, num_embeddings: int, embedding_dim: int,
  272. max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
  273. mode: str = 'mean', sparse: bool = False, _weight: Optional[Tensor] = None,
  274. include_last_offset: bool = False, padding_idx: Optional[int] = None,
  275. device=None, dtype=None) -> None:
  276. factory_kwargs = {'device': device, 'dtype': dtype}
  277. super(EmbeddingBag, self).__init__()
  278. self.num_embeddings = num_embeddings
  279. self.embedding_dim = embedding_dim
  280. self.max_norm = max_norm
  281. self.norm_type = norm_type
  282. self.scale_grad_by_freq = scale_grad_by_freq
  283. if padding_idx is not None:
  284. if padding_idx > 0:
  285. assert padding_idx < self.num_embeddings, 'padding_idx must be within num_embeddings'
  286. elif padding_idx < 0:
  287. assert padding_idx >= -self.num_embeddings, 'padding_idx must be within num_embeddings'
  288. padding_idx = self.num_embeddings + padding_idx
  289. self.padding_idx = padding_idx
  290. if _weight is None:
  291. self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
  292. self.reset_parameters()
  293. else:
  294. assert list(_weight.shape) == [num_embeddings, embedding_dim], \
  295. 'Shape of weight does not match num_embeddings and embedding_dim'
  296. self.weight = Parameter(_weight)
  297. self.mode = mode
  298. self.sparse = sparse
  299. self.include_last_offset = include_last_offset
  300. def reset_parameters(self) -> None:
  301. init.normal_(self.weight)
  302. self._fill_padding_idx_with_zero()
  303. def _fill_padding_idx_with_zero(self) -> None:
  304. if self.padding_idx is not None:
  305. with torch.no_grad():
  306. self.weight[self.padding_idx].fill_(0)
  307. def forward(self, input: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None) -> Tensor:
  308. """Forward pass of EmbeddingBag.
  309. Args:
  310. input (Tensor): Tensor containing bags of indices into the embedding matrix.
  311. offsets (Tensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines
  312. the starting index position of each bag (sequence) in :attr:`input`.
  313. per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
  314. to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
  315. must have exactly the same shape as input and is treated as having the same
  316. :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.
  317. Returns:
  318. Tensor output shape of `(B, embedding_dim)`.
  319. .. note::
  320. A few notes about ``input`` and ``offsets``:
  321. - :attr:`input` and :attr:`offsets` have to be of the same type, either int or long
  322. - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences)
  323. each of fixed length ``N``, and this will return ``B`` values aggregated in a way
  324. depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case.
  325. - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of
  326. multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing the
  327. starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` of shape `(B)`,
  328. :attr:`input` will be viewed as having ``B`` bags. Empty bags (i.e., having 0-length) will have
  329. returned vectors filled by zeros.
  330. """
  331. return F.embedding_bag(input, self.weight, offsets,
  332. self.max_norm, self.norm_type,
  333. self.scale_grad_by_freq, self.mode, self.sparse,
  334. per_sample_weights, self.include_last_offset,
  335. self.padding_idx)
  336. def extra_repr(self) -> str:
  337. s = '{num_embeddings}, {embedding_dim}'
  338. if self.max_norm is not None:
  339. s += ', max_norm={max_norm}'
  340. if self.norm_type != 2:
  341. s += ', norm_type={norm_type}'
  342. if self.scale_grad_by_freq is not False:
  343. s += ', scale_grad_by_freq={scale_grad_by_freq}'
  344. s += ', mode={mode}'
  345. if self.padding_idx is not None:
  346. s += ', padding_idx={padding_idx}'
  347. return s.format(**self.__dict__)
  348. @classmethod
  349. def from_pretrained(cls, embeddings: Tensor, freeze: bool = True, max_norm: Optional[float] = None,
  350. norm_type: float = 2., scale_grad_by_freq: bool = False,
  351. mode: str = 'mean', sparse: bool = False, include_last_offset: bool = False,
  352. padding_idx: Optional[int] = None) -> 'EmbeddingBag':
  353. r"""Creates EmbeddingBag instance from given 2-dimensional FloatTensor.
  354. Args:
  355. embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag.
  356. First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'.
  357. freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
  358. Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True``
  359. max_norm (float, optional): See module initialization documentation. Default: ``None``
  360. norm_type (float, optional): See module initialization documentation. Default ``2``.
  361. scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
  362. mode (string, optional): See module initialization documentation. Default: ``"mean"``
  363. sparse (bool, optional): See module initialization documentation. Default: ``False``.
  364. include_last_offset (bool, optional): See module initialization documentation. Default: ``False``.
  365. padding_idx (int, optional): See module initialization documentation. Default: ``None``.
  366. Examples::
  367. >>> # FloatTensor containing pretrained weights
  368. >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
  369. >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
  370. >>> # Get embeddings for index 1
  371. >>> input = torch.LongTensor([[1, 0]])
  372. >>> embeddingbag(input)
  373. tensor([[ 2.5000, 3.7000, 4.6500]])
  374. """
  375. assert embeddings.dim() == 2, \
  376. 'Embeddings parameter is expected to be 2-dimensional'
  377. rows, cols = embeddings.shape
  378. embeddingbag = cls(
  379. num_embeddings=rows,
  380. embedding_dim=cols,
  381. _weight=embeddings,
  382. max_norm=max_norm,
  383. norm_type=norm_type,
  384. scale_grad_by_freq=scale_grad_by_freq,
  385. mode=mode,
  386. sparse=sparse,
  387. include_last_offset=include_last_offset,
  388. padding_idx=padding_idx)
  389. embeddingbag.weight.requires_grad = not freeze
  390. return embeddingbag