regnet.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532
  1. import math
  2. from collections import OrderedDict
  3. from functools import partial
  4. from typing import Any, Callable, Dict, List, Optional, Tuple
  5. import torch
  6. from torch import nn, Tensor
  7. from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
  8. from ..transforms._presets import ImageClassification, InterpolationMode
  9. from ..utils import _log_api_usage_once
  10. from ._api import WeightsEnum, Weights
  11. from ._meta import _IMAGENET_CATEGORIES
  12. from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible
# Names exported by ``from <this module> import *``: the ``RegNet`` class, the
# ``*_Weights`` pretrained-weights enums defined below, and the lowercase
# ``regnet_*`` names (presumably the per-variant builder functions defined later
# in the file -- not visible in this chunk).
__all__ = [
    "RegNet",
    "RegNet_Y_400MF_Weights",
    "RegNet_Y_800MF_Weights",
    "RegNet_Y_1_6GF_Weights",
    "RegNet_Y_3_2GF_Weights",
    "RegNet_Y_8GF_Weights",
    "RegNet_Y_16GF_Weights",
    "RegNet_Y_32GF_Weights",
    "RegNet_Y_128GF_Weights",
    "RegNet_X_400MF_Weights",
    "RegNet_X_800MF_Weights",
    "RegNet_X_1_6GF_Weights",
    "RegNet_X_3_2GF_Weights",
    "RegNet_X_8GF_Weights",
    "RegNet_X_16GF_Weights",
    "RegNet_X_32GF_Weights",
    "regnet_y_400mf",
    "regnet_y_800mf",
    "regnet_y_1_6gf",
    "regnet_y_3_2gf",
    "regnet_y_8gf",
    "regnet_y_16gf",
    "regnet_y_32gf",
    "regnet_y_128gf",
    "regnet_x_400mf",
    "regnet_x_800mf",
    "regnet_x_1_6gf",
    "regnet_x_3_2gf",
    "regnet_x_8gf",
    "regnet_x_16gf",
    "regnet_x_32gf",
]
  46. class SimpleStemIN(Conv2dNormActivation):
  47. """Simple stem for ImageNet: 3x3, BN, ReLU."""
  48. def __init__(
  49. self,
  50. width_in: int,
  51. width_out: int,
  52. norm_layer: Callable[..., nn.Module],
  53. activation_layer: Callable[..., nn.Module],
  54. ) -> None:
  55. super().__init__(
  56. width_in, width_out, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=activation_layer
  57. )
  58. class BottleneckTransform(nn.Sequential):
  59. """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1."""
  60. def __init__(
  61. self,
  62. width_in: int,
  63. width_out: int,
  64. stride: int,
  65. norm_layer: Callable[..., nn.Module],
  66. activation_layer: Callable[..., nn.Module],
  67. group_width: int,
  68. bottleneck_multiplier: float,
  69. se_ratio: Optional[float],
  70. ) -> None:
  71. layers: OrderedDict[str, nn.Module] = OrderedDict()
  72. w_b = int(round(width_out * bottleneck_multiplier))
  73. g = w_b // group_width
  74. layers["a"] = Conv2dNormActivation(
  75. width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer
  76. )
  77. layers["b"] = Conv2dNormActivation(
  78. w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer
  79. )
  80. if se_ratio:
  81. # The SE reduction ratio is defined with respect to the
  82. # beginning of the block
  83. width_se_out = int(round(se_ratio * width_in))
  84. layers["se"] = SqueezeExcitation(
  85. input_channels=w_b,
  86. squeeze_channels=width_se_out,
  87. activation=activation_layer,
  88. )
  89. layers["c"] = Conv2dNormActivation(
  90. w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None
  91. )
  92. super().__init__(layers)
  93. class ResBottleneckBlock(nn.Module):
  94. """Residual bottleneck block: x + F(x), F = bottleneck transform."""
  95. def __init__(
  96. self,
  97. width_in: int,
  98. width_out: int,
  99. stride: int,
  100. norm_layer: Callable[..., nn.Module],
  101. activation_layer: Callable[..., nn.Module],
  102. group_width: int = 1,
  103. bottleneck_multiplier: float = 1.0,
  104. se_ratio: Optional[float] = None,
  105. ) -> None:
  106. super().__init__()
  107. # Use skip connection with projection if shape changes
  108. self.proj = None
  109. should_proj = (width_in != width_out) or (stride != 1)
  110. if should_proj:
  111. self.proj = Conv2dNormActivation(
  112. width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None
  113. )
  114. self.f = BottleneckTransform(
  115. width_in,
  116. width_out,
  117. stride,
  118. norm_layer,
  119. activation_layer,
  120. group_width,
  121. bottleneck_multiplier,
  122. se_ratio,
  123. )
  124. self.activation = activation_layer(inplace=True)
  125. def forward(self, x: Tensor) -> Tensor:
  126. if self.proj is not None:
  127. x = self.proj(x) + self.f(x)
  128. else:
  129. x = x + self.f(x)
  130. return self.activation(x)
  131. class AnyStage(nn.Sequential):
  132. """AnyNet stage (sequence of blocks w/ the same output shape)."""
  133. def __init__(
  134. self,
  135. width_in: int,
  136. width_out: int,
  137. stride: int,
  138. depth: int,
  139. block_constructor: Callable[..., nn.Module],
  140. norm_layer: Callable[..., nn.Module],
  141. activation_layer: Callable[..., nn.Module],
  142. group_width: int,
  143. bottleneck_multiplier: float,
  144. se_ratio: Optional[float] = None,
  145. stage_index: int = 0,
  146. ) -> None:
  147. super().__init__()
  148. for i in range(depth):
  149. block = block_constructor(
  150. width_in if i == 0 else width_out,
  151. width_out,
  152. stride if i == 0 else 1,
  153. norm_layer,
  154. activation_layer,
  155. group_width,
  156. bottleneck_multiplier,
  157. se_ratio,
  158. )
  159. self.add_module(f"block{stage_index}-{i}", block)
  160. class BlockParams:
  161. def __init__(
  162. self,
  163. depths: List[int],
  164. widths: List[int],
  165. group_widths: List[int],
  166. bottleneck_multipliers: List[float],
  167. strides: List[int],
  168. se_ratio: Optional[float] = None,
  169. ) -> None:
  170. self.depths = depths
  171. self.widths = widths
  172. self.group_widths = group_widths
  173. self.bottleneck_multipliers = bottleneck_multipliers
  174. self.strides = strides
  175. self.se_ratio = se_ratio
  176. @classmethod
  177. def from_init_params(
  178. cls,
  179. depth: int,
  180. w_0: int,
  181. w_a: float,
  182. w_m: float,
  183. group_width: int,
  184. bottleneck_multiplier: float = 1.0,
  185. se_ratio: Optional[float] = None,
  186. **kwargs: Any,
  187. ) -> "BlockParams":
  188. """
  189. Programatically compute all the per-block settings,
  190. given the RegNet parameters.
  191. The first step is to compute the quantized linear block parameters,
  192. in log space. Key parameters are:
  193. - `w_a` is the width progression slope
  194. - `w_0` is the initial width
  195. - `w_m` is the width stepping in the log space
  196. In other terms
  197. `log(block_width) = log(w_0) + w_m * block_capacity`,
  198. with `bock_capacity` ramping up following the w_0 and w_a params.
  199. This block width is finally quantized to multiples of 8.
  200. The second step is to compute the parameters per stage,
  201. taking into account the skip connection and the final 1x1 convolutions.
  202. We use the fact that the output width is constant within a stage.
  203. """
  204. QUANT = 8
  205. STRIDE = 2
  206. if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % 8 != 0:
  207. raise ValueError("Invalid RegNet settings")
  208. # Compute the block widths. Each stage has one unique block width
  209. widths_cont = torch.arange(depth) * w_a + w_0
  210. block_capacity = torch.round(torch.log(widths_cont / w_0) / math.log(w_m))
  211. block_widths = (torch.round(torch.divide(w_0 * torch.pow(w_m, block_capacity), QUANT)) * QUANT).int().tolist()
  212. num_stages = len(set(block_widths))
  213. # Convert to per stage parameters
  214. split_helper = zip(
  215. block_widths + [0],
  216. [0] + block_widths,
  217. block_widths + [0],
  218. [0] + block_widths,
  219. )
  220. splits = [w != wp or r != rp for w, wp, r, rp in split_helper]
  221. stage_widths = [w for w, t in zip(block_widths, splits[:-1]) if t]
  222. stage_depths = torch.diff(torch.tensor([d for d, t in enumerate(splits) if t])).int().tolist()
  223. strides = [STRIDE] * num_stages
  224. bottleneck_multipliers = [bottleneck_multiplier] * num_stages
  225. group_widths = [group_width] * num_stages
  226. # Adjust the compatibility of stage widths and group widths
  227. stage_widths, group_widths = cls._adjust_widths_groups_compatibilty(
  228. stage_widths, bottleneck_multipliers, group_widths
  229. )
  230. return cls(
  231. depths=stage_depths,
  232. widths=stage_widths,
  233. group_widths=group_widths,
  234. bottleneck_multipliers=bottleneck_multipliers,
  235. strides=strides,
  236. se_ratio=se_ratio,
  237. )
  238. def _get_expanded_params(self):
  239. return zip(self.widths, self.strides, self.depths, self.group_widths, self.bottleneck_multipliers)
  240. @staticmethod
  241. def _adjust_widths_groups_compatibilty(
  242. stage_widths: List[int], bottleneck_ratios: List[float], group_widths: List[int]
  243. ) -> Tuple[List[int], List[int]]:
  244. """
  245. Adjusts the compatibility of widths and groups,
  246. depending on the bottleneck ratio.
  247. """
  248. # Compute all widths for the current settings
  249. widths = [int(w * b) for w, b in zip(stage_widths, bottleneck_ratios)]
  250. group_widths_min = [min(g, w_bot) for g, w_bot in zip(group_widths, widths)]
  251. # Compute the adjusted widths so that stage and group widths fit
  252. ws_bot = [_make_divisible(w_bot, g) for w_bot, g in zip(widths, group_widths_min)]
  253. stage_widths = [int(w_bot / b) for w_bot, b in zip(ws_bot, bottleneck_ratios)]
  254. return stage_widths, group_widths_min
  255. class RegNet(nn.Module):
  256. def __init__(
  257. self,
  258. block_params: BlockParams,
  259. num_classes: int = 1000,
  260. stem_width: int = 32,
  261. stem_type: Optional[Callable[..., nn.Module]] = None,
  262. block_type: Optional[Callable[..., nn.Module]] = None,
  263. norm_layer: Optional[Callable[..., nn.Module]] = None,
  264. activation: Optional[Callable[..., nn.Module]] = None,
  265. ) -> None:
  266. super().__init__()
  267. _log_api_usage_once(self)
  268. if stem_type is None:
  269. stem_type = SimpleStemIN
  270. if norm_layer is None:
  271. norm_layer = nn.BatchNorm2d
  272. if block_type is None:
  273. block_type = ResBottleneckBlock
  274. if activation is None:
  275. activation = nn.ReLU
  276. # Ad hoc stem
  277. self.stem = stem_type(
  278. 3, # width_in
  279. stem_width,
  280. norm_layer,
  281. activation,
  282. )
  283. current_width = stem_width
  284. blocks = []
  285. for i, (
  286. width_out,
  287. stride,
  288. depth,
  289. group_width,
  290. bottleneck_multiplier,
  291. ) in enumerate(block_params._get_expanded_params()):
  292. blocks.append(
  293. (
  294. f"block{i+1}",
  295. AnyStage(
  296. current_width,
  297. width_out,
  298. stride,
  299. depth,
  300. block_type,
  301. norm_layer,
  302. activation,
  303. group_width,
  304. bottleneck_multiplier,
  305. block_params.se_ratio,
  306. stage_index=i + 1,
  307. ),
  308. )
  309. )
  310. current_width = width_out
  311. self.trunk_output = nn.Sequential(OrderedDict(blocks))
  312. self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  313. self.fc = nn.Linear(in_features=current_width, out_features=num_classes)
  314. # Performs ResNet-style weight initialization
  315. for m in self.modules():
  316. if isinstance(m, nn.Conv2d):
  317. # Note that there is no bias due to BN
  318. fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  319. nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
  320. elif isinstance(m, nn.BatchNorm2d):
  321. nn.init.ones_(m.weight)
  322. nn.init.zeros_(m.bias)
  323. elif isinstance(m, nn.Linear):
  324. nn.init.normal_(m.weight, mean=0.0, std=0.01)
  325. nn.init.zeros_(m.bias)
  326. def forward(self, x: Tensor) -> Tensor:
  327. x = self.stem(x)
  328. x = self.trunk_output(x)
  329. x = self.avgpool(x)
  330. x = x.flatten(start_dim=1)
  331. x = self.fc(x)
  332. return x
  333. def _regnet(
  334. block_params: BlockParams,
  335. weights: Optional[WeightsEnum],
  336. progress: bool,
  337. **kwargs: Any,
  338. ) -> RegNet:
  339. if weights is not None:
  340. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  341. norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))
  342. model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
  343. if weights is not None:
  344. model.load_state_dict(weights.get_state_dict(progress=progress))
  345. return model
  346. _COMMON_META: Dict[str, Any] = {
  347. "min_size": (1, 1),
  348. "categories": _IMAGENET_CATEGORIES,
  349. }
  350. _COMMON_SWAG_META = {
  351. **_COMMON_META,
  352. "recipe": "https://github.com/facebookresearch/SWAG",
  353. "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
  354. }
