mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 03:56:16 +08:00
[mlir] Add RuntimeVerifiableOpInterface and transform
Static op verification cannot detect cases where an op is valid at compile time but may be invalid at runtime. An example of such an op is `memref::ExpandShapeOp`. Invalid at compile time: `memref.expand_shape %m [[0, 1]] : memref<11xf32> into memref<2x5xf32>` Valid at compile time (because we do not know any better): `memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x5xf32>`. This op may or may not be valid at runtime depending on the runtime shape of `%m`. Invalid runtime ops such as the one above are hard to debug because they can crash the program execution at a seemingly unrelated position or (even worse) compute an invalid result without crashing. This revision adds a new op interface `RuntimeVerifiableOpInterface` that can be implemented by ops that provide additional runtime verification. Such runtime verification can be computationally expensive, so it is only generated on an opt-in basis by running `-generate-runtime-verification`. A simple runtime verifier for `memref::ExpandShapeOp` is provided as an example. Differential Revision: https://reviews.llvm.org/D138576
This commit is contained in:
@@ -0,0 +1,21 @@
|
||||
//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
|
||||
#define MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
|
||||
namespace memref {
|
||||
void registerRuntimeVerifiableOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry);
|
||||
} // namespace memref
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
|
||||
@@ -45,6 +45,7 @@
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
|
||||
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
|
||||
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
@@ -130,6 +131,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
|
||||
registry);
|
||||
linalg::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
linalg::registerTilingInterfaceExternalModels(registry);
|
||||
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
|
||||
scf::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
shape::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
|
||||
@@ -8,6 +8,7 @@ add_mlir_interface(InferIntRangeInterface)
|
||||
add_mlir_interface(InferTypeOpInterface)
|
||||
add_mlir_interface(LoopLikeInterface)
|
||||
add_mlir_interface(ParallelCombiningOpInterface)
|
||||
add_mlir_interface(RuntimeVerifiableOpInterface)
|
||||
add_mlir_interface(ShapedOpInterfaces)
|
||||
add_mlir_interface(SideEffectInterfaces)
|
||||
add_mlir_interface(TilingInterface)
|
||||
|
||||
17
mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h
Normal file
17
mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h
Normal file
@@ -0,0 +1,17 @@
|
||||
//===- RuntimeVerifiableOpInterface.h - Op Verification ---------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
|
||||
#define MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
/// Include the generated interface declarations.
|
||||
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h.inc"
|
||||
|
||||
#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
|
||||
40
mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
Normal file
40
mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
Normal file
@@ -0,0 +1,40 @@
|
||||
//===- RuntimeVerifiableOpInterface.td - Op Verification ---*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
|
||||
#define MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
|
||||
let description = [{
|
||||
Implementations of this interface generate IR for runtime op verification.
|
||||
|
||||
Incorrect op usage can often be caught by op verifiers based on static
|
||||
program information. However, in the absence of static program information,
|
||||
it can remain undetected at compile time (e.g., in case of dynamic memref
|
||||
strides instead of static memref strides). Such cases can be checked at
|
||||
runtime. The op-specific checks are generated by this interface.
|
||||
}];
|
||||
let cppNamespace = "::mlir";
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Generate IR to verify this op at runtime, aborting runtime execution if
|
||||
verification fails.
|
||||
}],
|
||||
/*retTy=*/"void",
|
||||
/*methodName=*/"generateRuntimeVerification",
|
||||
/*args=*/(ins "::mlir::OpBuilder &":$builder,
|
||||
"::mlir::Location":$loc)
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
|
||||
@@ -64,6 +64,9 @@ std::unique_ptr<Pass> createControlFlowSinkPass();
|
||||
/// Creates a pass to perform common sub expression elimination.
|
||||
std::unique_ptr<Pass> createCSEPass();
|
||||
|
||||
/// Creates a pass that generates IR to verify ops at runtime.
|
||||
std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
|
||||
|
||||
/// Creates a loop invariant code motion pass that hoists loop invariant
|
||||
/// instructions out of the loop.
|
||||
std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
|
||||
|
||||
@@ -77,6 +77,16 @@ def CSE : Pass<"cse"> {
|
||||
];
|
||||
}
|
||||
|
||||
def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
|
||||
let summary = "Generate additional runtime op verification checks";
|
||||
let description = [{
|
||||
This pass generates op-specific runtime checks using the
|
||||
`RuntimeVerifiableOpInterface`. It can be run for debugging purposes after
|
||||
passes that are suspected to introduce faulty IR.
|
||||
}];
|
||||
let constructor = "mlir::createGenerateRuntimeVerificationPass()";
|
||||
}
|
||||
|
||||
def Inliner : Pass<"inline"> {
|
||||
let summary = "Inline function calls";
|
||||
let constructor = "mlir::createInlinerPass()";
|
||||
|
||||
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
|
||||
MultiBuffer.cpp
|
||||
NormalizeMemRefs.cpp
|
||||
ResolveShapedTypeResultDims.cpp
|
||||
RuntimeOpVerification.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
|
||||
|
||||
70
mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
Normal file
70
mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
Normal file
@@ -0,0 +1,70 @@
|
||||
//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace memref {
|
||||
namespace {
|
||||
struct ExpandShapeOpInterface
|
||||
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
|
||||
ExpandShapeOp> {
|
||||
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
|
||||
Location loc) const {
|
||||
auto expandShapeOp = cast<ExpandShapeOp>(op);
|
||||
|
||||
// Verify that the expanded dim sizes are a product of the collapsed dim
|
||||
// size.
|
||||
for (auto it : llvm::enumerate(expandShapeOp.getReassociationIndices())) {
|
||||
Value srcDimSz =
|
||||
builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
|
||||
int64_t groupSz = 1;
|
||||
bool foundDynamicDim = false;
|
||||
for (int64_t resultDim : it.value()) {
|
||||
if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
|
||||
// Keep this assert here in case the op is extended in the future.
|
||||
assert(!foundDynamicDim &&
|
||||
"more than one dynamic dim found in reassoc group");
|
||||
foundDynamicDim = true;
|
||||
continue;
|
||||
}
|
||||
groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
|
||||
}
|
||||
Value staticResultDimSz =
|
||||
builder.create<arith::ConstantIndexOp>(loc, groupSz);
|
||||
// staticResultDimSz must divide srcDimSz evenly.
|
||||
Value mod =
|
||||
builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
|
||||
Value isModZero = builder.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, mod,
|
||||
builder.create<arith::ConstantIndexOp>(loc, 0));
|
||||
builder.create<cf::AssertOp>(
|
||||
loc, isModZero,
|
||||
"static result dims in reassoc group do not divide src dim evenly");
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
} // namespace memref
|
||||
} // namespace mlir
|
||||
|
||||
void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
|
||||
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
|
||||
|
||||
// Load additional dialects of which ops may get created.
|
||||
ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
|
||||
});
|
||||
}
|
||||
@@ -10,6 +10,7 @@ set(LLVM_OPTIONAL_SOURCES
|
||||
InferTypeOpInterface.cpp
|
||||
LoopLikeInterface.cpp
|
||||
ParallelCombiningOpInterface.cpp
|
||||
RuntimeVerifiableOpInterface.cpp
|
||||
ShapedOpInterfaces.cpp
|
||||
SideEffectInterfaces.cpp
|
||||
TilingInterface.cpp
|
||||
@@ -44,6 +45,7 @@ add_mlir_interface_library(InferIntRangeInterface)
|
||||
add_mlir_interface_library(InferTypeOpInterface)
|
||||
add_mlir_interface_library(LoopLikeInterface)
|
||||
add_mlir_interface_library(ParallelCombiningOpInterface)
|
||||
add_mlir_interface_library(RuntimeVerifiableOpInterface)
|
||||
add_mlir_interface_library(ShapedOpInterfaces)
|
||||
add_mlir_interface_library(SideEffectInterfaces)
|
||||
add_mlir_interface_library(TilingInterface)
|
||||
|
||||
17
mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
Normal file
17
mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
Normal file
@@ -0,0 +1,17 @@
|
||||
//===- RuntimeVerifiableOpInterface.cpp - Op Verification -----------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
|
||||
|
||||
namespace mlir {
|
||||
class Location;
|
||||
class OpBuilder;
|
||||
} // namespace mlir
|
||||
|
||||
/// Include the definitions of the interface.
|
||||
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc"
|
||||
@@ -4,6 +4,7 @@ add_mlir_library(MLIRTransforms
|
||||
Canonicalizer.cpp
|
||||
ControlFlowSink.cpp
|
||||
CSE.cpp
|
||||
GenerateRuntimeVerification.cpp
|
||||
Inliner.cpp
|
||||
LocationSnapshot.cpp
|
||||
LoopInvariantCodeMotion.cpp
|
||||
@@ -26,6 +27,7 @@ add_mlir_library(MLIRTransforms
|
||||
MLIRCopyOpInterface
|
||||
MLIRLoopLikeInterface
|
||||
MLIRPass
|
||||
MLIRRuntimeVerifiableOpInterface
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRSupport
|
||||
MLIRTransformUtils
|
||||
|
||||
40
mlir/lib/Transforms/GenerateRuntimeVerification.cpp
Normal file
40
mlir/lib/Transforms/GenerateRuntimeVerification.cpp
Normal file
@@ -0,0 +1,40 @@
|
||||
//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_GENERATERUNTIMEVERIFICATION
|
||||
#include "mlir/Transforms/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct GenerateRuntimeVerificationPass
|
||||
: public impl::GenerateRuntimeVerificationBase<
|
||||
GenerateRuntimeVerificationPass> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void GenerateRuntimeVerificationPass::runOnOperation() {
|
||||
getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
|
||||
OpBuilder builder(getOperation()->getContext());
|
||||
builder.setInsertionPoint(verifiableOp);
|
||||
verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
|
||||
});
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
|
||||
return std::make_unique<GenerateRuntimeVerificationPass>();
|
||||
}
|
||||
14
mlir/test/Dialect/MemRef/runtime-verification.mlir
Normal file
14
mlir/test/Dialect/MemRef/runtime-verification.mlir
Normal file
@@ -0,0 +1,14 @@
|
||||
// RUN: mlir-opt %s -generate-runtime-verification -cse | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @expand_shape(
|
||||
// CHECK-SAME: %[[m:.*]]: memref<?xf32>
|
||||
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
|
||||
// CHECK-DAG: %[[dim:.*]] = memref.dim %[[m]], %[[c0]]
|
||||
// CHECK: %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]]
|
||||
// CHECK: %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]]
|
||||
// CHECK: cf.assert %[[cmpi]], "static result dims in reassoc group do not divide src dim evenly"
|
||||
func.func @expand_shape(%m: memref<?xf32>) -> memref<?x5xf32> {
|
||||
%0 = memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x5xf32>
|
||||
return %0 : memref<?x5xf32>
|
||||
}
|
||||
@@ -1037,6 +1037,13 @@ td_library(
|
||||
deps = [":OpBaseTdFiles"],
|
||||
)
|
||||
|
||||
td_library(
|
||||
name = "RuntimeVerifiableOpInterfaceTdFiles",
|
||||
srcs = ["include/mlir/Interfaces/RuntimeVerifiableOpInterface.td"],
|
||||
includes = ["include"],
|
||||
deps = [":OpBaseTdFiles"],
|
||||
)
|
||||
|
||||
td_library(
|
||||
name = "SideEffectInterfacesTdFiles",
|
||||
srcs = [
|
||||
@@ -2992,6 +2999,18 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "RuntimeVerifiableOpInterface",
|
||||
srcs = ["lib/Interfaces/RuntimeVerifiableOpInterface.cpp"],
|
||||
hdrs = ["include/mlir/Interfaces/RuntimeVerifiableOpInterface.h"],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":IR",
|
||||
":RuntimeVerifiableOpInterfaceIncGen",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "VectorInterfaces",
|
||||
srcs = ["lib/Interfaces/VectorInterfaces.cpp"],
|
||||
@@ -5715,6 +5734,24 @@ gentbl_cc_library(
|
||||
deps = [":ParallelCombiningOpInterfaceTdFiles"],
|
||||
)
|
||||
|
||||
gentbl_cc_library(
|
||||
name = "RuntimeVerifiableOpInterfaceIncGen",
|
||||
strip_include_prefix = "include",
|
||||
tbl_outs = [
|
||||
(
|
||||
["-gen-op-interface-decls"],
|
||||
"include/mlir/Interfaces/RuntimeVerifiableOpInterface.h.inc",
|
||||
),
|
||||
(
|
||||
["-gen-op-interface-defs"],
|
||||
"include/mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc",
|
||||
),
|
||||
],
|
||||
tblgen = ":mlir-tblgen",
|
||||
td_file = "include/mlir/Interfaces/RuntimeVerifiableOpInterface.td",
|
||||
deps = [":RuntimeVerifiableOpInterfaceTdFiles"],
|
||||
)
|
||||
|
||||
gentbl_cc_library(
|
||||
name = "VectorInterfacesIncGen",
|
||||
strip_include_prefix = "include",
|
||||
@@ -5818,6 +5855,7 @@ cc_library(
|
||||
":LoopLikeInterface",
|
||||
":Pass",
|
||||
":Rewrite",
|
||||
":RuntimeVerifiableOpInterface",
|
||||
":SideEffectInterfaces",
|
||||
":Support",
|
||||
":TransformUtils",
|
||||
@@ -9783,6 +9821,7 @@ cc_library(
|
||||
":ArithDialect",
|
||||
":ArithTransforms",
|
||||
":ArithUtils",
|
||||
":ControlFlowDialect",
|
||||
":DialectUtils",
|
||||
":FuncDialect",
|
||||
":IR",
|
||||
@@ -9791,6 +9830,7 @@ cc_library(
|
||||
":MemRefDialect",
|
||||
":MemRefPassIncGen",
|
||||
":Pass",
|
||||
":RuntimeVerifiableOpInterface",
|
||||
":TensorDialect",
|
||||
":Transforms",
|
||||
":VectorDialect",
|
||||
|
||||
Reference in New Issue
Block a user