# crf_viterbi_test.py

import hypothesis.strategies as st
import numpy as np
from hypothesis import given, settings

from caffe2.python import workspace, crf
from caffe2.python.cnn import CNNModelHelper
from caffe2.python.crf_predict import crf_update_predictions
from caffe2.python.test_util import TestCase


  8. class TestCrfDecode(TestCase):
  9. @given(num_tags=st.integers(2, 4), num_words=st.integers(2, 15))
  10. @settings(deadline=2000)
  11. def test_crf_viterbi(self, num_tags, num_words):
  12. model = CNNModelHelper(name='external')
  13. predictions = np.random.randn(num_words, num_tags).astype(np.float32)
  14. transitions = np.random.uniform(
  15. low=-1, high=1, size=(num_tags + 2, num_tags + 2)
  16. ).astype(np.float32)
  17. predictions_blob, transitions_blob = (
  18. model.net.AddExternalInputs('predictions', 'crf_transitions')
  19. )
  20. workspace.FeedBlob(str(transitions_blob), transitions)
  21. workspace.FeedBlob(str(predictions_blob), predictions)
  22. crf_layer = crf.CRFWithLoss(model, num_tags, transitions_blob)
  23. updated_predictions = crf_update_predictions(
  24. model, crf_layer, predictions_blob
  25. )
  26. ref_predictions = crf_layer.update_predictions(predictions_blob)
  27. workspace.RunNetOnce(model.param_init_net)
  28. workspace.RunNetOnce(model.net)
  29. updated_predictions = workspace.FetchBlob(str(updated_predictions))
  30. ref_predictions = workspace.FetchBlob(str(ref_predictions))
  31. np.testing.assert_allclose(
  32. updated_predictions,
  33. ref_predictions,
  34. atol=1e-4, rtol=1e-4, err_msg='Mismatch in CRF predictions'
  35. )