A First Look at the Triton Source Code

1. The core interface

def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    do_not_specialize: Optional[Iterable[int]] = None,
    debug: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:

The interface returns a JITFunction object, which inherits from KernelInterface:

class JITFunction(KernelInterface[T]):
    # omitted
    
class KernelInterface(Generic[T]):
    run: T

    def __getitem__(self, grid) -> T:
        return cast(T, functools.partial(cast(Callable, self.run), grid=grid))

When a JITFunction is called it takes one extra argument, grid, i.e. fn[grid](*args, **kwargs): __getitem__ binds grid onto run via functools.partial.
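
To make this concrete, here is a minimal usage sketch (the kernel and tensor names are mine, not from the Triton source). Note that the grid can also be a callable receiving the kernel's meta-parameters, handled by the `if callable(grid)` branch in the launcher shown below:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(out_ptr + offs,
             tl.load(x_ptr + offs, mask=mask) + tl.load(y_ptr + offs, mask=mask),
             mask=mask)

n = 4096
x, y = torch.rand(n, device="cuda"), torch.rand(n, device="cuda")
out = torch.empty_like(x)
# add_kernel[grid] returns functools.partial(run, grid=grid); calling it runs the kernel.
add_kernel[lambda meta: (triton.cdiv(n, meta["BLOCK"]),)](x, y, out, n, BLOCK=1024)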

Inside the JITFunction implementation, the core logic is the _make_launcher() method, which exec's a templated launcher function (the {...} placeholders below are filled in per kernel):

def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None):
    sig_key =  {sig_keys},
    constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
    spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
    key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_stages, self.debug)
    if not extern_libs is None:
      key = (key, tuple(extern_libs.items()))
    assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
    if callable(grid):
        grid = grid({{{grid_args}}})
    grid_size = len(grid)
    grid_0 = grid[0]
    grid_1 = grid[1] if grid_size > 1 else 1
    grid_2 = grid[2] if grid_size > 2 else 1
    if device is None:
        device = get_current_device()
        set_current_device(device)
    if stream is None and not warmup:
      stream = get_cuda_stream(device)
    try:
      bin = cache[device][key]      # <-------- first try self.cache for a binary already compiled for this key
      if not warmup:
          bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
      return bin
    # kernel not cached -- compile
    except KeyError:
      # build dict of constant values
      args = [{args}]
      all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
      configs = self._get_config(*all_args),
      constants = self._make_constants(constexpr_key)
      constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
      constants.update({{i: 1 for i in configs[0].equal_to_1}})
      # build kernel signature -- doesn't include specialized arguments
      signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
      # build stub signature -- includes arguments that are specialized
      for i, arg in constants.items():
        if callable(arg):
          raise TypeError(f"Callable constexpr at index {{i}} is not supported")
      if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):   # <-------- on a cache miss, fall through to triton.compile
        bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug)
        if not warmup:
            bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
        self.cache[device][key] = bin
        return bin
      return None
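
Note everything that feeds the cache key: the signature types (sig_key), the constexpr values, the specialization key, num_warps and num_stages; changing any of them produces a new binary. Continuing the add_kernel sketch above, one rough way to observe the fan-out (assuming the Triton 2.x cache layout visible in the template: a dict keyed by device, then by key):

add_kernel[(16,)](x, y, out, n, BLOCK=1024)
add_kernel[(8,)](x, y, out, n, BLOCK=2048)   # new constexpr value -> new cache key
device = torch.cuda.current_device()
print(len(add_kernel.cache[device]))         # 2 distinct compiled binaries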

Next we turn to the compile function itself. Its first key step is make_stub, which returns a so_path. This happens in two steps:

  • step 1: create a main.c file and write the CUDA-launcher code into it
  • step 2: call _build to compile it into a .so, cache it, and return its so_path
def make_stub(name, signature, constants):
    # name of files that are cached
    so_cache_key = make_so_cache_key(version_key(), signature, constants)
    so_cache_manager = CacheManager(so_cache_key)
    so_name = f"{name}.so"
    # retrieve stub from cache if it exists
    if not so_cache_manager.has_file(so_name):
        with tempfile.TemporaryDirectory() as tmpdir:
            src = generate_launcher(constants, signature)
            src_path = os.path.join(tmpdir, "main.c")    # <------- step 1
            with open(src_path, "w") as f:
                f.write(src)
            so = _build(name, src_path, tmpdir)          # <------- step 2
            with open(so, "rb") as f:
                so_cache_manager.put(f.read(), so_name, binary=True)
    return so_cache_manager._make_path(so_name)
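
The CacheManager persists these artifacts on disk. As an aside (assuming the usual Triton 2.x default location, overridable via the TRITON_CACHE_DIR environment variable), you can poke at the cache directly:

import os

cache_dir = os.environ.get("TRITON_CACHE_DIR",
                           os.path.join(os.path.expanduser("~"), ".triton", "cache"))
for root, _, files in os.walk(cache_dir):
    for name in files:
        print(os.path.join(root, name))   # launcher .so stubs, .ptx, .cubin, metadata ...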

Let's first look at what step 1 writes into main.c. The main content follows; as the code shows, this is launch-side glue that ultimately calls cuLaunchKernel(function, ...):

#include \"cuda.h\"
#include <stdbool.h>
#include <Python.h>

static inline void gpuAssert(CUresult code, const char *file, int line)
{{
   if (code != CUDA_SUCCESS)
   {{
      const char* prefix = "Triton Error [CUDA]: ";
      const char* str;
      cuGetErrorString(code, &str);
      char err[1024] = {{0}};
      strcat(err, prefix);
      strcat(err, str);
      PyErr_SetString(PyExc_RuntimeError, err);
   }}
}}

#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
  void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
  if(gridX*gridY*gridZ > 0){{
    CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));    // <-------- this `function` is the CUfunction of the compiled kernel (bin.cu_function)
  }}
}}

typedef struct _DevicePtrInfo {{
    CUdeviceptr dev_ptr;
    bool valid;
}} DevicePtrInfo;

static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
  DevicePtrInfo ptr_info;
  ptr_info.dev_ptr = 0;
  ptr_info.valid = true;
  if (PyLong_Check(obj)) {{
    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
    return ptr_info;
  }}
  if (obj == Py_None) {{
    // valid nullptr
    return ptr_info;
  }}
  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
  if(ptr){{
    PyObject *empty_tuple = PyTuple_New(0);
    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
    Py_DECREF(empty_tuple);
    Py_DECREF(ptr);
    if (!PyLong_Check(ret)) {{
      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
      ptr_info.valid = false;
      return ptr_info;
    }}
    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
    if(!ptr_info.dev_ptr)
      return ptr_info;
    uint64_t dev_ptr;
    int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
    if (status == CUDA_ERROR_INVALID_VALUE) {{
        PyErr_Format(PyExc_ValueError,
                     "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
        ptr_info.valid = false;
    }}
    ptr_info.dev_ptr = dev_ptr;
    Py_DECREF(ret);  // Thanks ChatGPT!
    return ptr_info;
  }}
  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
  return ptr_info;
}}

static PyObject* launch(PyObject* self, PyObject* args) {{
  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
  int num_warps;
  int shared_memory;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *compiled_kernel = NULL;
  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
    return NULL;
  }}

  if (launch_enter_hook != Py_None) {{
    PyObject_CallObject(launch_enter_hook, args);
  }}


  // raise exception asap
  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
  _launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});

  if (launch_exit_hook != Py_None) {{
    PyObject_CallObject(launch_exit_hook, args);
  }}

  if(PyErr_Occurred()) {{
    return NULL;
  }}
  // return None
  Py_INCREF(Py_None);
  return Py_None;
}}

static PyMethodDef ModuleMethods[] = {{
  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},   # <-------- 與 CompiledKernel 中的 self.c_wrapper = getattr(mod, "launch") 相呼應
  {{NULL, NULL, 0, NULL}} // sentinel
}};

static struct PyModuleDef ModuleDef = {{
  PyModuleDef_HEAD_INIT,
  \"__triton_launcher\",
  NULL, //documentation
  -1, //size
  ModuleMethods
}};

PyMODINIT_FUNC PyInit___triton_launcher(void) {{
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {{
    return NULL;
  }}
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}}
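
Keep in mind that the listing above is itself a Python format string: {arg_decls}, {format}, and the ', '.join(...) expressions are spliced in per kernel signature. A quick illustration of what one splice expands to, using a hypothetical signature:

# Hypothetical two-argument kernel: arg 0 is a float pointer, arg 1 an i32 scalar.
signature = {0: "*fp32", 1: "i32"}
constants = {}
print(', '.join(f"&arg{i}" for i in signature.keys() if i not in constants))
# -> "&arg0, &arg1", which lands in:  void *params[] = { &arg0, &arg1 };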

step 2 then spawns a subprocess that compiles the file into a .so via the CC command, with setuptools as a backup. (Note: setuptools appears to be purely a fallback, reached only if the CC compile path fails.)

def _build(name, src, srcdir):
    cuda_lib_dirs = libcuda_dirs()
    base_dir = os.path.join(os.path.dirname(__file__), os.path.pardir)
    cuda_path = os.path.join(base_dir, "third_party", "cuda")

    cu_include_dir = os.path.join(cuda_path, "include")
    triton_include_dir = os.path.join(os.path.dirname(__file__), "include")
    cuda_header = os.path.join(cu_include_dir, "cuda.h")
    triton_cuda_header = os.path.join(triton_include_dir, "cuda.h")
    cu_include_dir = triton_include_dir
    
    so = os.path.join(srcdir, f"{name}.so")
    py_include_dir = sysconfig.get_paths()["include"]   # CPython headers (import of sysconfig abridged)
    cc = os.environ.get("CC")
    # (abridged: when CC is unset, Triton picks up gcc or clang from PATH)
    cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
    cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
    ret = subprocess.check_call(cc_cmd)
    if ret == 0:
        return so
    
    
    # fallback on setuptools
    extra_compile_args = []
    library_dirs = cuda_lib_dirs
    include_dirs = [srcdir, cu_include_dir]
    libraries = ['cuda']
    # extra arguments
    extra_link_args = []
    # create extension module
    ext = setuptools.Extension(
        name=name,
        language='c',
        sources=[src],
        include_dirs=include_dirs,
        extra_compile_args=extra_compile_args + ['-O3'],
        extra_link_args=extra_link_args,
        library_dirs=library_dirs,
        libraries=libraries,
    )
    # build extension module
    args = ['build_ext']
    args.append('--build-temp=' + srcdir)
    args.append('--build-lib=' + srcdir)
    args.append('-q')
    args = dict(
        name=name,
        ext_modules=[ext],
        script_args=args,
    )
    with quiet():
        setuptools.setup(**args)
    return so

compile then moves on to its second phase, running the module through several IR conversion stages:

def compile(xxx):
    # omitted
    asm = dict()
    module = fn
    first_stage = list(stages.keys()).index(ext)
    # run compilation pipeline  and populate metadata
    for ir, (parse, compile_kernel) in list(stages.items())[first_stage:]:
        # stage 1: ast_to_ttir
        if ir == ext:
            next_module = parse(fn)
        else:
            # later stages: ttir -> ttgir -> llir -> ptx -> cubin
            next_module = compile_kernel(module)
      
        if ir == "llir" and "shared" not in metadata:
            metadata["shared"] = _triton.get_shared_memory_size(module)
    
        if ir == "ptx":
            metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
            
        if ir == "cubin":
            asm[ir] = next_module
        else:
            asm[ir] = str(next_module)
        
        module = next_module
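
The stages table that drives this loop maps each IR name to a (parse, compile_kernel) pair. Roughly, in the Triton 2.x layout (a sketch from memory: the parse callbacks, used when resuming from cached files, are elided into named helpers here):

stages = {
    "ast":   (lambda path: fn, None),
    "ttir":  (parse_mlir_module, lambda src: ast_to_ttir(src, signature, configs[0], constants, debug)),
    "ttgir": (parse_mlir_module, lambda src: ttir_to_ttgir(src, num_warps)),
    "llir":  (read_text,         lambda src: ttgir_to_llir(src, extern_libs, arch)),
    "ptx":   (read_text,         lambda src: llir_to_ptx(src, arch)),
    "cubin": (read_bytes,        lambda src: ptx_to_cubin(src, arch)),
}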

The IR conversions at the major stages correspond to the following functions:

  • ast_to_ttir: CodeGenerator(ast.NodeVisitor) does the main work; it returns an mlir::ModuleOp object pybind-exported from the C++ side
  • ttir_to_ttgir: the pm variable is an mlir::PassManager object pybind-exported from the C++ side
  • ttgir_to_llir: the core call _triton.translate_triton_gpu_to_llvmir is backed by the C++ function translateTritonGPUToLLVMIR
  • llir_to_ptx: the core call translate_llvmir_to_ptx is implemented by the C++ function translateLLVMIRToPTX
  • ptx_to_cubin: the core call compile_ptx_to_cubin is likewise implemented on the C++ side
def ast_to_ttir(fn, signature, specialization, constants, debug):
    context = ir.context()
    context.load_triton()
    prototype = language.function_type([], arg_types)
    generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
                              function_name=function_name, attributes=new_attrs,
                              is_kernel=True, debug=debug)
    generator.visit(fn.parse())    
    ret = generator.module   # <------- created by ir.builder(context).create_module()
    # module takes ownership of the context
    ret.context = context
    return ret

def ttir_to_ttgir(mod, num_warps):
    pm = _triton.ir.pass_manager(mod.context)
    pm.add_convert_triton_to_tritongpu_pass(num_warps)
    pm.run(mod)
    return mod

def ttgir_to_llir(mod, extern_libs, arch):
    if extern_libs:
        _add_external_libs(mod, extern_libs)
    # TODO: separate tritongpu_to_llvmir for different backends
    if _is_cuda(arch):
        return _triton.translate_triton_gpu_to_llvmir(mod, arch, False)
    else:
        return _triton.translate_triton_gpu_to_llvmir(mod, 0, True)

def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str:
    if ptx_version is None:
        _, cuda_version = path_to_ptxas()
        ptx_version = ptx_get_version(cuda_version)
    return _triton.translate_llvmir_to_ptx(mod, arch, ptx_version)

def ptx_to_cubin(ptx: str, arch: int):
    ptxas, _ = path_to_ptxas()
    return _triton.compile_ptx_to_cubin(ptx, ptxas, arch)
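
Every stage's output ends up in the asm dict (see the compile excerpt above), which survives on the returned kernel, so the intermediate IRs are easy to inspect after a launch. Continuing the hypothetical add_kernel example:

bin = add_kernel[(16,)](x, y, out, n, BLOCK=1024)   # the launcher returns the CompiledKernel
for ir in ("ttir", "ttgir", "llir", "ptx"):
    print(f"// ----- {ir} -----")
    print(bin.asm[ir][:300])                        # head of each IR dump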

As an extra, here is the core function behind ttgir_to_llir:

std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
                           mlir::ModuleOp module, int computeCapability,
                           bool isROCM) {
  mlir::PassManager pm(module->getContext());
  applyPassManagerCLOptions(pm);
  auto printingFlags = mlir::OpPrintingFlags();
  printingFlags.elideLargeElementsAttrs(16);
  pm.enableIRPrinting(
      /*shouldPrintBeforePass=*/nullptr,
      /*shouldPrintAfterPass=*/
      [](mlir::Pass *pass, mlir::Operation *) {
        return ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
      },
      /*printModuleScope=*/false,
      /*printAfterOnlyOnChange=*/true,
      /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);

  pm.addPass(mlir::createConvertSCFToCFPass());
  pm.addPass(mlir::createConvertIndexToLLVMPass());
  pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability, isROCM));
  pm.addPass(mlir::createArithToLLVMConversionPass());
  pm.addPass(mlir::createCanonicalizerPass());
  // Simplify the IR
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createSymbolDCEPass());

  if (failed(pm.run(module))) {
    llvm::errs() << "Pass execution failed";
    return nullptr;
  }

  auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, isROCM);
  if (!llvmIR) {
    llvm::errs() << "Translate to LLVM IR failed";
    return nullptr;
  }

  if (::triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) {
    std::string mod_string;
    std::unique_ptr<llvm::raw_string_ostream> ir_ss(
        new llvm::raw_string_ostream(mod_string));
    llvmIR->print(*ir_ss, nullptr);
    std::cout << "// -----// LLVM IR Dump //----- //\n"
              << mod_string << std::endl;
  }

  return llvmIR;
}
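
Note the shouldPrintAfterPass callback gated on MLIR_ENABLE_DUMP: with that environment variable set, the pass manager dumps the IR after every pass. Set it before anything gets compiled, e.g.:

import os
os.environ["MLIR_ENABLE_DUMP"] = "1"   # must be set before the first kernel compiles
# ... then import triton and launch any @triton.jit kernel as usual;
# LLVM_IR_ENABLE_DUMP (seen further below) works the same way for the LLVM IR.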

The core of translateLLVMIRToPTX, in turn, is:

std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
  // LLVM version in use may not officially support target hardware.
  // Supported versions for LLVM 14 are here:
  // https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
  int maxPTX = std::min(80, version);
  int maxCC = std::min(90, cc);
  // options
  auto options = llvm::cl::getRegisteredOptions();
  auto *shortPtr =
      static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
  assert(shortPtr);
  shortPtr->setValue(true);
  std::string sm = cc == 90 ? "sm_90a" : "sm_" + std::to_string(cc);
  // max PTX version
  int ptxMajor = maxPTX / 10;
  int ptxMinor = maxPTX % 10;
  // create
  std::string triple = "nvptx64-nvidia-cuda";
  std::string proc = "sm_" + std::to_string(maxCC);
  std::string layout = "";
  std::string features = "";
  // std::string features = "+ptx" + std::to_string(maxPTX);
  initLLVM();
  // verify and store llvm
  llvm::legacy::PassManager pm;
  pm.add(llvm::createVerifierPass());
  pm.run(module);
  // module->print(llvm::outs(), nullptr);

  // create machine
  module.setTargetTriple(triple);
  std::string error;
  auto target =
      llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
  llvm::TargetOptions opt;
  opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
  opt.UnsafeFPMath = false;
  opt.NoInfsFPMath = false;
  opt.NoNaNsFPMath = true;
  llvm::TargetMachine *machine = target->createTargetMachine(
      module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
      std::nullopt, llvm::CodeGenOpt::Aggressive);
  // set data layout
  if (layout.empty())
    module.setDataLayout(machine->createDataLayout());
  else
    module.setDataLayout(layout);
  // emit machine code
  std::string result;
  {
    llvm::raw_string_ostream stream(result);
    llvm::buffer_ostream pstream(stream);
    for (llvm::Function &f : module.functions())
      f.addFnAttr(llvm::Attribute::AlwaysInline);
    llvm::legacy::PassManager pass;
    // emit
    machine->addPassesToEmitFile(pass, pstream, nullptr,
                                 llvm::CodeGenFileType::CGFT_AssemblyFile);
    pass.run(module);
  }
  // post-process
  findAndReplace(result, ".version", "\n",
                 ".version " + std::to_string(ptxMajor) + "." +
                     std::to_string(ptxMinor) + "\n");
  findAndReplace(result, ".target", "\n", ".target " + sm + "\n");
  while (findAndReplace(result, "\t// begin inline asm", "\n", ""))
    ;
  while (findAndReplace(result, "\t// end inline asm", "\n", ""))
    ;
  return result;
}
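
The clamping at the top and the .version/.target rewriting at the bottom are easy to trace by hand. For example (my own arithmetic, mirroring the code above):

# Say the toolkit reports PTX ISA 8.3 (version=83) on an sm_90 GPU (cc=90):
version, cc = 83, 90
maxPTX, maxCC = min(80, version), min(90, cc)      # clamp to what this LLVM supports
print(f".version {maxPTX // 10}.{maxPTX % 10}")    # .version 8.0
print("sm_90a" if cc == 90 else f"sm_{cc}")        # .target sm_90a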

With that we have the cubin. The final step of compile is to return a CompiledKernel(fn, so_path, metadata, asm) handle object; the remaining logic lives in CompiledKernel.
Recall that in the generated launcher at the very beginning, the statement that finally executes the kernel is:

bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)

That is, the CompiledKernel.c_wrapper function:

class CompiledKernel:

    def __init__(self, fn, so_path, metadata, asm):
        # initialize launcher
        import importlib.util
        spec = importlib.util.spec_from_file_location("__triton_launcher", so_path)
        mod = importlib.util.module_from_spec(spec)
        self.fn = fn
        spec.loader.exec_module(mod)
        self.c_wrapper = getattr(mod, "launch")    # <-------- the `launch` entry point from the stub .so

The so_path here is exactly the path of the launcher .so built by make_stub, which closes the loop: the generated stub is imported as a Python extension module, and its launch function performs the actual cuLaunchKernel.