class RegNet_Y_400MF_Weights(WeightsEnum):
    """Pretrained ImageNet-1K checkpoints for the RegNet-Y 400MF variant."""

    # Simple training recipe; inference preset uses crop_size=224 and
    # ImageClassification's default resize size.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 4344144,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.046,
                    "acc@5": 91.716,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Newer TorchVision recipe; inference preset resizes to 232 and crops to 224.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 4344144,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.804,
                    "acc@5": 92.742,
                }
            },
            "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
        },
    )
    # DEFAULT refers to the IMAGENET1K_V2 entry.
    DEFAULT = IMAGENET1K_V2
class RegNet_Y_800MF_Weights(WeightsEnum):
    """Pretrained ImageNet-1K checkpoints for the RegNet-Y 800MF variant."""

    # Simple training recipe; inference preset uses crop_size=224 and
    # ImageClassification's default resize size.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 6432512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.420,
                    "acc@5": 93.136,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Newer TorchVision recipe; inference preset resizes to 232 and crops to 224.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 6432512,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.828,
                    "acc@5": 94.502,
                }
            },
            "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
        },
    )
    # DEFAULT refers to the IMAGENET1K_V2 entry.
    DEFAULT = IMAGENET1K_V2
class RegNet_Y_1_6GF_Weights(WeightsEnum):
    """Pretrained ImageNet-1K checkpoints for the RegNet-Y 1.6GF variant."""

    # Simple training recipe; inference preset uses crop_size=224 and
    # ImageClassification's default resize size.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11202430,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.950,
                    "acc@5": 93.966,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Newer TorchVision recipe; inference preset resizes to 232 and crops to 224.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 11202430,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.876,
                    "acc@5": 95.444,
                }
            },
            "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
        },
    )
    # DEFAULT refers to the IMAGENET1K_V2 entry.
    DEFAULT = IMAGENET1K_V2
