quantized.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574
  1. from torch import Tensor, _VF # noqa: F401
  2. from torch.nn.utils.rnn import PackedSequence
  3. import torch
  4. import warnings
  5. from typing import List, Optional, Tuple
  6. class QuantizedLinear(torch.jit.ScriptModule):
  7. __constants__ = ['scale', 'zero_point']
  8. def __init__(self, other):
  9. super(QuantizedLinear, self).__init__()
  10. warnings.warn(
  11. "torch.jit.QuantizedLinear is deprecated and will be removed in an upcoming "
  12. "PyTorch release. Please use the torch.nn.quantized.dynamic.Linear instead.")
  13. self.in_features = other.in_features
  14. self.out_features = other.out_features
  15. # Quantize weight and discard the original
  16. self.weight, self.col_offsets, self.scale, self.zero_point = torch.fbgemm_linear_quantize_weight(
  17. other.weight.clone(memory_format=torch.contiguous_format).float())
  18. self.weight = torch.nn.Parameter(self.weight, requires_grad=False)
  19. self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False)
  20. assert other.bias is not None, 'QuantizedLinear requires a bias'
  21. self.bias = torch.nn.Parameter(other.bias.clone(memory_format=torch.contiguous_format).float(), requires_grad=False)
  22. self.register_buffer(
  23. 'packed_tensor_ptr',
  24. torch.fbgemm_pack_quantized_matrix(self.weight.clone(memory_format=torch.contiguous_format)))
  25. @torch.jit.script_method
  26. def _unpack(self):
  27. self.packed_tensor_ptr.set_(
  28. torch.fbgemm_pack_quantized_matrix(self.weight))
  29. @torch.jit.script_method
  30. def _pack(self):
  31. self.packed_tensor_ptr.set_(
  32. torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
  33. @torch.jit.script_method
  34. def forward(self, input):
  35. out = torch.fbgemm_linear_int8_weight_fp32_activation(
  36. input.float(), self.weight, self.packed_tensor_ptr, self.col_offsets,
  37. self.scale, self.zero_point, self.bias)
  38. return out.to(input.dtype)
  39. def extra_repr(self):
  40. repr = 'in_features={in_features}, out_features={out_features}, ' \
  41. 'scale={scale}, zero_point={zero_point}'.format(**self.__dict__)
  42. return repr
  43. # FP16 weights
  44. class QuantizedLinearFP16(torch.jit.ScriptModule):
  45. def __init__(self, other):
  46. super(QuantizedLinearFP16, self).__init__()
  47. warnings.warn(
  48. "torch.jit.QuantizedLinearFP16 is deprecated and will be removed in an upcoming "
  49. "PyTorch release. Please use the torch.nn.quantized.dynamic.Linear instead.")
  50. self.in_features = other.in_features
  51. self.out_features = other.out_features
  52. self.original_weight = other.weight
  53. self.weight = torch.fbgemm_pack_gemm_matrix_fp16(
  54. other.weight.clone(memory_format=torch.contiguous_format).float())
  55. assert other.bias is not None, 'QuantizedLinearFP16 requires a bias'
  56. self.bias = torch.nn.Parameter(other.bias.clone(memory_format=torch.contiguous_format).float(), requires_grad=False)
  57. self.register_buffer('packed_weight', self.weight)
  58. @torch.jit.script_method
  59. def _unpack(self):
  60. self.packed_weight.set_(
  61. torch.fbgemm_pack_gemm_matrix_fp16(
  62. self.original_weight))
  63. @torch.jit.script_method
  64. def _pack(self):
  65. self.packed_weight.set_(
  66. torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
  67. @torch.jit.script_method
  68. def forward(self, input):
  69. out = torch.fbgemm_linear_fp16_weight_fp32_activation(
  70. input.float(), self.packed_weight, self.bias)
  71. return out
  72. def extra_repr(self):
  73. repr = 'in_features={in_features}, out_features={out_features}, '.format(**self.__dict__)
  74. return repr
  75. # Quantized RNN cell implementations
  76. class QuantizedRNNCellBase(torch.jit.ScriptModule):
  77. __constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
  78. 'zero_point_ih', 'zero_point_hh']
  79. def __init__(self, other):
  80. super(QuantizedRNNCellBase, self).__init__()
  81. warnings.warn(
  82. "torch.jit.QuantizedRNNCellBase is deprecated and will be removed in an upcoming "
  83. "PyTorch release. Please use the torch.nn.quantized.dynamic.RNNCell instead.")
  84. self.input_size = other.input_size
  85. self.hidden_size = other.hidden_size
  86. self.bias = other.bias
  87. if not self.bias:
  88. raise ValueError("Quantized RNN cells require bias terms")
  89. weight_ih, col_offsets_ih, self.scale_ih, self.zero_point_ih = \
  90. torch.fbgemm_linear_quantize_weight(other.weight_ih.clone(memory_format=torch.contiguous_format).float())
  91. self.register_buffer('weight_ih', weight_ih)
  92. self.register_buffer('col_offsets_ih', col_offsets_ih)
  93. weight_hh, col_offsets_hh, self.scale_hh, self.zero_point_hh = \
  94. torch.fbgemm_linear_quantize_weight(other.weight_hh.clone(memory_format=torch.contiguous_format).float())
  95. self.register_buffer('weight_hh', weight_hh)
  96. self.register_buffer('col_offsets_hh', col_offsets_hh)
  97. packed_ih = torch.fbgemm_pack_quantized_matrix(self.weight_ih)
  98. self.register_buffer('packed_ih', packed_ih)
  99. packed_hh = torch.fbgemm_pack_quantized_matrix(self.weight_hh)
  100. self.register_buffer('packed_hh', packed_hh)
  101. self.bias_ih = torch.nn.Parameter(other.bias_ih.clone(memory_format=torch.contiguous_format).float(), requires_grad=False)
  102. self.bias_hh = torch.nn.Parameter(other.bias_hh.clone(memory_format=torch.contiguous_format).float(), requires_grad=False)
  103. def extra_repr(self):
  104. s = '{input_size}, {hidden_size}'
  105. if 'bias' in self.__dict__ and self.bias is not True:
  106. s += ', bias={bias}'
  107. if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
  108. s += ', nonlinearity={nonlinearity}'
  109. return s.format(**self.__dict__)
  110. @torch.jit.script_method
  111. def check_forward_input(self, input):
  112. if input.size(1) != self.input_size:
  113. raise RuntimeError(
  114. "input has inconsistent input_size: got {}, expected {}".format(
  115. input.size(1), self.input_size))
  116. @torch.jit.script_method
  117. def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None:
  118. if input.size(0) != hx.size(0):
  119. raise RuntimeError(
  120. "Input batch size {} doesn't match hidden{} batch size {}".format(
  121. input.size(0), hidden_label, hx.size(0)))
  122. if hx.size(1) != self.hidden_size:
  123. raise RuntimeError(
  124. "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
  125. hidden_label, hx.size(1), self.hidden_size))
  126. # TODO: for some reason weak_script_method causes a destruction of the
  127. # module to occur, which in turn frees the packed_ih object via its DataPtr
  128. # deleter. This is bizarre and should probably get fixed.
  129. # @torch._jit_internal.weak_script_method
  130. @torch.jit.script_method
  131. def _unpack(self):
  132. self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(self.weight_ih))
  133. self.packed_hh.set_(torch.fbgemm_pack_quantized_matrix(self.weight_hh))
  134. # @torch._jit_internal.weak_script_method
  135. @torch.jit.script_method
  136. def _pack(self):
  137. self.packed_ih.set_(
  138. torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
  139. self.packed_hh.set_(
  140. torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
  141. class QuantizedRNNCell(QuantizedRNNCellBase):
  142. __constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
  143. 'zero_point_ih', 'zero_point_hh', 'nonlinearity']
  144. def __init__(self, other):
  145. super(QuantizedRNNCell, self).__init__(other)
  146. warnings.warn(
  147. "torch.jit.QuantizedRNNCell is deprecated and will be removed in an upcoming "
  148. "PyTorch release. Please use the torch.nn.quantized.dynamic.RNNCell instead.")
  149. self.nonlinearity = other.nonlinearity
  150. @torch.jit.script_method
  151. def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
  152. self.check_forward_input(input)
  153. if hx is None:
  154. hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
  155. self.check_forward_hidden(input, hx, '')
  156. if self.nonlinearity == "tanh":
  157. ret = _VF.quantized_rnn_tanh_cell(
  158. input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
  159. self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
  160. self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
  161. self.zero_point_hh
  162. )
  163. elif self.nonlinearity == "relu":
  164. ret = _VF.quantized_rnn_relu_cell(
  165. input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
  166. self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
  167. self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
  168. self.zero_point_hh
  169. )
  170. else:
  171. ret = input # TODO: remove when jit supports exception flow
  172. raise RuntimeError(
  173. "Unknown nonlinearity: {}".format(self.nonlinearity))
  174. return ret
  175. class QuantizedLSTMCell(QuantizedRNNCellBase):
  176. def __init__(self, other):
  177. super(QuantizedLSTMCell, self).__init__(other)
  178. warnings.warn(
  179. "torch.jit.QuantizedLSTMCell is deprecated and will be removed in an upcoming "
  180. "PyTorch release. Please use the torch.nn.quantized.dynamic.LSTMCell instead.")
  181. @torch.jit.script_method
  182. def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
  183. self.check_forward_input(input)
  184. if hx is None:
  185. zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
  186. hx = (zeros, zeros)
  187. self.check_forward_hidden(input, hx[0], '[0]')
  188. self.check_forward_hidden(input, hx[1], '[1]')
  189. return _VF.quantized_lstm_cell(
  190. input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
  191. self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
  192. self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
  193. self.zero_point_hh
  194. )
  195. class QuantizedGRUCell(QuantizedRNNCellBase):
  196. def __init__(self, other):
  197. super(QuantizedGRUCell, self).__init__(other)
  198. warnings.warn(
  199. "torch.jit.QuantizedGRUCell is deprecated and will be removed in an upcoming "
  200. "PyTorch release. Please use the torch.nn.quantized.dynamic.GRUCell instead.")
  201. @torch.jit.script_method
  202. def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
  203. self.check_forward_input(input)
  204. if hx is None:
  205. hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
  206. self.check_forward_hidden(input, hx, '')
  207. return _VF.quantized_gru_cell(
  208. input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
  209. self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
  210. self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
  211. self.zero_point_hh
  212. )
  213. def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
  214. return tensor.index_select(dim, permutation)
  215. class QuantizedRNNBase(torch.jit.ScriptModule):
  216. __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
  217. 'batch_first', 'dropout', 'bidirectional', 'dtype']
  218. def __init__(self, other, dtype=torch.int8):
  219. super(QuantizedRNNBase, self).__init__()
  220. warnings.warn(
  221. "torch.jit.QuantizedRNNBase is deprecated and will be removed in an upcoming "
  222. "PyTorch release. Please use the torch.nn.quantized.dynamic instead.")
  223. self.mode = other.mode
  224. self.input_size = other.input_size
  225. self.hidden_size = other.hidden_size
  226. self.num_layers = other.num_layers
  227. self.bias = other.bias
  228. self.batch_first = other.batch_first
  229. if self.mode != 'GRU':
  230. assert not self.batch_first
  231. self.dropout = other.dropout
  232. self.bidirectional = other.bidirectional
  233. num_directions = 2 if self.bidirectional else 1
  234. self.dtype = dtype
  235. assert self.bias
  236. # TODO: support more than just LSTM
  237. if self.mode != 'LSTM' and self.mode != 'GRU':
  238. raise RuntimeError('Only LSTM or GRU is supported for QuantizedRNN')
  239. if dtype != torch.int8 and dtype != torch.float16:
  240. raise RuntimeError('Unsupported dtype: {}'.format(dtype))
  241. self.all_weights = []
  242. for layer in range(self.num_layers):
  243. for direction in range(num_directions):
  244. layer_input_size = self.input_size if layer == 0 else self.hidden_size * num_directions
  245. suffix = '_reverse' if direction == 1 else ''
  246. def get_weight_bias(ihhh):
  247. weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
  248. bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)
  249. weight = getattr(other, weight_name)
  250. bias = getattr(other, bias_name)
  251. return weight, bias
  252. weight_ih, bias_ih = get_weight_bias('ih')
  253. weight_hh, bias_hh = get_weight_bias('hh')
  254. if dtype == torch.int8:
  255. cell_params = torch.ops.quantized.make_quantized_cell_params(
  256. weight_ih, weight_hh, bias_ih, bias_hh)
  257. else:
  258. packed_ih = torch.ops.quantized.linear_prepack_fp16(
  259. weight_ih.float(), bias_ih)
  260. packed_hh = torch.ops.quantized.linear_prepack_fp16(
  261. weight_hh.float(), bias_hh)
  262. cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
  263. packed_ih, packed_hh)
  264. setattr(self, 'cell_params_{}_{}'.format(layer, suffix), cell_params)
  265. self.all_weights.append(cell_params)
  266. @torch.jit.script_method
  267. def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
  268. expected_input_dim = 2 if batch_sizes is not None else 3
  269. if input.dim() != expected_input_dim:
  270. raise RuntimeError(
  271. 'input must have {} dimensions, got {}'.format(
  272. expected_input_dim, input.dim()))
  273. if self.input_size != input.size(-1):
  274. raise RuntimeError(
  275. 'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
  276. self.input_size, input.size(-1)))
  277. @torch.jit.script_method
  278. def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
  279. if batch_sizes is not None:
  280. mini_batch = int(batch_sizes[0])
  281. else:
  282. mini_batch = input.size(0) if self.batch_first else input.size(1)
  283. num_directions = 2 if self.bidirectional else 1
  284. expected_hidden_size = (self.num_layers * num_directions,
  285. mini_batch, self.hidden_size)
  286. return expected_hidden_size
  287. @torch.jit.script_method
  288. def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
  289. msg: str = 'Expected hidden size {}, got {}') -> None:
  290. if hx.size() != expected_hidden_size:
  291. raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
  292. @torch.jit.script_method
  293. def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
  294. self.check_input(input, batch_sizes)
  295. expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
  296. self.check_hidden_size(hidden, expected_hidden_size, msg='Expected hidden size {}, got {}')
  297. @torch.jit.script_method
  298. def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
  299. if permutation is None:
  300. return hx
  301. return apply_permutation(hx, permutation)
  302. class QuantizedLSTM(QuantizedRNNBase):
  303. __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}
  304. def __init__(self, other, dtype):
  305. super(QuantizedLSTM, self).__init__(other, dtype)
  306. warnings.warn(
  307. "torch.jit.QuantizedLSTM is deprecated and will be removed in an upcoming "
  308. "PyTorch release. Please use the torch.nn.quantized.dynamic.LSTM instead.")
  309. @torch.jit.script_method
  310. def forward_impl(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]], batch_sizes: Optional[Tensor],
  311. max_batch_size: int, sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
  312. if hx is None:
  313. num_directions = 2 if self.bidirectional else 1
  314. zeros = torch.zeros(self.num_layers * num_directions,
  315. max_batch_size, self.hidden_size,
  316. dtype=input.dtype, device=input.device)
  317. hx = (zeros, zeros)
  318. else:
  319. # Each batch of the hidden state should match the input sequence that
  320. # the user believes he/she is passing in.
  321. hx = self.permute_hidden(hx, sorted_indices)
  322. self.check_forward_args(input, hx, batch_sizes)
  323. assert batch_sizes is None
  324. result = torch.quantized_lstm(input, hx, self.all_weights, self.bias, self.num_layers,
  325. float(self.dropout), self.training, self.bidirectional,
  326. self.batch_first, dtype=self.dtype, use_dynamic=False)
  327. output = result[0]
  328. hidden = result[1:]
  329. return output, hidden
  330. @torch.jit.script_method
  331. def forward_tensor(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
  332. batch_sizes = None
  333. max_batch_size = input.size(0) if self.batch_first else input.size(1)
  334. sorted_indices = None
  335. unsorted_indices = None
  336. output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
  337. return output, self.permute_hidden(hidden, unsorted_indices)
  338. @torch.jit.script_method
  339. def forward_packed(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
  340. ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
  341. input, batch_sizes, sorted_indices, unsorted_indices = input
  342. max_batch_size = batch_sizes[0]
  343. max_batch_size = int(max_batch_size)
  344. output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
  345. output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
  346. return output, self.permute_hidden(hidden, unsorted_indices)
  347. @torch.jit.script_method
  348. def permute_hidden(self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
  349. if permutation is None:
  350. return hx
  351. return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
  352. @torch.jit.script_method
  353. def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]) -> None:
  354. self.check_input(input, batch_sizes)
  355. expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
  356. self.check_hidden_size(hidden[0], expected_hidden_size,
  357. 'Expected hidden[0] size {}, got {}')
  358. self.check_hidden_size(hidden[1], expected_hidden_size,
  359. 'Expected hidden[1] size {}, got {}')
  360. def forward(self, input, hx=None):
  361. if isinstance(input, PackedSequence):
  362. return self.forward_packed(input, hx)
  363. else:
  364. return self.forward_tensor(input, hx)
  365. class QuantizedGRU(QuantizedRNNBase):
  366. __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}
  367. def __init__(self, *args, **kwargs):
  368. super().__init__(*args, **kwargs)
  369. warnings.warn(
  370. "torch.jit.QuantizedGRU is deprecated and will be removed in an upcoming "
  371. "PyTorch release. Please use the torch.nn.quantized.dynamic.GRU instead.")
  372. @torch.jit.script_method
  373. def forward_impl(self, input: Tensor, hx: Optional[Tensor], batch_sizes: Optional[Tensor], max_batch_size: int,
  374. sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
  375. if hx is None:
  376. num_directions = 2 if self.bidirectional else 1
  377. hx = torch.zeros(self.num_layers * num_directions,
  378. max_batch_size, self.hidden_size,
  379. dtype=input.dtype, device=input.device)
  380. else:
  381. # Each batch of the hidden state should match the input sequence that
  382. # the user believes he/she is passing in.
  383. hx = self.permute_hidden(hx, sorted_indices)
  384. self.check_forward_args(input, hx, batch_sizes)
  385. if batch_sizes is None:
  386. result = torch.quantized_gru(input, hx, self.all_weights, self.bias, self.num_layers,
  387. float(self.dropout), self.training, self.bidirectional,
  388. self.batch_first)
  389. else:
  390. result = torch.quantized_gru(input, batch_sizes, hx, self.all_weights, self.bias, self.num_layers,
  391. float(self.dropout), self.training, self.bidirectional)
  392. output = result[0]
  393. hidden = result[1]
  394. return output, hidden
  395. @torch.jit.script_method
  396. def forward_tensor(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
  397. batch_sizes = None
  398. max_batch_size = input.size(0) if self.batch_first else input.size(1)
  399. sorted_indices = None
  400. unsorted_indices = None
  401. output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
  402. return output, self.permute_hidden(hidden, unsorted_indices)
  403. @torch.jit.script_method
  404. def forward_packed(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]:
  405. input, batch_sizes, sorted_indices, unsorted_indices = input
  406. max_batch_size = batch_sizes[0]
  407. max_batch_size = int(max_batch_size)
  408. output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
  409. output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
  410. return output, self.permute_hidden(hidden, unsorted_indices)
  411. def forward(self, input, hx=None):
  412. if isinstance(input, PackedSequence):
  413. return self.forward_packed(input, hx)
  414. else:
  415. return self.forward_tensor(input, hx)
  416. def quantize_rnn_cell_modules(module):
  417. warnings.warn("quantize_rnn_cell_modules function has been deprecated. "
  418. "Please use torch.ao.quantization.quantize_dynamic API instead.")
  419. reassign = {}
  420. for name, mod in module.named_modules():
  421. if mod is module:
  422. continue
  423. new_mod = quantize_rnn_cell_modules(mod)
  424. if new_mod is not mod:
  425. reassign[name] = new_mod
  426. for name, mod in reassign.items():
  427. setattr(module, name, mod)
  428. if isinstance(module, torch.nn.LSTMCell):
  429. return QuantizedLSTMCell(module)
  430. if isinstance(module, torch.nn.GRUCell):
  431. return QuantizedGRUCell(module)
  432. if isinstance(module, torch.nn.RNNCell):
  433. return QuantizedRNNCell(module)
  434. return module
  435. def quantize_linear_modules(module, dtype=torch.int8):
  436. warnings.warn("quantize_linear_modules function has been deprecated. "
  437. "Please use torch.ao.quantization.quantize_dynamic API instead.")
  438. reassign = {}
  439. for name, mod in module.named_modules():
  440. if mod is module:
  441. continue
  442. new_mod = quantize_linear_modules(mod, dtype)
  443. if new_mod is not mod:
  444. reassign[name] = new_mod
  445. for name, mod in reassign.items():
  446. setattr(module, name, mod)
  447. if isinstance(module, torch.nn.Linear):
  448. if dtype == torch.int8:
  449. return QuantizedLinear(module)
  450. elif dtype == torch.float16:
  451. return QuantizedLinearFP16(module)
  452. else:
  453. raise RuntimeError(
  454. "Unsupported dtype: {}".format(dtype))
  455. return module
  456. def quantize_rnn_modules(module, dtype=torch.int8):
  457. warnings.warn("quantize_rnn_modules function has been deprecated. "
  458. "Please use torch.ao.quantization.quantize_dynamic API instead.")
  459. reassign = {}
  460. for name, mod in module.named_modules():
  461. if mod is module:
  462. continue
  463. new_mod = quantize_rnn_modules(mod, dtype)
  464. if new_mod is not mod:
  465. reassign[name] = new_mod
  466. for name, mod in reassign.items():
  467. setattr(module, name, mod)
  468. if isinstance(module, torch.nn.LSTM):
  469. if dtype != torch.int8 and dtype != torch.float16:
  470. raise RuntimeError("Unsupported dtype: {}".format(dtype))
  471. return QuantizedLSTM(module, dtype)
  472. if isinstance(module, torch.nn.GRU):
  473. return QuantizedGRU(module)
  474. return module