diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h index ac4d38e0c5b1..d0fc85ccc9de 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.h @@ -18,12 +18,18 @@ namespace mlir { namespace spirv { -/// Appends to a pattern list additional patterns to expand extended -/// multiplication ops into regular arithmetic ops. Extended multiplication ops -/// are not supported by the WebGPU Shading Language (WGSL). +/// Appends patterns to expand extended multiplication and adition ops into +/// regular arithmetic ops. Extended arithmetic ops are not supported by the +/// WebGPU Shading Language (WGSL). void populateSPIRVExpandExtendedMultiplicationPatterns( RewritePatternSet &patterns); +/// Appends patterns to expand non-finite arithmetic ops `IsNan` and `IsInf`. +/// These are not supported by the WebGPU Shading Language (WGSL). We follow +/// fast math assumptions and assume that all floating point values are finite. +void populateSPIRVExpandNonFiniteArithmeticPatterns( + RewritePatternSet &patterns); + } // namespace spirv } // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp index 21de1c9e867c..5d4dd5b3a1e0 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp @@ -39,7 +39,7 @@ namespace { //===----------------------------------------------------------------------===// // Helpers //===----------------------------------------------------------------------===// -Attribute getScalarOrSplatAttr(Type type, int64_t value) { +static Attribute getScalarOrSplatAttr(Type type, int64_t value) { APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value); if (auto intTy = dyn_cast(type)) return IntegerAttr::get(intTy, sizedValue); @@ -47,9 +47,9 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) { return SplatElementsAttr::get(cast(type), sizedValue); } -Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter, - Value lhs, Value rhs, - bool signExtendArguments) { +static Value lowerExtendedMultiplication(Operation *mulOp, + PatternRewriter &rewriter, Value lhs, + Value rhs, bool signExtendArguments) { Location loc = mulOp->getLoc(); Type argTy = lhs.getType(); // Emulate 64-bit multiplication by splitting each input element of type i32 @@ -203,15 +203,39 @@ struct ExpandAddCarryPattern final : OpRewritePattern { } }; +struct ExpandIsInfPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IsInfOp op, + PatternRewriter &rewriter) const override { + // We assume values to be finite and turn `IsInf` info `false`. + rewriter.replaceOpWithNewOp( + op, op.getType(), getScalarOrSplatAttr(op.getType(), 0)); + return success(); + } +}; + +struct ExpandIsNanPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IsNanOp op, + PatternRewriter &rewriter) const override { + // We assume values to be finite and turn `IsNan` info `false`. + rewriter.replaceOpWithNewOp( + op, op.getType(), getScalarOrSplatAttr(op.getType(), 0)); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Passes //===----------------------------------------------------------------------===// -class WebGPUPreparePass - : public impl::SPIRVWebGPUPreparePassBase { -public: +struct WebGPUPreparePass final + : impl::SPIRVWebGPUPreparePassBase { void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateSPIRVExpandExtendedMultiplicationPatterns(patterns); + populateSPIRVExpandNonFiniteArithmeticPatterns(patterns); if (failed( applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) @@ -227,12 +251,16 @@ void populateSPIRVExpandExtendedMultiplicationPatterns( RewritePatternSet &patterns) { // WGSL currently does not support extended multiplication ops, see: // https://github.com/gpuweb/gpuweb/issues/1565. - patterns.add< - // clang-format off - ExpandSMulExtendedPattern, - ExpandUMulExtendedPattern, - ExpandAddCarryPattern - >(patterns.getContext()); + patterns.add(patterns.getContext()); } + +void populateSPIRVExpandNonFiniteArithmeticPatterns( + RewritePatternSet &patterns) { + // WGSL currently does not support `isInf` and `isNan`, see: + // https://github.com/gpuweb/gpuweb/pull/2311. + patterns.add(patterns.getContext()); +} + } // namespace spirv } // namespace mlir diff --git a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir index 1ec4e5e4f966..45f188da3815 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/webgpu-prepare.mlir @@ -182,4 +182,36 @@ spirv.func @iaddcarry_i16(%a : i16, %b : i16) -> !spirv.struct<(i16, i16)> "None spirv.ReturnValue %0 : !spirv.struct<(i16, i16)> } +// CHECK-LABEL: func @is_inf_f32 +// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false +// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1 +spirv.func @is_inf_f32(%a : f32) -> i1 "None" { + %0 = spirv.IsInf %a : f32 + spirv.ReturnValue %0 : i1 +} + +// CHECK-LABEL: func @is_inf_4xf32 +// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense : vector<4xi1> +// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1> +spirv.func @is_inf_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" { + %0 = spirv.IsInf %a : vector<4xf32> + spirv.ReturnValue %0 : vector<4xi1> +} + +// CHECK-LABEL: func @is_nan_f32 +// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant false +// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : i1 +spirv.func @is_nan_f32(%a : f32) -> i1 "None" { + %0 = spirv.IsNan %a : f32 + spirv.ReturnValue %0 : i1 +} + +// CHECK-LABEL: func @is_nan_4xf32 +// CHECK-NEXT: [[FALSE:%.+]] = spirv.Constant dense : vector<4xi1> +// CHECK-NEXT: spirv.ReturnValue [[FALSE]] : vector<4xi1> +spirv.func @is_nan_4xf32(%a : vector<4xf32>) -> vector<4xi1> "None" { + %0 = spirv.IsNan %a : vector<4xf32> + spirv.ReturnValue %0 : vector<4xi1> +} + } // end module