[mlir][mesh] fixes for 0d tensors (#132948)

In some cases 0d tensors have no sharding. This PR provides a few minor
fixes to account for such cases.
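For context: the failing pattern is a 0d result such as tensor<f32> with no mesh.shard user (see the new test at the end of this diff), which leaves spmdization holding a default MeshSharding whose mesh symbol is null. A hedged caller-side sketch of the guard pattern the hunks below introduce (assumes in-tree MLIR headers; maybeShard is a hypothetical helper, not part of this commit):

#include "mlir/Dialect/Mesh/IR/MeshOps.h"

// Hypothetical helper, illustration only: tolerate a sharding whose mesh
// symbol is null instead of feeding it into a symbol lookup.
static mlir::Type maybeShard(mlir::Type type, mlir::Operation *op,
                             mlir::mesh::MeshSharding sharding,
                             mlir::SymbolTableCollection &tables) {
  mlir::mesh::MeshOp mesh =
      mlir::mesh::getMeshOrNull(op, sharding.getMeshAttr(), tables);
  if (!mesh)
    return type; // no mesh at all: keep the (possibly 0d) type unchanged
  return mlir::mesh::shardType(type, mesh, sharding);
}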
Author: Frank Schlimbach
Date: 2025-03-26 18:13:41 +01:00
Committed by: GitHub
Commit: 9269aaecff (parent: e8dfd70fe2)
6 changed files with 54 additions and 26 deletions


@@ -119,6 +119,8 @@ inline bool isFullReplication(MeshSharding sharding) {
 inline mesh::MeshOp
 getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
               SymbolTableCollection &symbolTableCollection) {
+  if (!meshSymbol)
+    return nullptr;
   return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
       op, meshSymbol);
 }


@@ -269,7 +269,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
 Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
-  if (rankedTensorType) {
+  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
     return shardShapedType(rankedTensorType, mesh, sharding);
   }
   return type;
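Why pass-through is the right behavior here: sharding shrinks each tensor dimension by the number of shards assigned to it, and a rank-0 tensor has no dimensions to shrink. A self-contained toy model of that rule (plain C++, not MLIR code; even divisibility assumed for brevity):

#include <cassert>
#include <cstdint>
#include <vector>

// Toy model of per-dimension sharding: each dimension is divided by the
// number of shards assigned to it. An empty shape (rank 0) is a no-op.
static std::vector<int64_t> shardShape(std::vector<int64_t> shape,
                                       const std::vector<int64_t> &shards) {
  for (size_t i = 0; i < shape.size(); ++i)
    shape[i] /= shards[i]; // assumes even divisibility, for brevity
  return shape;
}

int main() {
  assert((shardShape({8, 6}, {2, 3}) == std::vector<int64_t>{4, 2}));
  assert(shardShape({}, {}).empty()); // 0d tensor: nothing to shard
  return 0;
}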


@@ -716,8 +716,8 @@ void mesh::spmdizeTriviallyShardableOperation(
   // Set the result types to the sharded counterparts.
   for (auto [oldResult, newResult, sharding] :
        llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
-    newResult.setType(
-        shardType(newResult.getType(),
-                  getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
+    newResult.setType(shardType(
+        newResult.getType(),
+        getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
   }
 }
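Note how this pairs with the shardType change above: for a 0d result, the empty-shape check returns before mesh is ever used, so the null MeshOp that getMeshOrNull returns for a default sharding is never dereferenced. A hedged sketch of that invariant (assumes in-tree MLIR headers; MeshOp() default-constructs a null op handle, as in a hunk further below):

#include "mlir/Dialect/Mesh/IR/MeshOps.h"

// Illustration only: for a rank-0 tensor type, shardType returns its input
// before consulting the mesh, so a null MeshOp is harmless on this path.
static mlir::Type shardTypeOf0d(mlir::Type zeroDimTensorType) {
  return mlir::mesh::shardType(zeroDimTensorType, mlir::mesh::MeshOp(),
                               mlir::mesh::MeshSharding());
}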


@@ -622,7 +622,7 @@ shardedBlockArgumentTypes(Block &block,
       block.getArguments(), std::back_inserter(res),
       [&symbolTableCollection](BlockArgument arg) {
         auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
-        if (!rankedTensorArg) {
+        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
           return arg.getType();
         }
@@ -672,7 +672,7 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
   llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
     TypedValue<RankedTensorType> rankedTensor =
         dyn_cast<TypedValue<RankedTensorType>>(operand);
-    if (!rankedTensor) {
+    if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
       return MeshSharding();
     }
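Both guards implement the same rule: values that are not ranked tensors, or that are rank-0, take part in spmdization unsharded (the block argument keeps its type, the operand gets an empty MeshSharding). A hedged helper capturing that predicate (assumes in-tree MLIR headers; isShardableValue is hypothetical, not part of this commit):

#include "mlir/Dialect/Mesh/IR/MeshOps.h"

// Hypothetical predicate matching the two guards above: only non-0d ranked
// tensors receive a real sharding during spmdization.
static bool isShardableValue(mlir::Value v) {
  auto rankedTensor =
      mlir::dyn_cast<mlir::TypedValue<mlir::RankedTensorType>>(v);
  return rankedTensor && rankedTensor.getType().getRank() != 0;
}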
@@ -689,20 +689,33 @@
 static std::vector<MeshSharding> getResultShardings(Operation &op) {
   std::vector<MeshSharding> res;
   res.reserve(op.getNumResults());
-  llvm::transform(op.getResults(), std::back_inserter(res),
-                  [](OpResult result) {
-                    TypedValue<RankedTensorType> rankedTensor =
-                        dyn_cast<TypedValue<RankedTensorType>>(result);
-                    if (!rankedTensor) {
-                      return MeshSharding();
-                    }
-                    if (!result.hasOneUse()) {
-                      return MeshSharding();
-                    }
-                    Operation *userOp = *result.getUsers().begin();
-                    ShardOp shardOp = llvm::cast<ShardOp>(userOp);
-                    return MeshSharding(shardOp.getSharding());
-                  });
+  llvm::transform(
+      op.getResults(), std::back_inserter(res), [&op](OpResult result) {
+        if (!result.hasOneUse() || result.use_empty()) {
+          return MeshSharding();
+        }
+        TypedValue<RankedTensorType> rankedTensor =
+            dyn_cast<TypedValue<RankedTensorType>>(result);
+        if (!rankedTensor) {
+          return MeshSharding();
+        }
+        Operation *userOp = *result.getUsers().begin();
+        ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
+        if (shardOp) {
+          return MeshSharding(shardOp.getSharding());
+        }
+        if (rankedTensor.getType().getRank() == 0) {
+          // This is a 0d tensor result without explicit sharding.
+          // Find mesh symbol from operands, if any.
+          // Shardings without mesh are not always fully supported yet.
+          for (auto operand : op.getOperands()) {
+            if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
+              return MeshSharding(sharding.getMeshAttr());
+            }
+          }
+        }
+        return MeshSharding();
+      });
   return res;
 }
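The new tail handles the case the commit message describes: a 0d result with no mesh.shard user. It falls back to scanning the operands for a mesh.sharding-defined value and, if one is found, inherits only its mesh symbol (the in-code comment notes that shardings without a mesh are not yet fully supported everywhere). Extracted for clarity as a hedged standalone helper (assumes in-tree MLIR headers; not part of this commit):

#include "mlir/Dialect/Mesh/IR/MeshOps.h"

// Illustration only: the fallback scan from getResultShardings, in isolation.
// Returns a sharding carrying just the mesh symbol of the first operand that
// is produced by a mesh.sharding op, or an empty sharding otherwise.
static mlir::mesh::MeshSharding meshFromOperands(mlir::Operation &op) {
  for (mlir::Value operand : op.getOperands())
    if (auto sharding = operand.getDefiningOp<mlir::mesh::ShardingOp>())
      return mlir::mesh::MeshSharding(sharding.getMeshAttr());
  return mlir::mesh::MeshSharding();
}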


@@ -50,19 +50,25 @@ struct CreatorOpShardingInterface
           IRMapping &spmdizationMap,
           SymbolTableCollection &symbolTable,
           OpBuilder &builder) const {
-    auto mesh =
-        mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
-    auto shardType = cast<ShapedType>(
-        mesh::shardType(op->getResult(0).getType(), mesh, resultShardings[0]));
+    assert(resultShardings.size() == 1);
+    auto resType = cast<RankedTensorType>(op->getResult(0).getType());
+    mlir::mesh::MeshOp mesh;
+    ShapedType shardType;
+    if (resType.getRank() > 0) {
+      mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+      shardType =
+          cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0]));
+    } else {
+      shardType = resType;
+    }
     Operation *newOp = nullptr;
     // if the sharding introduces a new dynamic dimension, we take it from
     // the dynamic sharding info. For now bail out if it's not
     // provided.
-    assert(resultShardings.size() == 1);
     if (!shardType.hasStaticShape()) {
       assert(op->getResult(0).hasOneUse());
       SmallVector<Value> newOperands;
-      auto oldType = cast<ShapedType>(op->getResult(0).getType());
+      auto oldType = cast<ShapedType>(resType);
       assert(oldType.getRank() == shardType.getRank());
       int currOldOprndNum = -1;
       mesh::ShardShapeOp shapeForDevice;
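The rank-0 branch intentionally leaves mesh null: a 0d RankedTensorType always has a static (empty) shape, so the !shardType.hasStaticShape() path below, the only consumer of the dynamic sharding info, can never be reached for it. A minimal self-contained check of that reasoning (plain C++, not MLIR code):

#include <cassert>
#include <vector>

// A shape is static when no dimension is dynamic; with zero dimensions there
// is nothing that could be dynamic, so rank-0 shapes are always static.
static bool hasStaticShape(const std::vector<bool> &isDimDynamic) {
  for (bool dynamic : isDimDynamic)
    if (dynamic)
      return false;
  return true;
}

int main() {
  assert(hasStaticShape({}));             // rank 0: trivially static
  assert(!hasStaticShape({false, true})); // e.g. tensor<4x?xf32>: dynamic
  return 0;
}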


@@ -43,3 +43,10 @@ func.func @tensor_empty_same_static_dims_sizes() -> () {
   return
 }
+
+// CHECK-LABEL: func @tensor_empty_0d
+func.func @tensor_empty_0d() -> () {
+  tensor.empty() : tensor<f32>
+  // CHECK-NEXT: tensor.empty() : tensor<f32>
+  return
+}