convnet_benchmarks.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699
  1. # Copyright (c) 2016-present, Facebook, Inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. ##############################################################################
  15. ## @package convnet_benchmarks
  16. # Module caffe2.experiments.python.convnet_benchmarks
  17. """
  18. Benchmark for common convnets.
  19. (NOTE: The numbers below were measured before the parameter-update step was added; TODO: update them.)
  20. Speed on Titan X, with 10 warmup steps and 10 main steps and with different
  21. versions of cudnn, are as follows (time reported below is per-batch time,
  22. forward / forward+backward):
  23. CuDNN V3 CuDNN v4
  24. AlexNet 32.5 / 108.0 27.4 / 90.1
  25. OverFeat 113.0 / 342.3 91.7 / 276.5
  26. Inception 134.5 / 485.8 125.7 / 450.6
  27. VGG (batch 64) 200.8 / 650.0 164.1 / 551.7
  28. Speed on Inception with varied batch sizes and CuDNN v4 is as follows:
  29. Batch Size Speed per batch Speed per image
  30. 16 22.8 / 72.7 1.43 / 4.54
  31. 32 38.0 / 127.5 1.19 / 3.98
  32. 64 67.2 / 233.6 1.05 / 3.65
  33. 128 125.7 / 450.6 0.98 / 3.52
  34. Speed on Tesla M40, with 10 warmup steps and 10 main steps and with cudnn
  35. v4, is as follows:
  36. AlexNet 68.4 / 218.1
  37. OverFeat 210.5 / 630.3
  38. Inception 300.2 / 1122.2
  39. VGG (batch 64) 405.8 / 1327.7
  40. (Note that these numbers involve a "full" backprop, i.e. the gradient
  41. with respect to the input image is also computed.)
  42. To get the numbers, simply run:
  43. for MODEL in AlexNet OverFeat Inception; do
  44. PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
  45. --batch_size 128 --model $MODEL --forward_only
  46. done
  47. for MODEL in AlexNet OverFeat Inception; do
  48. PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
  49. --batch_size 128 --model $MODEL
  50. done
  51. PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
  52. --batch_size 64 --model VGGA --forward_only
  53. PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
  54. --batch_size 64 --model VGGA
  55. for BS in 16 32 64 128; do
  56. PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
  57. --batch_size $BS --model Inception --forward_only
  58. PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
  59. --batch_size $BS --model Inception
  60. done
  61. Note that VGG needs to be run at batch 64 due to memory limit on the backward
  62. pass.
  63. """
  64. import argparse
  65. import time
  66. from caffe2.python import cnn, workspace, core
  67. import caffe2.python.SparseTransformer as SparseTransformer # type: ignore[import]
  def MLP(order):
      """Build a multi-branch fully-connected benchmark network.

      The net has ``num_branches`` parallel chains of ``num_layers`` FC layers,
      every chain starting from the "data" blob.  The chain outputs are summed
      into "sum", passed through a final FC layer with 1000 outputs ("last"),
      and scored with a label cross-entropy that is averaged into "loss".

      Args:
          order: Accepted so the signature matches the other model builders;
              it is not used by this function.

      Returns:
          A ``(model, hidden_dim)`` tuple; ``hidden_dim`` (256) is the width
          of every hidden FC layer and of the "data" blob they consume.
      """
      model = cnn.CNNModelHelper()
      hidden_dim = 256
      num_layers = 20
      num_branches = 3

      def fc_blob(layer, branch):
          # Blob naming scheme shared by every hidden layer.
          return "fc_{}_{}".format(layer, branch)

      for layer in range(num_layers):
          for branch in range(num_branches):
              # The first layer of every branch reads the shared input blob.
              source = fc_blob(layer, branch) if layer > 0 else "data"
              target = fc_blob(layer + 1, branch)
              model.FC(
                  source,
                  target,
                  dim_in=hidden_dim,
                  dim_out=hidden_dim,
                  weight_init=model.XavierInit,
                  bias_init=model.XavierInit,
              )

      # NOTE(review): the indentation of this file was lost; the Sum and the
      # layers after it are assumed to sit outside both loops because the Sum
      # consumes the last-layer output of every branch -- confirm.
      branch_outputs = [
          fc_blob(num_layers, branch) for branch in range(num_branches)
      ]
      model.Sum(branch_outputs, ["sum"])
      model.FC(
          "sum",
          "last",
          dim_in=hidden_dim,
          dim_out=1000,
          weight_init=model.XavierInit,
          bias_init=model.XavierInit,
      )
      cross_entropy = model.LabelCrossEntropy(["last", "label"], "xent")
      model.AveragedLoss(cross_entropy, "loss")
      return model, hidden_dim
  def AlexNet(order):
      """Build the AlexNet benchmark network.

      Five convolutions (three of them followed by max pooling) feed three
      fully-connected layers, a softmax, and an averaged label cross-entropy
      written to the "loss" blob.  Every ReLU writes its output under the
      blob name of the layer it follows.

      Args:
          order: Storage order forwarded to ``cnn.CNNModelHelper``.

      Returns:
          A ``(model, 224)`` tuple; Benchmark uses the second element as the
          height and width of the "data" blob.
      """
      net = cnn.CNNModelHelper(
          order, name="alexnet", use_cudnn=True, cudnn_exhaustive_search=True
      )

      # --- convolutional feature extractor ---
      c1 = net.Conv(
          "data", "conv1", 3, 64, 11,
          ('XavierFill', {}), ('ConstantFill', {}),
          stride=4, pad=2
      )
      a1 = net.Relu(c1, "conv1")
      p1 = net.MaxPool(a1, "pool1", kernel=3, stride=2)

      c2 = net.Conv(
          p1, "conv2", 64, 192, 5,
          ('XavierFill', {}), ('ConstantFill', {}),
          pad=2
      )
      a2 = net.Relu(c2, "conv2")
      p2 = net.MaxPool(a2, "pool2", kernel=3, stride=2)

      c3 = net.Conv(
          p2, "conv3", 192, 384, 3,
          ('XavierFill', {}), ('ConstantFill', {}),
          pad=1
      )
      a3 = net.Relu(c3, "conv3")

      c4 = net.Conv(
          a3, "conv4", 384, 256, 3,
          ('XavierFill', {}), ('ConstantFill', {}),
          pad=1
      )
      a4 = net.Relu(c4, "conv4")

      c5 = net.Conv(
          a4, "conv5", 256, 256, 3,
          ('XavierFill', {}), ('ConstantFill', {}),
          pad=1
      )
      a5 = net.Relu(c5, "conv5")
      p5 = net.MaxPool(a5, "pool5", kernel=3, stride=2)

      # --- fully-connected classifier ---
      f6 = net.FC(
          p5, "fc6", 256 * 6 * 6, 4096,
          ('XavierFill', {}), ('ConstantFill', {})
      )
      a6 = net.Relu(f6, "fc6")
      f7 = net.FC(
          a6, "fc7", 4096, 4096,
          ('XavierFill', {}), ('ConstantFill', {})
      )
      a7 = net.Relu(f7, "fc7")
      f8 = net.FC(
          a7, "fc8", 4096, 1000,
          ('XavierFill', {}), ('ConstantFill', {})
      )

      # --- loss ---
      probabilities = net.Softmax(f8, "pred")
      cross_entropy = net.LabelCrossEntropy([probabilities, "label"], "xent")
      net.AveragedLoss(cross_entropy, "loss")
      return net, 224
  def OverFeat(order):
      """Build the OverFeat benchmark network.

      Five convolutions (the first two and the last followed by 2x2 max
      pooling) feed three fully-connected layers, a softmax, and an averaged
      label cross-entropy written to the "loss" blob.  Every ReLU writes its
      output under the blob name of the layer it follows.

      Args:
          order: Storage order forwarded to ``cnn.CNNModelHelper``.

      Returns:
          A ``(model, 231)`` tuple; Benchmark uses the second element as the
          height and width of the "data" blob.
      """
      net = cnn.CNNModelHelper(
          order, name="overfeat", use_cudnn=True, cudnn_exhaustive_search=True
      )

      # --- convolutional feature extractor ---
      c1 = net.Conv(
          "data", "conv1", 3, 96, 11,
          ('XavierFill', {}), ('ConstantFill', {}),
          stride=4
      )
      a1 = net.Relu(c1, "conv1")
      p1 = net.MaxPool(a1, "pool1", kernel=2, stride=2)

      c2 = net.Conv(
          p1, "conv2", 96, 256, 5,
          ('XavierFill', {}), ('ConstantFill', {})
      )
      a2 = net.Relu(c2, "conv2")
      p2 = net.MaxPool(a2, "pool2", kernel=2, stride=2)

      c3 = net.Conv(
          p2, "conv3", 256, 512, 3,
          ('XavierFill', {}), ('ConstantFill', {}),
          pad=1
      )
      a3 = net.Relu(c3, "conv3")

      c4 = net.Conv(
          a3, "conv4", 512, 1024, 3,
          ('XavierFill', {}), ('ConstantFill', {}),
          pad=1
      )
      a4 = net.Relu(c4, "conv4")

      c5 = net.Conv(
          a4, "conv5", 1024, 1024, 3,
          ('XavierFill', {}), ('ConstantFill', {}),
          pad=1
      )
      a5 = net.Relu(c5, "conv5")
      p5 = net.MaxPool(a5, "pool5", kernel=2, stride=2)

      # --- fully-connected classifier ---
      f6 = net.FC(
          p5, "fc6", 1024 * 6 * 6, 3072,
          ('XavierFill', {}), ('ConstantFill', {})
      )
      a6 = net.Relu(f6, "fc6")
      f7 = net.FC(
          a6, "fc7", 3072, 4096,
          ('XavierFill', {}), ('ConstantFill', {})
      )
      a7 = net.Relu(f7, "fc7")
      f8 = net.FC(
          a7, "fc8", 4096, 1000,
          ('XavierFill', {}), ('ConstantFill', {})
      )

      # --- loss ---
      probabilities = net.Softmax(f8, "pred")
      cross_entropy = net.LabelCrossEntropy([probabilities, "label"], "xent")
      net.AveragedLoss(cross_entropy, "loss")
      return net, 231
  def VGGA(order):
      """Build the VGG-A benchmark network.

      Eight 3x3 convolutions with padding 1, each followed by a ReLU that
      writes to the convolution's own blob name, with 2x2 max pooling after
      conv1, conv2, conv4, conv6 and conv8.  Three fully-connected layers, a
      softmax, and an averaged label cross-entropy ("loss") follow.

      Args:
          order: Storage order forwarded to ``cnn.CNNModelHelper``.

      Returns:
          A ``(model, 231)`` tuple; Benchmark uses the second element as the
          height and width of the "data" blob.
      """
      net = cnn.CNNModelHelper(
          order, name='vgg-a', use_cudnn=True, cudnn_exhaustive_search=True
      )

      # (conv blob, input channels, output channels, pool blob or None)
      conv_stack = [
          ("conv1", 3, 64, "pool1"),
          ("conv2", 64, 128, "pool2"),
          ("conv3", 128, 256, None),
          ("conv4", 256, 256, "pool4"),
          ("conv5", 256, 512, None),
          ("conv6", 512, 512, "pool6"),
          ("conv7", 512, 512, None),
          ("conv8", 512, 512, "pool8"),
      ]
      blob = "data"
      for conv_name, dim_in, dim_out, pool_name in conv_stack:
          blob = net.Conv(
              blob, conv_name, dim_in, dim_out, 3,
              ('XavierFill', {}), ('ConstantFill', {}),
              pad=1
          )
          blob = net.Relu(blob, conv_name)
          if pool_name is not None:
              blob = net.MaxPool(blob, pool_name, kernel=2, stride=2)

      # Fully-connected classifier; the layer names are roman numerals 9-11.
      fc_ix = net.FC(
          blob, "fcix", 512 * 7 * 7, 4096,
          ('XavierFill', {}), ('ConstantFill', {})
      )
      act_ix = net.Relu(fc_ix, "fcix")
      fc_x = net.FC(
          act_ix, "fcx", 4096, 4096,
          ('XavierFill', {}), ('ConstantFill', {})
      )
      act_x = net.Relu(fc_x, "fcx")
      fc_xi = net.FC(
          act_x, "fcxi", 4096, 1000,
          ('XavierFill', {}), ('ConstantFill', {})
      )

      probabilities = net.Softmax(fc_xi, "pred")
      cross_entropy = net.LabelCrossEntropy([probabilities, "label"], "xent")
      net.AveragedLoss(cross_entropy, "loss")
      return net, 231
  def net_DAG_Builder(model):
      """Print a banner, then build a DAG for *model*.

      Args:
          model: The model handed to ``SparseTransformer.netbuilder``.

      Returns:
          Whatever ``SparseTransformer.netbuilder(model)`` returns.
      """
      separator = "===================================================="
      # NOTE(review): the padding inside the banner title may have been
      # collapsed when this file was extracted -- confirm against upstream.
      for banner_line in (separator, " Start Building DAG ", separator):
          print(banner_line)
      return SparseTransformer.netbuilder(model)
  def _InceptionModule(
      model, input_blob, input_depth, output_name, conv1_depth, conv3_depths,
      conv5_depths, pool_depth
  ):
      """Add one four-path inception module to *model*.

      Every convolution is followed by a ReLU that overwrites the
      convolution's own output blob.  The four path outputs are concatenated
      into the blob named *output_name*.

      Args:
          model: Model helper the operators are added to.
          input_blob: Blob feeding all four paths.
          input_depth: Input channel count given to the convolutions that
              read *input_blob* (and to the pool projection).
          output_name: Name of the concatenated output; also the prefix of
              every intermediate blob (``output_name + ":conv1"`` etc.).
          conv1_depth: Output channels of the 1x1 path.
          conv3_depths: ``[reduce, out]`` channel counts of the 3x3 path.
          conv5_depths: ``[reduce, out]`` channel counts of the 5x5 path.
          pool_depth: Output channels of the projection after the pool.

      Returns:
          The result of ``model.Concat`` over the four path outputs.
      """

      def conv_relu(source, suffix, dim_in, dim_out, kernel, **conv_kwargs):
          # Convolution followed by a ReLU written back to the same blob.
          blob = model.Conv(
              source, output_name + suffix, dim_in, dim_out, kernel,
              ('XavierFill', {}), ('ConstantFill', {}),
              **conv_kwargs
          )
          return model.Relu(blob, blob)

      # path 1: 1x1 conv
      branch_1x1 = conv_relu(input_blob, ":conv1", input_depth, conv1_depth, 1)

      # path 2: 1x1 conv + 3x3 conv
      reduce_3x3 = conv_relu(
          input_blob, ":conv3_reduce", input_depth, conv3_depths[0], 1
      )
      branch_3x3 = conv_relu(
          reduce_3x3, ":conv3", conv3_depths[0], conv3_depths[1], 3, pad=1
      )

      # path 3: 1x1 conv + 5x5 conv
      reduce_5x5 = conv_relu(
          input_blob, ":conv5_reduce", input_depth, conv5_depths[0], 1
      )
      branch_5x5 = conv_relu(
          reduce_5x5, ":conv5", conv5_depths[0], conv5_depths[1], 5, pad=2
      )

      # path 4: pool + 1x1 conv
      pooled = model.MaxPool(
          input_blob, output_name + ":pool", kernel=3, stride=1, pad=1
      )
      branch_pool = conv_relu(pooled, ":pool_proj", input_depth, pool_depth, 1)

      return model.Concat(
          [branch_1x1, branch_3x3, branch_5x5, branch_pool], output_name
      )
  def Inception(order):
      """Build the Inception benchmark network.

      A convolutional stem (conv1, conv2a, conv2 with two max pools) is
      followed by nine inception modules ("inc3" .. "inc11") with max pools
      after inc4 and inc9, a 7x7 average pool, a 1000-output FC layer, a
      softmax, and an averaged label cross-entropy written to "loss".

      Args:
          order: Storage order forwarded to ``cnn.CNNModelHelper``.

      Returns:
          A ``(model, 224)`` tuple; Benchmark uses the second element as the
          height and width of the "data" blob.
      """
      net = cnn.CNNModelHelper(
          order, name="inception", use_cudnn=True, cudnn_exhaustive_search=True
      )

      # --- stem ---
      stem1 = net.Conv(
          "data", "conv1", 3, 64, 7,
          ('XavierFill', {}), ('ConstantFill', {}),
          stride=2, pad=3
      )
      stem1_act = net.Relu(stem1, "conv1")
      stem1_pool = net.MaxPool(stem1_act, "pool1", kernel=3, stride=2, pad=1)
      stem2a = net.Conv(
          stem1_pool, "conv2a", 64, 64, 1,
          ('XavierFill', {}), ('ConstantFill', {})
      )
      stem2a = net.Relu(stem2a, stem2a)
      stem2 = net.Conv(
          stem2a, "conv2", 64, 192, 3,
          ('XavierFill', {}), ('ConstantFill', {}),
          pad=1
      )
      stem2_act = net.Relu(stem2, "conv2")
      blob = net.MaxPool(stem2_act, "pool2", kernel=3, stride=2, pad=1)

      # --- inception modules ---
      blob = _InceptionModule(
          net, blob, 192, "inc3", 64, [96, 128], [16, 32], 32
      )
      blob = _InceptionModule(
          net, blob, 256, "inc4", 128, [128, 192], [32, 96], 64
      )
      blob = net.MaxPool(blob, "pool5", kernel=3, stride=2, pad=1)
      blob = _InceptionModule(
          net, blob, 480, "inc5", 192, [96, 208], [16, 48], 64
      )
      blob = _InceptionModule(
          net, blob, 512, "inc6", 160, [112, 224], [24, 64], 64
      )
      blob = _InceptionModule(
          net, blob, 512, "inc7", 128, [128, 256], [24, 64], 64
      )
      blob = _InceptionModule(
          net, blob, 512, "inc8", 112, [144, 288], [32, 64], 64
      )
      blob = _InceptionModule(
          net, blob, 528, "inc9", 256, [160, 320], [32, 128], 128
      )
      blob = net.MaxPool(blob, "pool9", kernel=3, stride=2, pad=1)
      blob = _InceptionModule(
          net, blob, 832, "inc10", 256, [160, 320], [32, 128], 128
      )
      blob = _InceptionModule(
          net, blob, 832, "inc11", 384, [192, 384], [48, 128], 128
      )

      # --- classifier ---
      blob = net.AveragePool(blob, "pool11", kernel=7, stride=1)
      logits = net.FC(
          blob, "fc", 1024, 1000,
          ('XavierFill', {}), ('ConstantFill', {})
      )
      # Soumith's benchmark apparently has no softmax on top of Inception;
      # one is added here anyway so that a proper backward pass exists.
      probabilities = net.Softmax(logits, "pred")
      cross_entropy = net.LabelCrossEntropy([probabilities, "label"], "xent")
      net.AveragedLoss(cross_entropy, "loss")
      return net, 224
  def AddInput(model, batch_size, db, db_type):
      """Add the database-backed data input operators to *model*.

      Reads "data_uint8" and "label" through ``TensorProtosDBInput``, casts
      the data to float ("data_nhwc"), converts it from NHWC to NCHW
      ("data"), scales it by 1/256 in place, and stops gradients at it.

      Args:
          model: Model helper the operators are added to.
          batch_size: Batch size passed to the DB input operator.
          db: Database passed to the DB input operator.
          db_type: Database type passed to the DB input operator.

      Returns:
          A ``(data, label)`` tuple of the resulting blobs.
      """
      raw_images, labels = model.TensorProtosDBInput(
          [],
          ["data_uint8", "label"],
          batch_size=batch_size,
          db=db,
          db_type=db_type,
      )
      images = model.Cast(raw_images, "data_nhwc", to=core.DataType.FLOAT)
      images = model.NHWC2NCHW(images, "data")
      images = model.Scale(images, images, scale=float(1. / 256))
      # Nothing upstream of the input needs gradients.
      images = model.StopGradient(images, images)
      return images, labels
  def AddParameterUpdate(model):
      """Add a plain SGD update for every parameter of *model*.

      The update is not tuned to actually train the models; it only makes
      the benchmark include a parameter-update step.  Gradients are looked
      up in ``model.param_to_grad``, so the gradient operators must have
      been added already.

      Args:
          model: Model helper whose ``params`` are updated in place.
      """
      iteration = model.Iter("iter")
      learning_rate = model.LearningRate(
          iteration,
          "LR",
          base_lr=-1e-8,
          policy="step",
          stepsize=10000,
          gamma=0.999,
      )
      # Constant weight applied to the current parameter value.
      unit_weight = model.param_init_net.ConstantFill(
          [], "ONE", shape=[1], value=1.0
      )
      for parameter in model.params:
          gradient = model.param_to_grad[parameter]
          # parameter <- parameter * ONE + gradient * LR
          model.WeightedSum(
              [parameter, unit_weight, gradient, learning_rate], parameter
          )
  def Benchmark(model_gen, arg):
      """Build a model, run it, and print the average time per iteration.

      Args:
          model_gen: Callable taking the storage order and returning a
              ``(model, input_size)`` tuple.
          arg: Parsed command-line namespace.  The attributes read here are
              ``order``, ``net_type``, ``num_workers``, ``batch_size``,
              ``model``, ``forward_only``, ``cpu``, ``dump_model``,
              ``warmup_iterations``, ``iterations`` and
              ``layer_wise_benchmark``.
      """
      model, input_size = model_gen(arg.order)
      model.Proto().type = arg.net_type
      model.Proto().num_workers = arg.num_workers

      # So that everything runs without feeding more blobs, the data and
      # label blobs are created by the parameter initialization net too.
      if arg.model == "MLP":
          # The MLP consumes flat feature vectors, whatever the order is.
          data_shape = [arg.batch_size, input_size]
      elif arg.order == "NCHW":
          data_shape = [arg.batch_size, 3, input_size, input_size]
      else:
          data_shape = [arg.batch_size, input_size, input_size, 3]

      model.param_init_net.GaussianFill(
          [], "data", shape=data_shape, mean=0.0, std=1.0
      )
      model.param_init_net.UniformIntFill(
          [], "label", shape=[arg.batch_size, ], min=0, max=999
      )

      if arg.forward_only:
          print('{}: running forward only.'.format(arg.model))
      else:
          print('{}: running forward-backward.'.format(arg.model))
          model.AddGradientOperators(["loss"])
          AddParameterUpdate(model)
          # NOTE(review): the indentation of this file was lost; this warning
          # is assumed to belong to the forward-backward branch -- confirm.
          if arg.order == 'NHWC':
              print(
                  '==WARNING==\n'
                  'NHWC order with CuDNN may not be supported yet, so I might\n'
                  'exit suddenly.'
              )

      if not arg.cpu:
          model.param_init_net.RunAllOnGPU()
          model.net.RunAllOnGPU()

      if arg.dump_model:
          # Writes out the pbtxt for benchmarks on e.g. Android
          init_path = "{0}_init_batch_{1}.pbtxt".format(
              arg.model, arg.batch_size
          )
          with open(init_path, "w") as init_file:
              init_file.write(str(model.param_init_net.Proto()))
          net_path = "{0}.pbtxt".format(arg.model)
          with open(net_path, "w") as net_file:
              net_file.write(str(model.net.Proto()))

      workspace.RunNetOnce(model.param_init_net)
      workspace.CreateNet(model.net)
      net_name = model.net.Proto().name
      # Warm-up runs are executed but not timed.
      for _ in range(arg.warmup_iterations):
          workspace.RunNet(net_name)

      plan = core.Plan("plan")
      plan.AddStep(core.ExecutionStep("run", model.net, arg.iterations))
      started_at = time.time()
      workspace.RunPlan(plan)
      elapsed = time.time() - started_at
      print('Spent: {}'.format(elapsed / arg.iterations))

      if arg.layer_wise_benchmark:
          print('Layer-wise benchmark.')
          workspace.BenchmarkNet(net_name, 1, arg.iterations, True)
  def GetArgumentParser():
      """Create the command-line parser for the benchmark script.

      Returns:
          An ``argparse.ArgumentParser`` with the options ``--batch_size``,
          ``--model``, ``--order``, ``--cudnn_ws``, ``--iterations``,
          ``--warmup_iterations``, ``--forward_only``,
          ``--layer_wise_benchmark``, ``--cpu``, ``--dump_model``,
          ``--net_type`` and ``--num_workers``.
      """
      parser = argparse.ArgumentParser(description="Caffe2 benchmark.")

      # (flag, keyword arguments for add_argument), in registration order.
      option_table = [
          ("--batch_size",
           dict(type=int, default=128, help="The batch size.")),
          ("--model",
           dict(type=str, help="The model to benchmark.")),
          ("--order",
           dict(type=str, default="NCHW", help="The order to evaluate.")),
          ("--cudnn_ws",
           dict(type=int, default=-1, help="The cudnn workspace size.")),
          ("--iterations",
           dict(type=int, default=10,
                help="Number of iterations to run the network.")),
          ("--warmup_iterations",
           dict(type=int, default=10,
                help="Number of warm-up iterations before benchmarking.")),
          ("--forward_only",
           dict(action='store_true',
                help="If set, only run the forward pass.")),
          ("--layer_wise_benchmark",
           dict(action='store_true',
                help="If True, run the layer-wise benchmark as well.")),
          ("--cpu",
           dict(action='store_true',
                help="If True, run testing on CPU instead of GPU.")),
          ("--dump_model",
           dict(action='store_true',
                help="If True, dump the model prototxts to disk.")),
          ("--net_type", dict(type=str, default="dag")),
          ("--num_workers", dict(type=int, default=2)),
      ]
      for flag, options in option_table:
          parser.add_argument(flag, **options)
      return parser
  if __name__ == '__main__':
      # Script entry point: parse the flags, validate them, then benchmark the
      # requested model.
      parser = GetArgumentParser()
      args = parser.parse_args()

      # Maps the value of --model to the function that builds that network.
      model_map = {
          'AlexNet': AlexNet,
          'OverFeat': OverFeat,
          'VGGA': VGGA,
          'Inception': Inception,
          'MLP': MLP,
      }

      if (
          not args.batch_size or not args.model or not args.order or
          not args.cudnn_ws
      ):
          parser.print_help()
          # Stop here: carrying on without usable arguments would only fail
          # later with a KeyError in the model lookup below.
          raise SystemExit(1)

      if args.model not in model_map:
          # Report a mistyped model name as a usage error rather than a raw
          # KeyError traceback.
          parser.error(
              "unknown model {!r}; choose one of: {}".format(
                  args.model, ", ".join(sorted(model_map))
              )
          )

      workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
      Benchmark(model_map[args.model], args)