From 9cd47a26d593985ad2b36857cb9fcbc7659ffde8 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 16 Feb 2021 18:05:47 +0100 Subject: [PATCH] [mlir] add verifiers for NVVM and ROCDL kernel attributes Make sure they can only be attached to LLVM functions as a result of converting GPU functions to the LLVM Dialect. --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 1 + mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 1 + mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 12 ++++++++++++ mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 12 ++++++++++++ .../LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 5 +---- .../Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp | 4 +--- mlir/test/Dialect/LLVMIR/nvvm.mlir | 7 ++++++- mlir/test/Dialect/LLVMIR/rocdl.mlir | 6 +++++- 8 files changed, 39 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index de7fd01fe1ca..203a0b2031c9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -24,6 +24,7 @@ def NVVM_Dialect : Dialect { let name = "nvvm"; let cppNamespace = "::mlir::NVVM"; let dependentDialects = ["LLVM::LLVMDialect"]; + let hasOperationAttrVerify = 1; let extraClassDeclaration = [{ /// Get the name of the attribute used to annotate external kernel diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index cfb08ff465a2..1b45e5144263 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -24,6 +24,7 @@ def ROCDL_Dialect : Dialect { let name = "rocdl"; let cppNamespace = "::mlir::ROCDL"; let dependentDialects = ["LLVM::LLVMDialect"]; + let hasOperationAttrVerify = 1; let extraClassDeclaration = [{ /// Get the name of the attribute used to annotate external kernel diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 06e7378d5fe1..3b6d2395d0ca 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -145,5 +145,17 @@ void NVVMDialect::initialize() { allowUnknownOperations(); } +LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // Kernel function attribute should be attached to functions. + if (attr.first == NVVMDialect::getKernelFuncAttrName()) { + if (!isa(op)) { + return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName() + << "' attribute attached to unexpected op"; + } + } + return success(); +} + #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc" diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp index 1cdceaf78904..f54fcdbca319 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -91,5 +91,17 @@ void ROCDLDialect::initialize() { allowUnknownOperations(); } +LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // Kernel function attribute should be attached to functions. + if (attr.first == ROCDLDialect::getKernelFuncAttrName()) { + if (!isa(op)) { + return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName() + << "' attribute attached to unexpected op"; + } + } + return success(); +} + #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc" diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index fdc9add52dee..a0eb87ec3125 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -47,10 +47,7 @@ LogicalResult mlir::NVVMDialectLLVMIRTranslationInterface::amendOperation( Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const { if (attribute.first == NVVM::NVVMDialect::getKernelFuncAttrName()) { - auto func = dyn_cast(op); - if (!func) - return failure(); - + auto func = cast(op); llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName()); llvm::Metadata *llvmMetadata[] = { diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp index 7b34f803e7e3..9288af17beef 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp @@ -54,9 +54,7 @@ LogicalResult mlir::ROCDLDialectLLVMIRTranslationInterface::amendOperation( Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const { if (attribute.first == ROCDL::ROCDLDialect::getKernelFuncAttrName()) { - auto func = dyn_cast(op); - if (!func) - return failure(); + auto func = cast(op); // For GPU kernels, // 1. Insert AMDGPU_KERNEL calling convention. diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index 545364dc0732..1e3d6dc73978 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s func @nvvm_special_regs() -> i32 { // CHECK: nvvm.read.ptx.sreg.tid.x : i32 @@ -68,3 +68,8 @@ func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>, %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } + +// ----- + +// expected-error@below {{attribute attached to unexpected op}} +func private @expected_llvm_func() attributes { nvvm.kernel } diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 31a56bede3bd..e9a3a59d8501 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s func @rocdl_special_regs() -> i32 { // CHECK-LABEL: rocdl_special_regs @@ -167,3 +167,7 @@ llvm.func @rocdl.mubuf(%rsrc : vector<4xi32>, %vindex : i32, llvm.return } +// ----- + +// expected-error@below {{attribute attached to unexpected op}} +func private @expected_llvm_func() attributes { rocdl.kernel }