[mlir][spirv] Add IsInf/IsNan expansion for WebGPU (#86903)

These ops, which check for non-finite floating-point values, are supported by SPIR-V but not by WGSL. Assume finite floating-point values and expand both ops into a constant `false`.
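
For illustration, the expansion has roughly the following effect on the IR (a sketch with a placeholder value `%arg0`; the tests added below show the exact patterns):

    // Before: non-finite checks on a float value.
    %inf = spirv.IsInf %arg0 : f32
    %nan = spirv.IsNan %arg0 : f32

    // After: under the finite-math assumption, both checks fold to a constant.
    %inf = spirv.Constant false
    %nan = spirv.Constant false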

Previously, this was handled by adding fast math flags during the arith-to-spirv conversion, but that mechanism was removed in
https://github.com/llvm/llvm-project/pull/86578.

Also do some misc cleanups in the surrounding code.
Author: Jakub Kuderski
Date: 2024-03-28 14:13:04 -04:00
Committed by: GitHub
Parent: 599027857e
Commit: d61ec513c4
3 changed files with 82 additions and 16 deletions


@@ -18,12 +18,18 @@
 namespace mlir {
 namespace spirv {
-/// Appends to a pattern list additional patterns to expand extended
-/// multiplication ops into regular arithmetic ops. Extended multiplication ops
-/// are not supported by the WebGPU Shading Language (WGSL).
+/// Appends patterns to expand extended multiplication and addition ops into
+/// regular arithmetic ops. Extended arithmetic ops are not supported by the
+/// WebGPU Shading Language (WGSL).
 void populateSPIRVExpandExtendedMultiplicationPatterns(
     RewritePatternSet &patterns);
+
+/// Appends patterns to expand non-finite arithmetic ops `IsNan` and `IsInf`.
+/// These are not supported by the WebGPU Shading Language (WGSL). We follow
+/// fast math assumptions and assume that all floating point values are finite.
+void populateSPIRVExpandNonFiniteArithmeticPatterns(
+    RewritePatternSet &patterns);
 } // namespace spirv
 } // namespace mlir


@@ -39,7 +39,7 @@ namespace {
 //===----------------------------------------------------------------------===//
 // Helpers
 //===----------------------------------------------------------------------===//
-Attribute getScalarOrSplatAttr(Type type, int64_t value) {
+static Attribute getScalarOrSplatAttr(Type type, int64_t value) {
   APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
   if (auto intTy = dyn_cast<IntegerType>(type))
     return IntegerAttr::get(intTy, sizedValue);
@@ -47,9 +47,9 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
   return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
 }
 
-Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
-                                  Value lhs, Value rhs,
-                                  bool signExtendArguments) {
+static Value lowerExtendedMultiplication(Operation *mulOp,
+                                         PatternRewriter &rewriter, Value lhs,
+                                         Value rhs, bool signExtendArguments) {
   Location loc = mulOp->getLoc();
   Type argTy = lhs.getType();
   // Emulate 64-bit multiplication by splitting each input element of type i32
@@ -203,15 +203,39 @@ struct ExpandAddCarryPattern final : OpRewritePattern<IAddCarryOp> {
   }
 };
 
+struct ExpandIsInfPattern final : OpRewritePattern<IsInfOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IsInfOp op,
+                                PatternRewriter &rewriter) const override {
+    // We assume values to be finite and turn `IsInf` into `false`.
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+        op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
+    return success();
+  }
+};
+
+struct ExpandIsNanPattern final : OpRewritePattern<IsNanOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IsNanOp op,
+                                PatternRewriter &rewriter) const override {
+    // We assume values to be finite and turn `IsNan` into `false`.
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+        op, op.getType(), getScalarOrSplatAttr(op.getType(), 0));
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
-class WebGPUPreparePass
-    : public impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
-public:
+struct WebGPUPreparePass final
+    : impl::SPIRVWebGPUPreparePassBase<WebGPUPreparePass> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
+    populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
     if (failed(
             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
@@ -227,12 +251,16 @@ void populateSPIRVExpandExtendedMultiplicationPatterns(
     RewritePatternSet &patterns) {
   // WGSL currently does not support extended multiplication ops, see:
   // https://github.com/gpuweb/gpuweb/issues/1565.
-  patterns.add<
-      // clang-format off
-      ExpandSMulExtendedPattern,
-      ExpandUMulExtendedPattern,
-      ExpandAddCarryPattern
-  >(patterns.getContext());
+  patterns.add<ExpandSMulExtendedPattern, ExpandUMulExtendedPattern,
+               ExpandAddCarryPattern>(patterns.getContext());
 }
+
+void populateSPIRVExpandNonFiniteArithmeticPatterns(
+    RewritePatternSet &patterns) {
+  // WGSL currently does not support `isInf` and `isNan`, see:
+  // https://github.com/gpuweb/gpuweb/pull/2311.
+  patterns.add<ExpandIsInfPattern, ExpandIsNanPattern>(patterns.getContext());
+}
 } // namespace spirv
 } // namespace mlir


@@ -182,4 +182,36 @@ spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None
   spirv.ReturnValue %0 : !spirv.struct<(i16, i16)>
 }
 
+// CHECK-LABEL: func @is_inf_f32
+// CHECK-NEXT:  [[FALSE:%.+]] = spirv.Constant false
+// CHECK-NEXT:  spirv.ReturnValue [[FALSE]] : i1
+spirv.func @is_inf_f32(%a : f32) -> i1 "None" {
+  %0 = spirv.IsInf %a : f32
+  spirv.ReturnValue %0 : i1
+}
+
+// CHECK-LABEL: func @is_inf_4xf32
+// CHECK-NEXT:  [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
+// CHECK-NEXT:  spirv.ReturnValue [[FALSE]] : vector<4xi1>
+spirv.func @is_inf_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
+  %0 = spirv.IsInf %a : vector<4xf32>
+  spirv.ReturnValue %0 : vector<4xi1>
+}
+
+// CHECK-LABEL: func @is_nan_f32
+// CHECK-NEXT:  [[FALSE:%.+]] = spirv.Constant false
+// CHECK-NEXT:  spirv.ReturnValue [[FALSE]] : i1
+spirv.func @is_nan_f32(%a : f32) -> i1 "None" {
+  %0 = spirv.IsNan %a : f32
+  spirv.ReturnValue %0 : i1
+}
+
+// CHECK-LABEL: func @is_nan_4xf32
+// CHECK-NEXT:  [[FALSE:%.+]] = spirv.Constant dense<false> : vector<4xi1>
+// CHECK-NEXT:  spirv.ReturnValue [[FALSE]] : vector<4xi1>
+spirv.func @is_nan_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" {
+  %0 = spirv.IsNan %a : vector<4xf32>
+  spirv.ReturnValue %0 : vector<4xi1>
+}
+
 } // end module