mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 00:20:25 +08:00
[mlir][EmitC]Expand the MemRefToEmitC pass - Adding scalars (#148055)
This aims to expand the the MemRefToEmitC pass so that it can accept
global scalars.
From:
```
memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
func.func @globals() {
memref.get_global @__constant_xi32 : memref<i32>
}
```
To:
```
emitc.global static const @__constant_xi32 : i32 = -1
emitc.func @globals() {
%0 = get_global @__constant_xi32 : !emitc.lvalue<i32>
%1 = apply "&"(%0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
return
}
```
This commit is contained in:
@@ -16,7 +16,9 @@
|
||||
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeRange.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -77,13 +79,23 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
|
||||
}
|
||||
};
|
||||
|
||||
Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
|
||||
Type resultTy;
|
||||
if (opTy.getRank() == 0) {
|
||||
resultTy = typeConverter->convertType(mlir::getElementTypeOrSelf(opTy));
|
||||
} else {
|
||||
resultTy = typeConverter->convertType(opTy);
|
||||
}
|
||||
return resultTy;
|
||||
}
|
||||
|
||||
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
MemRefType opTy = op.getType();
|
||||
if (!op.getType().hasStaticShape()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op.getLoc(), "cannot transform global with dynamic shape");
|
||||
@@ -95,7 +107,9 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
|
||||
op.getLoc(), "global variable with alignment requirement is "
|
||||
"currently not supported");
|
||||
}
|
||||
auto resultTy = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
Type resultTy = convertMemRefType(opTy, getTypeConverter());
|
||||
|
||||
if (!resultTy) {
|
||||
return rewriter.notifyMatchFailure(op.getLoc(),
|
||||
"cannot convert result type");
|
||||
@@ -114,6 +128,10 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
|
||||
bool externSpecifier = !staticSpecifier;
|
||||
|
||||
Attribute initialValue = operands.getInitialValueAttr();
|
||||
if (opTy.getRank() == 0) {
|
||||
auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
|
||||
initialValue = elementsAttr.getSplatValue<Attribute>();
|
||||
}
|
||||
if (isa_and_present<UnitAttr>(initialValue))
|
||||
initialValue = {};
|
||||
|
||||
@@ -132,11 +150,23 @@ struct ConvertGetGlobal final
|
||||
matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
auto resultTy = getTypeConverter()->convertType(op.getType());
|
||||
MemRefType opTy = op.getType();
|
||||
Type resultTy = convertMemRefType(opTy, getTypeConverter());
|
||||
|
||||
if (!resultTy) {
|
||||
return rewriter.notifyMatchFailure(op.getLoc(),
|
||||
"cannot convert result type");
|
||||
}
|
||||
|
||||
if (opTy.getRank() == 0) {
|
||||
emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
|
||||
emitc::GetGlobalOp globalLValue = rewriter.create<emitc::GetGlobalOp>(
|
||||
op.getLoc(), lvalueType, operands.getNameAttr());
|
||||
emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
|
||||
rewriter.replaceOpWithNewOp<emitc::ApplyOp>(
|
||||
op, pointerType, rewriter.getStringAttr("&"), globalLValue);
|
||||
return success();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
|
||||
operands.getNameAttr());
|
||||
return success();
|
||||
|
||||
@@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 {
|
||||
module @globals {
|
||||
memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
|
||||
// CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
|
||||
memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
|
||||
// CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1
|
||||
memref.global @public_global : memref<3x7xf32>
|
||||
// CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32>
|
||||
memref.global @uninitialized_global : memref<3x7xf32> = uninitialized
|
||||
@@ -50,6 +52,9 @@ module @globals {
|
||||
func.func @use_global() {
|
||||
// CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
|
||||
%0 = memref.get_global @public_global : memref<3x7xf32>
|
||||
// CHECK-NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
|
||||
// CHECK-NEXT: emitc.apply "&"(%1) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
|
||||
%1 = memref.get_global @__constant_xi32 : memref<i32>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user