| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- from caffe2.python import workspace, crf
- from caffe2.python.cnn import CNNModelHelper
- from caffe2.python.crf_predict import crf_update_predictions
- from caffe2.python.test_util import TestCase
- import hypothesis.strategies as st
- from hypothesis import given, settings
- import numpy as np
- class TestCrfDecode(TestCase):
- @given(num_tags=st.integers(2, 4), num_words=st.integers(2, 15))
- @settings(deadline=2000)
- def test_crf_viterbi(self, num_tags, num_words):
- model = CNNModelHelper(name='external')
- predictions = np.random.randn(num_words, num_tags).astype(np.float32)
- transitions = np.random.uniform(
- low=-1, high=1, size=(num_tags + 2, num_tags + 2)
- ).astype(np.float32)
- predictions_blob, transitions_blob = (
- model.net.AddExternalInputs('predictions', 'crf_transitions')
- )
- workspace.FeedBlob(str(transitions_blob), transitions)
- workspace.FeedBlob(str(predictions_blob), predictions)
- crf_layer = crf.CRFWithLoss(model, num_tags, transitions_blob)
- updated_predictions = crf_update_predictions(
- model, crf_layer, predictions_blob
- )
- ref_predictions = crf_layer.update_predictions(predictions_blob)
- workspace.RunNetOnce(model.param_init_net)
- workspace.RunNetOnce(model.net)
- updated_predictions = workspace.FetchBlob(str(updated_predictions))
- ref_predictions = workspace.FetchBlob(str(ref_predictions))
- np.testing.assert_allclose(
- updated_predictions,
- ref_predictions,
- atol=1e-4, rtol=1e-4, err_msg='Mismatch in CRF predictions'
- )
|