[mlir][mesh] Add TableGen deffinitions of more collective ops (#73842)

Add definitions for
broadcast, gather, receive, reduce, scatter, send and shift.
This commit is contained in:
Boian Petkantchin
2023-12-04 09:11:47 -08:00
committed by GitHub
parent 01e40a8a3d
commit dff2f59be3
3 changed files with 409 additions and 0 deletions

View File

@@ -15,14 +15,17 @@ explanation.
The main addition is that the collectives in this dialect have mesh
semantics.
### Device groups
The operation attributes `mesh` and `mesh_axes` specifies a list of device mesh
axes that partition the devices into disjoint groups.
The collective operation is performed between devices in the same group.
Devices that have the same coordinates outside of axes `mesh_axes` are in the
same group.
A group is described by its multi-index along the axes outside of `mesh_axes`.
For example if we have a device mesh of size `2x3x4x5` and the partition mesh
axes list is `[0, 1]` then devices are partitioned into the groups
`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
The device groups would be `{ (k, m) | 0<=k<4, 0<=m<5 }`.
Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
Device (1, 0, 2, 4) will be in another group.
Some collective operations like all-to-all and all-gather care about the
@@ -33,6 +36,17 @@ The axes are ordered from outer to inner.
If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede
both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
### In-group Device
Some operations like `broadcast`, `scatter` and `send` specify devices in each
device-group.
These devices are represented with their multi-index over the mesh axes that
are not constant within a device group.
These are the axes specified by `mesh_axes` attribute.
For Example on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
an in-group device with `(i, j)`. Then for each group with index `g` on the
second axis, the in-group device would be `(i, g, j)`.
## Operations

View File

@@ -339,6 +339,185 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
let hasCanonicalizer = 1;
}
def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
AllShapesMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
let summary = "Broadcast over a device mesh.";
let description = [{
Broadcast the tensor on `root` to all devices in each respective group.
The operation broadcasts along mesh axes `mesh_axes`.
The `root` device specifies the in-group multi-index that is broadcast to
all other devices in the group.
Example:
```
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
%1 = mesh.broadcast %0 on @mesh0
mesh_axes = [0]
root = [0]
: (tensor<2xi8>) -> tensor<2xi8>
```
Input:
```
+-------+-------+ | broadcast
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
+-------+-------+ ↓
device (1, 0) -> | | | <- device (1, 1)
+-------+-------+
```
Output:
```
+-------+-------+
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1)
+-------+-------+
device (1, 0) -> | 1 2 | 3 4 | <- device (1, 1)
+-------+-------+
```
}];
let arguments = !con(commonArgs, (ins
AnyRankedTensor:$input,
DenseI64ArrayAttr:$root,
Variadic<Index>:$root_dynamic
));
let results = (outs
AnyRankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
}
def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
AllRanksMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
let summary = "Gather over a device mesh.";
let description = [{
Gathers on device `root` along the `gather_axis` tensor axis.
`root` specifies the coordinates of a device along `mesh_axes`.
It uniquely identifies the root device for each device group.
The result tensor on non-root devices is undefined.
Using it will result in undefined behavior.
Example:
```mlir
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
...
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
gather_axis = 1 root = [1]
: (tensor<2x2xi8>) -> tensor<2x4xi8>
```
Input:
```
gather tensor
axis 1
------------>
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
```
Result:
```
+-------------+
| 1 2 5 6 | <- devices (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- devices (1, 1)
| 11 12 15 16 |
+-------------+
```
Devices `(0, 0)` and `(1, 0)` have undefined result.
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
IndexAttr:$gather_axis,
DenseI64ArrayAttr:$root,
Variadic<Index>:$root_dynamic
));
let results = (outs
AnyNon0RankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`gather_axis` `=` $gather_axis
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
}
def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
AllShapesMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
let summary = "Send over a device mesh.";
let description = [{
Receive from a device within a device group.
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
OptionalAttr<DenseI64ArrayAttr>:$source,
Variadic<Index>:$source_dynamic
));
let results = (outs
AnyRankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
attr-dict `:` functional-type(operands, results)
}];
}
def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
AllShapesMatch<["input", "result"]>
]> {
let summary = "Reduce over a device mesh.";
let description = [{
Reduces on device `root` within each device group.
`root` specifies the coordinates of a device along `mesh_axes`.
It uniquely identifies the root device within its device group.
The accumulation element type is specified by the result type and
it does not need to match the input element type.
The input element is converted to the result element type before
performing the reduction.
Attributes:
`reduction`: Indicates the reduction method.
Example:
```
%1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
reduction = <max> root = [2, 3]
: (tensor<3x4xf32>) -> tensor<3x4xf64>
```
}];
let arguments = !con(commonArgs, (ins
AnyRankedTensor:$input,
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
DenseI64ArrayAttr:$root,
Variadic<Index>:$root_dynamic
));
let results = (outs
AnyRankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
(`reduction` `=` $reduction^)?
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
}
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
SameOperandsAndResultRank]> {
let summary = "Reduce-scatter over a device mesh.";
@@ -400,4 +579,154 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
let hasCanonicalizer = 1;
}
def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
AllRanksMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
let summary = "Scatter over a device mesh.";
let description = [{
For each device group split the input tensor on the `root` device along
axis `scatter_axis` and scatter the parts across the group devices.
Example:
```
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
scatter_axis = 0
root = [1]
: (tensor<2x2xi8>) -> tensor<1x2xi8>
```
Input:
```
device
(0, 1)
+-------+-------+ | scatter tensor
device (0, 0) -> | | | | axis 0
| | | ↓
+-------+-------+
device (1, 0) -> | 1 2 | 5 6 |
| 3 4 | 7 8 |
+-------+-------+
device
(1, 1)
```
Result:
```
device
(0, 1)
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 |
+-------+-------+
device (1, 0) -> | 3 4 | 7 8 |
+-------+-------+
device
(1, 1)
```
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
IndexAttr:$scatter_axis,
DenseI64ArrayAttr:$root,
Variadic<Index>:$root_dynamic
));
let results = (outs
AnyRankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`scatter_axis` `=` $scatter_axis
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
}
def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
AllShapesMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
let summary = "Send over a device mesh.";
let description = [{
Send from one device to another within a device group.
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
DenseI64ArrayAttr:$destination,
Variadic<Index>:$destination_dynamic
));
let results = (outs
AnyRankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
attr-dict `:` functional-type(operands, results)
}];
}
def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
SameOperandsAndResultElementType,
SameOperandsAndResultShape
]> {
let summary = "Sift over a device mesh.";
let description = [{
Within each device group shift along mesh axis `shift_axis` by an offset
`offset`.
The result on devices that do not have a corresponding source is undefined.
`shift_axis` must be one of `mesh_axes`.
If the `rotate` attribute is present,
instead of a shift a rotation is done.
Example:
```
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
%1 = mesh.shift on @mesh0 mesh_axes = [1]
shift_axis = 1 offset = 2 rotate
: tensor<2xi8> -> tensor<2xi8>
```
Input:
```
mesh axis 1
----------->
+----+----+----+----+
| 1 | 2 | 3 | 4 |
+----+----+----+----+
| 5 | 6 | 7 | 8 |
+----+----+----+----+
```
Result:
```
+----+----+----+----+
| 3 | 4 | 1 | 2 |
+----+----+----+----+
| 7 | 8 | 5 | 6 |
+----+----+----+----+
```
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
IndexAttr:$shift_axis,
I64Attr:$offset,
UnitAttr:$rotate
));
let results = (outs
AnyRankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`shift_axis` `=` $shift_axis
`offset` `=` $offset
(`rotate` $rotate^)?
attr-dict `:` type($input) `->` type($result)
}];
}
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD

View File

@@ -14,6 +14,8 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
@@ -507,6 +509,43 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
}
//===----------------------------------------------------------------------===//
// mesh.broadcast op
//===----------------------------------------------------------------------===//
LogicalResult
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
}
//===----------------------------------------------------------------------===//
// mesh.gather op
//===----------------------------------------------------------------------===//
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
}
//===----------------------------------------------------------------------===//
// mesh.receive op
//===----------------------------------------------------------------------===//
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
}
//===----------------------------------------------------------------------===//
// mesh.reduce op
//===----------------------------------------------------------------------===//
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
}
//===----------------------------------------------------------------------===//
// mesh.reduce_scatter op
//===----------------------------------------------------------------------===//
@@ -528,6 +567,33 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}
//===----------------------------------------------------------------------===//
// mesh.scatter op
//===----------------------------------------------------------------------===//
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
}
//===----------------------------------------------------------------------===//
// mesh.send op
//===----------------------------------------------------------------------===//
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
}
//===----------------------------------------------------------------------===//
// mesh.shift op
//===----------------------------------------------------------------------===//
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// TODO
return failure();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//