efficientnet.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106
  1. import copy
  2. import math
  3. import warnings
  4. from dataclasses import dataclass
  5. from functools import partial
  6. from typing import Any, Callable, Dict, Optional, List, Sequence, Tuple, Union
  7. import torch
  8. from torch import nn, Tensor
  9. from torchvision.ops import StochasticDepth
  10. from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
  11. from ..transforms._presets import ImageClassification, InterpolationMode
  12. from ..utils import _log_api_usage_once
  13. from ._api import WeightsEnum, Weights
  14. from ._meta import _IMAGENET_CATEGORIES
  15. from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible
  # Public names exported by ``from <this module> import *``: the model class, one weights enum per
  # variant, and one builder function per variant.
  __all__ = [
      "EfficientNet",
      "EfficientNet_B0_Weights",
      "EfficientNet_B1_Weights",
      "EfficientNet_B2_Weights",
      "EfficientNet_B3_Weights",
      "EfficientNet_B4_Weights",
      "EfficientNet_B5_Weights",
      "EfficientNet_B6_Weights",
      "EfficientNet_B7_Weights",
      "EfficientNet_V2_S_Weights",
      "EfficientNet_V2_M_Weights",
      "EfficientNet_V2_L_Weights",
      "efficientnet_b0",
      "efficientnet_b1",
      "efficientnet_b2",
      "efficientnet_b3",
      "efficientnet_b4",
      "efficientnet_b5",
      "efficientnet_b6",
      "efficientnet_b7",
      "efficientnet_v2_s",
      "efficientnet_v2_m",
      "efficientnet_v2_l",
  ]
  @dataclass
  class _MBConvConfig:
      """Per-stage settings shared by :class:`MBConvConfig` and :class:`FusedMBConvConfig`.

      Field order matters: subclasses pass these values positionally to the generated ``__init__``.
      """

      expand_ratio: float
      kernel: int
      stride: int
      input_channels: int
      out_channels: int
      num_layers: int
      block: Callable[..., nn.Module]

      @staticmethod
      def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
          """Scale ``channels`` by ``width_mult`` and round it with ``_make_divisible`` (divisor 8)."""
          scaled = channels * width_mult
          return _make_divisible(scaled, 8, min_value)
  class MBConvConfig(_MBConvConfig):
      """Stage configuration for :class:`MBConv` blocks.

      Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper.
      Before being stored, ``input_channels``/``out_channels`` are rescaled by ``width_mult`` through
      ``adjust_channels``, and ``num_layers`` is rescaled by ``depth_mult`` (rounded up). ``block``
      falls back to :class:`MBConv` when not given.
      """

      def __init__(
          self,
          expand_ratio: float,
          kernel: int,
          stride: int,
          input_channels: int,
          out_channels: int,
          num_layers: int,
          width_mult: float = 1.0,
          depth_mult: float = 1.0,
          block: Optional[Callable[..., nn.Module]] = None,
      ) -> None:
          # Arguments evaluate left to right: input width, output width, depth, then the block default.
          super().__init__(
              expand_ratio,
              kernel,
              stride,
              self.adjust_channels(input_channels, width_mult),
              self.adjust_channels(out_channels, width_mult),
              self.adjust_depth(num_layers, depth_mult),
              MBConv if block is None else block,
          )

      @staticmethod
      def adjust_depth(num_layers: int, depth_mult: float):
          """Scale a stage's layer count by ``depth_mult``, rounding up to an ``int``."""
          scaled_layers = num_layers * depth_mult
          return int(math.ceil(scaled_layers))
  class FusedMBConvConfig(_MBConvConfig):
      """Stage configuration for :class:`FusedMBConv` blocks.

      Stores information listed at Table 4 of the EfficientNetV2 paper. Unlike :class:`MBConvConfig`,
      channel and layer counts are stored exactly as given (no width/depth scaling). ``block`` falls
      back to :class:`FusedMBConv` when not given.
      """

      def __init__(
          self,
          expand_ratio: float,
          kernel: int,
          stride: int,
          input_channels: int,
          out_channels: int,
          num_layers: int,
          block: Optional[Callable[..., nn.Module]] = None,
      ) -> None:
          chosen_block = FusedMBConv if block is None else block
          super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, chosen_block)
  class MBConv(nn.Module):
      """Inverted-residual (MBConv) block.

      The pipeline is an optional 1x1 expansion, a kxk depthwise conv (which carries the stride),
      squeeze-excitation and a linear 1x1 projection. A residual connection, gated by stochastic depth,
      is added only when the stride is 1 and the input and output widths match.
      """

      def __init__(
          self,
          cnf: MBConvConfig,
          stochastic_depth_prob: float,
          norm_layer: Callable[..., nn.Module],
          se_layer: Callable[..., nn.Module] = SqueezeExcitation,
      ) -> None:
          super().__init__()

          if not 1 <= cnf.stride <= 2:
              raise ValueError("illegal stride value")

          self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

          act = nn.SiLU
          hidden = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
          modules: List[nn.Module] = []

          # 1x1 expansion, skipped when the expand ratio leaves the width unchanged
          if hidden != cnf.input_channels:
              modules.append(
                  Conv2dNormActivation(
                      cnf.input_channels,
                      hidden,
                      kernel_size=1,
                      norm_layer=norm_layer,
                      activation_layer=act,
                  )
              )

          # Remaining stages are built in order: depthwise, squeeze-excitation, projection.
          # The SE bottleneck width is derived from the block *input* width, not the expanded width.
          modules.extend(
              [
                  Conv2dNormActivation(
                      hidden,
                      hidden,
                      kernel_size=cnf.kernel,
                      stride=cnf.stride,
                      groups=hidden,
                      norm_layer=norm_layer,
                      activation_layer=act,
                  ),
                  se_layer(hidden, max(1, cnf.input_channels // 4), activation=partial(nn.SiLU, inplace=True)),
                  Conv2dNormActivation(
                      hidden, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
                  ),
              ]
          )

          self.block = nn.Sequential(*modules)
          self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
          self.out_channels = cnf.out_channels

      def forward(self, input: Tensor) -> Tensor:
          out = self.block(input)
          if not self.use_res_connect:
              return out
          out = self.stochastic_depth(out)
          out += input
          return out
  class FusedMBConv(nn.Module):
      """Fused inverted-residual block (used for the first EfficientNetV2 stages).

      A single regular kxk conv replaces MBConv's 1x1 expansion plus depthwise conv, and there is no
      squeeze-excitation. When ``expand_ratio`` leaves the width unchanged, that kxk conv maps straight
      to ``out_channels``. Otherwise it is followed by a linear 1x1 projection. The residual connection
      follows the same rule as :class:`MBConv`.
      """

      def __init__(
          self,
          cnf: FusedMBConvConfig,
          stochastic_depth_prob: float,
          norm_layer: Callable[..., nn.Module],
      ) -> None:
          super().__init__()

          if not 1 <= cnf.stride <= 2:
              raise ValueError("illegal stride value")

          self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

          act = nn.SiLU
          hidden = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
          body: List[nn.Module]
          if hidden == cnf.input_channels:
              # no expansion: one kxk conv (with activation) produces the output width directly
              body = [
                  Conv2dNormActivation(
                      cnf.input_channels,
                      cnf.out_channels,
                      kernel_size=cnf.kernel,
                      stride=cnf.stride,
                      norm_layer=norm_layer,
                      activation_layer=act,
                  )
              ]
          else:
              body = [
                  # fused expand: kxk conv that also applies the stride
                  Conv2dNormActivation(
                      cnf.input_channels,
                      hidden,
                      kernel_size=cnf.kernel,
                      stride=cnf.stride,
                      norm_layer=norm_layer,
                      activation_layer=act,
                  ),
                  # linear 1x1 projection (no activation)
                  Conv2dNormActivation(
                      hidden, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
                  ),
              ]

          self.block = nn.Sequential(*body)
          self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
          self.out_channels = cnf.out_channels

      def forward(self, input: Tensor) -> Tensor:
          out = self.block(input)
          if not self.use_res_connect:
              return out
          out = self.stochastic_depth(out)
          out += input
          return out
  class EfficientNet(nn.Module):
      def __init__(
          self,
          inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
          dropout: float,
          stochastic_depth_prob: float = 0.2,
          num_classes: int = 1000,
          norm_layer: Optional[Callable[..., nn.Module]] = None,
          last_channel: Optional[int] = None,
          **kwargs: Any,
      ) -> None:
          """
          EfficientNet V1 and V2 main class

          Args:
              inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure,
                  one config per stage. Each stage is expanded into ``num_layers`` blocks.
              dropout (float): The dropout probability applied before the final linear classifier
              stochastic_depth_prob (float): The stochastic depth probability. Block ``i`` (0-based, counted
                  across all stages) uses ``stochastic_depth_prob * i / total_blocks``.
              num_classes (int): Number of classes
              norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use.
                  Defaults to ``nn.BatchNorm2d``.
              last_channel (int): The number of channels on the penultimate layer. Defaults to
                  ``4 * inverted_residual_setting[-1].out_channels``.
              **kwargs: Only the deprecated ``block`` key is read. A non-``None`` value overrides ``block`` on
                  every ``MBConvConfig`` entry (``FusedMBConvConfig`` entries are left alone); the caller's
                  config objects are not modified. Other keys are ignored.

          Raises:
              ValueError: If ``inverted_residual_setting`` is empty.
              TypeError: If ``inverted_residual_setting`` is not a sequence of ``_MBConvConfig`` instances.
          """
          super().__init__()
          _log_api_usage_once(self)

          if not inverted_residual_setting:
              raise ValueError("The inverted_residual_setting should not be empty")
          elif not (
              isinstance(inverted_residual_setting, Sequence)
              and all(isinstance(s, _MBConvConfig) for s in inverted_residual_setting)
          ):
              raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")

          if "block" in kwargs:
              warnings.warn(
                  "The parameter 'block' is deprecated since 0.13 and will be removed 0.15. "
                  "Please pass this information on 'MBConvConfig.block' instead."
              )
              if kwargs["block"] is not None:
                  # Apply the legacy override to shallow copies: assigning to the caller's own config
                  # objects would silently leak the override into any other model built from them.
                  overridden: List[Union[MBConvConfig, FusedMBConvConfig]] = []
                  for s in inverted_residual_setting:
                      if isinstance(s, MBConvConfig):
                          s = copy.copy(s)
                          s.block = kwargs["block"]
                      overridden.append(s)
                  inverted_residual_setting = overridden

          if norm_layer is None:
              norm_layer = nn.BatchNorm2d

          layers: List[nn.Module] = []

          # building first layer: 3x3 stride-2 stem from 3 input channels to the first stage's input width
          firstconv_output_channels = inverted_residual_setting[0].input_channels
          layers.append(
              Conv2dNormActivation(
                  3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
              )
          )

          # building inverted residual blocks
          total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)
          stage_block_id = 0
          for cnf in inverted_residual_setting:
              stage: List[nn.Module] = []
              for _ in range(cnf.num_layers):
                  # copy to avoid modifications. shallow copy is enough
                  block_cnf = copy.copy(cnf)

                  # overwrite info if not the first conv in the stage: only the first block of a stage
                  # changes width and applies the stage stride
                  if stage:
                      block_cnf.input_channels = block_cnf.out_channels
                      block_cnf.stride = 1

                  # adjust stochastic depth probability based on the depth of the stage block
                  sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks

                  stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
                  stage_block_id += 1

              layers.append(nn.Sequential(*stage))

          # building last several layers
          lastconv_input_channels = inverted_residual_setting[-1].out_channels
          lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
          layers.append(
              Conv2dNormActivation(
                  lastconv_input_channels,
                  lastconv_output_channels,
                  kernel_size=1,
                  norm_layer=norm_layer,
                  activation_layer=nn.SiLU,
              )
          )

          self.features = nn.Sequential(*layers)
          self.avgpool = nn.AdaptiveAvgPool2d(1)
          self.classifier = nn.Sequential(
              nn.Dropout(p=dropout, inplace=True),
              nn.Linear(lastconv_output_channels, num_classes),
          )

          for m in self.modules():
              if isinstance(m, nn.Conv2d):
                  nn.init.kaiming_normal_(m.weight, mode="fan_out")
                  if m.bias is not None:
                      nn.init.zeros_(m.bias)
              elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                  nn.init.ones_(m.weight)
                  nn.init.zeros_(m.bias)
              elif isinstance(m, nn.Linear):
                  # the uniform range is derived from out_features, not fan-in
                  init_range = 1.0 / math.sqrt(m.out_features)
                  nn.init.uniform_(m.weight, -init_range, init_range)
                  nn.init.zeros_(m.bias)

      def _forward_impl(self, x: Tensor) -> Tensor:
          # features -> global average pool -> flatten to (N, C) -> dropout + linear classifier
          x = self.features(x)
          x = self.avgpool(x)
          x = torch.flatten(x, 1)
          x = self.classifier(x)
          return x

      def forward(self, x: Tensor) -> Tensor:
          return self._forward_impl(x)
  def _efficientnet(
      inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
      dropout: float,
      last_channel: Optional[int],
      weights: Optional[WeightsEnum],
      progress: bool,
      **kwargs: Any,
  ) -> EfficientNet:
      """Build an :class:`EfficientNet` and, when ``weights`` is given, load its checkpoint.

      With ``weights``, ``num_classes`` in ``kwargs`` is routed through ``_ovewrite_named_param`` so that it
      equals the number of categories in ``weights.meta`` before the model is built.
      """
      if weights is None:
          return EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)

      _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
      model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
      model.load_state_dict(weights.get_state_dict(progress=progress))
      return model
  def _efficientnet_conf(
      arch: str,
      **kwargs: Any,
  ) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
      """Return ``(inverted_residual_setting, last_channel)`` for the architecture named ``arch``.

      * ``efficientnet_b*``: ``width_mult`` then ``depth_mult`` are popped from ``kwargs`` (``KeyError`` if
        missing). They rescale a single base table of ``MBConvConfig`` stages, and ``last_channel`` is ``None``.
      * ``efficientnet_v2_s`` / ``_m`` / ``_l``: fixed tables of ``FusedMBConvConfig`` and ``MBConvConfig``
        stages, with ``last_channel`` of 1280.
      * anything else raises ``ValueError``.

      Table rows are ``(expand_ratio, kernel, stride, input_channels, out_channels, num_layers)``.
      """
      if arch.startswith("efficientnet_b"):
          scaled = partial(MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult"))
          v1_rows = (
              (1, 3, 1, 32, 16, 1),
              (6, 3, 2, 16, 24, 2),
              (6, 5, 2, 24, 40, 2),
              (6, 3, 2, 40, 80, 3),
              (6, 5, 1, 80, 112, 3),
              (6, 5, 2, 112, 192, 4),
              (6, 3, 1, 192, 320, 1),
          )
          return [scaled(*row) for row in v1_rows], None

      fused, mb = FusedMBConvConfig, MBConvConfig
      if arch.startswith("efficientnet_v2_s"):
          v2_rows = (
              (fused, 1, 3, 1, 24, 24, 2),
              (fused, 4, 3, 2, 24, 48, 4),
              (fused, 4, 3, 2, 48, 64, 4),
              (mb, 4, 3, 2, 64, 128, 6),
              (mb, 6, 3, 1, 128, 160, 9),
              (mb, 6, 3, 2, 160, 256, 15),
          )
      elif arch.startswith("efficientnet_v2_m"):
          v2_rows = (
              (fused, 1, 3, 1, 24, 24, 3),
              (fused, 4, 3, 2, 24, 48, 5),
              (fused, 4, 3, 2, 48, 80, 5),
              (mb, 4, 3, 2, 80, 160, 7),
              (mb, 6, 3, 1, 160, 176, 14),
              (mb, 6, 3, 2, 176, 304, 18),
              (mb, 6, 3, 1, 304, 512, 5),
          )
      elif arch.startswith("efficientnet_v2_l"):
          v2_rows = (
              (fused, 1, 3, 1, 32, 32, 4),
              (fused, 4, 3, 2, 32, 64, 7),
              (fused, 4, 3, 2, 64, 96, 7),
              (mb, 4, 3, 2, 96, 192, 10),
              (mb, 6, 3, 1, 192, 224, 19),
              (mb, 6, 3, 2, 224, 384, 25),
              (mb, 6, 3, 1, 384, 640, 7),
          )
      else:
          raise ValueError(f"Unsupported model type {arch}")

      # Configs are built only for the selected variant; every V2 variant uses a 1280-wide penultimate layer.
      return [cfg_cls(*row) for cfg_cls, *row in v2_rows], 1280
  # Metadata shared by every EfficientNet weights entry.
  _COMMON_META: Dict[str, Any] = dict(categories=_IMAGENET_CATEGORIES)

  # V1 (B0-B7) entries: shared metadata plus the V1 minimum input size and training-recipe link.
  _COMMON_META_V1 = dict(
      _COMMON_META,
      min_size=(1, 1),
      recipe="https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v1",
  )

  # V2 (S/M/L) entries: shared metadata plus the V2 minimum input size and training-recipe link.
  _COMMON_META_V2 = dict(
      _COMMON_META,
      min_size=(33, 33),
      recipe="https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v2",
  )
  class EfficientNet_B0_Weights(WeightsEnum):
      # Pretrained weights accepted by efficientnet_b0. Each member bundles a checkpoint URL,
      # an evaluation-transform preset and metadata.
      IMAGENET1K_V1 = Weights(
          # Weights ported from https://github.com/rwightman/pytorch-image-models/
          url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
          # partial defers constructing ImageClassification until the preset is called
          transforms=partial(
              ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
          ),
          meta={
              # categories, min_size and recipe shared by all V1 variants
              **_COMMON_META_V1,
              "num_params": 5288548,
              # top-1 / top-5 accuracy reported for this checkpoint
              "_metrics": {
                  "ImageNet-1K": {
                      "acc@1": 77.692,
                      "acc@5": 93.532,
                  }
              },
              "_docs": """These weights are ported from the original paper.""",
          },
      )
      # DEFAULT names the same Weights object as IMAGENET1K_V1.
      DEFAULT = IMAGENET1K_V1
  class EfficientNet_B1_Weights(WeightsEnum):
      # Pretrained weights accepted by efficientnet_b1. There are two checkpoints: a ported one (V1)
      # and a TorchVision-retrained one (V2), which is the default.
      IMAGENET1K_V1 = Weights(
          # Weights ported from https://github.com/rwightman/pytorch-image-models/
          url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
          transforms=partial(
              ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC
          ),
          meta={
              **_COMMON_META_V1,
              "num_params": 7794184,
              # top-1 / top-5 accuracy reported for this checkpoint
              "_metrics": {
                  "ImageNet-1K": {
                      "acc@1": 78.642,
                      "acc@5": 94.186,
                  }
              },
              "_docs": """These weights are ported from the original paper.""",
          },
      )
      IMAGENET1K_V2 = Weights(
          url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
          # eval preset differs from V1: bilinear interpolation and resize_size=255
          transforms=partial(
              ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR
          ),
          meta={
              **_COMMON_META_V1,
              "num_params": 7794184,
              # listed after the spread, so this overrides the recipe inherited from _COMMON_META_V1
              "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
              "_metrics": {
                  "ImageNet-1K": {
                      "acc@1": 79.838,
                      "acc@5": 94.934,
                  }
              },
              "_docs": """
                  These weights improve upon the results of the original paper by using a modified version of TorchVision's
                  `new training recipe
                  <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
              """,
          },
      )
      # DEFAULT names the same Weights object as IMAGENET1K_V2 (the retrained checkpoint).
      DEFAULT = IMAGENET1K_V2
  class EfficientNet_B2_Weights(WeightsEnum):
      # Pretrained weights accepted by efficientnet_b2.
      IMAGENET1K_V1 = Weights(
          # Weights ported from https://github.com/rwightman/pytorch-image-models/
          url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
          # eval preset: crop_size and resize_size are equal for this checkpoint
          transforms=partial(
              ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC
          ),
          meta={
              **_COMMON_META_V1,
              "num_params": 9109994,
              # top-1 / top-5 accuracy reported for this checkpoint
              "_metrics": {
                  "ImageNet-1K": {
                      "acc@1": 80.608,
                      "acc@5": 95.310,
                  }
              },
              "_docs": """These weights are ported from the original paper.""",
          },
      )
      # DEFAULT names the same Weights object as IMAGENET1K_V1.
      DEFAULT = IMAGENET1K_V1
  class EfficientNet_B3_Weights(WeightsEnum):
      # Pretrained weights accepted by efficientnet_b3.
      IMAGENET1K_V1 = Weights(
          # Weights ported from https://github.com/rwightman/pytorch-image-models/
          url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
          transforms=partial(
              ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC
          ),
          meta={
              **_COMMON_META_V1,
              "num_params": 12233232,
              # top-1 / top-5 accuracy reported for this checkpoint
              "_metrics": {
                  "ImageNet-1K": {
                      "acc@1": 82.008,
                      "acc@5": 96.054,
                  }
              },
              "_docs": """These weights are ported from the original paper.""",
          },
      )
      # DEFAULT names the same Weights object as IMAGENET1K_V1.
      DEFAULT = IMAGENET1K_V1
  class EfficientNet_B4_Weights(WeightsEnum):
      # Pretrained weights accepted by efficientnet_b4.
      IMAGENET1K_V1 = Weights(
          # Weights ported from https://github.com/rwightman/pytorch-image-models/
          url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
          transforms=partial(
              ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC
          ),
          meta={
              **_COMMON_META_V1,
              "num_params": 19341616,
              # top-1 / top-5 accuracy reported for this checkpoint
              "_metrics": {
                  "ImageNet-1K": {
                      "acc@1": 83.384,
                      "acc@5": 96.594,
                  }
              },
              "_docs": """These weights are ported from the original paper.""",
          },
      )
      # DEFAULT names the same Weights object as IMAGENET1K_V1.
      DEFAULT = IMAGENET1K_V1
  class EfficientNet_B5_Weights(WeightsEnum):
      # Pretrained weights for efficientnet_b5 (ported from a different upstream than B0-B4).
      IMAGENET1K_V1 = Weights(
          # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
          url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
          transforms=partial(
              ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC
          ),
          meta={
              **_COMMON_META_V1,
              "num_params": 30389784,
              # top-1 / top-5 accuracy reported for this checkpoint
              "_metrics": {
                  "ImageNet-1K": {
                      "acc@1": 83.444,
                      "acc@5": 96.628,
                  }
              },
              "_docs": """These weights are ported from the original paper.""",
          },
      )
      # DEFAULT names the same Weights object as IMAGENET1K_V1.
      DEFAULT = IMAGENET1K_V1
  class EfficientNet_B6_Weights(WeightsEnum):
      # Pretrained weights accepted by efficientnet_b6.
      IMAGENET1K_V1 = Weights(
          # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
          url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
          transforms=partial(
              ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC
          ),
          meta={
              **_COMMON_META_V1,
              "num_params": 43040704,
              # top-1 / top-5 accuracy reported for this checkpoint
              "_metrics": {
                  "ImageNet-1K": {
                      "acc@1": 84.008,
                      "acc@5": 96.916,
                  }
              },
              "_docs": """These weights are ported from the original paper.""",
          },
      )
      # DEFAULT names the same Weights object as IMAGENET1K_V1.
      DEFAULT = IMAGENET1K_V1
  class EfficientNet_B7_Weights(WeightsEnum):
      # Pretrained weights accepted by efficientnet_b7.
      IMAGENET1K_V1 = Weights(
          # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
          url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
          transforms=partial(
              ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC
          ),
          meta={
              **_COMMON_META_V1,
              "num_params": 66347960,
              # top-1 / top-5 accuracy reported for this checkpoint
              "_metrics": {
                  "ImageNet-1K": {
                      "acc@1": 84.122,
                      "acc@5": 96.908,
                  }
              },
              "_docs": """These weights are ported from the original paper.""",
          },
      )
      # DEFAULT names the same Weights object as IMAGENET1K_V1.
      DEFAULT = IMAGENET1K_V1
  class EfficientNet_V2_S_Weights(WeightsEnum):
      # Pretrained weights accepted by efficientnet_v2_s (TorchVision-trained checkpoint).
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
          transforms=partial(
              ImageClassification,
              crop_size=384,
              resize_size=384,
              interpolation=InterpolationMode.BILINEAR,
          ),
          meta={
              # categories, min_size and recipe shared by all V2 variants
              **_COMMON_META_V2,
              "num_params": 21458488,
              # top-1 / top-5 accuracy reported for this checkpoint
              "_metrics": {
                  "ImageNet-1K": {
                      "acc@1": 84.228,
                      "acc@5": 96.878,
                  }
              },
              "_docs": """
                  These weights improve upon the results of the original paper by using a modified version of TorchVision's
                  `new training recipe
                  <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
              """,
          },
      )
      # DEFAULT names the same Weights object as IMAGENET1K_V1.
      DEFAULT = IMAGENET1K_V1
  class EfficientNet_V2_M_Weights(WeightsEnum):
      # Pretrained weights accepted by efficientnet_v2_m (TorchVision-trained checkpoint).
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
          transforms=partial(
              ImageClassification,
              crop_size=480,
              resize_size=480,
              interpolation=InterpolationMode.BILINEAR,
          ),
          meta={
              # categories, min_size and recipe shared by all V2 variants
              **_COMMON_META_V2,
              "num_params": 54139356,
              # top-1 / top-5 accuracy reported for this checkpoint
              "_metrics": {
                  "ImageNet-1K": {
                      "acc@1": 85.112,
                      "acc@5": 97.156,
                  }
              },
              "_docs": """
                  These weights improve upon the results of the original paper by using a modified version of TorchVision's
                  `new training recipe
                  <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
              """,
          },
      )
      # DEFAULT names the same Weights object as IMAGENET1K_V1.
      DEFAULT = IMAGENET1K_V1
  class EfficientNet_V2_L_Weights(WeightsEnum):
      # Pretrained weights accepted by efficientnet_v2_l.
      # Weights ported from https://github.com/google/automl/tree/master/efficientnetv2
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
          # Unlike the other presets in this file, mean/std are set explicitly to 0.5.
          # NOTE(review): presumably matches the normalization of the ported checkpoint; confirm upstream.
          transforms=partial(
              ImageClassification,
              crop_size=480,
              resize_size=480,
              interpolation=InterpolationMode.BICUBIC,
              mean=(0.5, 0.5, 0.5),
              std=(0.5, 0.5, 0.5),
          ),
          meta={
              # categories, min_size and recipe shared by all V2 variants
              **_COMMON_META_V2,
              "num_params": 118515272,
              # top-1 / top-5 accuracy reported for this checkpoint
              "_metrics": {
                  "ImageNet-1K": {
                      "acc@1": 85.808,
                      "acc@5": 97.788,
                  }
              },
              "_docs": """These weights are ported from the original paper.""",
          },
      )
      # DEFAULT names the same Weights object as IMAGENET1K_V1.
      DEFAULT = IMAGENET1K_V1
  @handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
  def efficientnet_b0(
      *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> EfficientNet:
      """EfficientNet B0 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
      Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

      B0 is the unscaled baseline (``width_mult=1.0``, ``depth_mult=1.0``) with classifier dropout ``0.2``.

      Args:
          weights (:class:`~torchvision.models.EfficientNet_B0_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.EfficientNet_B0_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.EfficientNet_B0_Weights
          :members:
      """
      checked_weights = EfficientNet_B0_Weights.verify(weights)
      stages, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
      return _efficientnet(stages, 0.2, last_channel, checked_weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
  def efficientnet_b1(
      *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> EfficientNet:
      """EfficientNet B1 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
      Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

      B1 scales the baseline with ``width_mult=1.0``, ``depth_mult=1.1`` and uses classifier dropout ``0.2``.

      Args:
          weights (:class:`~torchvision.models.EfficientNet_B1_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.EfficientNet_B1_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.EfficientNet_B1_Weights
          :members:
      """
      checked_weights = EfficientNet_B1_Weights.verify(weights)
      stages, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)
      return _efficientnet(stages, 0.2, last_channel, checked_weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
  def efficientnet_b2(
      *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> EfficientNet:
      """EfficientNet B2 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
      Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

      B2 scales the baseline with ``width_mult=1.1``, ``depth_mult=1.2`` and uses classifier dropout ``0.3``.

      Args:
          weights (:class:`~torchvision.models.EfficientNet_B2_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.EfficientNet_B2_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.EfficientNet_B2_Weights
          :members:
      """
      checked_weights = EfficientNet_B2_Weights.verify(weights)
      stages, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2)
      return _efficientnet(stages, 0.3, last_channel, checked_weights, progress, **kwargs)
  @handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
  def efficientnet_b3(
      *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> EfficientNet:
      """EfficientNet B3 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
      Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

      B3 scales the baseline with ``width_mult=1.2``, ``depth_mult=1.4`` and uses classifier dropout ``0.3``.

      Args:
          weights (:class:`~torchvision.models.EfficientNet_B3_Weights`, optional): The
              pretrained weights to use. See
              :class:`~torchvision.models.EfficientNet_B3_Weights` below for
              more details, and possible values. By default, no pre-trained
              weights are used.
          progress (bool, optional): If True, displays a progress bar of the
              download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.EfficientNet_B3_Weights
          :members:
      """
      checked_weights = EfficientNet_B3_Weights.verify(weights)
      stages, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4)
      return _efficientnet(stages, 0.3, last_channel, checked_weights, progress, **kwargs)
  737. @handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
  738. def efficientnet_b4(
  739. *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
  740. ) -> EfficientNet:
  741. """EfficientNet B4 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
  742. Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
  743. Args:
  744. weights (:class:`~torchvision.models.EfficientNet_B4_Weights`, optional): The
  745. pretrained weights to use. See
  746. :class:`~torchvision.models.EfficientNet_B4_Weights` below for
  747. more details, and possible values. By default, no pre-trained
  748. weights are used.
  749. progress (bool, optional): If True, displays a progress bar of the
  750. download to stderr. Default is True.
  751. **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
  752. base class. Please refer to the `source code
  753. <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
  754. for more details about this class.
  755. .. autoclass:: torchvision.models.EfficientNet_B4_Weights
  756. :members:
  757. """
  758. weights = EfficientNet_B4_Weights.verify(weights)
  759. inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8)
  760. return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs)
  761. @handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
  762. def efficientnet_b5(
  763. *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
  764. ) -> EfficientNet:
  765. """EfficientNet B5 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
  766. Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
  767. Args:
  768. weights (:class:`~torchvision.models.EfficientNet_B5_Weights`, optional): The
  769. pretrained weights to use. See
  770. :class:`~torchvision.models.EfficientNet_B5_Weights` below for
  771. more details, and possible values. By default, no pre-trained
  772. weights are used.
  773. progress (bool, optional): If True, displays a progress bar of the
  774. download to stderr. Default is True.
  775. **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
  776. base class. Please refer to the `source code
  777. <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
  778. for more details about this class.
  779. .. autoclass:: torchvision.models.EfficientNet_B5_Weights
  780. :members:
  781. """
  782. weights = EfficientNet_B5_Weights.verify(weights)
  783. inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2)
  784. return _efficientnet(
  785. inverted_residual_setting,
  786. 0.4,
  787. last_channel,
  788. weights,
  789. progress,
  790. norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
  791. **kwargs,
  792. )
  793. @handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1))
  794. def efficientnet_b6(
  795. *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
  796. ) -> EfficientNet:
  797. """EfficientNet B6 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
  798. Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
  799. Args:
  800. weights (:class:`~torchvision.models.EfficientNet_B6_Weights`, optional): The
  801. pretrained weights to use. See
  802. :class:`~torchvision.models.EfficientNet_B6_Weights` below for
  803. more details, and possible values. By default, no pre-trained
  804. weights are used.
  805. progress (bool, optional): If True, displays a progress bar of the
  806. download to stderr. Default is True.
  807. **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
  808. base class. Please refer to the `source code
  809. <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
  810. for more details about this class.
  811. .. autoclass:: torchvision.models.EfficientNet_B6_Weights
  812. :members:
  813. """
  814. weights = EfficientNet_B6_Weights.verify(weights)
  815. inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6)
  816. return _efficientnet(
  817. inverted_residual_setting,
  818. 0.5,
  819. last_channel,
  820. weights,
  821. progress,
  822. norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
  823. **kwargs,
  824. )
  825. @handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1))
  826. def efficientnet_b7(
  827. *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
  828. ) -> EfficientNet:
  829. """EfficientNet B7 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
  830. Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
  831. Args:
  832. weights (:class:`~torchvision.models.EfficientNet_B7_Weights`, optional): The
  833. pretrained weights to use. See
  834. :class:`~torchvision.models.EfficientNet_B7_Weights` below for
  835. more details, and possible values. By default, no pre-trained
  836. weights are used.
  837. progress (bool, optional): If True, displays a progress bar of the
  838. download to stderr. Default is True.
  839. **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
  840. base class. Please refer to the `source code
  841. <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
  842. for more details about this class.
  843. .. autoclass:: torchvision.models.EfficientNet_B7_Weights
  844. :members:
  845. """
  846. weights = EfficientNet_B7_Weights.verify(weights)
  847. inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1)
  848. return _efficientnet(
  849. inverted_residual_setting,
  850. 0.5,
  851. last_channel,
  852. weights,
  853. progress,
  854. norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
  855. **kwargs,
  856. )
  857. @handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
  858. def efficientnet_v2_s(
  859. *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
  860. ) -> EfficientNet:
  861. """
  862. Constructs an EfficientNetV2-S architecture from
  863. `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.
  864. Args:
  865. weights (:class:`~torchvision.models.EfficientNet_V2_S_Weights`, optional): The
  866. pretrained weights to use. See
  867. :class:`~torchvision.models.EfficientNet_V2_S_Weights` below for
  868. more details, and possible values. By default, no pre-trained
  869. weights are used.
  870. progress (bool, optional): If True, displays a progress bar of the
  871. download to stderr. Default is True.
  872. **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
  873. base class. Please refer to the `source code
  874. <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
  875. for more details about this class.
  876. .. autoclass:: torchvision.models.EfficientNet_V2_S_Weights
  877. :members:
  878. """
  879. weights = EfficientNet_V2_S_Weights.verify(weights)
  880. inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s")
  881. return _efficientnet(
  882. inverted_residual_setting,
  883. 0.2,
  884. last_channel,
  885. weights,
  886. progress,
  887. norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
  888. **kwargs,
  889. )
  890. @handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
  891. def efficientnet_v2_m(
  892. *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
  893. ) -> EfficientNet:
  894. """
  895. Constructs an EfficientNetV2-M architecture from
  896. `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.
  897. Args:
  898. weights (:class:`~torchvision.models.EfficientNet_V2_M_Weights`, optional): The
  899. pretrained weights to use. See
  900. :class:`~torchvision.models.EfficientNet_V2_M_Weights` below for
  901. more details, and possible values. By default, no pre-trained
  902. weights are used.
  903. progress (bool, optional): If True, displays a progress bar of the
  904. download to stderr. Default is True.
  905. **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
  906. base class. Please refer to the `source code
  907. <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
  908. for more details about this class.
  909. .. autoclass:: torchvision.models.EfficientNet_V2_M_Weights
  910. :members:
  911. """
  912. weights = EfficientNet_V2_M_Weights.verify(weights)
  913. inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m")
  914. return _efficientnet(
  915. inverted_residual_setting,
  916. 0.3,
  917. last_channel,
  918. weights,
  919. progress,
  920. norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
  921. **kwargs,
  922. )
  923. @handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
  924. def efficientnet_v2_l(
  925. *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
  926. ) -> EfficientNet:
  927. """
  928. Constructs an EfficientNetV2-L architecture from
  929. `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.
  930. Args:
  931. weights (:class:`~torchvision.models.EfficientNet_V2_L_Weights`, optional): The
  932. pretrained weights to use. See
  933. :class:`~torchvision.models.EfficientNet_V2_L_Weights` below for
  934. more details, and possible values. By default, no pre-trained
  935. weights are used.
  936. progress (bool, optional): If True, displays a progress bar of the
  937. download to stderr. Default is True.
  938. **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
  939. base class. Please refer to the `source code
  940. <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
  941. for more details about this class.
  942. .. autoclass:: torchvision.models.EfficientNet_V2_L_Weights
  943. :members:
  944. """
  945. weights = EfficientNet_V2_L_Weights.verify(weights)
  946. inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l")
  947. return _efficientnet(
  948. inverted_residual_setting,
  949. 0.4,
  950. last_channel,
  951. weights,
  952. progress,
  953. norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
  954. **kwargs,
  955. )
  956. # The dictionary below is internal implementation detail and will be removed in v0.15
  957. from ._utils import _ModelURLs
  958. model_urls = _ModelURLs(
  959. {
  960. "efficientnet_b0": EfficientNet_B0_Weights.IMAGENET1K_V1.url,
  961. "efficientnet_b1": EfficientNet_B1_Weights.IMAGENET1K_V1.url,
  962. "efficientnet_b2": EfficientNet_B2_Weights.IMAGENET1K_V1.url,
  963. "efficientnet_b3": EfficientNet_B3_Weights.IMAGENET1K_V1.url,
  964. "efficientnet_b4": EfficientNet_B4_Weights.IMAGENET1K_V1.url,
  965. "efficientnet_b5": EfficientNet_B5_Weights.IMAGENET1K_V1.url,
  966. "efficientnet_b6": EfficientNet_B6_Weights.IMAGENET1K_V1.url,
  967. "efficientnet_b7": EfficientNet_B7_Weights.IMAGENET1K_V1.url,
  968. }
  969. )