[mlir][spirv] Fix extended umul expansion for WebGPU

Fix an off-by-one error in extended umul extension for WebGPU.
Revert to the long multiplication algorithm originally added to wide
integer emulation, which was deleted in D139776. It is much easier
to see why it is correct.

Add runtime tests based on the mlir-vulkan-runner. These run both with
and without umul extension.

Issue: https://github.com/llvm/llvm-project/issues/59563

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D141085
This commit is contained in:
Jakub Kuderski
2023-01-05 18:37:49 -05:00
parent 2c6ecc9db6
commit 47232bea9e
6 changed files with 174 additions and 46 deletions

View File

@@ -17,8 +17,13 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include <array>
#include <cstdint>
namespace mlir {
namespace spirv {
#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
@@ -61,41 +66,62 @@ struct ExpandUMulExtendedPattern final : OpRewritePattern<UMulExtendedOp> {
loc,
llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
// Calculate the 'low' and the 'high' result separately, using long
// multiplication:
// Emulate 64-bit multiplication by splitting each input element of type i32
// into 2 16-bit digits of type i32. This is so that the intermediate
// multiplications and additions do not overflow. We extract these 16-bit
// digits from i32 vector elements by masking (low digit) and shifting right
// (high digit).
//
// lhs = [0 0] [a b]
// rhs = [0 0] [c d]
// --lhs * rhs--
// = [ a * c ] [ b * d ] +
// [ 0 ] [a * d + b * c] [ 0 ]
//
// ==> high = (a * c) + (a * d + b * c) >> 16
Value low = rewriter.create<IMulOp>(loc, lhs, rhs);
// The multiplication algorithm used is the standard (long) multiplication.
// Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
// digits. After constant-folding, we end up emitting only 4 multiplications
// and 4 additions.
Value cstLowMask = rewriter.create<ConstantOp>(
loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
auto getLowHalf = [&rewriter, loc, cstLowMask](Value val) {
auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
};
Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
getScalarOrSplatAttr(argTy, 16));
auto getHighHalf = [&rewriter, loc, cst16](Value val) {
auto getHighDigit = [&rewriter, loc, cst16](Value val) {
return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
};
Value lhsLow = getLowHalf(lhs);
Value lhsHigh = getHighHalf(lhs);
Value rhsLow = getLowHalf(rhs);
Value rhsHigh = getHighHalf(rhs);
Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
getScalarOrSplatAttr(argTy, 0));
Value high0 = rewriter.create<IMulOp>(loc, lhsHigh, rhsHigh);
Value mid = rewriter.create<IAddOp>(
loc, rewriter.create<IMulOp>(loc, lhsHigh, rhsLow),
rewriter.create<IMulOp>(loc, lhsLow, rhsHigh));
Value high1 = getHighHalf(mid);
Value high = rewriter.create<IAddOp>(loc, high0, high1);
Value lhsLow = getLowDigit(lhs);
Value lhsHigh = getHighDigit(lhs);
Value rhsLow = getLowDigit(rhs);
Value rhsHigh = getHighDigit(rhs);
std::array<Value, 2> lhsDigits = {lhsLow, lhsHigh};
std::array<Value, 2> rhsDigits = {rhsLow, rhsHigh};
std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
Value &thisResDigit = resultDigits[i + j];
Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
thisResDigit = getLowDigit(current);
if (i + j + 1 != resultDigits.size()) {
Value &nextResDigit = resultDigits[i + j + 1];
Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
getHighDigit(current));
nextResDigit = carry;
}
}
}
auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
return rewriter.create<BitwiseOrOp>(loc, low, highBits);
};
Value low = combineDigits(resultDigits[0], resultDigits[1]);
Value high = combineDigits(resultDigits[2], resultDigits[3]);
rewriter.replaceOpWithNewOp<CompositeConstructOp>(
op, op.getType(), llvm::makeArrayRef({low, high}));

View File

