jiterator.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import torch
  2. from torch import Tensor
  3. from typing import Callable, List
  4. import re
  5. __all__ : List[str] = []
  6. class _CodeParser:
  7. def __init__(self, code_string: str):
  8. optional_ws = r"\s*"
  9. required_ws = r"\s+"
  10. template_params = r"(?P<template_params>\<.+\>)"
  11. return_type = r"(?P<return_type>\w+)"
  12. function_name = r"(?P<function_name>\w+)"
  13. function_params = r"(?P<function_params>\(.+\))"
  14. function_body = r"(?P<function_body>\{.+\})"
  15. pattern = \
  16. optional_ws \
  17. + "template" \
  18. + optional_ws + template_params \
  19. + optional_ws + return_type \
  20. + required_ws + function_name \
  21. + optional_ws + function_params \
  22. + optional_ws + function_body \
  23. + optional_ws
  24. result = re.match(pattern, code_string, re.DOTALL) # DOTALL for matching multiline
  25. if result is None:
  26. raise Exception(f"Couldn't parse code, please check correctness:\n {code_string}")
  27. self.template_params = result["template_params"]
  28. self.return_type = result["return_type"]
  29. self.function_name = result["function_name"]
  30. self.function_params = result["function_params"]
  31. self.function_body = result["function_body"]
  32. def _create_jit_fn(code_string: str, **kwargs) -> Callable:
  33. """
  34. Create a jiterator-generated cuda kernel for an elementwise op.
  35. The code string has to be a valid CUDA function that describes the computation for a single element. The code
  36. string has to follow the c++ template pattern, as shown in the example below. This function will be inlined
  37. into elementwise kernel template, and compiled on the fly. Compiled kernel will be cached in memory, as well as
  38. local temp dir.
  39. Jiterator-generated kernels accepts noncontiguous tensors, and supports boardcasting and type promotion.
  40. Args:
  41. code_string (string): CUDA code string to be compiled by jiterator.
  42. kwargs (Dict, optional): Keyword arguments for generated function
  43. Example:
  44. >>> code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
  45. >>> jitted_fn = create_jit_fn(code_string, alpha=1.0)
  46. >>> a = torch.rand(3, device='cuda')
  47. >>> b = torch.rand(3, device='cuda')
  48. >>> # invoke jitted function like a regular python function
  49. >>> result = jitted_fn(a, b, alpha=3.14)
  50. Jiterator can be used together with python registration to override an operator's cuda kernel
  51. Following example is overriding gelu's cuda kernel with relu:
  52. >>> code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
  53. >>> my_gelu = create_jit_fn(code_string)
  54. >>> my_lib = torch.library.Library("aten", "IMPL")
  55. >>> my_lib.impl('aten::gelu', my_gelu, "CUDA")
  56. >>> # torch.nn.GELU and torch.nn.function.gelu are now overridden
  57. >>> a = torch.rand(3, device='cuda')
  58. >>> torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))
  59. .. warning::
  60. This API is in beta and may change in future releases.
  61. .. warning::
  62. Jiterator only supports up to 8 tensor inputs
  63. .. warning::
  64. All input tensors must live in CUDA device
  65. """
  66. class JittedFunction:
  67. def __init__(self, code_string: str, **kwargs):
  68. self.code_string = code_string
  69. parsed_code = _CodeParser(code_string)
  70. self.kernel_name = parsed_code.function_name
  71. self.kwargs_dict = kwargs
  72. self.is_cuda_available = torch.cuda.is_available()
  73. def __call__(self, *tensors: Tensor, **kwargs):
  74. # Jiterator follow torch.cuda's lazy initialization behavior
  75. # Defer checking cuda's availability at the function invocation time
  76. assert self.is_cuda_available, "Jiterator is only supported on CUDA GPUs, no CUDA GPUs are available."
  77. assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."
  78. expanded_kwargs = self.kwargs_dict.copy()
  79. for key, value in kwargs.items():
  80. if key in self.kwargs_dict:
  81. expanded_kwargs[key] = value
  82. else:
  83. raise KeyError(f"{key} is not declared in function definition")
  84. return torch._C._cuda_jiterator_compile_and_launch_kernel(
  85. self.code_string,
  86. self.kernel_name,
  87. tensors,
  88. expanded_kwargs)
  89. return JittedFunction(code_string, **kwargs)