[mlir] Add RuntimeVerifiableOpInterface and transform

Static op verification cannot detect cases where an op is valid at compile time but may be invalid at runtime.

An example of such an op is `memref::ExpandShapeOp`.

Invalid at compile time: `memref.expand_shape %m [[0, 1]] : memref<11xf32> into memref<2x5xf32>`

Valid at compile time (because we do not know any better): `memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x5xf32>`. This op may or may not be valid at runtime depending on the runtime shape of `%m`.

Invalid runtime ops such as the one above are hard to debug because they can crash the program execution at a seemingly unrelated position or (even worse) compute an invalid result without crashing.

This revision adds a new op interface `RuntimeVerifiableOpInterface` that can be implemented by ops that provide additional runtime verification. Such runtime verification can be computationally expensive, so it is only generated on an opt-in basis by running `-generate-runtime-verification`. A simple runtime verifier for `memref::ExpandShapeOp` is provided as an example.

Differential Revision: https://reviews.llvm.org/D138576
This commit is contained in:
Matthias Springer
2022-12-21 10:51:10 +01:00
parent b8e1071a29
commit 108b08f2a9
15 changed files with 280 additions and 0 deletions

View File

@@ -0,0 +1,21 @@
//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
#define MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
namespace mlir {
class DialectRegistry;
namespace memref {
void registerRuntimeVerifiableOpInterfaceExternalModels(
DialectRegistry &registry);
} // namespace memref
} // namespace mlir
#endif // MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H

View File

@@ -45,6 +45,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -130,6 +131,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
registry);
linalg::registerBufferizableOpInterfaceExternalModels(registry);
linalg::registerTilingInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
shape::registerBufferizableOpInterfaceExternalModels(registry);
sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);

View File

@@ -8,6 +8,7 @@ add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(ParallelCombiningOpInterface)
add_mlir_interface(RuntimeVerifiableOpInterface)
add_mlir_interface(ShapedOpInterfaces)
add_mlir_interface(SideEffectInterfaces)
add_mlir_interface(TilingInterface)

View File

@@ -0,0 +1,17 @@
//===- RuntimeVerifiableOpInterface.h - Op Verification ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
#define MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
#include "mlir/IR/OpDefinition.h"
/// Include the generated interface declarations.
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h.inc"
#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_

View File

@@ -0,0 +1,40 @@
//===- RuntimeVerifiableOpInterface.td - Op Verification ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
#define MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
include "mlir/IR/OpBase.td"
def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
let description = [{
Implementations of this interface generate IR for runtime op verification.
Incorrect op usage can often be caught by op verifiers based on static
program information. However, in the absence of static program information,
it can remain undetected at compile time (e.g., in case of dynamic memref
strides instead of static memref strides). Such cases can be checked at
runtime. The op-specific checks are generated by this interface.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<
/*desc=*/[{
Generate IR to verify this op at runtime, aborting runtime execution if
verification fails.
}],
/*retTy=*/"void",
/*methodName=*/"generateRuntimeVerification",
/*args=*/(ins "::mlir::OpBuilder &":$builder,
"::mlir::Location":$loc)
>,
];
}
#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE

View File

@@ -64,6 +64,9 @@ std::unique_ptr<Pass> createControlFlowSinkPass();
/// Creates a pass to perform common sub expression elimination.
std::unique_ptr<Pass> createCSEPass();
/// Creates a pass that generates IR to verify ops at runtime.
std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
/// Creates a loop invariant code motion pass that hoists loop invariant
/// instructions out of the loop.
std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();

View File

@@ -77,6 +77,16 @@ def CSE : Pass<"cse"> {
];
}
def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
let summary = "Generate additional runtime op verification checks";
let description = [{
This pass generates op-specific runtime checks using the
`RuntimeVerifiableOpInterface`. It can be run for debugging purposes after
passes that are suspected to introduce faulty IR.
}];
let constructor = "mlir::createGenerateRuntimeVerificationPass()";
}
def Inliner : Pass<"inline"> {
let summary = "Inline function calls";
let constructor = "mlir::createInlinerPass()";

View File

@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
MultiBuffer.cpp
NormalizeMemRefs.cpp
ResolveShapedTypeResultDims.cpp
RuntimeOpVerification.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef

View File

@@ -0,0 +1,70 @@
//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
namespace mlir {
namespace memref {
namespace {
struct ExpandShapeOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
ExpandShapeOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto expandShapeOp = cast<ExpandShapeOp>(op);
// Verify that the expanded dim sizes are a product of the collapsed dim
// size.
for (auto it : llvm::enumerate(expandShapeOp.getReassociationIndices())) {
Value srcDimSz =
builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
int64_t groupSz = 1;
bool foundDynamicDim = false;
for (int64_t resultDim : it.value()) {
if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
// Keep this assert here in case the op is extended in the future.
assert(!foundDynamicDim &&
"more than one dynamic dim found in reassoc group");
foundDynamicDim = true;
continue;
}
groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
}
Value staticResultDimSz =
builder.create<arith::ConstantIndexOp>(loc, groupSz);
// staticResultDimSz must divide srcDimSz evenly.
Value mod =
builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
Value isModZero = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, mod,
builder.create<arith::ConstantIndexOp>(loc, 0));
builder.create<cf::AssertOp>(
loc, isModZero,
"static result dims in reassoc group do not divide src dim evenly");
}
}
};
} // namespace
} // namespace memref
} // namespace mlir
void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
// Load additional dialects of which ops may get created.
ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
});
}