@@ -1,4 +1,5 @@
// RUN: mlir-opt --split-input-file --verify-diagnostics --spirv-webgpu-prepare %s | FileCheck %s
// RUN: mlir-opt --split-input-file --verify-diagnostics \
// RUN: --spirv-webgpu-prepare --cse %s | FileCheck %s
//===----------------------------------------------------------------------===//
// spirv.UMulExtended
@@ -10,18 +11,23 @@ spirv.module Logical GLSL450 {
// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32)
// CHECK-DAG: [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
// CHECK-DAG: [[CST16:%.+]] = spirv.Constant 16 : i32
// CHECK-NEXT: [[RESLOW:%.+]] = spirv.IMul [[ARG0]], [[ARG1]] : i32
// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : i32
// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : i32
// CHECK-NEXT: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : i32
// CHECK-NEXT: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : i32
// CHECK-DAG: [[RESHI0:%.+]] = spirv.IMul [[LHSHI]], [[RHSHI]] : i32
// CHECK-DAG: [[MID0:%.+]] = spirv.IMul [[LHSHI]], [[RHSLOW]] : i32
// CHECK-DAG: [[MID1:%.+]] = spirv.IMul [[LHSLOW]], [[RHSHI]] : i32
// CHECK-NEXT: [[MID:%.+]] = spirv.IAdd [[MID0]], [[MID1]] : i32
// CHECK-NEXT: [[RESHI1:%.+]] = spirv.ShiftRightLogical [[MID]], [[CST16]] : i32
// CHECK-NEXT: [[RESHI:%.+]] = spirv.IAdd [[RESHI0]], [[RESHI1]] : i32
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW]], [[RESHI]] : (i32, i32) -> !spirv.struct<(i32, i32)>
// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSLOW]]
// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSHI]]
// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSLOW]]
// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSHI]]
// CHECK-DAG: spirv.IAdd
// CHECK-DAG: spirv.IAdd
// CHECK-DAG: spirv.IAdd
// CHECK-DAG: spirv.IAdd
// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
// CHECK: spirv.BitwiseOr
// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
// CHECK: spirv.BitwiseOr
// CHECK: [[RES:%.+]] = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (i32, i32) -> !spirv.struct<(i32, i32)>
// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
spirv.func @umul_extended_i32(%arg0 : i32, %arg1 : i32) -> !spirv.struct<(i32, i32)> "None" {
%0 = spirv.UMulExtended %arg0, %arg1 : !spirv.struct<(i32, i32)>
@@ -32,18 +38,23 @@ spirv.func @umul_extended_i32(%arg0 : i32, %arg1 : i32) -> !spirv.struct<(i32, i
// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi32>, [[ARG1:%.+]]: vector<3xi32>)
// CHECK-DAG: [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
// CHECK-DAG: [[CST16:%.+]] = spirv.Constant dense<16> : vector<3xi32>
// CHECK-NEXT: [[RESLOW:%.+]] = spirv.IMul [[ARG0]], [[ARG1]] : vector<3xi32>
// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : vector<3xi32>
// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : vector<3xi32>
// CHECK-NEXT: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : vector<3xi32>
// CHECK-NEXT: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : vector<3xi32>
// CHECK-DAG: [[RESHI0:%.+]] = spirv.IMul [[LHSHI]], [[RHSHI]] : vector<3xi32>
// CHECK-DAG: [[MID0:%.+]] = spirv.IMul [[LHSHI]], [[RHSLOW]] : vector<3xi32>
// CHECK-DAG: [[MID1:%.+]] = spirv.IMul [[LHSLOW]], [[RHSHI]] : vector<3xi32>
// CHECK-NEXT: [[MID:%.+]] = spirv.IAdd [[MID0]], [[MID1]] : vector<3xi32>
// CHECK-NEXT: [[RESHI1:%.+]] = spirv.ShiftRightLogical [[MID]], [[CST16]] : vector<3xi32>
// CHECK-NEXT: [[RESHI:%.+]] = spirv.IAdd [[RESHI0]], [[RESHI1]] : vector<3xi32>
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW]], [[RESHI]]
// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSLOW]]
// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSHI]]
// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSLOW]]
// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSHI]]
// CHECK-DAG: spirv.IAdd
// CHECK-DAG: spirv.IAdd
// CHECK-DAG: spirv.IAdd
// CHECK-DAG: spirv.IAdd
// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
// CHECK: spirv.BitwiseOr
// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
// CHECK: spirv.BitwiseOr
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW:%.+]], [[RESHI:%.+]]
// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
spirv.func @umul_extended_vector_i32(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>)
-> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {

View File

@@ -0,0 +1,66 @@
// Make sure that unsigned extended multiplication produces expected results
// with and without expansion to primitive mul/add ops for WebGPU.
// RUN: mlir-vulkan-runner %s \
// RUN: --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
// RUN: --entry-point-result=void | FileCheck %s
// RUN: mlir-vulkan-runner %s --vulkan-runner-spirv-webgpu-prepare \
// RUN: --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
// RUN: --entry-point-result=void | FileCheck %s
// CHECK: [0, 1, -2, 1, 1048560, -87620295, -131071, -49]
// CHECK: [0, 0, 1, -2, 0, 65534, -131070, 6]
module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {
gpu.module @kernels {
gpu.func @kernel_add(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %arg2 : memref<8xi32>, %arg3 : memref<8xi32>)
kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
%0 = gpu.block_id x
%lhs = memref.load %arg0[%0] : memref<8xi32>
%rhs = memref.load %arg1[%0] : memref<8xi32>
%low, %hi = arith.mului_extended %lhs, %rhs : i32
memref.store %low, %arg2[%0] : memref<8xi32>
memref.store %hi, %arg3[%0] : memref<8xi32>
gpu.return
}
}
func.func @main() {
%buf0 = memref.alloc() : memref<8xi32>
%buf1 = memref.alloc() : memref<8xi32>
%buf2 = memref.alloc() : memref<8xi32>
%buf3 = memref.alloc() : memref<8xi32>
%i32_0 = arith.constant 0 : i32
// Initialize output buffers.
%buf4 = memref.cast %buf2 : memref<8xi32> to memref<?xi32>
%buf5 = memref.cast %buf3 : memref<8xi32> to memref<?xi32>
call @fillResource1DInt(%buf4, %i32_0) : (memref<?xi32>, i32) -> ()
call @fillResource1DInt(%buf5, %i32_0) : (memref<?xi32>, i32) -> ()
%idx_0 = arith.constant 0 : index
%idx_1 = arith.constant 1 : index
%idx_8 = arith.constant 8 : index
// Initialize input buffers.
%lhs_vals = arith.constant dense<[0, 1, -1, -1, 65535, 65535, -65535, 7]> : vector<8xi32>
%rhs_vals = arith.constant dense<[0, 1, 2, -1, 16, -1337, -65535, -7]> : vector<8xi32>
vector.store %lhs_vals, %buf0[%idx_0] : memref<8xi32>, vector<8xi32>
vector.store %rhs_vals, %buf1[%idx_0] : memref<8xi32>, vector<8xi32>
gpu.launch_func @kernels::@kernel_add
blocks in (%idx_8, %idx_1, %idx_1) threads in (%idx_1, %idx_1, %idx_1)
args(%buf0 : memref<8xi32>, %buf1 : memref<8xi32>, %buf2 : memref<8xi32>, %buf3 : memref<8xi32>)
%buf_low = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
%buf_hi = memref.cast %buf5 : memref<?xi32> to memref<*xi32>
call @printMemrefI32(%buf_low) : (memref<*xi32>) -> ()
call @printMemrefI32(%buf_hi) : (memref<*xi32>) -> ()
return
}
func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
func.func private @printMemrefI32(%ptr : memref<*xi32>)
}

