[MLIR] Add initial convert-memref-to-emitc pass (#85389)

This converts `memref.alloca`, `memref.load` & `memref.store` to
`emitc.variable`, `emitc.subscript` and `emitc.assign`.
This commit is contained in:
Matthias Gehre
2024-03-21 14:27:37 +01:00
committed by GitHub
parent 538257bf00
commit 0aa6d57e57
11 changed files with 334 additions and 0 deletions

View File

@@ -0,0 +1,21 @@
//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
namespace mlir {
class RewritePatternSet;
class TypeConverter;
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter);
} // namespace mlir
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H

View File

@@ -0,0 +1,20 @@
//===- MemRefToEmitCPass.h - A Pass to convert MemRef to EmitC ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
#include <memory>
namespace mlir {
class Pass;
#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H

View File

@@ -45,6 +45,7 @@
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

View File

@@ -753,6 +753,15 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
];
}
//===----------------------------------------------------------------------===//
// MemRefToEmitC
//===----------------------------------------------------------------------===//
def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
let summary = "Convert MemRef dialect to EmitC dialect";
let dependentDialects = ["emitc::EmitCDialect"];
}
//===----------------------------------------------------------------------===//
// MemRefToLLVM
//===----------------------------------------------------------------------===//

View File

@@ -35,6 +35,7 @@ add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
add_subdirectory(NVGPUToNVVM)

View File

@@ -0,0 +1,18 @@
add_mlir_conversion_library(MLIRMemRefToEmitC
MemRefToEmitC.cpp
MemRefToEmitCPass.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
DEPENDS
MLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIREmitCDialect
MLIRMemRefDialect
MLIRTransforms
)

View File

@@ -0,0 +1,114 @@
//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert memref ops into emitc ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
if (!op.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform alloca with dynamic shape");
}
if (op.getAlignment().value_or(1) > 1) {
// TODO: Allow alignment if it is not more than the natural alignment
// of the C array.
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform alloca with alignment requirement");
}
auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
}
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
return success();
}
};
struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
}
auto subscript = rewriter.create<emitc::SubscriptOp>(
op.getLoc(), operands.getMemref(), operands.getIndices());
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
auto var =
rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit);
rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript);
rewriter.replaceOp(op, var);
return success();
}
};
struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto subscript = rewriter.create<emitc::SubscriptOp>(
op.getLoc(), operands.getMemref(), operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
return success();
}
};
} // namespace
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion(
[&](MemRefType memRefType) -> std::optional<Type> {
if (!memRefType.hasStaticShape() ||
!memRefType.getLayout().isIdentity() || memRefType.getRank() == 0) {
return {};
}
Type convertedElementType =
typeConverter.convertType(memRefType.getElementType());
if (!convertedElementType)
return {};
return emitc::ArrayType::get(memRefType.getShape(),
convertedElementType);
});
}
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
patterns.getContext());
}

View File

@@ -0,0 +1,55 @@
//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert memref ops into emitc ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
void runOnOperation() override {
TypeConverter converter;
// Fallback for other types.
converter.addConversion([](Type type) -> std::optional<Type> {
if (isa<MemRefType>(type))
return {};
return type;
});
populateMemRefToEmitCTypeConversion(converter);
RewritePatternSet patterns(&getContext());
populateMemRefToEmitCConversionPatterns(patterns, converter);
ConversionTarget target(getContext());
target.addIllegalDialect<memref::MemRefDialect>();
target.addLegalDialect<emitc::EmitCDialect>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace

View File

@@ -0,0 +1,40 @@
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
func.func @memref_op(%arg0 : memref<2x4xf32>) {
// expected-error@+1 {{failed to legalize operation 'memref.copy'}}
memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
return
}
// -----
func.func @alloca_with_dynamic_shape() {
%0 = index.constant 1
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
%1 = memref.alloca(%0) : memref<4x?xf32>
return
}
// -----
func.func @alloca_with_alignment() {
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
%0 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
return
}
// -----
func.func @non_identity_layout() {
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
%0 = memref.alloca() : memref<4x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
return
}
// -----
func.func @zero_rank() {
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
%0 = memref.alloca() : memref<f32>
return
}

View File

@@ -0,0 +1,28 @@
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
// CHECK-LABEL: memref_store
// CHECK-SAME: %[[v:.*]]: f32, %[[i:.*]]: index, %[[j:.*]]: index
func.func @memref_store(%v : f32, %i: index, %j: index) {
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
%0 = memref.alloca() : memref<4x8xf32>
// CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
// CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32
memref.store %v, %0[%i, %j] : memref<4x8xf32>
return
}
// -----
// CHECK-LABEL: memref_load
// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index
func.func @memref_load(%i: index, %j: index) -> f32 {
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
%0 = memref.alloca() : memref<4x8xf32>
// CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
// CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
%1 = memref.load %0[%i, %j] : memref<4x8xf32>
// CHECK: return %[[VAR]] : f32
return %1 : f32
}

View File

@@ -4186,6 +4186,7 @@ cc_library(
":MathToLLVM",
":MathToLibm",
":MathToSPIRV",
":MemRefToEmitC",
":MemRefToLLVM",
":MemRefToSPIRV",
":NVGPUToNVVM",
@@ -8256,6 +8257,32 @@ cc_library(
],
)
cc_library(
name = "MemRefToEmitC",
srcs = glob([
"lib/Conversion/MemRefToEmitC/*.cpp",
"lib/Conversion/MemRefToEmitC/*.h",
]),
hdrs = glob([
"include/mlir/Conversion/MemRefToEmitC/*.h",
]),
includes = [
"include",
"lib/Conversion/MemRefToEmitC",
],
deps = [
":ConversionPassIncGen",
":EmitCDialect",
":MemRefDialect",
":IR",
":Pass",
":Support",
":TransformUtils",
":Transforms",
"//llvm:Support",
],
)
cc_library(
name = "MemRefToLLVM",
srcs = glob(["lib/Conversion/MemRefToLLVM/*.cpp"]),