layer_test_util.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. ## @package layer_test_util
  2. # Module caffe2.python.layer_test_util
  3. from collections import namedtuple
  4. from caffe2.python import (
  5. core,
  6. layer_model_instantiator,
  7. layer_model_helper,
  8. schema,
  9. test_util,
  10. workspace,
  11. utils,
  12. )
  13. from caffe2.proto import caffe2_pb2
  14. import numpy as np
  15. # pyre-fixme[13]: Pyre can't detect attribute initialization through the
  16. # super().__new__ call
  class OpSpec(namedtuple("OpSpec", "type input output arg")):
      """Expected description of one operator, used by net-matching asserts.

      Fields line up with the operator attributes they are compared against:
      ``type`` is the operator type, ``input``/``output`` are lists of blob
      names (or None to skip that check), and ``arg`` is an optional
      name -> value lookup of expected arguments (None skips the check).
      """

      def __new__(cls, op_type, op_input, op_output, op_arg=None):
          # ``arg`` is the only optional field; it defaults to None so that
          # callers can omit it when they do not care about arguments.
          fields = (op_type, op_input, op_output, op_arg)
          return super().__new__(cls, *fields)
  class LayersTestCase(test_util.TestCase):
      """Base test case for layer tests.

      Each test starts from a reset workspace and a fresh
      ``LayerModelHelper`` stored in ``self.model``. Helpers are provided to
      build or run the model's training / eval / predict nets and to compare
      a net's operators against a list of ``OpSpec``.
      """

      def setUp(self):
          super(LayersTestCase, self).setUp()
          self.setup_example()

      def setup_example(self):
          """
          Reset the workspace and rebuild the model.

          Hypothesis calls ``setup_example`` before each generated example;
          this hook is an undocumented feature of hypothesis, see
          https://github.com/HypothesisWorks/hypothesis-python/issues/59
          ``setUp`` also calls it, so plain unittest tests get the same state.
          """
          workspace.ResetWorkspace()
          self.reset_model()

      def reset_model(self, input_feature_schema=None, trainer_extra_schema=None):
          """Replace ``self.model`` with a new ``LayerModelHelper``.

          Args:
              input_feature_schema: input schema for the model. When None,
                  a Struct with a single ``float_features`` field (float32,
                  shape ``(32,)``) is used.
              trainer_extra_schema: extra trainer schema. When None, an
                  empty ``schema.Struct()`` is used.
          """
          # Test against None explicitly so that any schema the caller
          # passes is used as-is, even one that evaluates as falsy.
          if input_feature_schema is None:
              input_feature_schema = schema.Struct(
                  ('float_features', schema.Scalar((np.float32, (32,)))),
              )
          if trainer_extra_schema is None:
              trainer_extra_schema = schema.Struct()
          self.model = layer_model_helper.LayerModelHelper(
              'test_model',
              input_feature_schema=input_feature_schema,
              trainer_extra_schema=trainer_extra_schema)

      def new_record(self, schema_obj):
          """Return ``schema.NewRecord`` for ``schema_obj`` on the model's net."""
          return schema.NewRecord(self.model.net, schema_obj)

      def get_training_nets(self, add_constants=False):
          """
          Build ``(train_init_net, train_net)`` from the model's layers.

          Every layer in ``self.model.layers`` adds its operators to the two
          nets. We don't use
          layer_model_instantiator.generate_training_nets_forward_only()
          here because it includes initialization of global constants, which
          makes testing tricky.

          Args:
              add_constants: if True, the init net is created with
                  ``self.model.create_init_net``; otherwise it starts as a
                  fresh ``core.Net``.

          Returns:
              Tuple ``(train_init_net, train_net)``.
          """
          train_net = core.Net('train_net')
          if add_constants:
              train_init_net = self.model.create_init_net('train_init_net')
          else:
              train_init_net = core.Net('train_init_net')
          for layer in self.model.layers:
              layer.add_operators(train_net, train_init_net)
          return train_init_net, train_net

      def get_eval_net(self):
          """Return the eval net generated by ``layer_model_instantiator``."""
          return layer_model_instantiator.generate_eval_net(self.model)

      def get_predict_net(self):
          """Return the predict net generated by ``layer_model_instantiator``."""
          return layer_model_instantiator.generate_predict_net(self.model)

      def run_train_net(self):
          """Generate the full training nets and run init and train once each."""
          # NOTE(review): the output schema is cleared before generating the
          # nets; presumably so the test needs no model outputs — confirm.
          self.model.output_schema = schema.Struct()
          train_init_net, train_net = \
              layer_model_instantiator.generate_training_nets(self.model)
          workspace.RunNetOnce(train_init_net)
          workspace.RunNetOnce(train_net)

      def run_train_net_forward_only(self, num_iter=1):
          """Run the forward-only training nets.

          The init net runs once, then the train net is created and run for
          ``num_iter`` iterations.

          Args:
              num_iter: number of train-net iterations; must be > 0.
          """
          # Validate before mutating the model or touching the workspace so a
          # bad argument fails fast without side effects.
          assert num_iter > 0, 'num_iter must be larger than 0'
          self.model.output_schema = schema.Struct()
          train_init_net, train_net = \
              layer_model_instantiator.generate_training_nets_forward_only(
                  self.model)
          workspace.RunNetOnce(train_init_net)
          workspace.CreateNet(train_net)
          workspace.RunNet(train_net.Proto().name, num_iter=num_iter)

      def assertBlobsEqual(self, spec_blobs, op_blobs):
          """
          spec_blobs can either be None or a list of blob names. If it's None,
          then no assertion is performed. The elements of the list can be
          None, in which case that position is not checked.
          """
          if spec_blobs is None:
              return
          self.assertEqual(len(spec_blobs), len(op_blobs))
          for spec_blob, op_blob in zip(spec_blobs, op_blobs):
              if spec_blob is None:
                  continue
              self.assertEqual(spec_blob, op_blob)

      def assertArgsEqual(self, spec_args, op_args):
          """
          Check an operator's ``arg`` field against ``spec_args``.

          ``spec_args`` is looked up by argument name (e.g. a dict of
          name -> expected value). The expected arguments are rebuilt with
          ``utils.MakeArgument`` in the same order as ``op_args``, so the
          comparison does not depend on the order of ``spec_args``.
          """
          self.assertEqual(len(spec_args), len(op_args))
          keys = [a.name for a in op_args]
          # Report argument names missing from the spec as a test failure
          # instead of letting the lookup below escape as a KeyError.
          missing = [k for k in keys if k not in spec_args]
          self.assertFalse(
              missing,
              'operator has arguments not in the spec: {}'.format(missing))

          def parse_args(args):
              operator = caffe2_pb2.OperatorDef()
              # Generate the expected value in the same order
              for k in keys:
                  v = args[k]
                  arg = utils.MakeArgument(k, v)
                  operator.arg.add().CopyFrom(arg)
              return operator.arg

          self.assertEqual(parse_args(spec_args), op_args)

      def assertNetContainOps(self, net, op_specs):
          """
          Given a net and a list of OpSpecs, check that the net matches them.

          The net must have exactly ``len(op_specs)`` operators. The i-th
          operator is compared with the i-th spec on type, inputs and
          outputs, and on arguments when the spec's ``arg`` is not None.

          Returns:
              The net's operator list (``net.Proto().op``).
          """
          ops = net.Proto().op
          self.assertEqual(len(op_specs), len(ops))
          for op, op_spec in zip(ops, op_specs):
              self.assertEqual(op_spec.type, op.type)
              self.assertBlobsEqual(op_spec.input, op.input)
              self.assertBlobsEqual(op_spec.output, op.output)
              if op_spec.arg is not None:
                  self.assertArgsEqual(op_spec.arg, op.arg)
          return ops