From d61ec513c42005bb071eb15386deb5de585ff267 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 28 Mar 2024 14:13:04 -0400 Subject: [PATCH] [mlir][spirv] Add IsInf/IsNan expansion for WebGPU (#86903) These non-finite math ops are supported by SPIR-V but not by WGSL. Assume finite floating point values and expand these ops into `false`. Previously, this worked by adding fast math flags during conversion from arith to spirv, but this got removed in https://github.com/llvm/llvm-project/pull/86578. Also do some misc cleanups in the surrounding code. --- .../SPIRV/Transforms/SPIRVWebGPUTransforms.h | 12 +++-- .../Transforms/SPIRVWebGPUTransforms.cpp | 54 ++++++++++++++----- .../SPIRV/Transforms/webgpu-prepare.mlir | 32 +++++++++++ 3 files changed, 82 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h index ac4d38e0c5b1..d0fc85ccc9de 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h @@ -18,12 +18,18 @@ namespace mlir { namespace spirv { -/// Appends to a pattern list additional patterns to expand extended -/// multiplication ops into regular arithmetic ops. Extended multiplication ops -/// are not supported by the WebGPU Shading Language (WGSL). +/// Appends patterns to expand extended multiplication and addition ops into +/// regular arithmetic ops. Extended arithmetic ops are not supported by the +/// WebGPU Shading Language (WGSL). void populateSPIRVExpandExtendedMultiplicationPatterns( RewritePatternSet &patterns); +/// Appends patterns to expand non-finite arithmetic ops `IsNan` and `IsInf`. +/// These are not supported by the WebGPU Shading Language (WGSL). We follow +/// fast math assumptions and assume that all floating point values are finite. 
+void populateSPIRVExpandNonFiniteArithmeticPatterns( + RewritePatternSet &patterns); + } // namespace spirv } // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp index 21de1c9e867c..5d4dd5b3a1e0 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp @@ -39,7 +39,7 @@ namespace { //===----------------------------------------------------------------------===// // Helpers //===----------------------------------------------------------------------===// -Attribute getScalarOrSplatAttr(Type type, int64_t value) { +static Attribute getScalarOrSplatAttr(Type type, int64_t value) { APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value); if (auto intTy = dyn_cast(type)) return IntegerAttr::get(intTy, sizedValue); @@ -47,9 +47,9 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) { return SplatElementsAttr::get(cast(type), sizedValue); } -Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter, - Value lhs, Value rhs, - bool signExtendArguments) { +static Value lowerExtendedMultiplication(Operation *mulOp, + PatternRewriter &rewriter, Value lhs, + Value rhs, bool signExtendArguments) { Location loc = mulOp->getLoc(); Type argTy = lhs.getType(); // Emulate 64-bit multiplication by splitting each input element of type i32 @@ -203,15 +203,39 @@ struct ExpandAddCarryPattern final : OpRewritePattern { } }; +struct ExpandIsInfPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IsInfOp op, + PatternRewriter &rewriter) const override { + // We assume values to be finite and turn `IsInf` into `false`. 
+ rewriter.replaceOpWithNewOp( + op, op.getType(), getScalarOrSplatAttr(op.getType(), 0)); + return success(); + } +}; + +struct ExpandIsNanPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IsNanOp op, + PatternRewriter &rewriter) const override { + // We assume values to be finite and turn `IsNan` into `false`. + rewriter.replaceOpWithNewOp( + op, op.getType(), getScalarOrSplatAttr(op.getType(), 0)); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Passes //===----------------------------------------------------------------------===// -class WebGPUPreparePass - : public impl::SPIRVWebGPUPreparePassBase { -public: +struct WebGPUPreparePass final + : impl::SPIRVWebGPUPreparePassBase { void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateSPIRVExpandExtendedMultiplicationPatterns(patterns); + populateSPIRVExpandNonFiniteArithmeticPatterns(patterns); if (failed( applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) @@ -227,12 +251,16 @@ void populateSPIRVExpandExtendedMultiplicationPatterns( RewritePatternSet &patterns) { // WGSL currently does not support extended multiplication ops, see: // https://github.com/gpuweb/gpuweb/issues/1565. - patterns.add< - // clang-format off - ExpandSMulExtendedPattern, - ExpandUMulExtendedPattern, - ExpandAddCarryPattern - >(patterns.getContext()); + patterns.add(patterns.getContext()); } + +void populateSPIRVExpandNonFiniteArithmeticPatterns( + RewritePatternSet &patterns) { + // WGSL currently does not support `isInf` and `isNan`, see: + // https://github.com/gpuweb/gpuweb/pull/2311. 
+ patterns.add(patterns.getContext()); +} + } // namespace spirv } // namespace mlir diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir index 1ec4e5e4f966..45f188da3815 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir @@ -182,4 +182,36 @@ spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None spirv.ReturnValue %0 : !spirv.struct<(i16, i16)> } +// CHECK-LABEL: func @is_inf_f32 +// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false +// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1 +spirv.func @is_inf_f32(%a : f32) -> i1 "None" { + %0 = spirv.IsInf %a : f32 + spirv.ReturnValue %0 : i1 +} + +// CHECK-LABEL: func @is_inf_4xf32 +// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense : vector<4xi1> +// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1> +spirv.func @is_inf_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" { + %0 = spirv.IsInf %a : vector<4xf32> + spirv.ReturnValue %0 : vector<4xi1> +} + +// CHECK-LABEL: func @is_nan_f32 +// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false +// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1 +spirv.func @is_nan_f32(%a : f32) -> i1 "None" { + %0 = spirv.IsNan %a : f32 + spirv.ReturnValue %0 : i1 +} + +// CHECK-LABEL: func @is_nan_4xf32 +// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense : vector<4xi1> +// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1> +spirv.func @is_nan_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" { + %0 = spirv.IsNan %a : vector<4xf32> + spirv.ReturnValue %0 : vector<4xi1> +} + } // end module