class RegNet_Y_3_2GF_Weights(WeightsEnum):
    """Pretrained ImageNet-1K checkpoints for the RegNet-Y 3.2GF variant."""

    # Simple training recipe (medium-models section); inference preset uses
    # crop_size=224 and ImageClassification's default resize size.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 19436338,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.948,
                    "acc@5": 94.576,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Newer TorchVision recipe; inference preset resizes to 232 and crops to 224.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 19436338,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.982,
                    "acc@5": 95.972,
                }
            },
            "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
        },
    )
    # DEFAULT refers to the IMAGENET1K_V2 entry.
    DEFAULT = IMAGENET1K_V2
class RegNet_Y_8GF_Weights(WeightsEnum):
    """Pretrained ImageNet-1K checkpoints for the RegNet-Y 8GF variant."""

    # Simple training recipe (medium-models section); inference preset uses
    # crop_size=224 and ImageClassification's default resize size.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 39381472,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.032,
                    "acc@5": 95.048,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Newer TorchVision recipe; inference preset resizes to 232 and crops to 224.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 39381472,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.828,
                    "acc@5": 96.330,
                }
            },
            "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
        },
    )
    # DEFAULT refers to the IMAGENET1K_V2 entry.
    DEFAULT = IMAGENET1K_V2
class RegNet_Y_16GF_Weights(WeightsEnum):
    """Pretrained ImageNet-1K checkpoints for the RegNet-Y 16GF variant.

    Besides the two TorchVision-trained entries, two entries derived from SWAG
    weights are provided (end-to-end fine-tuned and linear-classifier-only).
    """

    # Simple training recipe (large-models section); inference preset uses
    # crop_size=224 and ImageClassification's default resize size.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 83590140,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.424,
                    "acc@5": 95.240,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Newer TorchVision recipe; inference preset resizes to 232 and crops to 224.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 83590140,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.886,
                    "acc@5": 96.328,
                }
            },
            "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
        },
    )
    # SWAG, end-to-end fine-tuned; inference preset is 384x384 with bicubic
    # interpolation. "recipe" comes from _COMMON_SWAG_META.
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 83590140,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 86.012,
                    "acc@5": 98.054,
                }
            },
            "_docs": """
These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
""",
        },
    )
    # Frozen SWAG trunk + linear classifier; inference preset is 224x224 with
    # bicubic interpolation. Overrides the SWAG "recipe" link.
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 83590140,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.976,
                    "acc@5": 97.244,
                }
            },
            "_docs": """
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
""",
        },
    )
    # DEFAULT refers to the IMAGENET1K_V2 entry.
    DEFAULT = IMAGENET1K_V2
