[mlir][UB] Add ub.unreachable operation (#169872)

Add `ub.unreachable` operation and lowerings to LLVM/SPIRV.

---------

Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
This commit is contained in:
Matthias Springer
2025-11-28 18:35:18 +08:00
committed by GitHub
parent b76089c7f3
commit 310211cce5
6 changed files with 88 additions and 8 deletions

View File

@@ -66,4 +66,24 @@ def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> {
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// UnreachableOp
//===----------------------------------------------------------------------===//
def UnreachableOp : UB_Op<"unreachable", [Terminator]> {
let summary = "Unreachable operation.";
let description = [{
The `unreachable` operation triggers immediate undefined behavior if
executed.
Example:
```
ub.unreachable
```
}];
let assemblyFormat = "attr-dict";
}
#endif // MLIR_DIALECT_UB_IR_UBOPS_TD

View File

@@ -23,8 +23,11 @@ namespace mlir {
using namespace mlir;
namespace {
//===----------------------------------------------------------------------===//
// PoisonOpLowering
//===----------------------------------------------------------------------===//
namespace {
struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
@@ -32,13 +35,8 @@ struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> {
matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
//===----------------------------------------------------------------------===//
// PoisonOpLowering
//===----------------------------------------------------------------------===//
LogicalResult
PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -60,6 +58,29 @@ PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
return success();
}
//===----------------------------------------------------------------------===//
// UnreachableOpLowering
//===----------------------------------------------------------------------===//
namespace {
struct UnreachableOpLowering
: public ConvertOpToLLVMPattern<ub::UnreachableOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(ub::UnreachableOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
LogicalResult
UnreachableOpLowering::matchAndRewrite(
ub::UnreachableOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<LLVM::UnreachableOp>(op);
return success();
}
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -93,7 +114,7 @@ struct UBToLLVMConversionPass
void mlir::ub::populateUBToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<PoisonOpLowering>(converter);
patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter);
}
//===----------------------------------------------------------------------===//

View File

@@ -40,6 +40,17 @@ struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> {
}
};
struct UnreachableOpLowering final : OpConversionPattern<ub::UnreachableOp> {
using Base::Base;
LogicalResult
matchAndRewrite(ub::UnreachableOp op, OpAdaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<spirv::UnreachableOp>(op);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
@@ -75,5 +86,6 @@ struct UBToSPIRVConversionPass final
void mlir::ub::populateUBToSPIRVConversionPatterns(
const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<PoisonOpLowering>(converter, patterns.getContext());
patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter,
patterns.getContext());
}

View File

@@ -17,3 +17,9 @@ func.func @check_poison() {
%3 = ub.poison : !llvm.ptr
return
}
// CHECK-LABEL: @check_unrechable
func.func @check_unrechable() {
// CHECK: llvm.unreachable
ub.unreachable
}

View File

@@ -19,3 +19,18 @@ func.func @check_poison() {
}
}
// -----
// No successful test because the dialect conversion framework does not convert
// unreachable blocks.
module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
} {
func.func @check_unrechable() {
// expected-error@+1{{cannot be used in reachable block}}
spirv.Unreachable
}
}

View File

@@ -38,3 +38,9 @@ func.func @poison_tensor() -> tensor<8x?xf64> {
%0 = ub.poison : tensor<8x?xf64>
return %0 : tensor<8x?xf64>
}
// CHECK-LABEL: func @unreachable()
// CHECK: ub.unreachable
func.func @unreachable() {
ub.unreachable
}