# initializers_test.py (2.1 KB)
import unittest

from caffe2.python import brew, model_helper, workspace
from caffe2.python.modeling.initializers import (
    Initializer,
    PseudoFP16Initializer,
)
  5. class InitializerTest(unittest.TestCase):
  6. def test_fc_initializer(self):
  7. model = model_helper.ModelHelper(name="test")
  8. data = model.net.AddExternalInput("data")
  9. fc1 = brew.fc(model, data, "fc1", dim_in=1, dim_out=1)
  10. # no operator name set, will use default
  11. fc2 = brew.fc(model, fc1, "fc2", dim_in=1, dim_out=1,
  12. WeightInitializer=Initializer)
  13. # no operator name set, will use custom
  14. fc3 = brew.fc(model, fc2, "fc3", dim_in=1, dim_out=1,
  15. WeightInitializer=Initializer,
  16. weight_init=("ConstantFill", {}),
  17. )
  18. # operator name set, no initializer class set
  19. fc4 = brew.fc(model, fc3, "fc4", dim_in=1, dim_out=1,
  20. WeightInitializer=None,
  21. weight_init=("ConstantFill", {})
  22. )
  23. @unittest.skipIf(not workspace.has_gpu_support, 'No GPU support')
  24. def test_fc_fp16_initializer(self):
  25. model = model_helper.ModelHelper(name="test")
  26. data = model.net.AddExternalInput("data")
  27. fc1 = brew.fc(model, data, "fc1", dim_in=1, dim_out=1)
  28. # default operator, PseudoFP16Initializer
  29. fc2 = brew.fc(model, fc1, "fc2", dim_in=1, dim_out=1,
  30. WeightInitializer=PseudoFP16Initializer
  31. )
  32. # specified operator, PseudoFP16Initializer
  33. fc3 = brew.fc(model, fc2, "fc3", dim_in=1, dim_out=1,
  34. weight_init=("ConstantFill", {}),
  35. WeightInitializer=PseudoFP16Initializer
  36. )
  37. def test_fc_external_initializer(self):
  38. model = model_helper.ModelHelper(name="test", init_params=False)
  39. data = model.net.AddExternalInput("data")
  40. fc1 = brew.fc(model, data, "fc1", dim_in=1, dim_out=1) # noqa
  41. self.assertEqual(len(model.net.Proto().op), 1)
  42. self.assertEqual(len(model.param_init_net.Proto().op), 0)