functional.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # @package functional
  2. # Module caffe2.python.layers.functional
  3. from caffe2.python import core, schema, scope, workspace
  4. from caffe2.python.layers.layers import (
  5. ModelLayer,
  6. )
  7. import caffe2.proto.caffe2_pb2 as caffe2_pb2
  8. import numpy as np
  9. import logging
  10. logger = logging.getLogger(__name__)
  11. logger.setLevel(logging.INFO)
  12. class Functional(ModelLayer):
  13. def __init__(self, model, input_record, output_names_or_num, function,
  14. name='functional', output_dtypes=None, tags=None, **kwargs):
  15. # allow coercion
  16. input_record = schema.as_record(input_record)
  17. super(Functional, self).__init__(model, name, input_record, tags=tags, **kwargs)
  18. self._function = function
  19. self._kwargs = kwargs
  20. return_struct = (
  21. isinstance(output_names_or_num, list) or
  22. (isinstance(output_names_or_num, int) and
  23. output_names_or_num != 1)
  24. )
  25. with scope.NameScope(self.name, reset=True):
  26. if isinstance(output_names_or_num, int):
  27. struct_output_schema = schema.NewRecord(
  28. model.net, schema.RawTuple(output_names_or_num))
  29. elif isinstance(output_names_or_num, schema.Field):
  30. self.output_schema = output_names_or_num.clone(keep_blobs=True)
  31. return
  32. else:
  33. if not isinstance(output_names_or_num, list):
  34. output_names_or_num = [output_names_or_num]
  35. out_tuple = [(out, np.void) for out in output_names_or_num]
  36. struct_output_schema = schema.NewRecord(
  37. model.net, schema.Struct(*out_tuple))
  38. num_outputs = len(struct_output_schema.field_blobs())
  39. # functional layer returns Struct if more than one outputs or output is
  40. # a list, otherwise Scalar
  41. if return_struct:
  42. self.output_schema = struct_output_schema
  43. else:
  44. self.output_schema = struct_output_schema[0]
  45. # If output_dtypes is provided, use it for output schema. Otherwise
  46. # the shape and type will be inferred.
  47. if output_dtypes is not None:
  48. if not isinstance(output_dtypes, list):
  49. output_dtypes = [output_dtypes] * num_outputs
  50. assert len(output_dtypes) == num_outputs
  51. for dtype, scalar in zip(output_dtypes,
  52. self.output_schema.all_scalars()):
  53. scalar.set_type(dtype)
  54. return
  55. # Fake execution of the function to infer shapes and types automatically
  56. had_issues = False
  57. try:
  58. type_net = core.Net('_temp_type_and_shape_inference_net')
  59. schema.InitEmptyRecord(type_net, input_record, enforce_types=True)
  60. function(type_net, self.input_record, self.output_schema, **kwargs)
  61. (shapes, types) = workspace.InferShapesAndTypes([type_net], {})
  62. for i in range(num_outputs):
  63. scalar_schema = (self.output_schema[i] if return_struct
  64. else self.output_schema)
  65. blob = scalar_schema()
  66. if blob not in types or blob not in shapes:
  67. had_issues = True
  68. continue
  69. if shapes[blob] == []:
  70. # Scalar type
  71. shape = tuple()
  72. elif shapes[blob][0] == 0:
  73. shape = tuple(shapes[blob][1:])
  74. else:
  75. logger.warning("unexpected shape: {}".format(shapes[blob]))
  76. # If batch dimension is not first - give up on shape
  77. # inference for that blob
  78. had_issues = True
  79. continue
  80. # TODO(amalevich): Move it to some shared library
  81. dtype = None
  82. if types[blob] == caffe2_pb2.TensorProto.DOUBLE:
  83. dtype = (np.float64, shape)
  84. elif types[blob] == caffe2_pb2.TensorProto.FLOAT:
  85. dtype = (np.float32, shape)
  86. elif types[blob] == caffe2_pb2.TensorProto.INT32:
  87. dtype = (np.int32, shape)
  88. elif types[blob] == caffe2_pb2.TensorProto.INT64:
  89. dtype = (np.int64, shape)
  90. elif types[blob] == caffe2_pb2.TensorProto.FLOAT16:
  91. dtype = (np.float16, shape)
  92. if dtype is not None:
  93. scalar_schema.set_type(dtype)
  94. except TypeError as ex:
  95. had_issues = True
  96. logger.warning(str(ex))
  97. if had_issues:
  98. logger.warning(
  99. "Type inference had problems for layer: {}".format(self.name))
  100. def add_ops(self, net):
  101. self._function(
  102. net, self.input_record, self.output_schema, **(self._kwargs))