diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md index 03877f1a6544..77da2f10d890 100644 --- a/mlir/docs/Dialects/Mesh.md +++ b/mlir/docs/Dialects/Mesh.md @@ -15,14 +15,17 @@ explanation. The main addition is that the collectives in this dialect have mesh semantics. +### Device groups The operation attributes `mesh` and `mesh_axes` specifies a list of device mesh axes that partition the devices into disjoint groups. The collective operation is performed between devices in the same group. Devices that have the same coordinates outside of axes `mesh_axes` are in the same group. +A group is described by its multi-index along the axes outside of `mesh_axes`. For example if we have a device mesh of size `2x3x4x5` and the partition mesh axes list is `[0, 1]` then devices are partitioned into the groups `{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`. +The device groups would be `{ (k, m) | 0<=k<4, 0<=m<5 }`. Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group. Device (1, 0, 2, 4) will be in another group. Some collective operations like all-to-all and all-gather care about the @@ -33,6 +36,17 @@ The axes are ordered from outer to inner. If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`. +### In-group Device +Some operations like `broadcast`, `scatter` and `send` specify devices in each +device-group. +These devices are represented with their multi-index over the mesh axes that +are not constant within a device group. +These are the axes specified by `mesh_axes` attribute. + +For Example on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify +an in-group device with `(i, j)`. Then for each group with index `g` on the +second axis, the in-group device would be `(i, g, j)`. + ## Operations diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 5cce15dd1015..361e67fd1e19 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -339,6 +339,185 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [ let hasCanonicalizer = 1; } +def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [ + AllShapesMatch<["input", "result"]>, + AllElementTypesMatch<["input", "result"]> + ]> { + let summary = "Broadcast over a device mesh."; + let description = [{ + Broadcast the tensor on `root` to all devices in each respective group. + The operation broadcasts along mesh axes `mesh_axes`. + The `root` device specifies the in-group multi-index that is broadcast to + all other devices in the group. + + Example: + ``` + mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2]) + + %1 = mesh.broadcast %0 on @mesh0 + mesh_axes = [0] + root = [0] + : (tensor<2xi8>) -> tensor<2xi8> + ``` + + Input: + ``` + +-------+-------+ | broadcast + device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0 + +-------+-------+ ↓ + device (1, 0) -> | | | <- device (1, 1) + +-------+-------+ + ``` + + Output: + ``` + +-------+-------+ + device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) + +-------+-------+ + device (1, 0) -> | 1 2 | 3 4 | <- device (1, 1) + +-------+-------+ + ``` + }]; + let arguments = !con(commonArgs, (ins + AnyRankedTensor:$input, + DenseI64ArrayAttr:$root, + Variadic:$root_dynamic + )); + let results = (outs + AnyRankedTensor:$result + ); + let assemblyFormat = [{ + $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + `root` `=` custom($root_dynamic, $root) + attr-dict `:` functional-type(operands, results) + }]; +} + +def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [ + AllRanksMatch<["input", "result"]>, + AllElementTypesMatch<["input", "result"]> + ]> { + let summary = "Gather over a device mesh."; + let description = [{ + Gathers on device `root` along the `gather_axis` tensor axis. + `root` specifies the coordinates of a device along `mesh_axes`. + It uniquely identifies the root device for each device group. + The result tensor on non-root devices is undefined. + Using it will result in undefined behavior. + + Example: + ```mlir + mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2]) + ... + %1 = mesh.gather %0 on @mesh0 mesh_axes = [1] + gather_axis = 1 root = [1] + : (tensor<2x2xi8>) -> tensor<2x4xi8> + ``` + Input: + ``` + gather tensor + axis 1 + ------------> + +-------+-------+ + device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1) + | 3 4 | 7 8 | + +-------+-------+ + device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1) + | 11 12 | 15 16 | + +-------+-------+ + ``` + Result: + ``` + +-------------+ + | 1 2 5 6 | <- devices (0, 1) + | 3 4 7 8 | + +-------------+ + | 9 10 13 14 | <- devices (1, 1) + | 11 12 15 16 | + +-------------+ + ``` + Devices `(0, 0)` and `(1, 0)` have undefined result. + }]; + let arguments = !con(commonArgs, (ins + AnyNon0RankedTensor:$input, + IndexAttr:$gather_axis, + DenseI64ArrayAttr:$root, + Variadic:$root_dynamic + )); + let results = (outs + AnyNon0RankedTensor:$result + ); + let assemblyFormat = [{ + $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + `gather_axis` `=` $gather_axis + `root` `=` custom($root_dynamic, $root) + attr-dict `:` functional-type(operands, results) + }]; +} + +def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [ + AllShapesMatch<["input", "result"]>, + AllElementTypesMatch<["input", "result"]> + ]> { + let summary = "Send over a device mesh."; + let description = [{ + Receive from a device within a device group. + }]; + let arguments = !con(commonArgs, (ins + AnyNon0RankedTensor:$input, + OptionalAttr:$source, + Variadic:$source_dynamic + )); + let results = (outs + AnyRankedTensor:$result + ); + let assemblyFormat = [{ + $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + (`source` `=` custom($source_dynamic, $source)^)? + attr-dict `:` functional-type(operands, results) + }]; +} + +def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [ + AllShapesMatch<["input", "result"]> + ]> { + let summary = "Reduce over a device mesh."; + let description = [{ + Reduces on device `root` within each device group. + `root` specifies the coordinates of a device along `mesh_axes`. + It uniquely identifies the root device within its device group. + The accumulation element type is specified by the result type and + it does not need to match the input element type. + The input element is converted to the result element type before + performing the reduction. + + Attributes: + `reduction`: Indicates the reduction method. + + Example: + ``` + %1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0] + reduction = root = [2, 3] + : (tensor<3x4xf32>) -> tensor<3x4xf64> + ``` + }]; + let arguments = !con(commonArgs, (ins + AnyRankedTensor:$input, + DefaultValuedAttr:$reduction, + DenseI64ArrayAttr:$root, + Variadic:$root_dynamic + )); + let results = (outs + AnyRankedTensor:$result + ); + let assemblyFormat = [{ + $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + (`reduction` `=` $reduction^)? + `root` `=` custom($root_dynamic, $root) + attr-dict `:` functional-type(operands, results) + }]; +} + def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [ SameOperandsAndResultRank]> { let summary = "Reduce-scatter over a device mesh."; @@ -400,4 +579,154 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", let hasCanonicalizer = 1; } +def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [ + AllRanksMatch<["input", "result"]>, + AllElementTypesMatch<["input", "result"]> + ]> { + let summary = "Scatter over a device mesh."; + let description = [{ + For each device group split the input tensor on the `root` device along + axis `scatter_axis` and scatter the parts across the group devices. + + Example: + ``` + mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2]) + %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0] + scatter_axis = 0 + root = [1] + : (tensor<2x2xi8>) -> tensor<1x2xi8> + ``` + + Input: + ``` + device + (0, 1) + ↓ + +-------+-------+ | scatter tensor + device (0, 0) -> | | | | axis 0 + | | | ↓ + +-------+-------+ + device (1, 0) -> | 1 2 | 5 6 | + | 3 4 | 7 8 | + +-------+-------+ + ↑ + device + (1, 1) + ``` + + Result: + ``` + device + (0, 1) + ↓ + +-------+-------+ + device (0, 0) -> | 1 2 | 5 6 | + +-------+-------+ + device (1, 0) -> | 3 4 | 7 8 | + +-------+-------+ + ↑ + device + (1, 1) + ``` + }]; + let arguments = !con(commonArgs, (ins + AnyNon0RankedTensor:$input, + IndexAttr:$scatter_axis, + DenseI64ArrayAttr:$root, + Variadic:$root_dynamic + )); + let results = (outs + AnyRankedTensor:$result + ); + let assemblyFormat = [{ + $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + `scatter_axis` `=` $scatter_axis + `root` `=` custom($root_dynamic, $root) + attr-dict `:` functional-type(operands, results) + }]; +} + +def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [ + AllShapesMatch<["input", "result"]>, + AllElementTypesMatch<["input", "result"]> + ]> { + let summary = "Send over a device mesh."; + let description = [{ + Send from one device to another within a device group. + }]; + let arguments = !con(commonArgs, (ins + AnyNon0RankedTensor:$input, + DenseI64ArrayAttr:$destination, + Variadic:$destination_dynamic + )); + let results = (outs + AnyRankedTensor:$result + ); + let assemblyFormat = [{ + $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + `destination` `=` custom($destination_dynamic, $destination) + attr-dict `:` functional-type(operands, results) + }]; +} + +def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [ + SameOperandsAndResultElementType, + SameOperandsAndResultShape + ]> { + let summary = "Sift over a device mesh."; + let description = [{ + Within each device group shift along mesh axis `shift_axis` by an offset + `offset`. + The result on devices that do not have a corresponding source is undefined. + `shift_axis` must be one of `mesh_axes`. + If the `rotate` attribute is present, + instead of a shift a rotation is done. + + Example: + ``` + mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) + %1 = mesh.shift on @mesh0 mesh_axes = [1] + shift_axis = 1 offset = 2 rotate + : tensor<2xi8> -> tensor<2xi8> + ``` + + Input: + ``` + mesh axis 1 + -----------> + + +----+----+----+----+ + | 1 | 2 | 3 | 4 | + +----+----+----+----+ + | 5 | 6 | 7 | 8 | + +----+----+----+----+ + ``` + + Result: + ``` + +----+----+----+----+ + | 3 | 4 | 1 | 2 | + +----+----+----+----+ + | 7 | 8 | 5 | 6 | + +----+----+----+----+ + ``` + }]; + let arguments = !con(commonArgs, (ins + AnyNon0RankedTensor:$input, + IndexAttr:$shift_axis, + I64Attr:$offset, + UnitAttr:$rotate + )); + let results = (outs + AnyRankedTensor:$result + ); + let assemblyFormat = [{ + $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? + `shift_axis` `=` $shift_axis + `offset` `=` $offset + (`rotate` $rotate^)? + attr-dict `:` type($input) `->` type($result) + }]; +} + #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index b45f7cd21ce9..3b89860c14d9 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -14,6 +14,8 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/ArrayRef.h" @@ -507,6 +509,43 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add>(context); } +//===----------------------------------------------------------------------===// +// mesh.broadcast op +//===----------------------------------------------------------------------===// + +LogicalResult +BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // TODO + return failure(); +} + +//===----------------------------------------------------------------------===// +// mesh.gather op +//===----------------------------------------------------------------------===// + +LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // TODO + return failure(); +} + +//===----------------------------------------------------------------------===// +// mesh.receive op +//===----------------------------------------------------------------------===// + +LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // TODO + return failure(); +} + +//===----------------------------------------------------------------------===// +// mesh.reduce op +//===----------------------------------------------------------------------===// + +LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // TODO + return failure(); +} + //===----------------------------------------------------------------------===// // mesh.reduce_scatter op //===----------------------------------------------------------------------===// @@ -528,6 +567,33 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add>(context); } +//===----------------------------------------------------------------------===// +// mesh.scatter op +//===----------------------------------------------------------------------===// + +LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // TODO + return failure(); +} + +//===----------------------------------------------------------------------===// +// mesh.send op +//===----------------------------------------------------------------------===// + +LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // TODO + return failure(); +} + +//===----------------------------------------------------------------------===// +// mesh.shift op +//===----------------------------------------------------------------------===// + +LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // TODO + return failure(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===//