From abfac563f5b5a123e4bf773c3a09777e6fc4f50c Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 8 Mar 2024 08:14:36 -0800 Subject: [PATCH] [mlir][mesh] Make sharding propagation and spmdization work on FuncOpInterface (#84415) Make them more general instead of only supporting `func::FuncOp`. --- .../mlir/Dialect/Mesh/Transforms/Passes.td | 4 ++-- .../Mesh/Transforms/ShardingPropagation.cpp | 5 +++-- .../Dialect/Mesh/Transforms/Spmdization.cpp | 18 ++++++++++-------- mlir/test/Dialect/Linalg/mesh-spmdization.mlir | 3 +-- .../Dialect/Mesh/sharding-propagation.mlir | 2 +- mlir/test/Dialect/Mesh/spmdization.mlir | 4 +++- 6 files changed, 20 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td index 7fb6631574b4..06ebf151e7d6 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td @@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td" // ShardingPropagation //===----------------------------------------------------------------------===// -def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> { +def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionOpInterface"> { let summary = "sharding propagation"; let description = [{ Propagates sharding information throughout the graph. After this pass, each @@ -29,7 +29,7 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> { ]; } -def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> { +def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> { let summary = "Partition a function into SPMD form."; let description = [{ This pass fits in right after a pass that annotates the function with diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp index 9f2647b21cbf..29320f1e339f 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Mesh/IR/MeshDialect.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/Debug.h" #include @@ -172,9 +173,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) { struct ShardingPropagation : public mesh::impl::ShardingPropagationBase { void runOnOperation() override { - func::FuncOp funcOp = getOperation(); + FunctionOpInterface funcOp = getOperation(); MLIRContext *ctx = funcOp.getContext(); - Region ®ion = funcOp.getBody(); + Region ®ion = funcOp.getFunctionBody(); OpBuilder builder(ctx); if (!region.hasOneBlock()) { funcOp.emitOpError() << "only one block is supported!"; diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp index c4d8b0b15e46..e4868435135e 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp @@ -24,6 +24,8 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" @@ -694,7 +696,7 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, } static LogicalResult -spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap, +spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection) { OpBuilder builder(op.getFunctionBody()); @@ -717,21 +719,21 @@ spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap, // Find a return op and change the function results signature to its operands // signature. - func::ReturnOp returnOp; - for (Block &block : op.getBody()) { + Operation *returnOp = nullptr; + for (Block &block : op.getFunctionBody()) { if (block.empty()) { continue; } - returnOp = llvm::cast(block.back()); - if (returnOp) { + if (block.back().hasTrait()) { + returnOp = &block.back(); break; } } assert(returnOp); - op.setFunctionType(FunctionType::get(op->getContext(), - op.getBody().front().getArgumentTypes(), - returnOp->getOperandTypes())); + op.setType(FunctionType::get(op->getContext(), + op.getFunctionBody().front().getArgumentTypes(), + returnOp->getOperandTypes())); return success(); } diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir index 6d21def8de27..bd56c801283b 100644 --- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir +++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir @@ -1,6 +1,5 @@ // RUN: mlir-opt \ -// RUN: --mesh-spmdization \ -// RUN: --test-constant-fold \ +// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \ // RUN: --split-input-file \ // RUN: %s | FileCheck %s diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir index 94f8d94073c5..270787ab5188 100644 --- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir +++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -sharding-propagation %s | FileCheck %s +// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s mesh.mesh @mesh_1d(shape = ?) mesh.mesh @mesh_2d(shape = 2x4) diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir index 572d3eb55eaa..2df247aba351 100644 --- a/mlir/test/Dialect/Mesh/spmdization.mlir +++ b/mlir/test/Dialect/Mesh/spmdization.mlir @@ -1,4 +1,6 @@ -// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s +// RUN: mlir-opt \ +// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \ +// RUN: %s | FileCheck %s mesh.mesh @mesh_1d(shape = 2)