mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 03:56:16 +08:00
[mlir][spirv] Fix extended umul expansion for WebGPU
Fix an off-by-one error in extended umul extension for WebGPU. Revert to the long multiplication algorithm originally added to wide integer emulation, which was deleted in D139776. It is much easier to see why it is correct. Add runtime tests based on the mlir-vulkan-runner. These run both with and without umul extension. Issue: https://github.com/llvm/llvm-project/issues/59563 Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D141085
This commit is contained in:
@@ -17,8 +17,13 @@
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
|
||||
namespace mlir {
|
||||
namespace spirv {
|
||||
#define GEN_PASS_DEF_SPIRVWEBGPUPREPAREPASS
|
||||
@@ -61,41 +66,62 @@ struct ExpandUMulExtendedPattern final : OpRewritePattern<UMulExtendedOp> {
|
||||
loc,
|
||||
llvm::formatv("Unexpected integer type for WebGPU: '{0}'", elemTy));
|
||||
|
||||
// Calculate the 'low' and the 'high' result separately, using long
|
||||
// multiplication:
|
||||
// Emulate 64-bit multiplication by splitting each input element of type i32
|
||||
// into 2 16-bit digits of type i32. This is so that the intermediate
|
||||
// multiplications and additions do not overflow. We extract these 16-bit
|
||||
// digits from i32 vector elements by masking (low digit) and shifting right
|
||||
// (high digit).
|
||||
//
|
||||
// lhs = [0 0] [a b]
|
||||
// rhs = [0 0] [c d]
|
||||
// --lhs * rhs--
|
||||
// = [ a * c ] [ b * d ] +
|
||||
// [ 0 ] [a * d + b * c] [ 0 ]
|
||||
//
|
||||
// ==> high = (a * c) + (a * d + b * c) >> 16
|
||||
Value low = rewriter.create<IMulOp>(loc, lhs, rhs);
|
||||
|
||||
// The multiplication algorithm used is the standard (long) multiplication.
|
||||
// Multiplying two i32 integers produces 64 bits of result, i.e., 4 16-bit
|
||||
// digits. After constant-folding, we end up emitting only 4 multiplications
|
||||
// and 4 additions.
|
||||
Value cstLowMask = rewriter.create<ConstantOp>(
|
||||
loc, lhs.getType(), getScalarOrSplatAttr(argTy, (1 << 16) - 1));
|
||||
auto getLowHalf = [&rewriter, loc, cstLowMask](Value val) {
|
||||
auto getLowDigit = [&rewriter, loc, cstLowMask](Value val) {
|
||||
return rewriter.create<BitwiseAndOp>(loc, val, cstLowMask);
|
||||
};
|
||||
|
||||
Value cst16 = rewriter.create<ConstantOp>(loc, lhs.getType(),
|
||||
getScalarOrSplatAttr(argTy, 16));
|
||||
auto getHighHalf = [&rewriter, loc, cst16](Value val) {
|
||||
auto getHighDigit = [&rewriter, loc, cst16](Value val) {
|
||||
return rewriter.create<ShiftRightLogicalOp>(loc, val, cst16);
|
||||
};
|
||||
|
||||
Value lhsLow = getLowHalf(lhs);
|
||||
Value lhsHigh = getHighHalf(lhs);
|
||||
Value rhsLow = getLowHalf(rhs);
|
||||
Value rhsHigh = getHighHalf(rhs);
|
||||
Value cst0 = rewriter.create<ConstantOp>(loc, lhs.getType(),
|
||||
getScalarOrSplatAttr(argTy, 0));
|
||||
|
||||
Value high0 = rewriter.create<IMulOp>(loc, lhsHigh, rhsHigh);
|
||||
Value mid = rewriter.create<IAddOp>(
|
||||
loc, rewriter.create<IMulOp>(loc, lhsHigh, rhsLow),
|
||||
rewriter.create<IMulOp>(loc, lhsLow, rhsHigh));
|
||||
Value high1 = getHighHalf(mid);
|
||||
Value high = rewriter.create<IAddOp>(loc, high0, high1);
|
||||
Value lhsLow = getLowDigit(lhs);
|
||||
Value lhsHigh = getHighDigit(lhs);
|
||||
Value rhsLow = getLowDigit(rhs);
|
||||
Value rhsHigh = getHighDigit(rhs);
|
||||
|
||||
std::array<Value, 2> lhsDigits = {lhsLow, lhsHigh};
|
||||
std::array<Value, 2> rhsDigits = {rhsLow, rhsHigh};
|
||||
std::array<Value, 4> resultDigits = {cst0, cst0, cst0, cst0};
|
||||
|
||||
for (auto [i, lhsDigit] : llvm::enumerate(lhsDigits)) {
|
||||
for (auto [j, rhsDigit] : llvm::enumerate(rhsDigits)) {
|
||||
Value &thisResDigit = resultDigits[i + j];
|
||||
Value mul = rewriter.create<IMulOp>(loc, lhsDigit, rhsDigit);
|
||||
Value current = rewriter.createOrFold<IAddOp>(loc, thisResDigit, mul);
|
||||
thisResDigit = getLowDigit(current);
|
||||
|
||||
if (i + j + 1 != resultDigits.size()) {
|
||||
Value &nextResDigit = resultDigits[i + j + 1];
|
||||
Value carry = rewriter.createOrFold<IAddOp>(loc, nextResDigit,
|
||||
getHighDigit(current));
|
||||
nextResDigit = carry;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto combineDigits = [loc, cst16, &rewriter](Value low, Value high) {
|
||||
Value highBits = rewriter.create<ShiftLeftLogicalOp>(loc, high, cst16);
|
||||
return rewriter.create<BitwiseOrOp>(loc, low, highBits);
|
||||
};
|
||||
Value low = combineDigits(resultDigits[0], resultDigits[1]);
|
||||
Value high = combineDigits(resultDigits[2], resultDigits[3]);
|
||||
|
||||
rewriter.replaceOpWithNewOp<CompositeConstructOp>(
|
||||
op, op.getType(), llvm::makeArrayRef({low, high}));
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// RUN: mlir-opt --split-input-file --verify-diagnostics --spirv-webgpu-prepare %s | FileCheck %s
|
||||
// RUN: mlir-opt --split-input-file --verify-diagnostics \
|
||||
// RUN: --spirv-webgpu-prepare --cse %s | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.UMulExtended
|
||||
@@ -10,18 +11,23 @@ spirv.module Logical GLSL450 {
|
||||
// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32)
|
||||
// CHECK-DAG: [[CSTMASK:%.+]] = spirv.Constant 65535 : i32
|
||||
// CHECK-DAG: [[CST16:%.+]] = spirv.Constant 16 : i32
|
||||
// CHECK-NEXT: [[RESLOW:%.+]] = spirv.IMul [[ARG0]], [[ARG1]] : i32
|
||||
// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : i32
|
||||
// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : i32
|
||||
// CHECK-NEXT: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : i32
|
||||
// CHECK-NEXT: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : i32
|
||||
// CHECK-DAG: [[RESHI0:%.+]] = spirv.IMul [[LHSHI]], [[RHSHI]] : i32
|
||||
// CHECK-DAG: [[MID0:%.+]] = spirv.IMul [[LHSHI]], [[RHSLOW]] : i32
|
||||
// CHECK-DAG: [[MID1:%.+]] = spirv.IMul [[LHSLOW]], [[RHSHI]] : i32
|
||||
// CHECK-NEXT: [[MID:%.+]] = spirv.IAdd [[MID0]], [[MID1]] : i32
|
||||
// CHECK-NEXT: [[RESHI1:%.+]] = spirv.ShiftRightLogical [[MID]], [[CST16]] : i32
|
||||
// CHECK-NEXT: [[RESHI:%.+]] = spirv.IAdd [[RESHI0]], [[RESHI1]] : i32
|
||||
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW]], [[RESHI]] : (i32, i32) -> !spirv.struct<(i32, i32)>
|
||||
// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSLOW]]
|
||||
// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSHI]]
|
||||
// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSLOW]]
|
||||
// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSHI]]
|
||||
// CHECK-DAG: spirv.IAdd
|
||||
// CHECK-DAG: spirv.IAdd
|
||||
// CHECK-DAG: spirv.IAdd
|
||||
// CHECK-DAG: spirv.IAdd
|
||||
// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
|
||||
// CHECK: spirv.BitwiseOr
|
||||
// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]] : i32
|
||||
// CHECK: spirv.BitwiseOr
|
||||
// CHECK: [[RES:%.+]] = spirv.CompositeConstruct [[RESLO:%.+]], [[RESHI:%.+]] : (i32, i32) -> !spirv.struct<(i32, i32)>
|
||||
// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(i32, i32)>
|
||||
spirv.func @umul_extended_i32(%arg0 : i32, %arg1 : i32) -> !spirv.struct<(i32, i32)> "None" {
|
||||
%0 = spirv.UMulExtended %arg0, %arg1 : !spirv.struct<(i32, i32)>
|
||||
@@ -32,18 +38,23 @@ spirv.func @umul_extended_i32(%arg0 : i32, %arg1 : i32) -> !spirv.struct<(i32, i
|
||||
// CHECK-SAME: ([[ARG0:%.+]]: vector<3xi32>, [[ARG1:%.+]]: vector<3xi32>)
|
||||
// CHECK-DAG: [[CSTMASK:%.+]] = spirv.Constant dense<65535> : vector<3xi32>
|
||||
// CHECK-DAG: [[CST16:%.+]] = spirv.Constant dense<16> : vector<3xi32>
|
||||
// CHECK-NEXT: [[RESLOW:%.+]] = spirv.IMul [[ARG0]], [[ARG1]] : vector<3xi32>
|
||||
// CHECK-NEXT: [[LHSLOW:%.+]] = spirv.BitwiseAnd [[ARG0]], [[CSTMASK]] : vector<3xi32>
|
||||
// CHECK-NEXT: [[LHSHI:%.+]] = spirv.ShiftRightLogical [[ARG0]], [[CST16]] : vector<3xi32>
|
||||
// CHECK-NEXT: [[RHSLOW:%.+]] = spirv.BitwiseAnd [[ARG1]], [[CSTMASK]] : vector<3xi32>
|
||||
// CHECK-NEXT: [[RHSHI:%.+]] = spirv.ShiftRightLogical [[ARG1]], [[CST16]] : vector<3xi32>
|
||||
// CHECK-DAG: [[RESHI0:%.+]] = spirv.IMul [[LHSHI]], [[RHSHI]] : vector<3xi32>
|
||||
// CHECK-DAG: [[MID0:%.+]] = spirv.IMul [[LHSHI]], [[RHSLOW]] : vector<3xi32>
|
||||
// CHECK-DAG: [[MID1:%.+]] = spirv.IMul [[LHSLOW]], [[RHSHI]] : vector<3xi32>
|
||||
// CHECK-NEXT: [[MID:%.+]] = spirv.IAdd [[MID0]], [[MID1]] : vector<3xi32>
|
||||
// CHECK-NEXT: [[RESHI1:%.+]] = spirv.ShiftRightLogical [[MID]], [[CST16]] : vector<3xi32>
|
||||
// CHECK-NEXT: [[RESHI:%.+]] = spirv.IAdd [[RESHI0]], [[RESHI1]] : vector<3xi32>
|
||||
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW]], [[RESHI]]
|
||||
// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSLOW]]
|
||||
// CHECK-DAG: spirv.IMul [[LHSLOW]], [[RHSHI]]
|
||||
// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSLOW]]
|
||||
// CHECK-DAG: spirv.IMul [[LHSHI]], [[RHSHI]]
|
||||
// CHECK-DAG: spirv.IAdd
|
||||
// CHECK-DAG: spirv.IAdd
|
||||
// CHECK-DAG: spirv.IAdd
|
||||
// CHECK-DAG: spirv.IAdd
|
||||
// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
|
||||
// CHECK: spirv.BitwiseOr
|
||||
// CHECK: spirv.ShiftLeftLogical {{%.+}}, [[CST16]]
|
||||
// CHECK: spirv.BitwiseOr
|
||||
// CHECK-NEXT: [[RES:%.+]] = spirv.CompositeConstruct [[RESLOW:%.+]], [[RESHI:%.+]]
|
||||
// CHECK-NEXT: spirv.ReturnValue [[RES]] : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
|
||||
spirv.func @umul_extended_vector_i32(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>)
|
||||
-> !spirv.struct<(vector<3xi32>, vector<3xi32>)> "None" {
|
||||
|
||||
66
mlir/test/mlir-vulkan-runner/umul_extended.mlir
Normal file
66
mlir/test/mlir-vulkan-runner/umul_extended.mlir
Normal file
@@ -0,0 +1,66 @@
|
||||
// Make sure that unsigned extended multiplication produces expected results
|
||||
// with and without expansion to primitive mul/add ops for WebGPU.
|
||||
|
||||
// RUN: mlir-vulkan-runner %s \
|
||||
// RUN: --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
|
||||
// RUN: --entry-point-result=void | FileCheck %s
|
||||
|
||||
// RUN: mlir-vulkan-runner %s --vulkan-runner-spirv-webgpu-prepare \
|
||||
// RUN: --shared-libs=%mlir_lib_dir/libvulkan-runtime-wrappers%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext \
|
||||
// RUN: --entry-point-result=void | FileCheck %s
|
||||
|
||||
// CHECK: [0, 1, -2, 1, 1048560, -87620295, -131071, -49]
|
||||
// CHECK: [0, 0, 1, -2, 0, 65534, -131070, 6]
|
||||
module attributes {
|
||||
gpu.container_module,
|
||||
spirv.target_env = #spirv.target_env<
|
||||
#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
|
||||
} {
|
||||
gpu.module @kernels {
|
||||
gpu.func @kernel_add(%arg0 : memref<8xi32>, %arg1 : memref<8xi32>, %arg2 : memref<8xi32>, %arg3 : memref<8xi32>)
|
||||
kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
|
||||
%0 = gpu.block_id x
|
||||
%lhs = memref.load %arg0[%0] : memref<8xi32>
|
||||
%rhs = memref.load %arg1[%0] : memref<8xi32>
|
||||
%low, %hi = arith.mului_extended %lhs, %rhs : i32
|
||||
memref.store %low, %arg2[%0] : memref<8xi32>
|
||||
memref.store %hi, %arg3[%0] : memref<8xi32>
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
func.func @main() {
|
||||
%buf0 = memref.alloc() : memref<8xi32>
|
||||
%buf1 = memref.alloc() : memref<8xi32>
|
||||
%buf2 = memref.alloc() : memref<8xi32>
|
||||
%buf3 = memref.alloc() : memref<8xi32>
|
||||
%i32_0 = arith.constant 0 : i32
|
||||
|
||||
// Initialize output buffers.
|
||||
%buf4 = memref.cast %buf2 : memref<8xi32> to memref<?xi32>
|
||||
%buf5 = memref.cast %buf3 : memref<8xi32> to memref<?xi32>
|
||||
call @fillResource1DInt(%buf4, %i32_0) : (memref<?xi32>, i32) -> ()
|
||||
call @fillResource1DInt(%buf5, %i32_0) : (memref<?xi32>, i32) -> ()
|
||||
|
||||
%idx_0 = arith.constant 0 : index
|
||||
%idx_1 = arith.constant 1 : index
|
||||
%idx_8 = arith.constant 8 : index
|
||||
|
||||
// Initialize input buffers.
|
||||
%lhs_vals = arith.constant dense<[0, 1, -1, -1, 65535, 65535, -65535, 7]> : vector<8xi32>
|
||||
%rhs_vals = arith.constant dense<[0, 1, 2, -1, 16, -1337, -65535, -7]> : vector<8xi32>
|
||||
vector.store %lhs_vals, %buf0[%idx_0] : memref<8xi32>, vector<8xi32>
|
||||
vector.store %rhs_vals, %buf1[%idx_0] : memref<8xi32>, vector<8xi32>
|
||||
|
||||
gpu.launch_func @kernels::@kernel_add
|
||||
blocks in (%idx_8, %idx_1, %idx_1) threads in (%idx_1, %idx_1, %idx_1)
|
||||
args(%buf0 : memref<8xi32>, %buf1 : memref<8xi32>, %buf2 : memref<8xi32>, %buf3 : memref<8xi32>)
|
||||
%buf_low = memref.cast %buf4 : memref<?xi32> to memref<*xi32>
|
||||
%buf_hi = memref.cast %buf5 : memref<?xi32> to memref<*xi32>
|
||||
call @printMemrefI32(%buf_low) : (memref<*xi32>) -> ()
|
||||
call @printMemrefI32(%buf_hi) : (memref<*xi32>) -> ()
|
||||
return
|
||||
}
|
||||
func.func private @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
|
||||
func.func private @printMemrefI32(%ptr : memref<*xi32>)
|
||||
}
|
||||
@@ -74,6 +74,8 @@ if (MLIR_ENABLE_VULKAN_RUNNER)
|
||||
MLIRTargetLLVMIRExport
|
||||
MLIRTransforms
|
||||
MLIRTranslateLib
|
||||
MLIRVectorDialect
|
||||
MLIRVectorToLLVM
|
||||
${Vulkan_LIBRARY}
|
||||
)
|
||||
|
||||
|
||||
@@ -13,12 +13,12 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
|
||||
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
|
||||
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
|
||||
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
|
||||
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
|
||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
|
||||
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
@@ -30,18 +30,28 @@
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/ExecutionEngine/JitRunner.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Export.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static LogicalResult runMLIRPasses(Operation *op, JitRunnerOptions &options) {
|
||||
namespace {
|
||||
struct VulkanRunnerOptions {
|
||||
llvm::cl::OptionCategory category{"mlir-vulkan-runner options"};
|
||||
llvm::cl::opt<bool> spirvWebGPUPrepare{
|
||||
"vulkan-runner-spirv-webgpu-prepare",
|
||||
llvm::cl::desc("Run MLIR transforms used when targetting WebGPU"),
|
||||
llvm::cl::cat(category)};
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static LogicalResult runMLIRPasses(Operation *op,
|
||||
VulkanRunnerOptions &options) {
|
||||
auto module = dyn_cast<ModuleOp>(op);
|
||||
if (!module)
|
||||
return op->emitOpError("expected a 'builtin.module' op");
|
||||
@@ -55,10 +65,13 @@ static LogicalResult runMLIRPasses(Operation *op, JitRunnerOptions &options) {
|
||||
OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
|
||||
modulePM.addPass(spirv::createLowerABIAttributesPass());
|
||||
modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
|
||||
if (options.spirvWebGPUPrepare)
|
||||
modulePM.addPass(spirv::createSPIRVWebGPUPreparePass());
|
||||
|
||||
passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
|
||||
LowerToLLVMOptions llvmOptions(module.getContext(), DataLayout(module));
|
||||
passManager.addPass(createMemRefToLLVMConversionPass());
|
||||
passManager.addPass(createConvertVectorToLLVMPass());
|
||||
passManager.nest<func::FuncOp>().addPass(LLVM::createRequestCWrappersPass());
|
||||
passManager.addPass(createConvertFuncToLLVMPass(llvmOptions));
|
||||
passManager.addPass(createReconcileUnrealizedCastsPass());
|
||||
@@ -75,13 +88,21 @@ int main(int argc, char **argv) {
|
||||
llvm::InitializeNativeTarget();
|
||||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
|
||||
// Initialize runner-specific CLI options. These will be parsed and
|
||||
// initialzied in `JitRunnerMain`.
|
||||
VulkanRunnerOptions options;
|
||||
auto runPassesWithOptions = [&options](Operation *op, JitRunnerOptions &) {
|
||||
return runMLIRPasses(op, options);
|
||||
};
|
||||
|
||||
mlir::JitRunnerConfig jitRunnerConfig;
|
||||
jitRunnerConfig.mlirTransformer = runMLIRPasses;
|
||||
jitRunnerConfig.mlirTransformer = runPassesWithOptions;
|
||||
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<mlir::arith::ArithDialect, mlir::LLVM::LLVMDialect,
|
||||
mlir::gpu::GPUDialect, mlir::spirv::SPIRVDialect,
|
||||
mlir::func::FuncDialect, mlir::memref::MemRefDialect>();
|
||||
mlir::func::FuncDialect, mlir::memref::MemRefDialect,
|
||||
mlir::vector::VectorDialect>();
|
||||
mlir::registerLLVMDialectTranslation(registry);
|
||||
|
||||
return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);
|
||||
|
||||
@@ -7303,6 +7303,8 @@ cc_binary(
|
||||
":SPIRVDialect",
|
||||
":SPIRVTransforms",
|
||||
":ToLLVMIRTranslation",
|
||||
":VectorDialect",
|
||||
":VectorToLLVM",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user