mirror of
https://github.com/intel/llvm.git
synced 2026-01-19 09:31:59 +08:00
[MLIR] Add initial convert-memref-to-emitc pass (#85389)
This converts `memref.alloca`, `memref.load` & `memref.store` to `emitc.variable`, `emitc.subscript` and `emitc.assign`.
This commit is contained in:
21
mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
Normal file
21
mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
Normal file
@@ -0,0 +1,21 @@
|
||||
//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
|
||||
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
|
||||
|
||||
namespace mlir {
|
||||
class RewritePatternSet;
|
||||
class TypeConverter;
|
||||
|
||||
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
|
||||
|
||||
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
|
||||
TypeConverter &converter);
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
|
||||
@@ -0,0 +1,20 @@
|
||||
//===- MemRefToEmitCPass.h - A Pass to convert MemRef to EmitC ------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
|
||||
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class Pass;
|
||||
|
||||
#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
|
||||
@@ -45,6 +45,7 @@
|
||||
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
||||
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
|
||||
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
|
||||
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
|
||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
|
||||
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
|
||||
|
||||
@@ -753,6 +753,15 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MemRefToEmitC
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
|
||||
let summary = "Convert MemRef dialect to EmitC dialect";
|
||||
let dependentDialects = ["emitc::EmitCDialect"];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MemRefToLLVM
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -35,6 +35,7 @@ add_subdirectory(MathToFuncs)
|
||||
add_subdirectory(MathToLibm)
|
||||
add_subdirectory(MathToLLVM)
|
||||
add_subdirectory(MathToSPIRV)
|
||||
add_subdirectory(MemRefToEmitC)
|
||||
add_subdirectory(MemRefToLLVM)
|
||||
add_subdirectory(MemRefToSPIRV)
|
||||
add_subdirectory(NVGPUToNVVM)
|
||||
|
||||
18
mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
Normal file
18
mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
Normal file
@@ -0,0 +1,18 @@
|
||||
add_mlir_conversion_library(MLIRMemRefToEmitC
|
||||
MemRefToEmitC.cpp
|
||||
MemRefToEmitCPass.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIREmitCDialect
|
||||
MLIRMemRefDialect
|
||||
MLIRTransforms
|
||||
)
|
||||
114
mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Normal file
114
mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Normal file
@@ -0,0 +1,114 @@
|
||||
//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements patterns to convert memref ops into emitc ops.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
|
||||
|
||||
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
if (!op.getType().hasStaticShape()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op.getLoc(), "cannot transform alloca with dynamic shape");
|
||||
}
|
||||
|
||||
if (op.getAlignment().value_or(1) > 1) {
|
||||
// TODO: Allow alignment if it is not more than the natural alignment
|
||||
// of the C array.
|
||||
return rewriter.notifyMatchFailure(
|
||||
op.getLoc(), "cannot transform alloca with alignment requirement");
|
||||
}
|
||||
|
||||
auto resultTy = getTypeConverter()->convertType(op.getType());
|
||||
if (!resultTy) {
|
||||
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
|
||||
}
|
||||
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
|
||||
rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto resultTy = getTypeConverter()->convertType(op.getType());
|
||||
if (!resultTy) {
|
||||
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
|
||||
}
|
||||
|
||||
auto subscript = rewriter.create<emitc::SubscriptOp>(
|
||||
op.getLoc(), operands.getMemref(), operands.getIndices());
|
||||
|
||||
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
|
||||
auto var =
|
||||
rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit);
|
||||
|
||||
rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript);
|
||||
rewriter.replaceOp(op, var);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto subscript = rewriter.create<emitc::SubscriptOp>(
|
||||
op.getLoc(), operands.getMemref(), operands.getIndices());
|
||||
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
|
||||
operands.getValue());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
|
||||
typeConverter.addConversion(
|
||||
[&](MemRefType memRefType) -> std::optional<Type> {
|
||||
if (!memRefType.hasStaticShape() ||
|
||||
!memRefType.getLayout().isIdentity() || memRefType.getRank() == 0) {
|
||||
return {};
|
||||
}
|
||||
Type convertedElementType =
|
||||
typeConverter.convertType(memRefType.getElementType());
|
||||
if (!convertedElementType)
|
||||
return {};
|
||||
return emitc::ArrayType::get(memRefType.getShape(),
|
||||
convertedElementType);
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
|
||||
TypeConverter &converter) {
|
||||
patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
|
||||
patterns.getContext());
|
||||
}
|
||||
55
mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
Normal file
55
mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
Normal file
@@ -0,0 +1,55 @@
|
||||
//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a pass to convert memref ops into emitc ops.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
|
||||
|
||||
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
|
||||
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct ConvertMemRefToEmitCPass
|
||||
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
|
||||
void runOnOperation() override {
|
||||
TypeConverter converter;
|
||||
|
||||
// Fallback for other types.
|
||||
converter.addConversion([](Type type) -> std::optional<Type> {
|
||||
if (isa<MemRefType>(type))
|
||||
return {};
|
||||
return type;
|
||||
});
|
||||
|
||||
populateMemRefToEmitCTypeConversion(converter);
|
||||
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateMemRefToEmitCConversionPatterns(patterns, converter);
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addIllegalDialect<memref::MemRefDialect>();
|
||||
target.addLegalDialect<emitc::EmitCDialect>();
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
@@ -0,0 +1,40 @@
|
||||
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
|
||||
|
||||
func.func @memref_op(%arg0 : memref<2x4xf32>) {
|
||||
// expected-error@+1 {{failed to legalize operation 'memref.copy'}}
|
||||
memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @alloca_with_dynamic_shape() {
|
||||
%0 = index.constant 1
|
||||
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
|
||||
%1 = memref.alloca(%0) : memref<4x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @alloca_with_alignment() {
|
||||
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
|
||||
%0 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @non_identity_layout() {
|
||||
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
|
||||
%0 = memref.alloca() : memref<4x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @zero_rank() {
|
||||
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
|
||||
%0 = memref.alloca() : memref<f32>
|
||||
return
|
||||
}
|
||||
28
mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
Normal file
28
mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
Normal file
@@ -0,0 +1,28 @@
|
||||
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: memref_store
|
||||
// CHECK-SAME: %[[v:.*]]: f32, %[[i:.*]]: index, %[[j:.*]]: index
|
||||
func.func @memref_store(%v : f32, %i: index, %j: index) {
|
||||
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
|
||||
%0 = memref.alloca() : memref<4x8xf32>
|
||||
|
||||
// CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
|
||||
// CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32
|
||||
memref.store %v, %0[%i, %j] : memref<4x8xf32>
|
||||
return
|
||||
}
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: memref_load
|
||||
// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index
|
||||
func.func @memref_load(%i: index, %j: index) -> f32 {
|
||||
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
|
||||
%0 = memref.alloca() : memref<4x8xf32>
|
||||
|
||||
// CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
|
||||
// CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
|
||||
// CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
|
||||
%1 = memref.load %0[%i, %j] : memref<4x8xf32>
|
||||
// CHECK: return %[[VAR]] : f32
|
||||
return %1 : f32
|
||||
}
|
||||
@@ -4186,6 +4186,7 @@ cc_library(
|
||||
":MathToLLVM",
|
||||
":MathToLibm",
|
||||
":MathToSPIRV",
|
||||
":MemRefToEmitC",
|
||||
":MemRefToLLVM",
|
||||
":MemRefToSPIRV",
|
||||
":NVGPUToNVVM",
|
||||
@@ -8256,6 +8257,32 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "MemRefToEmitC",
|
||||
srcs = glob([
|
||||
"lib/Conversion/MemRefToEmitC/*.cpp",
|
||||
"lib/Conversion/MemRefToEmitC/*.h",
|
||||
]),
|
||||
hdrs = glob([
|
||||
"include/mlir/Conversion/MemRefToEmitC/*.h",
|
||||
]),
|
||||
includes = [
|
||||
"include",
|
||||
"lib/Conversion/MemRefToEmitC",
|
||||
],
|
||||
deps = [
|
||||
":ConversionPassIncGen",
|
||||
":EmitCDialect",
|
||||
":MemRefDialect",
|
||||
":IR",
|
||||
":Pass",
|
||||
":Support",
|
||||
":TransformUtils",
|
||||
":Transforms",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "MemRefToLLVM",
|
||||
srcs = glob(["lib/Conversion/MemRefToLLVM/*.cpp"]),
|
||||
|
||||
Reference in New Issue
Block a user