View File

@@ -74,6 +74,8 @@ if (MLIR_ENABLE_VULKAN_RUNNER)
MLIRTargetLLVMIRExport
MLIRTransforms
MLIRTranslateLib
MLIRVectorDialect
MLIRVectorToLLVM
${Vulkan_LIBRARY}
)

View File

@@ -13,12 +13,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -30,18 +30,28 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/ExecutionEngine/JitRunner.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
using namespace mlir;
static LogicalResult runMLIRPasses(Operation *op, JitRunnerOptions &options) {
namespace {
struct VulkanRunnerOptions {
llvm::cl::OptionCategory category{"mlir-vulkan-runner options"};
llvm::cl::opt<bool> spirvWebGPUPrepare{
"vulkan-runner-spirv-webgpu-prepare",
llvm::cl::desc("Run MLIR transforms used when targetting WebGPU"),
llvm::cl::cat(category)};
};
} // namespace
static LogicalResult runMLIRPasses(Operation *op,
VulkanRunnerOptions &options) {
auto module = dyn_cast<ModuleOp>(op);
if (!module)
return op->emitOpError("expected a 'builtin.module' op");
@@ -55,10 +65,13 @@ static LogicalResult runMLIRPasses(Operation *op, JitRunnerOptions &options) {
OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
modulePM.addPass(spirv::createLowerABIAttributesPass());
modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
if (options.spirvWebGPUPrepare)
modulePM.addPass(spirv::createSPIRVWebGPUPreparePass());
passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
LowerToLLVMOptions llvmOptions(module.getContext(), DataLayout(module));
passManager.addPass(createMemRefToLLVMConversionPass());
passManager.addPass(createConvertVectorToLLVMPass());
passManager.nest<func::FuncOp>().addPass(LLVM::createRequestCWrappersPass());
passManager.addPass(createConvertFuncToLLVMPass(llvmOptions));
passManager.addPass(createReconcileUnrealizedCastsPass());
@@ -75,13 +88,21 @@ int main(int argc, char **argv) {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
// Initialize runner-specific CLI options. These will be parsed and
// initialzied in `JitRunnerMain`.
VulkanRunnerOptions options;
auto runPassesWithOptions = [&options](Operation *op, JitRunnerOptions &) {
return runMLIRPasses(op, options);
};
mlir::JitRunnerConfig jitRunnerConfig;
jitRunnerConfig.mlirTransformer = runMLIRPasses;
jitRunnerConfig.mlirTransformer = runPassesWithOptions;
mlir::DialectRegistry registry;
registry.insert<mlir::arith::ArithDialect, mlir::LLVM::LLVMDialect,
mlir::gpu::GPUDialect, mlir::spirv::SPIRVDialect,
mlir::func::FuncDialect, mlir::memref::MemRefDialect>();
mlir::func::FuncDialect, mlir::memref::MemRefDialect,
mlir::vector::VectorDialect>();
mlir::registerLLVMDialectTranslation(registry);
return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);

View File

@@ -7303,6 +7303,8 @@ cc_binary(
":SPIRVDialect",
":SPIRVTransforms",
":ToLLVMIRTranslation",
":VectorDialect",
":VectorToLLVM",
"//llvm:Support",
],
)