# crf_predict.py
  1. import numpy as np
  2. from caffe2.python.crf import CRFWithLoss
  3. def crf_update_predictions(model, crf_with_loss, classes):
  4. return apply_crf(
  5. model.param_init_net,
  6. model.net,
  7. crf_with_loss.transitions,
  8. classes,
  9. crf_with_loss.num_classes,
  10. )
  11. def apply_crf(init_net, net, transitions, predictions, num_classes):
  12. padded_classes = CRFWithLoss.pad_predictions(
  13. predictions, init_net, net, num_classes
  14. )
  15. bestPath = net.ViterbiPath([padded_classes, transitions])
  16. new_padded_classes = net.SwapBestPath([padded_classes, bestPath])
  17. # Revert the effect of pad_predictions by removing the last two rows and
  18. # the last two columns
  19. new_classes = net.RemovePadding(
  20. [new_padded_classes], padding_width=1, end_padding_width=1
  21. )
  22. slice_starts = np.array([0, 0]).astype(np.int32)
  23. slice_ends = np.array([-1, -3]).astype(np.int32)
  24. slice_starts = net.GivenTensorIntFill([], shape=[2], values=slice_starts)
  25. slice_ends = net.GivenTensorIntFill([], shape=[2], values=slice_ends)
  26. new_classes = net.Slice([new_classes, slice_starts, slice_ends])
  27. return new_classes