[mlir][mesh] Make sharding propagation and spmdization work on FuncOpInterface (#84415)

Make them more general instead of only supporting `func::FuncOp`.
This commit is contained in:
Boian Petkantchin
2024-03-08 08:14:36 -08:00
committed by GitHub
parent 8160139136
commit abfac563f5
6 changed files with 20 additions and 16 deletions

View File

@@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td"
// ShardingPropagation
//===----------------------------------------------------------------------===//
def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionOpInterface"> {
let summary = "sharding propagation";
let description = [{
Propagates sharding information throughout the graph. After this pass, each
@@ -29,7 +29,7 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
];
}
def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
let summary = "Partition a function into SPMD form.";
let description = [{
This pass fits in right after a pass that annotates the function with

View File

@@ -12,6 +12,7 @@
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include <vector>
@@ -172,9 +173,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
struct ShardingPropagation
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
void runOnOperation() override {
func::FuncOp funcOp = getOperation();
FunctionOpInterface funcOp = getOperation();
MLIRContext *ctx = funcOp.getContext();
Region &region = funcOp.getBody();
Region &region = funcOp.getFunctionBody();
OpBuilder builder(ctx);
if (!region.hasOneBlock()) {
funcOp.emitOpError() << "only one block is supported!";

View File

@@ -24,6 +24,8 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
@@ -694,7 +696,7 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
}
static LogicalResult
spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection) {
OpBuilder builder(op.getFunctionBody());
@@ -717,21 +719,21 @@ spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
// Find a return op and change the function results signature to its operands
// signature.
func::ReturnOp returnOp;
for (Block &block : op.getBody()) {
Operation *returnOp = nullptr;
for (Block &block : op.getFunctionBody()) {
if (block.empty()) {
continue;
}
returnOp = llvm::cast<func::ReturnOp>(block.back());
if (returnOp) {
if (block.back().hasTrait<OpTrait::ReturnLike>()) {
returnOp = &block.back();
break;
}
}
assert(returnOp);
op.setFunctionType(FunctionType::get(op->getContext(),
op.getBody().front().getArgumentTypes(),
returnOp->getOperandTypes()));
op.setType(FunctionType::get(op->getContext(),
op.getFunctionBody().front().getArgumentTypes(),
returnOp->getOperandTypes()));
return success();
}

View File

@@ -1,6 +1,5 @@
// RUN: mlir-opt \
// RUN: --mesh-spmdization \
// RUN: --test-constant-fold \
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
// RUN: --split-input-file \
// RUN: %s | FileCheck %s

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
mesh.mesh @mesh_1d(shape = ?)
mesh.mesh @mesh_2d(shape = 2x4)

View File

@@ -1,4 +1,6 @@
// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
// RUN: mlir-opt \
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
// RUN: %s | FileCheck %s
mesh.mesh @mesh_1d(shape = 2)