[MLIR] Apply clang-tidy fixes for readability-identifier-naming in ShardOps.cpp (NFC)

This commit is contained in:
Mehdi Amini
2025-08-21 08:41:59 -07:00
parent 1bbff7290f
commit 60492898f8

View File

@@ -476,38 +476,37 @@ void GridShapeOp::getAsmResultNames(
//===----------------------------------------------------------------------===//
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr grid,
ArrayRef<GridAxesAttr> split_axes,
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes,
ArrayRef<int64_t> staticHalos,
ArrayRef<int64_t> staticOffsets) {
return build(
b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {});
}
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes,
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
llvm::StringRef grid, ArrayRef<GridAxesAttr> splitAxes,
ArrayRef<int64_t> staticHalos,
ArrayRef<int64_t> staticOffsets) {
return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
GridAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
GridAxesArrayAttr::get(b.getContext(), splitAxes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets),
{});
}
void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes,
::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes,
::mlir::ArrayRef<::mlir::OpFoldResult> haloSizes,
::mlir::ArrayRef<::mlir::OpFoldResult> shardedDimsOffsets) {
mlir::SmallVector<int64_t> staticHalos, staticDims;
mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
dispatchIndexOpFoldResults(haloSizes, dynamicHalos, staticHalos);
dispatchIndexOpFoldResults(shardedDimsOffsets, dynamicDims, staticDims);
return build(
b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
}
@@ -650,14 +649,14 @@ public:
if (dynamicOffs.empty() && !staticOffs.empty()) {
assert(staticOffs.size() >= 2);
auto diff = staticOffs[1] - staticOffs[0];
bool all_same = staticOffs.size() > 2;
bool allSame = staticOffs.size() > 2;
for (auto i = 2u; i < staticOffs.size(); ++i) {
if (staticOffs[i] - staticOffs[i - 1] != diff) {
all_same = false;
allSame = false;
break;
}
}
if (all_same) {
if (allSame) {
staticOffs.clear();
modified = true;
}
@@ -749,7 +748,7 @@ bool Sharding::operator==(const Sharding &rhs) const {
bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }
Sharding::Sharding(::mlir::FlatSymbolRefAttr grid_) : grid(grid_) {}
Sharding::Sharding(::mlir::FlatSymbolRefAttr grid) : grid(grid) {}
Sharding::Sharding(Value rhs) {
auto shardingOp = rhs.getDefiningOp<ShardingOp>();
@@ -767,21 +766,20 @@ Sharding::Sharding(Value rhs) {
SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
}
Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
ArrayRef<GridAxesAttr> split_axes_,
ArrayRef<int64_t> static_halo_sizes_,
ArrayRef<int64_t> static_sharded_dims_offsets_,
ArrayRef<Value> dynamic_halo_sizes_,
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
Sharding res(grid_);
if (split_axes_.empty()) {
Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid,
ArrayRef<GridAxesAttr> splitAxes,
ArrayRef<int64_t> staticHaloSizes,
ArrayRef<int64_t> staticShardedDimsOffsets,
ArrayRef<Value> dynamicHaloSizes,
ArrayRef<Value> dynamicShardedDimsOffsets) {
Sharding res(grid);
if (splitAxes.empty()) {
return res;
}
res.split_axes.resize(split_axes_.size());
for (auto [i, axis] : llvm::enumerate(split_axes_)) {
res.split_axes[i] =
GridAxesAttr::get(grid_.getContext(), axis.asArrayRef());
res.split_axes.resize(splitAxes.size());
for (auto [i, axis] : llvm::enumerate(splitAxes)) {
res.split_axes[i] = GridAxesAttr::get(grid.getContext(), axis.asArrayRef());
}
auto clone = [](const auto src, auto &dst) {
@@ -789,10 +787,10 @@ Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
llvm::copy(src, dst.begin());
};
clone(static_halo_sizes_, res.static_halo_sizes);
clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
clone(staticHaloSizes, res.static_halo_sizes);
clone(staticShardedDimsOffsets, res.static_sharded_dims_offsets);
clone(dynamicHaloSizes, res.dynamic_halo_sizes);
clone(dynamicShardedDimsOffsets, res.dynamic_sharded_dims_offsets);
return res;
}
@@ -809,10 +807,10 @@ void ShardShapeOp::getAsmResultNames(
void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState,
::llvm::ArrayRef<int64_t> dims,
ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
ArrayRef<Value> dimsDyn, ::mlir::Value sharding,
::mlir::ValueRange device) {
SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
build(odsBuilder, odsState, resType, dims, dimsDyn, sharding,
SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
}