net_printer_test.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from caffe2.python import net_printer
  2. from caffe2.python.checkpoint import Job
  3. from caffe2.python.net_builder import ops
  4. from caffe2.python.task import Task, final_output, WorkspaceType
  5. import unittest
  def example_loop():
      """Record a Task whose body is a doubly nested loop with an If/Elif/Else.

      Four accumulator blobs are created up front (overall total plus
      "large", "small" and "tiny" buckets). Inside an ``ops.loop(10)`` the
      outer iteration counter is multiplied by 10; an inner ``ops.loop`` whose
      count is the outer iteration counter adds its own counter to that
      product to form ``value``. ``value`` is then added into exactly one
      bucket, chosen by ``value >= 80`` / ``value >= 50`` / otherwise, and
      always into the overall total.

      Nothing is returned: all ops are recorded into the Task opened here.
      """
      with Task():
          # Created in the same order as before: total, large, small, tiny.
          grand_total, large_sum, small_sum, tiny_sum = (
              ops.Const(0) for _ in range(4)
          )
          with ops.loop(10) as outer_loop:
              base = ops.Mul([outer_loop.iter(), ops.Const(10)])
              with ops.loop(outer_loop.iter()) as inner_loop:
                  value = ops.Add([base, inner_loop.iter()])
                  # `branch` is rebound by Elif so Else attaches to the chain.
                  with ops.If(ops.GE([value, ops.Const(80)])) as branch:
                      ops.Add([large_sum, value], [large_sum])
                  with branch.Elif(ops.GE([value, ops.Const(50)])) as branch:
                      ops.Add([small_sum, value], [small_sum])
                  with branch.Else():
                      ops.Add([tiny_sum, value], [tiny_sum])
                  ops.Add([grand_total, value], grand_total)
  def example_task():
      """Record two Tasks exercising task_init, task_exit and task_instance_init.

      The first Task builds constants in two separate ``task_init`` sections,
      sums them into an accumulator, reads it once in the main body, bumps it
      in place, and reads it again. It also reads it from a ``task_exit``
      section. The second Task runs with ``num_instances=2`` and combines a
      ``task_init`` constant with a ``task_instance_init`` constant.

      Returns:
          Tuple of three ``final_output`` handles from the first Task. Per the
          inline comments these are expected to hold 6 (main body, before the
          in-place bump), 7 (task_exit), and 7 (main body, after the bump).
      """
      with Task():
          with ops.task_init():
              const_one = ops.Const(1)
          const_two = ops.Add([const_one, const_one])
          with ops.task_init():
              const_three = ops.Const(3)
          acc = ops.Add([const_two, const_three])
          # here, acc should be 5
          with ops.task_exit():
              # here, acc should be 6, since this executes after lines below
              exit_seven = ops.Add([acc, const_one])
          body_six = ops.Add([acc, const_one])
          ops.Add([acc, const_one], [acc])  # in-place bump: acc becomes 6
          body_seven = ops.Add([acc, const_one])
          six_out = final_output(body_six)
          seven_exit_out = final_output(exit_seven)
          seven_body_out = final_output(body_seven)

      with Task(num_instances=2):
          with ops.task_init():
              shared_one = ops.Const(1)
          with ops.task_instance_init():
              per_instance = ops.Const(2)
          ops.Add([shared_one, per_instance], [shared_one])
          ops.LogInfo('ble')

      return six_out, seven_exit_out, seven_body_out
  def example_job():
      """Assemble a Job from the example builders.

      ``example_loop()`` is recorded inside the job's ``init_group``;
      ``example_task()`` is recorded under the job context itself.

      Returns:
          The Job bound by the ``with Job() as ...`` statement.
      """
      with Job() as new_job:
          init_group = new_job.init_group
          with init_group:
              example_loop()
          example_task()
      return new_job
  class TestNetPrinter(unittest.TestCase):
      """Tests for ``net_printer.to_string`` and ``net_printer.analyze``."""

      def test_print(self):
          """to_string() renders the example job to a non-empty string."""
          # assertGreater reports the actual length on failure.
          self.assertGreater(len(net_printer.to_string(example_job())), 0)

      def test_valid_job(self):
          """Inputs named distributed_ctx_init_* must not fail analysis."""
          job = example_job()
          with job:
              with Task():
                  # distributed_ctx_init_* ignored by analyzer
                  ops.Add(['distributed_ctx_init_a', 'distributed_ctx_init_b'])
          # Analyze and print `job` itself (not a fresh example_job()) so the
          # distributed_ctx_init_* task recorded above is actually exercised.
          net_printer.analyze(job)
          print(net_printer.to_string(job))

      def test_undefined_blob(self):
          """analyze() raises AssertionError naming the undefined input 'a'."""
          job = example_job()
          with job:
              with Task():
                  ops.Add(['a', 'b'])
          with self.assertRaises(AssertionError) as ctx:
              net_printer.analyze(job)
          self.assertEqual("Blob undefined: a", str(ctx.exception))

      def test_multiple_definition(self):
          """Two GLOBAL-workspace tasks writing the same blob fail analysis."""
          job = example_job()
          with job:
              with Task(workspace_type=WorkspaceType.GLOBAL):
                  ops.Add([ops.Const(0), ops.Const(1)], 'out1')
              with Task(workspace_type=WorkspaceType.GLOBAL):
                  ops.Add([ops.Const(2), ops.Const(3)], 'out1')
          with self.assertRaises(AssertionError):
              net_printer.analyze(job)