net_modifier.py 823 B

12345678910111213141516171819202122232425262728293031323334
  1. import abc
  2. class NetModifier(metaclass=abc.ABCMeta):
  3. """
  4. An abstraction class for supporting modifying a generated net.
  5. Inherited classes should implement the modify_net method where
  6. related operators are added to the net.
  7. Example usage:
  8. modifier = SomeNetModifier(opts)
  9. modifier(net)
  10. """
  11. def __init__(self):
  12. pass
  13. @abc.abstractmethod
  14. def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None):
  15. pass
  16. def __call__(self, net, init_net=None, grad_map=None, blob_to_device=None,
  17. modify_output_record=False):
  18. self.modify_net(
  19. net,
  20. init_net=init_net,
  21. grad_map=grad_map,
  22. blob_to_device=blob_to_device,
  23. modify_output_record=modify_output_record)