mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 08:30:34 +08:00
[mlir][mesh] Make sharding propagation and spmdization work on FuncOpInterface (#84415)
Make them more general instead of only supporting `func::FuncOp`.
This commit is contained in:
committed by
GitHub
parent
8160139136
commit
abfac563f5
@@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td"
|
||||
// ShardingPropagation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
|
||||
def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionOpInterface"> {
|
||||
let summary = "sharding propagation";
|
||||
let description = [{
|
||||
Propagates sharding information throughout the graph. After this pass, each
|
||||
@@ -29,7 +29,7 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
|
||||
];
|
||||
}
|
||||
|
||||
def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
|
||||
def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
|
||||
let summary = "Partition a function into SPMD form.";
|
||||
let description = [{
|
||||
This pass fits in right after a pass that annotates the function with
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
|
||||
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
|
||||
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include <vector>
|
||||
@@ -172,9 +173,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
|
||||
struct ShardingPropagation
|
||||
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
|
||||
void runOnOperation() override {
|
||||
func::FuncOp funcOp = getOperation();
|
||||
FunctionOpInterface funcOp = getOperation();
|
||||
MLIRContext *ctx = funcOp.getContext();
|
||||
Region ®ion = funcOp.getBody();
|
||||
Region ®ion = funcOp.getFunctionBody();
|
||||
OpBuilder builder(ctx);
|
||||
if (!region.hasOneBlock()) {
|
||||
funcOp.emitOpError() << "only one block is supported!";
|
||||
|
||||
@@ -24,6 +24,8 @@
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
@@ -694,7 +696,7 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
|
||||
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
|
||||
SymbolTableCollection &symbolTableCollection) {
|
||||
OpBuilder builder(op.getFunctionBody());
|
||||
|
||||
@@ -717,21 +719,21 @@ spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
|
||||
|
||||
// Find a return op and change the function results signature to its operands
|
||||
// signature.
|
||||
func::ReturnOp returnOp;
|
||||
for (Block &block : op.getBody()) {
|
||||
Operation *returnOp = nullptr;
|
||||
for (Block &block : op.getFunctionBody()) {
|
||||
if (block.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
returnOp = llvm::cast<func::ReturnOp>(block.back());
|
||||
if (returnOp) {
|
||||
if (block.back().hasTrait<OpTrait::ReturnLike>()) {
|
||||
returnOp = &block.back();
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert(returnOp);
|
||||
op.setFunctionType(FunctionType::get(op->getContext(),
|
||||
op.getBody().front().getArgumentTypes(),
|
||||
returnOp->getOperandTypes()));
|
||||
op.setType(FunctionType::get(op->getContext(),
|
||||
op.getFunctionBody().front().getArgumentTypes(),
|
||||
returnOp->getOperandTypes()));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// RUN: mlir-opt \
|
||||
// RUN: --mesh-spmdization \
|
||||
// RUN: --test-constant-fold \
|
||||
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
|
||||
// RUN: --split-input-file \
|
||||
// RUN: %s | FileCheck %s
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
|
||||
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
|
||||
|
||||
mesh.mesh @mesh_1d(shape = ?)
|
||||
mesh.mesh @mesh_2d(shape = 2x4)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
|
||||
// RUN: mlir-opt \
|
||||
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
|
||||
// RUN: %s | FileCheck %s
|
||||
|
||||
mesh.mesh @mesh_1d(shape = 2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user