tools.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334
## @package tools
# Module caffe2.python.helpers.tools
  def image_input(
      model, blob_in, blob_out, order="NCHW", use_gpu_transform=False, **kwargs
  ):
      """Add an ImageInput operator to ``model.net``, producing ``order`` layout.

      Args:
          model: Object exposing a ``net`` attribute on which ``ImageInput``
              (and, on the CPU NCHW path, ``NHWC2NCHW``) are called.
          blob_in: Input blob(s), forwarded unchanged to ``ImageInput``.
          blob_out: Sequence (list or tuple) of output blob names.
              ``blob_out[0]`` is the image output; the remaining entries are
              forwarded to ``ImageInput`` unchanged.
          order: When ``"NCHW"``, the image output ends up in NCHW either via
              the GPU transform or via an explicit ``NHWC2NCHW`` op. Any other
              value passes ``blob_out`` straight to ``ImageInput``.
          use_gpu_transform: Only consulted when ``order == "NCHW"``. If
              truthy, ``use_gpu_transform=1`` is forwarded to ``ImageInput``.
          **kwargs: Extra operator arguments forwarded to ``ImageInput``;
              must include ``is_test``.

      Returns:
          The value returned by ``model.net.ImageInput``. On the CPU NCHW path
          it is a tuple whose first element is replaced by the result of
          ``model.net.NHWC2NCHW``.

      Raises:
          ValueError: If ``is_test`` is not present in ``kwargs``.
      """
      # Explicit check instead of ``assert`` so it still runs under ``python -O``.
      if 'is_test' not in kwargs:
          raise ValueError("Argument 'is_test' is required")

      if order != "NCHW":
          return model.net.ImageInput(blob_in, blob_out, **kwargs)

      if use_gpu_transform:
          # GPU transform will handle NHWC -> NCHW
          kwargs['use_gpu_transform'] = 1
          return model.net.ImageInput(blob_in, blob_out, **kwargs)

      # CPU path: the image output is written to a temporary '<name>_nhwc'
      # blob and then converted into the requested blob_out[0] by NHWC2NCHW.
      # list(...) lets blob_out be a tuple as well as a list.
      nhwc_outputs = [blob_out[0] + '_nhwc'] + list(blob_out[1:])
      outputs = list(model.net.ImageInput(blob_in, nhwc_outputs, **kwargs))
      outputs[0] = model.net.NHWC2NCHW(outputs[0], blob_out[0])
      return tuple(outputs)
  def video_input(model, blob_in, blob_out, **kwargs):
      """Add a VideoInput operator to ``model.net`` and return its result.

      Args:
          model: Object exposing a ``net`` attribute with a ``VideoInput``
              method.
          blob_in: Input blob(s), forwarded unchanged.
          blob_out: Output blob name(s), forwarded unchanged.
          **kwargs: Operator arguments, forwarded unchanged.

      Returns:
          Whatever ``model.net.VideoInput`` returns. The number of outputs
          can vary depending on ``kwargs``.
      """
      return model.net.VideoInput(blob_in, blob_out, **kwargs)