class RegNet_Y_32GF_Weights(WeightsEnum):
    """Pretrained ImageNet-1K checkpoints for the RegNet-Y 32GF variant.

    Besides the two TorchVision-trained entries, two entries derived from SWAG
    weights are provided (end-to-end fine-tuned and linear-classifier-only).
    """

    # Simple training recipe (large-models section); inference preset uses
    # crop_size=224 and ImageClassification's default resize size.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 145046770,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.878,
                    "acc@5": 95.340,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Newer TorchVision recipe; inference preset resizes to 232 and crops to 224.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 145046770,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.368,
                    "acc@5": 96.498,
                }
            },
            "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
        },
    )
    # SWAG, end-to-end fine-tuned; inference preset is 384x384 with bicubic
    # interpolation. "recipe" comes from _COMMON_SWAG_META.
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 145046770,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 86.838,
                    "acc@5": 98.362,
                }
            },
            "_docs": """
These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
""",
        },
    )
    # Frozen SWAG trunk + linear classifier; inference preset is 224x224 with
    # bicubic interpolation. Overrides the SWAG "recipe" link.
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 145046770,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.622,
                    "acc@5": 97.480,
                }
            },
            "_docs": """
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
""",
        },
    )
    # DEFAULT refers to the IMAGENET1K_V2 entry.
    DEFAULT = IMAGENET1K_V2
