mirror of
https://github.com/intel/llvm.git
synced 2026-01-23 16:06:39 +08:00
[mlir][spirv] Add IsInf/IsNan expansion for WebGPU (#86903)
These non-finite math ops are supported by SPIR-V but not by WGSL. Assume finite floating point values and expand these ops into `false`. Previously, this worked by adding fast math flags during conversion from arith to spirv, but this got removed in https://github.com/llvm/llvm-project/pull/86578. Also do some misc cleanups in the surrounding code.
This commit is contained in:
@@ -18,12 +18,18 @@
|
||||
namespace mlir {
|
||||
namespace spirv {
|
||||
|
||||
/// Appends to a pattern list additional patterns to expand extended
|
||||
/// multiplication ops into regular arithmetic ops. Extended multiplication ops
|
||||
/// are not supported by the WebGPU Shading Language (WGSL).
|
||||
/// Appends patterns to expand extended multiplication and adition ops into
|
||||
/// regular arithmetic ops. Extended arithmetic ops are not supported by the
|
||||
/// WebGPU Shading Language (WGSL).
|
||||
void populateSPIRVExpandExtendedMultiplicationPatterns(
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Appends patterns to expand non-finite arithmetic ops `IsNan` and `IsInf`.
|
||||
/// These are not supported by the WebGPU Shading Language (WGSL). We follow
|
||||
/// fast math assumptions and assume that all floating point values are finite.
|
||||
void populateSPIRVExpandNonFiniteArithmeticPatterns(
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
} // namespace spirv
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ namespace {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helpers
|
||||
//===----------------------------------------------------------------------===//
|
||||
Attribute getScalarOrSplatAttr(Type type, int64_t value) {
|
||||
static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
|
||||
APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
|
||||
if (auto intTy = dyn_cast<IntegerType>(type))
|
||||
return IntegerAttr::get(intTy, sizedValue);
|
||||
@@ -47,9 +47,9 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
|
||||
return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
|
||||
}
|
||||
|
||||
Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
|
||||
Value lhs, Value rhs,
|
||||
bool signExtendArguments) {
|
||||
static Value lowerExtendedMultiplication(Operation *mulOp,
|
||||
PatternRewriter &rewriter, Value lhs,
|
||||
Value rhs, bool signExtendArguments) {
|
||||
Location loc = mulOp->getLoc();
|
||||
Type argTy = lhs.getType();
|
||||
// Emulate 64-bit multiplication by splitting each input element of type i32
|
||||
@@ -203,15 +203,39 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(IsInfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// We assume values to be finite and turn `IsInf` info `false`.
|
||||
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
|
||||
op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(IsNanOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// We assume values to be finite and turn `IsNan` info `false`.
|
||||
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
|
||||
op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Passes
|
||||
//===----------------------------------------------------------------------===//
|
||||
class WebGPUPreparePass
|
||||
: public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
|
||||
public:
|
||||
struct WebGPUPreparePass final
|
||||
: impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
|
||||
populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
|
||||
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
@@ -227,12 +251,16 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
// WGSL currently does not support extended multiplication ops, see:
|
||||
// https://github.com/gpuweb/gpuweb/issues/1565.
|
||||
patterns.add<
|
||||
// clang-format off
|
||||
ExpandSMulExtendedPattern,
|
||||
ExpandUMulExtendedPattern,
|
||||
ExpandAddCarryPattern
|
||||
>(patterns.getContext());
|
||||
patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
|
||||
ExpandAddCarryPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
void populateSPIRVExpandNonFiniteArithmeticPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
// WGSL currently does not support `isInf` and `isNan`, see:
|
||||
// https://github.com/gpuweb/gpuweb/pull/2311.
|
||||
patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace spirv
|
||||
} // namespace mlir
|
||||
|
||||
@@ -182,4 +182,36 @@ spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None
|
||||
spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @is_inf_f32
|
||||
// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false
|
||||
// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1
|
||||
spirv.func @is_inf_f32(%a : f32) -> i1 "None" {
|
||||
%0 = spirv.IsInf %a : f32
|
||||
spirv.ReturnValue %0 : i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @is_inf_4xf32
|
||||
// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
|
||||
// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1>
|
||||
spirv.func @is_inf_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
|
||||
%0 = spirv.IsInf %a : vector<4xf32>
|
||||
spirv.ReturnValue %0 : vector<4xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @is_nan_f32
|
||||
// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false
|
||||
// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1
|
||||
spirv.func @is_nan_f32(%a : f32) -> i1 "None" {
|
||||
%0 = spirv.IsNan %a : f32
|
||||
spirv.ReturnValue %0 : i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @is_nan_4xf32
|
||||
// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
|
||||
// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1>
|
||||
spirv.func @is_nan_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
|
||||
%0 = spirv.IsNan %a : vector<4xf32>
|
||||
spirv.ReturnValue %0 : vector<4xi1>
|
||||
}
|
||||
|
||||
} // end module
|
||||
|
||||
Reference in New Issue
Block a user