mirror of https://github.com/intel/llvm.git
[mlir][mesh] fixes for 0d tensors (#132948)
In some cases, 0d (rank-0) tensors carry no sharding. This PR provides a few minor fixes to account for such cases.
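For context, a minimal sketch (not part of this commit) of what makes rank-0 tensors special here: a 0d RankedTensorType has an empty shape, so there is no dimension to assign mesh split axes to, and sharding logic must pass such types through unchanged.

// Illustrative C++ sketch only; assumes a standard MLIR build. None of these
// names come from the patch itself, only the builtin type APIs.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include <cassert>

int main() {
  mlir::MLIRContext ctx;
  auto f32 = mlir::Float32Type::get(&ctx);
  // tensor<f32>: ranked, but with zero dimensions.
  auto t0 = mlir::RankedTensorType::get({}, f32);
  assert(t0.getRank() == 0);
  assert(t0.getShape().empty());
  // With no dimensions there is nothing to split across mesh axes, so a
  // sharded version of tensor<f32> is just tensor<f32> again.
  return 0;
}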
@@ -119,6 +119,8 @@ inline bool isFullReplication(MeshSharding sharding) {
+inline mesh::MeshOp
+getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
+              SymbolTableCollection &symbolTableCollection) {
+  if (!meshSymbol)
+    return nullptr;
+  return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
+      op, meshSymbol);
+}
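A hedged usage sketch for the new helper: unlike mesh::getMesh, getMeshOrNull tolerates a null mesh symbol, which is exactly what a default-constructed MeshSharding (as produced for unsharded 0d values in the hunks below) carries. The call shape mirrors the later call sites; the surrounding variables are assumed.

// Sketch: `op`, `sharding`, and `symbolTable` stand in for the values used at
// the call sites in this patch.
mesh::MeshOp meshOp =
    getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable);
if (!meshOp) {
  // No mesh attached (sharding.getMeshAttr() was null). The 0d fix in
  // shardType below never dereferences the mesh in that case, so this is
  // safe where a getMesh lookup would have failed.
}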
@@ -269,7 +269,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
 
 Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
-  if (rankedTensorType) {
+  if (rankedTensorType && !rankedTensorType.getShape().empty()) {
     return shardShapedType(rankedTensorType, mesh, sharding);
   }
   return type;
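A minimal illustration (assumptions flagged in comments) of the intended behavior after this change: for a rank-0 tensor the early-out fires before the mesh is ever used, so even a null MeshOp is safe.

// Sketch; `f32` is a Float32Type obtained from an MLIRContext as above.
auto t0 = mlir::RankedTensorType::get({}, f32);  // tensor<f32>
mlir::mesh::MeshOp noMesh;                // null op, never dereferenced here
mlir::mesh::MeshSharding sharding;        // default-constructed: no sharding
mlir::Type out = mlir::mesh::shardType(t0, noMesh, sharding);
assert(out == t0);  // empty shape -> the original type is returned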
@@ -716,8 +716,8 @@ void mesh::spmdizeTriviallyShardableOperation(
   // Set the result types to the sharded counterparts.
   for (auto [oldResult, newResult, sharding] :
        llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
-    newResult.setType(
-        shardType(newResult.getType(),
-                  getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
+    newResult.setType(shardType(
+        newResult.getType(),
+        getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
   }
 }
@@ -622,7 +622,7 @@ shardedBlockArgumentTypes(Block &block,
       block.getArguments(), std::back_inserter(res),
       [&symbolTableCollection](BlockArgument arg) {
         auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
-        if (!rankedTensorArg) {
+        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
           return arg.getType();
         }
@@ -672,7 +672,7 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
   llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
     TypedValue<RankedTensorType> rankedTensor =
         dyn_cast<TypedValue<RankedTensorType>>(operand);
-    if (!rankedTensor) {
+    if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
       return MeshSharding();
     }
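Both one-line guards above follow the same idea: a 0d block argument or operand is treated like a non-tensor value and receives a default MeshSharding (no mesh, no split axes), which is precisely the case the getMeshOrNull helper and the shardType early-out are there to absorb.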
@@ -689,20 +689,33 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
 static std::vector<MeshSharding> getResultShardings(Operation &op) {
   std::vector<MeshSharding> res;
   res.reserve(op.getNumResults());
-  llvm::transform(op.getResults(), std::back_inserter(res),
-                  [](OpResult result) {
-                    TypedValue<RankedTensorType> rankedTensor =
-                        dyn_cast<TypedValue<RankedTensorType>>(result);
-                    if (!rankedTensor) {
-                      return MeshSharding();
-                    }
-                    if (!result.hasOneUse()) {
-                      return MeshSharding();
-                    }
-                    Operation *userOp = *result.getUsers().begin();
-                    ShardOp shardOp = llvm::cast<ShardOp>(userOp);
-                    return MeshSharding(shardOp.getSharding());
-                  });
+  llvm::transform(
+      op.getResults(), std::back_inserter(res), [&op](OpResult result) {
+        if (!result.hasOneUse() || result.use_empty()) {
+          return MeshSharding();
+        }
+        TypedValue<RankedTensorType> rankedTensor =
+            dyn_cast<TypedValue<RankedTensorType>>(result);
+        if (!rankedTensor) {
+          return MeshSharding();
+        }
+        Operation *userOp = *result.getUsers().begin();
+        ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
+        if (shardOp) {
+          return MeshSharding(shardOp.getSharding());
+        }
+        if (rankedTensor.getType().getRank() == 0) {
+          // This is a 0d tensor result without explicit sharding.
+          // Find mesh symbol from operands, if any.
+          // Shardings without mesh are not always fully supported yet.
+          for (auto operand : op.getOperands()) {
+            if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
+              return MeshSharding(sharding.getMeshAttr());
+            }
+          }
+        }
+        return MeshSharding();
+      });
   return res;
 }
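Two behavioral changes are folded into this rewrite, both visible above: results with no uses (or more than one) now bail out before the tensor cast, and the unconditional llvm::cast<ShardOp> became a dyn_cast, so a non-ShardOp user no longer asserts. Only when the user is not a ShardOp and the result is 0d does the new fallback kick in, borrowing the mesh symbol from the first operand produced by a ShardingOp.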
@@ -50,19 +50,25 @@ struct CreatorOpShardingInterface
                 IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTable,
                 OpBuilder &builder) const {
-    auto mesh =
-        mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
-    auto shardType = cast<ShapedType>(
-        mesh::shardType(op->getResult(0).getType(), mesh, resultShardings[0]));
+    assert(resultShardings.size() == 1);
+    auto resType = cast<RankedTensorType>(op->getResult(0).getType());
+    mlir::mesh::MeshOp mesh;
+    ShapedType shardType;
+    if (resType.getRank() > 0) {
+      mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+      shardType =
+          cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0]));
+    } else {
+      shardType = resType;
+    }
     Operation *newOp = nullptr;
     // if the sharding introduces a new dynamic dimension, we take it from
     // the dynamic sharding info. For now bail out if it's not
     // provided.
-    assert(resultShardings.size() == 1);
     if (!shardType.hasStaticShape()) {
       assert(op->getResult(0).hasOneUse());
       SmallVector<Value> newOperands;
-      auto oldType = cast<ShapedType>(op->getResult(0).getType());
+      auto oldType = cast<ShapedType>(resType);
       assert(oldType.getRank() == shardType.getRank());
       int currOldOprndNum = -1;
       mesh::ShardShapeOp shapeForDevice;
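The creator-op path follows the same pattern: for a rank-0 result the shard type is the result type itself, so no mesh lookup is needed and `mesh` stays null. In that case the !shardType.hasStaticShape() branch (the only consumer of `mesh` visible in this hunk) is never entered, since a rank-0 ranked tensor type always has a static shape.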
@@ -43,3 +43,10 @@ func.func @tensor_empty_same_static_dims_sizes() -> () {
 
   return
 }
+
+// CHECK-LABEL: func @tensor_empty_0d
+func.func @tensor_empty_0d() -> () {
+  tensor.empty() : tensor<f32>
+  // CHECK-NEXT: tensor.empty() : tensor<f32>
+  return
+}
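The new regression test pins down the end-to-end behavior: spmdization must leave an unsharded 0d tensor.empty untouched, which the CHECK-NEXT line verifies.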