class RegNet_Y_128GF_Weights(WeightsEnum):
    """Pretrained ImageNet-1K checkpoints for the RegNet-Y 128GF variant.

    Only SWAG-derived entries exist for this variant (no IMAGENET1K_V1/V2).
    """

    # SWAG, end-to-end fine-tuned; inference preset is 384x384 with bicubic
    # interpolation. "recipe" comes from _COMMON_SWAG_META.
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 644812894,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 88.228,
                    "acc@5": 98.682,
                }
            },
            "_docs": """
These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
""",
        },
    )
    # Frozen SWAG trunk + linear classifier; inference preset is 224x224 with
    # bicubic interpolation. Overrides the SWAG "recipe" link.
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 644812894,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 86.068,
                    "acc@5": 97.844,
                }
            },
            "_docs": """
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
""",
        },
    )
    # DEFAULT refers to the end-to-end fine-tuned SWAG entry.
    DEFAULT = IMAGENET1K_SWAG_E2E_V1
  class RegNet_X_400MF_Weights(WeightsEnum):
      # V1: reproduction of the paper with the simple "small-models" recipe, crop_size=224.
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
          transforms=partial(ImageClassification, crop_size=224),
          meta=dict(
              _COMMON_META,
              num_params=5_495_976,
              recipe="https://github.com/pytorch/vision/tree/main/references/classification#small-models",
              _metrics={"ImageNet-1K": {"acc@1": 72.834, "acc@5": 90.950}},
              _docs="These weights reproduce closely the results of the paper using a simple training recipe.",
          ),
      )

      # V2: improved TorchVision recipe (FixRes variant), crop_size=224 with resize_size=232.
      IMAGENET1K_V2 = Weights(
          url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth",
          transforms=partial(ImageClassification, crop_size=224, resize_size=232),
          meta=dict(
              _COMMON_META,
              num_params=5_495_976,
              recipe="https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
              _metrics={"ImageNet-1K": {"acc@1": 74.864, "acc@5": 92.322}},
              _docs=(
                  "\n"
                  "These weights improve upon the results of the original paper by using a modified version of TorchVision's\n"
                  "`new training recipe\n"
                  "<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.\n"
              ),
          ),
      )

      DEFAULT = IMAGENET1K_V2
  class RegNet_X_800MF_Weights(WeightsEnum):
      # V1: reproduction of the paper with the simple "small-models" recipe, crop_size=224.
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
          transforms=partial(ImageClassification, crop_size=224),
          meta=dict(
              _COMMON_META,
              num_params=7_259_656,
              recipe="https://github.com/pytorch/vision/tree/main/references/classification#small-models",
              _metrics={"ImageNet-1K": {"acc@1": 75.212, "acc@5": 92.348}},
              _docs="These weights reproduce closely the results of the paper using a simple training recipe.",
          ),
      )

      # V2: improved TorchVision recipe (FixRes variant), crop_size=224 with resize_size=232.
      IMAGENET1K_V2 = Weights(
          url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth",
          transforms=partial(ImageClassification, crop_size=224, resize_size=232),
          meta=dict(
              _COMMON_META,
              num_params=7_259_656,
              recipe="https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
              _metrics={"ImageNet-1K": {"acc@1": 77.522, "acc@5": 93.826}},
              _docs=(
                  "\n"
                  "These weights improve upon the results of the original paper by using a modified version of TorchVision's\n"
                  "`new training recipe\n"
                  "<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.\n"
              ),
          ),
      )

      DEFAULT = IMAGENET1K_V2
  class RegNet_X_1_6GF_Weights(WeightsEnum):
      # V1: reproduction of the paper with the simple "small-models" recipe, crop_size=224.
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
          transforms=partial(ImageClassification, crop_size=224),
          meta=dict(
              _COMMON_META,
              num_params=9_190_136,
              recipe="https://github.com/pytorch/vision/tree/main/references/classification#small-models",
              _metrics={"ImageNet-1K": {"acc@1": 77.040, "acc@5": 93.440}},
              _docs="These weights reproduce closely the results of the paper using a simple training recipe.",
          ),
      )

      # V2: improved TorchVision recipe (FixRes variant), crop_size=224 with resize_size=232.
      IMAGENET1K_V2 = Weights(
          url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth",
          transforms=partial(ImageClassification, crop_size=224, resize_size=232),
          meta=dict(
              _COMMON_META,
              num_params=9_190_136,
              recipe="https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
              _metrics={"ImageNet-1K": {"acc@1": 79.668, "acc@5": 94.922}},
              _docs=(
                  "\n"
                  "These weights improve upon the results of the original paper by using a modified version of TorchVision's\n"
                  "`new training recipe\n"
                  "<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.\n"
              ),
          ),
      )

      DEFAULT = IMAGENET1K_V2
  class RegNet_X_3_2GF_Weights(WeightsEnum):
      # V1: reproduction of the paper with the simple "medium-models" recipe, crop_size=224.
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
          transforms=partial(ImageClassification, crop_size=224),
          meta=dict(
              _COMMON_META,
              num_params=15_296_552,
              recipe="https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
              _metrics={"ImageNet-1K": {"acc@1": 78.364, "acc@5": 93.992}},
              _docs="These weights reproduce closely the results of the paper using a simple training recipe.",
          ),
      )

      # V2: improved TorchVision recipe, crop_size=224 with resize_size=232.
      IMAGENET1K_V2 = Weights(
          url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth",
          transforms=partial(ImageClassification, crop_size=224, resize_size=232),
          meta=dict(
              _COMMON_META,
              num_params=15_296_552,
              recipe="https://github.com/pytorch/vision/issues/3995#new-recipe",
              _metrics={"ImageNet-1K": {"acc@1": 81.196, "acc@5": 95.430}},
              _docs=(
                  "\n"
                  "These weights improve upon the results of the original paper by using a modified version of TorchVision's\n"
                  "`new training recipe\n"
                  "<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.\n"
              ),
          ),
      )

      DEFAULT = IMAGENET1K_V2
  class RegNet_X_8GF_Weights(WeightsEnum):
      # V1: reproduction of the paper with the simple "medium-models" recipe, crop_size=224.
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
          transforms=partial(ImageClassification, crop_size=224),
          meta=dict(
              _COMMON_META,
              num_params=39_572_648,
              recipe="https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
              _metrics={"ImageNet-1K": {"acc@1": 79.344, "acc@5": 94.686}},
              _docs="These weights reproduce closely the results of the paper using a simple training recipe.",
          ),
      )

      # V2: improved TorchVision recipe, crop_size=224 with resize_size=232.
      IMAGENET1K_V2 = Weights(
          url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth",
          transforms=partial(ImageClassification, crop_size=224, resize_size=232),
          meta=dict(
              _COMMON_META,
              num_params=39_572_648,
              recipe="https://github.com/pytorch/vision/issues/3995#new-recipe",
              _metrics={"ImageNet-1K": {"acc@1": 81.682, "acc@5": 95.678}},
              _docs=(
                  "\n"
                  "These weights improve upon the results of the original paper by using a modified version of TorchVision's\n"
                  "`new training recipe\n"
                  "<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.\n"
              ),
          ),
      )

      DEFAULT = IMAGENET1K_V2
  class RegNet_X_16GF_Weights(WeightsEnum):
      # V1: reproduction of the paper with the simple "medium-models" recipe, crop_size=224.
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
          transforms=partial(ImageClassification, crop_size=224),
          meta=dict(
              _COMMON_META,
              num_params=54_278_536,
              recipe="https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
              _metrics={"ImageNet-1K": {"acc@1": 80.058, "acc@5": 94.944}},
              _docs="These weights reproduce closely the results of the paper using a simple training recipe.",
          ),
      )

      # V2: improved TorchVision recipe, crop_size=224 with resize_size=232.
      IMAGENET1K_V2 = Weights(
          url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth",
          transforms=partial(ImageClassification, crop_size=224, resize_size=232),
          meta=dict(
              _COMMON_META,
              num_params=54_278_536,
              recipe="https://github.com/pytorch/vision/issues/3995#new-recipe",
              _metrics={"ImageNet-1K": {"acc@1": 82.716, "acc@5": 96.196}},
              _docs=(
                  "\n"
                  "These weights improve upon the results of the original paper by using a modified version of TorchVision's\n"
                  "`new training recipe\n"
                  "<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.\n"
              ),
          ),
      )

      DEFAULT = IMAGENET1K_V2
  class RegNet_X_32GF_Weights(WeightsEnum):
      # V1: reproduction of the paper with the simple "large-models" recipe, crop_size=224.
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
          transforms=partial(ImageClassification, crop_size=224),
          meta=dict(
              _COMMON_META,
              num_params=107_811_560,
              recipe="https://github.com/pytorch/vision/tree/main/references/classification#large-models",
              _metrics={"ImageNet-1K": {"acc@1": 80.622, "acc@5": 95.248}},
              _docs="These weights reproduce closely the results of the paper using a simple training recipe.",
          ),
      )

      # V2: improved TorchVision recipe, crop_size=224 with resize_size=232.
      IMAGENET1K_V2 = Weights(
          url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth",
          transforms=partial(ImageClassification, crop_size=224, resize_size=232),
          meta=dict(
              _COMMON_META,
              num_params=107_811_560,
              recipe="https://github.com/pytorch/vision/issues/3995#new-recipe",
              _metrics={"ImageNet-1K": {"acc@1": 83.014, "acc@5": 96.288}},
              _docs=(
                  "\n"
                  "These weights improve upon the results of the original paper by using a modified version of TorchVision's\n"
                  "`new training recipe\n"
                  "<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.\n"
              ),
          ),
      )

      DEFAULT = IMAGENET1K_V2
  @handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1))
  def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Build the RegNetY_400MF model described in
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_Y_400MF_Weights`, optional): Pretrained weights to load.
              The available values are listed under :class:`~torchvision.models.RegNet_Y_400MF_Weights` below.
              When left as None, no pretrained weights are used.
          progress (bool, optional): Whether to print a download progress bar to stderr. Default is True.
          **kwargs: Extra keyword arguments for ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams``. See the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for the details of both classes.

      .. autoclass:: torchvision.models.RegNet_Y_400MF_Weights
          :members:
      """
      checked_weights = RegNet_Y_400MF_Weights.verify(weights)
      # Design-space parameters for this size; Y-family builders also set se_ratio.
      design = dict(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25)
      block_params = BlockParams.from_init_params(**design, **kwargs)
      return _regnet(block_params, checked_weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1))
  def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Build the RegNetY_800MF model described in
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_Y_800MF_Weights`, optional): Pretrained weights to load.
              The available values are listed under :class:`~torchvision.models.RegNet_Y_800MF_Weights` below.
              When left as None, no pretrained weights are used.
          progress (bool, optional): Whether to print a download progress bar to stderr. Default is True.
          **kwargs: Extra keyword arguments for ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams``. See the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for the details of both classes.

      .. autoclass:: torchvision.models.RegNet_Y_800MF_Weights
          :members:
      """
      checked_weights = RegNet_Y_800MF_Weights.verify(weights)
      # Design-space parameters for this size; Y-family builders also set se_ratio.
      design = dict(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25)
      block_params = BlockParams.from_init_params(**design, **kwargs)
      return _regnet(block_params, checked_weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1))
  def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Build the RegNetY_1.6GF model described in
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_Y_1_6GF_Weights`, optional): Pretrained weights to load.
              The available values are listed under :class:`~torchvision.models.RegNet_Y_1_6GF_Weights` below.
              When left as None, no pretrained weights are used.
          progress (bool, optional): Whether to print a download progress bar to stderr. Default is True.
          **kwargs: Extra keyword arguments for ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams``. See the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for the details of both classes.

      .. autoclass:: torchvision.models.RegNet_Y_1_6GF_Weights
          :members:
      """
      checked_weights = RegNet_Y_1_6GF_Weights.verify(weights)
      # Design-space parameters for this size; Y-family builders also set se_ratio.
      design = dict(depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25)
      block_params = BlockParams.from_init_params(**design, **kwargs)
      return _regnet(block_params, checked_weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1))
  def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Build the RegNetY_3.2GF model described in
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_Y_3_2GF_Weights`, optional): Pretrained weights to load.
              The available values are listed under :class:`~torchvision.models.RegNet_Y_3_2GF_Weights` below.
              When left as None, no pretrained weights are used.
          progress (bool, optional): Whether to print a download progress bar to stderr. Default is True.
          **kwargs: Extra keyword arguments for ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams``. See the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for the details of both classes.

      .. autoclass:: torchvision.models.RegNet_Y_3_2GF_Weights
          :members:
      """
      checked_weights = RegNet_Y_3_2GF_Weights.verify(weights)
      # Design-space parameters for this size; Y-family builders also set se_ratio.
      design = dict(depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25)
      block_params = BlockParams.from_init_params(**design, **kwargs)
      return _regnet(block_params, checked_weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1))
  def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Build the RegNetY_8GF model described in
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_Y_8GF_Weights`, optional): Pretrained weights to load.
              The available values are listed under :class:`~torchvision.models.RegNet_Y_8GF_Weights` below.
              When left as None, no pretrained weights are used.
          progress (bool, optional): Whether to print a download progress bar to stderr. Default is True.
          **kwargs: Extra keyword arguments for ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams``. See the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for the details of both classes.

      .. autoclass:: torchvision.models.RegNet_Y_8GF_Weights
          :members:
      """
      checked_weights = RegNet_Y_8GF_Weights.verify(weights)
      # Design-space parameters for this size; Y-family builders also set se_ratio.
      design = dict(depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25)
      block_params = BlockParams.from_init_params(**design, **kwargs)
      return _regnet(block_params, checked_weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1))
  def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Build the RegNetY_16GF model described in
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_Y_16GF_Weights`, optional): Pretrained weights to load.
              The available values are listed under :class:`~torchvision.models.RegNet_Y_16GF_Weights` below.
              When left as None, no pretrained weights are used.
          progress (bool, optional): Whether to print a download progress bar to stderr. Default is True.
          **kwargs: Extra keyword arguments for ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams``. See the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for the details of both classes.

      .. autoclass:: torchvision.models.RegNet_Y_16GF_Weights
          :members:
      """
      checked_weights = RegNet_Y_16GF_Weights.verify(weights)
      # Design-space parameters for this size; Y-family builders also set se_ratio.
      design = dict(depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25)
      block_params = BlockParams.from_init_params(**design, **kwargs)
      return _regnet(block_params, checked_weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1))
  def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Build the RegNetY_32GF model described in
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_Y_32GF_Weights`, optional): Pretrained weights to load.
              The available values are listed under :class:`~torchvision.models.RegNet_Y_32GF_Weights` below.
              When left as None, no pretrained weights are used.
          progress (bool, optional): Whether to print a download progress bar to stderr. Default is True.
          **kwargs: Extra keyword arguments for ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams``. See the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for the details of both classes.

      .. autoclass:: torchvision.models.RegNet_Y_32GF_Weights
          :members:
      """
      checked_weights = RegNet_Y_32GF_Weights.verify(weights)
      # Design-space parameters for this size; Y-family builders also set se_ratio.
      design = dict(depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25)
      block_params = BlockParams.from_init_params(**design, **kwargs)
      return _regnet(block_params, checked_weights, progress, **kwargs)
  # NOTE(review): the legacy ``pretrained`` alias maps to None here, presumably because
  # RegNet_Y_128GF_Weights defines no IMAGENET1K_V1 entry; confirm against handle_legacy_interface.
  @handle_legacy_interface(weights=("pretrained", None))
  def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Build the RegNetY_128GF model described in
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_Y_128GF_Weights`, optional): Pretrained weights to load.
              The available values are listed under :class:`~torchvision.models.RegNet_Y_128GF_Weights` below.
              When left as None, no pretrained weights are used.
          progress (bool, optional): Whether to print a download progress bar to stderr. Default is True.
          **kwargs: Extra keyword arguments for ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams``. See the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for the details of both classes.

      .. autoclass:: torchvision.models.RegNet_Y_128GF_Weights
          :members:
      """
      checked_weights = RegNet_Y_128GF_Weights.verify(weights)
      # Design-space parameters for this size; Y-family builders also set se_ratio.
      design = dict(depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25)
      block_params = BlockParams.from_init_params(**design, **kwargs)
      return _regnet(block_params, checked_weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1))
  def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Build the RegNetX_400MF model described in
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_X_400MF_Weights`, optional): Pretrained weights to load.
              The available values are listed under :class:`~torchvision.models.RegNet_X_400MF_Weights` below.
              When left as None, no pretrained weights are used.
          progress (bool, optional): Whether to print a download progress bar to stderr. Default is True.
          **kwargs: Extra keyword arguments for ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams``. See the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for the details of both classes.

      .. autoclass:: torchvision.models.RegNet_X_400MF_Weights
          :members:
      """
      checked_weights = RegNet_X_400MF_Weights.verify(weights)
      # Design-space parameters for this size; X-family builders pass no se_ratio.
      design = dict(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16)
      block_params = BlockParams.from_init_params(**design, **kwargs)
      return _regnet(block_params, checked_weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1))
  def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Build the RegNetX_800MF model described in
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_X_800MF_Weights`, optional): Pretrained weights to load.
              The available values are listed under :class:`~torchvision.models.RegNet_X_800MF_Weights` below.
              When left as None, no pretrained weights are used.
          progress (bool, optional): Whether to print a download progress bar to stderr. Default is True.
          **kwargs: Extra keyword arguments for ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams``. See the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for the details of both classes.

      .. autoclass:: torchvision.models.RegNet_X_800MF_Weights
          :members:
      """
      checked_weights = RegNet_X_800MF_Weights.verify(weights)
      # Design-space parameters for this size; X-family builders pass no se_ratio.
      design = dict(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16)
      block_params = BlockParams.from_init_params(**design, **kwargs)
      return _regnet(block_params, checked_weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1))
  def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Constructs a RegNetX_1.6GF architecture from
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_X_1_6GF_Weights`, optional): The pretrained weights to use.
              See :class:`~torchvision.models.RegNet_X_1_6GF_Weights` below for more details and possible values.
              By default, no pretrained weights are used.
          progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
          **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for more detail about the classes.

      .. autoclass:: torchvision.models.RegNet_X_1_6GF_Weights
          :members:
      """
      weights = RegNet_X_1_6GF_Weights.verify(weights)
      # X-family design: unlike the Y builders, no se_ratio is passed.
      params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
      return _regnet(params, weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1))
  def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Constructs a RegNetX_3.2GF architecture from
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_X_3_2GF_Weights`, optional): The pretrained weights to use.
              See :class:`~torchvision.models.RegNet_X_3_2GF_Weights` below for more details and possible values.
              By default, no pretrained weights are used.
          progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
          **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for more detail about the classes.

      .. autoclass:: torchvision.models.RegNet_X_3_2GF_Weights
          :members:
      """
      weights = RegNet_X_3_2GF_Weights.verify(weights)
      # X-family design: unlike the Y builders, no se_ratio is passed.
      params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
      return _regnet(params, weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1))
  def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Constructs a RegNetX_8GF architecture from
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_X_8GF_Weights`, optional): The pretrained weights to use.
              See :class:`~torchvision.models.RegNet_X_8GF_Weights` below for more details and possible values.
              By default, no pretrained weights are used.
          progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
          **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for more detail about the classes.

      .. autoclass:: torchvision.models.RegNet_X_8GF_Weights
          :members:
      """
      weights = RegNet_X_8GF_Weights.verify(weights)
      # X-family design: unlike the Y builders, no se_ratio is passed.
      params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
      return _regnet(params, weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1))
  def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Constructs a RegNetX_16GF architecture from
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_X_16GF_Weights`, optional): The pretrained weights to use.
              See :class:`~torchvision.models.RegNet_X_16GF_Weights` below for more details and possible values.
              By default, no pretrained weights are used.
          progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
          **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for more detail about the classes.

      .. autoclass:: torchvision.models.RegNet_X_16GF_Weights
          :members:
      """
      weights = RegNet_X_16GF_Weights.verify(weights)
      # X-family design: unlike the Y builders, no se_ratio is passed.
      params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
      return _regnet(params, weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1))
  def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
      """
      Constructs a RegNetX_32GF architecture from
      `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

      Args:
          weights (:class:`~torchvision.models.RegNet_X_32GF_Weights`, optional): The pretrained weights to use.
              See :class:`~torchvision.models.RegNet_X_32GF_Weights` below for more details and possible values.
              By default, no pretrained weights are used.
          progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
          **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
              ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
              for more detail about the classes.

      .. autoclass:: torchvision.models.RegNet_X_32GF_Weights
          :members:
      """
      weights = RegNet_X_32GF_Weights.verify(weights)
      # X-family design: unlike the Y builders, no se_ratio is passed.
      params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
      return _regnet(params, weights, progress, **kwargs)
  # The dictionary below is internal implementation detail and will be removed in v0.15
  from ._utils import _ModelURLs

  # Legacy architecture-name -> checkpoint-URL mapping, taken from each enum's IMAGENET1K_V1 entry.
  # regnet_y_128gf is absent because its weights enum has no IMAGENET1K_V1 member.
  model_urls = _ModelURLs(
      {
          arch_name: weights_enum.IMAGENET1K_V1.url
          for arch_name, weights_enum in (
              ("regnet_y_400mf", RegNet_Y_400MF_Weights),
              ("regnet_y_800mf", RegNet_Y_800MF_Weights),
              ("regnet_y_1_6gf", RegNet_Y_1_6GF_Weights),
              ("regnet_y_3_2gf", RegNet_Y_3_2GF_Weights),
              ("regnet_y_8gf", RegNet_Y_8GF_Weights),
              ("regnet_y_16gf", RegNet_Y_16GF_Weights),
              ("regnet_y_32gf", RegNet_Y_32GF_Weights),
              ("regnet_x_400mf", RegNet_X_400MF_Weights),
              ("regnet_x_800mf", RegNet_X_800MF_Weights),
              ("regnet_x_1_6gf", RegNet_X_1_6GF_Weights),
              ("regnet_x_3_2gf", RegNet_X_3_2GF_Weights),
              ("regnet_x_8gf", RegNet_X_8GF_Weights),
              ("regnet_x_16gf", RegNet_X_16GF_Weights),
              ("regnet_x_32gf", RegNet_X_32GF_Weights),
          )
      }
  )