imagenet_trainer.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726
  1. # Module caffe2.python.examples.resnet50_trainer
  2. import argparse
  3. import logging
  4. import numpy as np
  5. import time
  6. import os
  7. from caffe2.python import core, workspace, experiment_util, data_parallel_model
  8. from caffe2.python import dyndep, optimizer
  9. from caffe2.python import timeout_guard, model_helper, brew
  10. from caffe2.proto import caffe2_pb2
  11. import caffe2.python.models.resnet as resnet
  12. import caffe2.python.models.shufflenet as shufflenet
  13. from caffe2.python.modeling.initializers import Initializer, PseudoFP16Initializer
  14. import caffe2.python.predictor.predictor_exporter as pred_exp
  15. import caffe2.python.predictor.predictor_py_utils as pred_utils
  16. from caffe2.python.predictor_constants import predictor_constants
  17. '''
  18. Parallelized multi-GPU distributed trainer for ResNe(X)t & ShuffleNet.
  19. Can be used to train on imagenet data, for example.
  20. The default parameters can train a standard Resnet-50 (1x64d), and parameters
  21. can be provided to train ResNe(X)t models (e.g., ResNeXt-101 32x4d).
  22. To run the trainer in single-machine multi-GPU mode, set num_shards = 1.
  23. To run the trainer in multi-machine multi-gpu mode with M machines,
  24. run the same program on all machines, specifying num_shards = M, and
  25. shard_id = a unique integer in the set [0, M-1].
  26. For rendezvous (the trainer processes have to know about each other),
  27. you can either use a directory path that is visible to all processes
  28. (e.g. NFS directory), or use a Redis instance. Use the former by
  29. passing the `file_store_path` argument. Use the latter by passing the
  30. `redis_host` and `redis_port` arguments.
  31. '''
  32. logging.basicConfig()
  33. log = logging.getLogger("Imagenet_trainer")
  34. log.setLevel(logging.DEBUG)
  35. dyndep.InitOpsLibrary('@/caffe2/caffe2/distributed:file_store_handler_ops')
  36. dyndep.InitOpsLibrary('@/caffe2/caffe2/distributed:redis_store_handler_ops')
  37. def AddImageInput(
  38. model,
  39. reader,
  40. batch_size,
  41. img_size,
  42. dtype,
  43. is_test,
  44. mean_per_channel=None,
  45. std_per_channel=None,
  46. ):
  47. '''
  48. The image input operator loads image and label data from the reader and
  49. applies transformations to the images (random cropping, mirroring, ...).
  50. '''
  51. data, label = brew.image_input(
  52. model,
  53. reader, ["data", "label"],
  54. batch_size=batch_size,
  55. output_type=dtype,
  56. use_gpu_transform=True if core.IsGPUDeviceType(model._device_type) else False,
  57. use_caffe_datum=True,
  58. mean_per_channel=mean_per_channel,
  59. std_per_channel=std_per_channel,
  60. # mean_per_channel takes precedence over mean
  61. mean=128.,
  62. std=128.,
  63. scale=256,
  64. crop=img_size,
  65. mirror=1,
  66. is_test=is_test,
  67. )
  68. data = model.StopGradient(data, data)
  69. def AddNullInput(model, reader, batch_size, img_size, dtype):
  70. '''
  71. The null input function uses a gaussian fill operator to emulate real image
  72. input. A label blob is hardcoded to a single value. This is useful if you
  73. want to test compute throughput or don't have a dataset available.
  74. '''
  75. suffix = "_fp16" if dtype == "float16" else ""
  76. model.param_init_net.GaussianFill(
  77. [],
  78. ["data" + suffix],
  79. shape=[batch_size, 3, img_size, img_size],
  80. )
  81. if dtype == "float16":
  82. model.param_init_net.FloatToHalf("data" + suffix, "data")
  83. model.param_init_net.ConstantFill(
  84. [],
  85. ["label"],
  86. shape=[batch_size],
  87. value=1,
  88. dtype=core.DataType.INT32,
  89. )
  90. def SaveModel(args, train_model, epoch, use_ideep):
  91. prefix = "[]_{}".format(train_model._device_prefix, train_model._devices[0])
  92. predictor_export_meta = pred_exp.PredictorExportMeta(
  93. predict_net=train_model.net.Proto(),
  94. parameters=data_parallel_model.GetCheckpointParams(train_model),
  95. inputs=[prefix + "/data"],
  96. outputs=[prefix + "/softmax"],
  97. shapes={
  98. prefix + "/softmax": (1, args.num_labels),
  99. prefix + "/data": (args.num_channels, args.image_size, args.image_size)
  100. }
  101. )
  102. # save the train_model for the current epoch
  103. model_path = "%s/%s_%d.mdl" % (
  104. args.file_store_path,
  105. args.save_model_name,
  106. epoch,
  107. )
  108. # set db_type to be "minidb" instead of "log_file_db", which breaks
  109. # the serialization in save_to_db. Need to switch back to log_file_db
  110. # after migration
  111. pred_exp.save_to_db(
  112. db_type="minidb",
  113. db_destination=model_path,
  114. predictor_export_meta=predictor_export_meta,
  115. use_ideep=use_ideep
  116. )
  117. def LoadModel(path, model, use_ideep):
  118. '''
  119. Load pretrained model from file
  120. '''
  121. log.info("Loading path: {}".format(path))
  122. meta_net_def = pred_exp.load_from_db(path, 'minidb')
  123. init_net = core.Net(pred_utils.GetNet(
  124. meta_net_def, predictor_constants.GLOBAL_INIT_NET_TYPE))
  125. predict_init_net = core.Net(pred_utils.GetNet(
  126. meta_net_def, predictor_constants.PREDICT_INIT_NET_TYPE))
  127. if use_ideep:
  128. predict_init_net.RunAllOnIDEEP()
  129. else:
  130. predict_init_net.RunAllOnGPU()
  131. if use_ideep:
  132. init_net.RunAllOnIDEEP()
  133. else:
  134. init_net.RunAllOnGPU()
  135. assert workspace.RunNetOnce(predict_init_net)
  136. assert workspace.RunNetOnce(init_net)
  137. # Hack: fix iteration counter which is in CUDA context after load model
  138. itercnt = workspace.FetchBlob("optimizer_iteration")
  139. workspace.FeedBlob(
  140. "optimizer_iteration",
  141. itercnt,
  142. device_option=core.DeviceOption(caffe2_pb2.CPU, 0)
  143. )
  144. def RunEpoch(
  145. args,
  146. epoch,
  147. train_model,
  148. test_model,
  149. total_batch_size,
  150. num_shards,
  151. expname,
  152. explog,
  153. ):
  154. '''
  155. Run one epoch of the trainer.
  156. TODO: add checkpointing here.
  157. '''
  158. # TODO: add loading from checkpoint
  159. log.info("Starting epoch {}/{}".format(epoch, args.num_epochs))
  160. epoch_iters = int(args.epoch_size / total_batch_size / num_shards)
  161. test_epoch_iters = int(args.test_epoch_size / total_batch_size / num_shards)
  162. for i in range(epoch_iters):
  163. # This timeout is required (temporarily) since CUDA-NCCL
  164. # operators might deadlock when synchronizing between GPUs.
  165. timeout = args.first_iter_timeout if i == 0 else args.timeout
  166. with timeout_guard.CompleteInTimeOrDie(timeout):
  167. t1 = time.time()
  168. workspace.RunNet(train_model.net.Proto().name)
  169. t2 = time.time()
  170. dt = t2 - t1
  171. fmt = "Finished iteration {}/{} of epoch {} ({:.2f} images/sec)"
  172. log.info(fmt.format(i + 1, epoch_iters, epoch, total_batch_size / dt))
  173. prefix = "{}_{}".format(
  174. train_model._device_prefix,
  175. train_model._devices[0])
  176. accuracy = workspace.FetchBlob(prefix + '/accuracy')
  177. loss = workspace.FetchBlob(prefix + '/loss')
  178. train_fmt = "Training loss: {}, accuracy: {}"
  179. log.info(train_fmt.format(loss, accuracy))
  180. num_images = epoch * epoch_iters * total_batch_size
  181. prefix = "{}_{}".format(train_model._device_prefix, train_model._devices[0])
  182. accuracy = workspace.FetchBlob(prefix + '/accuracy')
  183. loss = workspace.FetchBlob(prefix + '/loss')
  184. learning_rate = workspace.FetchBlob(
  185. data_parallel_model.GetLearningRateBlobNames(train_model)[0]
  186. )
  187. test_accuracy = 0
  188. test_accuracy_top5 = 0
  189. if test_model is not None:
  190. # Run 100 iters of testing
  191. ntests = 0
  192. for _ in range(test_epoch_iters):
  193. workspace.RunNet(test_model.net.Proto().name)
  194. for g in test_model._devices:
  195. test_accuracy += np.asscalar(workspace.FetchBlob(
  196. "{}_{}".format(test_model._device_prefix, g) + '/accuracy'
  197. ))
  198. test_accuracy_top5 += np.asscalar(workspace.FetchBlob(
  199. "{}_{}".format(test_model._device_prefix, g) + '/accuracy_top5'
  200. ))
  201. ntests += 1
  202. test_accuracy /= ntests
  203. test_accuracy_top5 /= ntests
  204. else:
  205. test_accuracy = (-1)
  206. test_accuracy_top5 = (-1)
  207. explog.log(
  208. input_count=num_images,
  209. batch_count=(i + epoch * epoch_iters),
  210. additional_values={
  211. 'accuracy': accuracy,
  212. 'loss': loss,
  213. 'learning_rate': learning_rate,
  214. 'epoch': epoch,
  215. 'top1_test_accuracy': test_accuracy,
  216. 'top5_test_accuracy': test_accuracy_top5,
  217. }
  218. )
  219. assert loss < 40, "Exploded gradients :("
  220. # TODO: add checkpointing
  221. return epoch + 1
  222. def Train(args):
  223. if args.model == "resnext":
  224. model_name = "resnext" + str(args.num_layers)
  225. elif args.model == "shufflenet":
  226. model_name = "shufflenet"
  227. # Either use specified device list or generate one
  228. if args.gpus is not None:
  229. gpus = [int(x) for x in args.gpus.split(',')]
  230. num_gpus = len(gpus)
  231. else:
  232. gpus = list(range(args.num_gpus))
  233. num_gpus = args.num_gpus
  234. log.info("Running on GPUs: {}".format(gpus))
  235. # Verify valid batch size
  236. total_batch_size = args.batch_size
  237. batch_per_device = total_batch_size // num_gpus
  238. assert \
  239. total_batch_size % num_gpus == 0, \
  240. "Number of GPUs must divide batch size"
  241. # Verify valid image mean/std per channel
  242. if args.image_mean_per_channel:
  243. assert \
  244. len(args.image_mean_per_channel) == args.num_channels, \
  245. "The number of channels of image mean doesn't match input"
  246. if args.image_std_per_channel:
  247. assert \
  248. len(args.image_std_per_channel) == args.num_channels, \
  249. "The number of channels of image std doesn't match input"
  250. # Round down epoch size to closest multiple of batch size across machines
  251. global_batch_size = total_batch_size * args.num_shards
  252. epoch_iters = int(args.epoch_size / global_batch_size)
  253. assert \
  254. epoch_iters > 0, \
  255. "Epoch size must be larger than batch size times shard count"
  256. args.epoch_size = epoch_iters * global_batch_size
  257. log.info("Using epoch size: {}".format(args.epoch_size))
  258. # Create ModelHelper object
  259. if args.use_ideep:
  260. train_arg_scope = {
  261. 'use_cudnn': False,
  262. 'cudnn_exhaustive_search': False,
  263. 'training_mode': 1
  264. }
  265. else:
  266. train_arg_scope = {
  267. 'order': 'NCHW',
  268. 'use_cudnn': True,
  269. 'cudnn_exhaustive_search': True,
  270. 'ws_nbytes_limit': (args.cudnn_workspace_limit_mb * 1024 * 1024),
  271. }
  272. train_model = model_helper.ModelHelper(
  273. name=model_name, arg_scope=train_arg_scope
  274. )
  275. num_shards = args.num_shards
  276. shard_id = args.shard_id
  277. # Expect interfaces to be comma separated.
  278. # Use of multiple network interfaces is not yet complete,
  279. # so simply use the first one in the list.
  280. interfaces = args.distributed_interfaces.split(",")
  281. # Rendezvous using MPI when run with mpirun
  282. if os.getenv("OMPI_COMM_WORLD_SIZE") is not None:
  283. num_shards = int(os.getenv("OMPI_COMM_WORLD_SIZE", 1))
  284. shard_id = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
  285. if num_shards > 1:
  286. rendezvous = dict(
  287. kv_handler=None,
  288. num_shards=num_shards,
  289. shard_id=shard_id,
  290. engine="GLOO",
  291. transport=args.distributed_transport,
  292. interface=interfaces[0],
  293. mpi_rendezvous=True,
  294. exit_nets=None)
  295. elif num_shards > 1:
  296. # Create rendezvous for distributed computation
  297. store_handler = "store_handler"
  298. if args.redis_host is not None:
  299. # Use Redis for rendezvous if Redis host is specified
  300. workspace.RunOperatorOnce(
  301. core.CreateOperator(
  302. "RedisStoreHandlerCreate", [], [store_handler],
  303. host=args.redis_host,
  304. port=args.redis_port,
  305. prefix=args.run_id,
  306. )
  307. )
  308. else:
  309. # Use filesystem for rendezvous otherwise
  310. workspace.RunOperatorOnce(
  311. core.CreateOperator(
  312. "FileStoreHandlerCreate", [], [store_handler],
  313. path=args.file_store_path,
  314. prefix=args.run_id,
  315. )
  316. )
  317. rendezvous = dict(
  318. kv_handler=store_handler,
  319. shard_id=shard_id,
  320. num_shards=num_shards,
  321. engine="GLOO",
  322. transport=args.distributed_transport,
  323. interface=interfaces[0],
  324. exit_nets=None)
  325. else:
  326. rendezvous = None
  327. # Model building functions
  328. def create_resnext_model_ops(model, loss_scale):
  329. initializer = (PseudoFP16Initializer if args.dtype == 'float16'
  330. else Initializer)
  331. with brew.arg_scope([brew.conv, brew.fc],
  332. WeightInitializer=initializer,
  333. BiasInitializer=initializer,
  334. enable_tensor_core=args.enable_tensor_core,
  335. float16_compute=args.float16_compute):
  336. pred = resnet.create_resnext(
  337. model,
  338. "data",
  339. num_input_channels=args.num_channels,
  340. num_labels=args.num_labels,
  341. num_layers=args.num_layers,
  342. num_groups=args.resnext_num_groups,
  343. num_width_per_group=args.resnext_width_per_group,
  344. no_bias=True,
  345. no_loss=True,
  346. )
  347. if args.dtype == 'float16':
  348. pred = model.net.HalfToFloat(pred, pred + '_fp32')
  349. softmax, loss = model.SoftmaxWithLoss([pred, 'label'],
  350. ['softmax', 'loss'])
  351. loss = model.Scale(loss, scale=loss_scale)
  352. brew.accuracy(model, [softmax, "label"], "accuracy", top_k=1)
  353. brew.accuracy(model, [softmax, "label"], "accuracy_top5", top_k=5)
  354. return [loss]
  355. def create_shufflenet_model_ops(model, loss_scale):
  356. initializer = (PseudoFP16Initializer if args.dtype == 'float16'
  357. else Initializer)
  358. with brew.arg_scope([brew.conv, brew.fc],
  359. WeightInitializer=initializer,
  360. BiasInitializer=initializer,
  361. enable_tensor_core=args.enable_tensor_core,
  362. float16_compute=args.float16_compute):
  363. pred = shufflenet.create_shufflenet(
  364. model,
  365. "data",
  366. num_input_channels=args.num_channels,
  367. num_labels=args.num_labels,
  368. no_loss=True,
  369. )
  370. if args.dtype == 'float16':
  371. pred = model.net.HalfToFloat(pred, pred + '_fp32')
  372. softmax, loss = model.SoftmaxWithLoss([pred, 'label'],
  373. ['softmax', 'loss'])
  374. loss = model.Scale(loss, scale=loss_scale)
  375. brew.accuracy(model, [softmax, "label"], "accuracy", top_k=1)
  376. brew.accuracy(model, [softmax, "label"], "accuracy_top5", top_k=5)
  377. return [loss]
  378. def add_optimizer(model):
  379. stepsz = int(30 * args.epoch_size / total_batch_size / num_shards)
  380. if args.float16_compute:
  381. # TODO: merge with multi-precision optimizer
  382. opt = optimizer.build_fp16_sgd(
  383. model,
  384. args.base_learning_rate,
  385. momentum=0.9,
  386. nesterov=1,
  387. weight_decay=args.weight_decay, # weight decay included
  388. policy="step",
  389. stepsize=stepsz,
  390. gamma=0.1
  391. )
  392. else:
  393. optimizer.add_weight_decay(model, args.weight_decay)
  394. opt = optimizer.build_multi_precision_sgd(
  395. model,
  396. args.base_learning_rate,
  397. momentum=0.9,
  398. nesterov=1,
  399. policy="step",
  400. stepsize=stepsz,
  401. gamma=0.1
  402. )
  403. return opt
  404. # Define add_image_input function.
  405. # Depends on the "train_data" argument.
  406. # Note that the reader will be shared with between all GPUS.
  407. if args.train_data == "null":
  408. def add_image_input(model):
  409. AddNullInput(
  410. model,
  411. None,
  412. batch_size=batch_per_device,
  413. img_size=args.image_size,
  414. dtype=args.dtype,
  415. )
  416. else:
  417. reader = train_model.CreateDB(
  418. "reader",
  419. db=args.train_data,
  420. db_type=args.db_type,
  421. num_shards=num_shards,
  422. shard_id=shard_id,
  423. )
  424. def add_image_input(model):
  425. AddImageInput(
  426. model,
  427. reader,
  428. batch_size=batch_per_device,
  429. img_size=args.image_size,
  430. dtype=args.dtype,
  431. is_test=False,
  432. mean_per_channel=args.image_mean_per_channel,
  433. std_per_channel=args.image_std_per_channel,
  434. )
  435. def add_post_sync_ops(model):
  436. """Add ops applied after initial parameter sync."""
  437. for param_info in model.GetOptimizationParamInfo(model.GetParams()):
  438. if param_info.blob_copy is not None:
  439. model.param_init_net.HalfToFloat(
  440. param_info.blob,
  441. param_info.blob_copy[core.DataType.FLOAT]
  442. )
  443. data_parallel_model.Parallelize(
  444. train_model,
  445. input_builder_fun=add_image_input,
  446. forward_pass_builder_fun=create_resnext_model_ops
  447. if args.model == "resnext" else create_shufflenet_model_ops,
  448. optimizer_builder_fun=add_optimizer,
  449. post_sync_builder_fun=add_post_sync_ops,
  450. devices=gpus,
  451. rendezvous=rendezvous,
  452. optimize_gradient_memory=False,
  453. use_nccl=args.use_nccl,
  454. cpu_device=args.use_cpu,
  455. ideep=args.use_ideep,
  456. shared_model=args.use_cpu,
  457. combine_spatial_bn=args.use_cpu,
  458. )
  459. data_parallel_model.OptimizeGradientMemory(train_model, {}, set(), False)
  460. workspace.RunNetOnce(train_model.param_init_net)
  461. workspace.CreateNet(train_model.net)
  462. # Add test model, if specified
  463. test_model = None
  464. if (args.test_data is not None):
  465. log.info("----- Create test net ----")
  466. if args.use_ideep:
  467. test_arg_scope = {
  468. 'use_cudnn': False,
  469. 'cudnn_exhaustive_search': False,
  470. }
  471. else:
  472. test_arg_scope = {
  473. 'order': "NCHW",
  474. 'use_cudnn': True,
  475. 'cudnn_exhaustive_search': True,
  476. }
  477. test_model = model_helper.ModelHelper(
  478. name=model_name + "_test",
  479. arg_scope=test_arg_scope,
  480. init_params=False,
  481. )
  482. test_reader = test_model.CreateDB(
  483. "test_reader",
  484. db=args.test_data,
  485. db_type=args.db_type,
  486. )
  487. def test_input_fn(model):
  488. AddImageInput(
  489. model,
  490. test_reader,
  491. batch_size=batch_per_device,
  492. img_size=args.image_size,
  493. dtype=args.dtype,
  494. is_test=True,
  495. mean_per_channel=args.image_mean_per_channel,
  496. std_per_channel=args.image_std_per_channel,
  497. )
  498. data_parallel_model.Parallelize(
  499. test_model,
  500. input_builder_fun=test_input_fn,
  501. forward_pass_builder_fun=create_resnext_model_ops
  502. if args.model == "resnext" else create_shufflenet_model_ops,
  503. post_sync_builder_fun=add_post_sync_ops,
  504. param_update_builder_fun=None,
  505. devices=gpus,
  506. use_nccl=args.use_nccl,
  507. cpu_device=args.use_cpu,
  508. )
  509. workspace.RunNetOnce(test_model.param_init_net)
  510. workspace.CreateNet(test_model.net)
  511. epoch = 0
  512. # load the pre-trained model and reset epoch
  513. if args.load_model_path is not None:
  514. LoadModel(args.load_model_path, train_model, args.use_ideep)
  515. # Sync the model params
  516. data_parallel_model.FinalizeAfterCheckpoint(train_model)
  517. # reset epoch. load_model_path should end with *_X.mdl,
  518. # where X is the epoch number
  519. last_str = args.load_model_path.split('_')[-1]
  520. if last_str.endswith('.mdl'):
  521. epoch = int(last_str[:-4])
  522. log.info("Reset epoch to {}".format(epoch))
  523. else:
  524. log.warning("The format of load_model_path doesn't match!")
  525. expname = "%s_gpu%d_b%d_L%d_lr%.2f_v2" % (
  526. model_name,
  527. args.num_gpus,
  528. total_batch_size,
  529. args.num_labels,
  530. args.base_learning_rate,
  531. )
  532. explog = experiment_util.ModelTrainerLog(expname, args)
  533. # Run the training one epoch a time
  534. while epoch < args.num_epochs:
  535. epoch = RunEpoch(
  536. args,
  537. epoch,
  538. train_model,
  539. test_model,
  540. total_batch_size,
  541. num_shards,
  542. expname,
  543. explog
  544. )
  545. # Save the model for each epoch
  546. SaveModel(args, train_model, epoch, args.use_ideep)
  547. model_path = "%s/%s_" % (
  548. args.file_store_path,
  549. args.save_model_name
  550. )
  551. # remove the saved model from the previous epoch if it exists
  552. if os.path.isfile(model_path + str(epoch - 1) + ".mdl"):
  553. os.remove(model_path + str(epoch - 1) + ".mdl")
  554. def main():
  555. # TODO: use argv
  556. parser = argparse.ArgumentParser(
  557. description="Caffe2: ImageNet Trainer"
  558. )
  559. parser.add_argument("--train_data", type=str, default=None, required=True,
  560. help="Path to training data (or 'null' to simulate)")
  561. parser.add_argument("--num_layers", type=int, default=50,
  562. help="The number of layers in ResNe(X)t model")
  563. parser.add_argument("--resnext_num_groups", type=int, default=1,
  564. help="The cardinality of resnext")
  565. parser.add_argument("--resnext_width_per_group", type=int, default=64,
  566. help="The cardinality of resnext")
  567. parser.add_argument("--test_data", type=str, default=None,
  568. help="Path to test data")
  569. parser.add_argument("--image_mean_per_channel", type=float, nargs='+',
  570. help="The per channel mean for the images")
  571. parser.add_argument("--image_std_per_channel", type=float, nargs='+',
  572. help="The per channel standard deviation for the images")
  573. parser.add_argument("--test_epoch_size", type=int, default=50000,
  574. help="Number of test images")
  575. parser.add_argument("--db_type", type=str, default="lmdb",
  576. help="Database type (such as lmdb or leveldb)")
  577. parser.add_argument("--gpus", type=str,
  578. help="Comma separated list of GPU devices to use")
  579. parser.add_argument("--num_gpus", type=int, default=1,
  580. help="Number of GPU devices (instead of --gpus)")
  581. parser.add_argument("--num_channels", type=int, default=3,
  582. help="Number of color channels")
  583. parser.add_argument("--image_size", type=int, default=224,
  584. help="Input image size (to crop to)")
  585. parser.add_argument("--num_labels", type=int, default=1000,
  586. help="Number of labels")
  587. parser.add_argument("--batch_size", type=int, default=32,
  588. help="Batch size, total over all GPUs")
  589. parser.add_argument("--epoch_size", type=int, default=1500000,
  590. help="Number of images/epoch, total over all machines")
  591. parser.add_argument("--num_epochs", type=int, default=1000,
  592. help="Num epochs.")
  593. parser.add_argument("--base_learning_rate", type=float, default=0.1,
  594. help="Initial learning rate.")
  595. parser.add_argument("--weight_decay", type=float, default=1e-4,
  596. help="Weight decay (L2 regularization)")
  597. parser.add_argument("--cudnn_workspace_limit_mb", type=int, default=64,
  598. help="CuDNN workspace limit in MBs")
  599. parser.add_argument("--num_shards", type=int, default=1,
  600. help="Number of machines in distributed run")
  601. parser.add_argument("--shard_id", type=int, default=0,
  602. help="Shard id.")
  603. parser.add_argument("--run_id", type=str,
  604. help="Unique run identifier (e.g. uuid)")
  605. parser.add_argument("--redis_host", type=str,
  606. help="Host of Redis server (for rendezvous)")
  607. parser.add_argument("--redis_port", type=int, default=6379,
  608. help="Port of Redis server (for rendezvous)")
  609. parser.add_argument("--file_store_path", type=str, default="/tmp",
  610. help="Path to directory to use for rendezvous")
  611. parser.add_argument("--save_model_name", type=str, default="resnext_model",
  612. help="Save the trained model to a given name")
  613. parser.add_argument("--load_model_path", type=str, default=None,
  614. help="Load previously saved model to continue training")
  615. parser.add_argument("--use_cpu", action="store_true",
  616. help="Use CPU instead of GPU")
  617. parser.add_argument("--use_nccl", action="store_true",
  618. help="Use nccl for inter-GPU collectives")
  619. parser.add_argument("--use_ideep", type=bool, default=False,
  620. help="Use ideep")
  621. parser.add_argument('--dtype', default='float',
  622. choices=['float', 'float16'],
  623. help='Data type used for training')
  624. parser.add_argument('--float16_compute', action='store_true',
  625. help="Use float 16 compute, if available")
  626. parser.add_argument('--enable_tensor_core', action='store_true',
  627. help='Enable Tensor Core math for Conv and FC ops')
  628. parser.add_argument("--distributed_transport", type=str, default="tcp",
  629. help="Transport to use for distributed run [tcp|ibverbs]")
  630. parser.add_argument("--distributed_interfaces", type=str, default="",
  631. help="Network interfaces to use for distributed run")
  632. parser.add_argument("--first_iter_timeout", type=int, default=1200,
  633. help="Timeout (secs) of the first iteration "
  634. "(default: %(default)s)")
  635. parser.add_argument("--timeout", type=int, default=60,
  636. help="Timeout (secs) of each (except the first) iteration "
  637. "(default: %(default)s)")
  638. parser.add_argument("--model",
  639. default="resnext", const="resnext", nargs="?",
  640. choices=["shufflenet", "resnext"],
  641. help="List of models which can be run")
  642. args = parser.parse_args()
  643. Train(args)
  644. if __name__ == '__main__':
  645. workspace.GlobalInit(['caffe2', '--caffe2_log_level=2'])
  646. main()