mirror of
https://github.com/intel/llvm.git
synced 2026-01-17 14:48:27 +08:00
[mlir][mesh] Add TableGen definitions of more collective ops (#73842)
Add definitions for broadcast, gather, receive, reduce, scatter, send and shift.
This commit is contained in:
committed by
GitHub
parent
01e40a8a3d
commit
dff2f59be3
@@ -15,14 +15,17 @@ explanation.
|
||||
The main addition is that the collectives in this dialect have mesh
|
||||
semantics.
|
||||
|
||||
### Device groups
|
||||
The operation attributes `mesh` and `mesh_axes` specify a list of device mesh
|
||||
axes that partition the devices into disjoint groups.
|
||||
The collective operation is performed between devices in the same group.
|
||||
Devices that have the same coordinates outside of axes `mesh_axes` are in the
|
||||
same group.
|
||||
A group is described by its multi-index along the axes outside of `mesh_axes`.
|
||||
For example if we have a device mesh of size `2x3x4x5` and the partition mesh
|
||||
axes list is `[0, 1]` then devices are partitioned into the groups
|
||||
`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
|
||||
The device groups would be `{ (k, m) | 0<=k<4, 0<=m<5 }`.
|
||||
Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
|
||||
Device (1, 0, 2, 4) will be in another group.
|
||||
Some collective operations like all-to-all and all-gather care about the
|
||||
@@ -33,6 +36,17 @@ The axes are ordered from outer to inner.
|
||||
If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede
|
||||
both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
|
||||
|
||||
### In-group Device
|
||||
Some operations like `broadcast`, `scatter` and `send` specify devices in each
|
||||
device-group.
|
||||
These devices are represented with their multi-index over the mesh axes that
|
||||
are not constant within a device group.
|
||||
These are the axes specified by `mesh_axes` attribute.
|
||||
|
||||
For example, on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
|
||||
an in-group device with `(i, j)`. Then for each group with index `g` on the
|
||||
second axis, the in-group device would be `(i, g, j)`.
|
||||
|
||||
|
||||
## Operations
|
||||
|
||||
|
||||
@@ -339,6 +339,185 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
|
||||
AllShapesMatch<["input", "result"]>,
|
||||
AllElementTypesMatch<["input", "result"]>
|
||||
]> {
|
||||
let summary = "Broadcast over a device mesh.";
|
||||
let description = [{
|
||||
Broadcast the tensor on `root` to all devices in each respective group.
|
||||
The operation broadcasts along mesh axes `mesh_axes`.
|
||||
The `root` device specifies the in-group multi-index that is broadcast to
|
||||
all other devices in the group.
|
||||
|
||||
Example:
|
||||
```
|
||||
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
|
||||
|
||||
%1 = mesh.broadcast %0 on @mesh0
|
||||
mesh_axes = [0]
|
||||
root = [0]
|
||||
: (tensor<2xi8>) -> tensor<2xi8>
|
||||
```
|
||||
|
||||
Input:
|
||||
```
|
||||
+-------+-------+ | broadcast
|
||||
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
|
||||
+-------+-------+ ↓
|
||||
device (1, 0) -> | | | <- device (1, 1)
|
||||
+-------+-------+
|
||||
```
|
||||
|
||||
Output:
|
||||
```
|
||||
+-------+-------+
|
||||
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1)
|
||||
+-------+-------+
|
||||
device (1, 0) -> | 1 2 | 3 4 | <- device (1, 1)
|
||||
+-------+-------+
|
||||
```
|
||||
}];
|
||||
let arguments = !con(commonArgs, (ins
|
||||
AnyRankedTensor:$input,
|
||||
DenseI64ArrayAttr:$root,
|
||||
Variadic<Index>:$root_dynamic
|
||||
));
|
||||
let results = (outs
|
||||
AnyRankedTensor:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
|
||||
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
|
||||
attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
}
|
||||
|
||||
def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
|
||||
AllRanksMatch<["input", "result"]>,
|
||||
AllElementTypesMatch<["input", "result"]>
|
||||
]> {
|
||||
let summary = "Gather over a device mesh.";
|
||||
let description = [{
|
||||
Gathers on device `root` along the `gather_axis` tensor axis.
|
||||
`root` specifies the coordinates of a device along `mesh_axes`.
|
||||
It uniquely identifies the root device for each device group.
|
||||
The result tensor on non-root devices is undefined.
|
||||
Using it will result in undefined behavior.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
|
||||
...
|
||||
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
|
||||
gather_axis = 1 root = [1]
|
||||
: (tensor<2x2xi8>) -> tensor<2x4xi8>
|
||||
```
|
||||
Input:
|
||||
```
|
||||
gather tensor
|
||||
axis 1
|
||||
------------>
|
||||
+-------+-------+
|
||||
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
|
||||
| 3 4 | 7 8 |
|
||||
+-------+-------+
|
||||
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
|
||||
| 11 12 | 15 16 |
|
||||
+-------+-------+
|
||||
```
|
||||
Result:
|
||||
```
|
||||
+-------------+
|
||||
| 1 2 5 6 | <- device (0, 1)
|
||||
| 3 4 7 8 |
|
||||
+-------------+
|
||||
| 9 10 13 14 | <- device (1, 1)
|
||||
| 11 12 15 16 |
|
||||
+-------------+
|
||||
```
|
||||
Devices `(0, 0)` and `(1, 0)` have undefined result.
|
||||
}];
|
||||
let arguments = !con(commonArgs, (ins
|
||||
AnyNon0RankedTensor:$input,
|
||||
IndexAttr:$gather_axis,
|
||||
DenseI64ArrayAttr:$root,
|
||||
Variadic<Index>:$root_dynamic
|
||||
));
|
||||
let results = (outs
|
||||
AnyNon0RankedTensor:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
|
||||
`gather_axis` `=` $gather_axis
|
||||
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
|
||||
attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
}
|
||||
|
||||
def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
|
||||
AllShapesMatch<["input", "result"]>,
|
||||
AllElementTypesMatch<["input", "result"]>
|
||||
]> {
|
||||
let summary = "Receive over a device mesh.";
|
||||
let description = [{
|
||||
Receive from a device within a device group.
|
||||
}];
|
||||
let arguments = !con(commonArgs, (ins
|
||||
AnyNon0RankedTensor:$input,
|
||||
OptionalAttr<DenseI64ArrayAttr>:$source,
|
||||
Variadic<Index>:$source_dynamic
|
||||
));
|
||||
let results = (outs
|
||||
AnyRankedTensor:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
|
||||
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
|
||||
attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
}
|
||||
|
||||
def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
|
||||
AllShapesMatch<["input", "result"]>
|
||||
]> {
|
||||
let summary = "Reduce over a device mesh.";
|
||||
let description = [{
|
||||
Reduces on device `root` within each device group.
|
||||
`root` specifies the coordinates of a device along `mesh_axes`.
|
||||
It uniquely identifies the root device within its device group.
|
||||
The accumulation element type is specified by the result type and
|
||||
it does not need to match the input element type.
|
||||
The input element is converted to the result element type before
|
||||
performing the reduction.
|
||||
|
||||
Attributes:
|
||||
`reduction`: Indicates the reduction method.
|
||||
|
||||
Example:
|
||||
```
|
||||
%1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
|
||||
reduction = <max> root = [2, 3]
|
||||
: (tensor<3x4xf32>) -> tensor<3x4xf64>
|
||||
```
|
||||
}];
|
||||
let arguments = !con(commonArgs, (ins
|
||||
AnyRankedTensor:$input,
|
||||
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
|
||||
DenseI64ArrayAttr:$root,
|
||||
Variadic<Index>:$root_dynamic
|
||||
));
|
||||
let results = (outs
|
||||
AnyRankedTensor:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
|
||||
(`reduction` `=` $reduction^)?
|
||||
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
|
||||
attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
}
|
||||
|
||||
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
|
||||
SameOperandsAndResultRank]> {
|
||||
let summary = "Reduce-scatter over a device mesh.";
|
||||
@@ -400,4 +579,154 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
|
||||
AllRanksMatch<["input", "result"]>,
|
||||
AllElementTypesMatch<["input", "result"]>
|
||||
]> {
|
||||
let summary = "Scatter over a device mesh.";
|
||||
let description = [{
|
||||
For each device group split the input tensor on the `root` device along
|
||||
axis `scatter_axis` and scatter the parts across the group devices.
|
||||
|
||||
Example:
|
||||
```
|
||||
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
|
||||
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
|
||||
scatter_axis = 0
|
||||
root = [1]
|
||||
: (tensor<2x2xi8>) -> tensor<1x2xi8>
|
||||
```
|
||||
|
||||
Input:
|
||||
```
|
||||
device
|
||||
(0, 1)
|
||||
↓
|
||||
+-------+-------+ | scatter tensor
|
||||
device (0, 0) -> | | | | axis 0
|
||||
| | | ↓
|
||||
+-------+-------+
|
||||
device (1, 0) -> | 1 2 | 5 6 |
|
||||
| 3 4 | 7 8 |
|
||||
+-------+-------+
|
||||
↑
|
||||
device
|
||||
(1, 1)
|
||||
```
|
||||
|
||||
Result:
|
||||
```
|
||||
device
|
||||
(0, 1)
|
||||
↓
|
||||
+-------+-------+
|
||||
device (0, 0) -> | 1 2 | 5 6 |
|
||||
+-------+-------+
|
||||
device (1, 0) -> | 3 4 | 7 8 |
|
||||
+-------+-------+
|
||||
↑
|
||||
device
|
||||
(1, 1)
|
||||
```
|
||||
}];
|
||||
let arguments = !con(commonArgs, (ins
|
||||
AnyNon0RankedTensor:$input,
|
||||
IndexAttr:$scatter_axis,
|
||||
DenseI64ArrayAttr:$root,
|
||||
Variadic<Index>:$root_dynamic
|
||||
));
|
||||
let results = (outs
|
||||
AnyRankedTensor:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
|
||||
`scatter_axis` `=` $scatter_axis
|
||||
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
|
||||
attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
}
|
||||
|
||||
def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
|
||||
AllShapesMatch<["input", "result"]>,
|
||||
AllElementTypesMatch<["input", "result"]>
|
||||
]> {
|
||||
let summary = "Send over a device mesh.";
|
||||
let description = [{
|
||||
Send from one device to another within a device group.
|
||||
}];
|
||||
let arguments = !con(commonArgs, (ins
|
||||
AnyNon0RankedTensor:$input,
|
||||
DenseI64ArrayAttr:$destination,
|
||||
Variadic<Index>:$destination_dynamic
|
||||
));
|
||||
let results = (outs
|
||||
AnyRankedTensor:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
|
||||
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
|
||||
attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
}
|
||||
|
||||
def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
|
||||
SameOperandsAndResultElementType,
|
||||
SameOperandsAndResultShape
|
||||
]> {
|
||||
let summary = "Shift over a device mesh.";
|
||||
let description = [{
|
||||
Within each device group shift along mesh axis `shift_axis` by an offset
|
||||
`offset`.
|
||||
The result on devices that do not have a corresponding source is undefined.
|
||||
`shift_axis` must be one of `mesh_axes`.
|
||||
If the `rotate` attribute is present,
|
||||
instead of a shift a rotation is done.
|
||||
|
||||
Example:
|
||||
```
|
||||
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
|
||||
%1 = mesh.shift %0 on @mesh0 mesh_axes = [1]
|
||||
shift_axis = 1 offset = 2 rotate
|
||||
: tensor<2xi8> -> tensor<2xi8>
|
||||
```
|
||||
|
||||
Input:
|
||||
```
|
||||
mesh axis 1
|
||||
----------->
|
||||
|
||||
+----+----+----+----+
|
||||
| 1 | 2 | 3 | 4 |
|
||||
+----+----+----+----+
|
||||
| 5 | 6 | 7 | 8 |
|
||||
+----+----+----+----+
|
||||
```
|
||||
|
||||
Result:
|
||||
```
|
||||
+----+----+----+----+
|
||||
| 3 | 4 | 1 | 2 |
|
||||
+----+----+----+----+
|
||||
| 7 | 8 | 5 | 6 |
|
||||
+----+----+----+----+
|
||||
```
|
||||
}];
|
||||
let arguments = !con(commonArgs, (ins
|
||||
AnyNon0RankedTensor:$input,
|
||||
IndexAttr:$shift_axis,
|
||||
I64Attr:$offset,
|
||||
UnitAttr:$rotate
|
||||
));
|
||||
let results = (outs
|
||||
AnyRankedTensor:$result
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
|
||||
`shift_axis` `=` $shift_axis
|
||||
`offset` `=` $offset
|
||||
(`rotate` $rotate^)?
|
||||
attr-dict `:` type($input) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
@@ -507,6 +509,43 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// mesh.broadcast op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
// TODO
|
||||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// mesh.gather op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
// TODO
|
||||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// mesh.recv op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
// TODO
|
||||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// mesh.reduce op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
// TODO
|
||||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// mesh.reduce_scatter op
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -528,6 +567,33 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// mesh.scatter op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
// TODO
|
||||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// mesh.send op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
// TODO
|
||||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// mesh.shift op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
// TODO
|
||||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Reference in New Issue
Block a user