mirror of
https://github.com/intel/llvm.git
synced 2026-01-21 04:14:03 +08:00
[MLIR] Apply clang-tidy fixes for readability-identifier-naming in ShardOps.cpp (NFC)
This commit is contained in:
@@ -476,38 +476,37 @@ void GridShapeOp::getAsmResultNames(
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
|
||||
FlatSymbolRefAttr grid,
|
||||
ArrayRef<GridAxesAttr> split_axes,
|
||||
ArrayRef<int64_t> static_halos,
|
||||
ArrayRef<int64_t> static_offsets) {
|
||||
FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes,
|
||||
ArrayRef<int64_t> staticHalos,
|
||||
ArrayRef<int64_t> staticOffsets) {
|
||||
return build(
|
||||
b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
|
||||
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
|
||||
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
|
||||
b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
|
||||
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
|
||||
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets), {});
|
||||
}
|
||||
|
||||
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
|
||||
llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes,
|
||||
ArrayRef<int64_t> static_halos,
|
||||
ArrayRef<int64_t> static_offsets) {
|
||||
llvm::StringRef grid, ArrayRef<GridAxesAttr> splitAxes,
|
||||
ArrayRef<int64_t> staticHalos,
|
||||
ArrayRef<int64_t> staticOffsets) {
|
||||
return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
|
||||
GridAxesArrayAttr::get(b.getContext(), split_axes),
|
||||
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
|
||||
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
|
||||
GridAxesArrayAttr::get(b.getContext(), splitAxes),
|
||||
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), {},
|
||||
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticOffsets),
|
||||
{});
|
||||
}
|
||||
|
||||
void ShardingOp::build(
|
||||
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
|
||||
FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes,
|
||||
::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
|
||||
::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
|
||||
FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> splitAxes,
|
||||
::mlir::ArrayRef<::mlir::OpFoldResult> haloSizes,
|
||||
::mlir::ArrayRef<::mlir::OpFoldResult> shardedDimsOffsets) {
|
||||
mlir::SmallVector<int64_t> staticHalos, staticDims;
|
||||
mlir::SmallVector<mlir::Value> dynamicHalos, dynamicDims;
|
||||
dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
|
||||
dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
|
||||
dispatchIndexOpFoldResults(haloSizes, dynamicHalos, staticHalos);
|
||||
dispatchIndexOpFoldResults(shardedDimsOffsets, dynamicDims, staticDims);
|
||||
return build(
|
||||
b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
|
||||
b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), splitAxes),
|
||||
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
|
||||
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
|
||||
}
|
||||
@@ -650,14 +649,14 @@ public:
|
||||
if (dynamicOffs.empty() && !staticOffs.empty()) {
|
||||
assert(staticOffs.size() >= 2);
|
||||
auto diff = staticOffs[1] - staticOffs[0];
|
||||
bool all_same = staticOffs.size() > 2;
|
||||
bool allSame = staticOffs.size() > 2;
|
||||
for (auto i = 2u; i < staticOffs.size(); ++i) {
|
||||
if (staticOffs[i] - staticOffs[i - 1] != diff) {
|
||||
all_same = false;
|
||||
allSame = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (all_same) {
|
||||
if (allSame) {
|
||||
staticOffs.clear();
|
||||
modified = true;
|
||||
}
|
||||
@@ -749,7 +748,7 @@ bool Sharding::operator==(const Sharding &rhs) const {
|
||||
|
||||
bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }
|
||||
|
||||
Sharding::Sharding(::mlir::FlatSymbolRefAttr grid_) : grid(grid_) {}
|
||||
Sharding::Sharding(::mlir::FlatSymbolRefAttr grid) : grid(grid) {}
|
||||
|
||||
Sharding::Sharding(Value rhs) {
|
||||
auto shardingOp = rhs.getDefiningOp<ShardingOp>();
|
||||
@@ -767,21 +766,20 @@ Sharding::Sharding(Value rhs) {
|
||||
SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
|
||||
}
|
||||
|
||||
Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
|
||||
ArrayRef<GridAxesAttr> split_axes_,
|
||||
ArrayRef<int64_t> static_halo_sizes_,
|
||||
ArrayRef<int64_t> static_sharded_dims_offsets_,
|
||||
ArrayRef<Value> dynamic_halo_sizes_,
|
||||
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
|
||||
Sharding res(grid_);
|
||||
if (split_axes_.empty()) {
|
||||
Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid,
|
||||
ArrayRef<GridAxesAttr> splitAxes,
|
||||
ArrayRef<int64_t> staticHaloSizes,
|
||||
ArrayRef<int64_t> staticShardedDimsOffsets,
|
||||
ArrayRef<Value> dynamicHaloSizes,
|
||||
ArrayRef<Value> dynamicShardedDimsOffsets) {
|
||||
Sharding res(grid);
|
||||
if (splitAxes.empty()) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res.split_axes.resize(split_axes_.size());
|
||||
for (auto [i, axis] : llvm::enumerate(split_axes_)) {
|
||||
res.split_axes[i] =
|
||||
GridAxesAttr::get(grid_.getContext(), axis.asArrayRef());
|
||||
res.split_axes.resize(splitAxes.size());
|
||||
for (auto [i, axis] : llvm::enumerate(splitAxes)) {
|
||||
res.split_axes[i] = GridAxesAttr::get(grid.getContext(), axis.asArrayRef());
|
||||
}
|
||||
|
||||
auto clone = [](const auto src, auto &dst) {
|
||||
@@ -789,10 +787,10 @@ Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
|
||||
llvm::copy(src, dst.begin());
|
||||
};
|
||||
|
||||
clone(static_halo_sizes_, res.static_halo_sizes);
|
||||
clone(static_sharded_dims_offsets_, res.static_sharded_dims_offsets);
|
||||
clone(dynamic_halo_sizes_, res.dynamic_halo_sizes);
|
||||
clone(dynamic_sharded_dims_offsets_, res.dynamic_sharded_dims_offsets);
|
||||
clone(staticHaloSizes, res.static_halo_sizes);
|
||||
clone(staticShardedDimsOffsets, res.static_sharded_dims_offsets);
|
||||
clone(dynamicHaloSizes, res.dynamic_halo_sizes);
|
||||
clone(dynamicShardedDimsOffsets, res.dynamic_sharded_dims_offsets);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -809,10 +807,10 @@ void ShardShapeOp::getAsmResultNames(
|
||||
void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
|
||||
::mlir::OperationState &odsState,
|
||||
::llvm::ArrayRef<int64_t> dims,
|
||||
ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
|
||||
ArrayRef<Value> dimsDyn, ::mlir::Value sharding,
|
||||
::mlir::ValueRange device) {
|
||||
SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
|
||||
build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
|
||||
build(odsBuilder, odsState, resType, dims, dimsDyn, sharding,
|
||||
SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user