View File

@@ -10,6 +10,7 @@ set(LLVM_OPTIONAL_SOURCES
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
ParallelCombiningOpInterface.cpp
RuntimeVerifiableOpInterface.cpp
ShapedOpInterfaces.cpp
SideEffectInterfaces.cpp
TilingInterface.cpp
@@ -44,6 +45,7 @@ add_mlir_interface_library(InferIntRangeInterface)
add_mlir_interface_library(InferTypeOpInterface)
add_mlir_interface_library(LoopLikeInterface)
add_mlir_interface_library(ParallelCombiningOpInterface)
add_mlir_interface_library(RuntimeVerifiableOpInterface)
add_mlir_interface_library(ShapedOpInterfaces)
add_mlir_interface_library(SideEffectInterfaces)
add_mlir_interface_library(TilingInterface)

View File

@@ -0,0 +1,17 @@
//===- RuntimeVerifiableOpInterface.cpp - Op Verification -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
namespace mlir {
class Location;
class OpBuilder;
} // namespace mlir
/// Include the definitions of the interface.
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc"

View File

@@ -4,6 +4,7 @@ add_mlir_library(MLIRTransforms
Canonicalizer.cpp
ControlFlowSink.cpp
CSE.cpp
GenerateRuntimeVerification.cpp
Inliner.cpp
LocationSnapshot.cpp
LoopInvariantCodeMotion.cpp
@@ -26,6 +27,7 @@ add_mlir_library(MLIRTransforms
MLIRCopyOpInterface
MLIRLoopLikeInterface
MLIRPass
MLIRRuntimeVerifiableOpInterface
MLIRSideEffectInterfaces
MLIRSupport
MLIRTransformUtils

View File

@@ -0,0 +1,40 @@
//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
namespace mlir {
#define GEN_PASS_DEF_GENERATERUNTIMEVERIFICATION
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
struct GenerateRuntimeVerificationPass
: public impl::GenerateRuntimeVerificationBase<
GenerateRuntimeVerificationPass> {
void runOnOperation() override;
};
} // namespace
void GenerateRuntimeVerificationPass::runOnOperation() {
getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
OpBuilder builder(getOperation()->getContext());
builder.setInsertionPoint(verifiableOp);
verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
});
}
std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
return std::make_unique<GenerateRuntimeVerificationPass>();
}

View File

@@ -0,0 +1,14 @@
// RUN: mlir-opt %s -generate-runtime-verification -cse | FileCheck %s
// CHECK-LABEL: func @expand_shape(
// CHECK-SAME: %[[m:.*]]: memref<?xf32>
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[dim:.*]] = memref.dim %[[m]], %[[c0]]
// CHECK: %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]]
// CHECK: %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]]
// CHECK: cf.assert %[[cmpi]], "static result dims in reassoc group do not divide src dim evenly"
func.func @expand_shape(%m: memref<?xf32>) -> memref<?x5xf32> {
%0 = memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x5xf32>
return %0 : memref<?x5xf32>
}

View File

@@ -1037,6 +1037,13 @@ td_library(
deps = [":OpBaseTdFiles"],
)
td_library(
name = "RuntimeVerifiableOpInterfaceTdFiles",
srcs = ["include/mlir/Interfaces/RuntimeVerifiableOpInterface.td"],
includes = ["include"],
deps = [":OpBaseTdFiles"],
)
td_library(
name = "SideEffectInterfacesTdFiles",
srcs = [
@@ -2992,6 +2999,18 @@ cc_library(
],
)
cc_library(
name = "RuntimeVerifiableOpInterface",
srcs = ["lib/Interfaces/RuntimeVerifiableOpInterface.cpp"],
hdrs = ["include/mlir/Interfaces/RuntimeVerifiableOpInterface.h"],
includes = ["include"],
deps = [
":IR",
":RuntimeVerifiableOpInterfaceIncGen",
"//llvm:Support",
],
)
cc_library(
name = "VectorInterfaces",
srcs = ["lib/Interfaces/VectorInterfaces.cpp"],
@@ -5715,6 +5734,24 @@ gentbl_cc_library(
deps = [":ParallelCombiningOpInterfaceTdFiles"],
)
gentbl_cc_library(
name = "RuntimeVerifiableOpInterfaceIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-op-interface-decls"],
"include/mlir/Interfaces/RuntimeVerifiableOpInterface.h.inc",
),
(
["-gen-op-interface-defs"],
"include/mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Interfaces/RuntimeVerifiableOpInterface.td",
deps = [":RuntimeVerifiableOpInterfaceTdFiles"],
)
gentbl_cc_library(
name = "VectorInterfacesIncGen",
strip_include_prefix = "include",
@@ -5818,6 +5855,7 @@ cc_library(
":LoopLikeInterface",
":Pass",
":Rewrite",
":RuntimeVerifiableOpInterface",
":SideEffectInterfaces",
":Support",
":TransformUtils",
@@ -9783,6 +9821,7 @@ cc_library(
":ArithDialect",
":ArithTransforms",
":ArithUtils",
":ControlFlowDialect",
":DialectUtils",
":FuncDialect",
":IR",
@@ -9791,6 +9830,7 @@ cc_library(
":MemRefDialect",
":MemRefPassIncGen",
":Pass",
":RuntimeVerifiableOpInterface",
":TensorDialect",
":Transforms",
":VectorDialect",