mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
[mlir][Linalg] Add an interface to decompose complex ops
This patch adds an interface, named AggregatedOpInterface, that decomposes complex operations into simpler ones. For now, make the interface specific to Linalg because although the concept is general, the way to materialize it needs some maturing. Use that interface with the softmax operator. Differential Revision: https://reviews.llvm.org/D154363
This commit is contained in:
@@ -897,4 +897,34 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
||||
let verifyWithRegions = 1;
|
||||
}
|
||||
|
||||
def AggregatedOpInterface : OpInterface<"AggregatedOpInterface"> {
|
||||
let description = [{
|
||||
Interface for decomposing aggregated operations into a sequence of simpler
|
||||
ops.
|
||||
}];
|
||||
let cppNamespace = "::mlir";
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Method to decompose the operation into simpler operations.
|
||||
|
||||
On success, this method returns one `Value` per result in the
|
||||
original operation.
|
||||
The order of the returned values must match the order of the
|
||||
original values.
|
||||
In other words, the returned vector can be used directly with
|
||||
`RewriterBase::replaceOp(this, returnedValues)`.
|
||||
}],
|
||||
/*retType=*/"FailureOr<SmallVector<Value>>",
|
||||
/*methodName=*/"decomposeOperation",
|
||||
/*args=*/(ins
|
||||
"OpBuilder &":$b),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return {};
|
||||
}]
|
||||
>
|
||||
];
|
||||
}
|
||||
|
||||
#endif // LINALG_IR_LINALGINTERFACES
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#define LINALG_OPS
|
||||
|
||||
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/DestinationStyleOpInterface.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
@@ -93,6 +94,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
|
||||
[DestinationStyleOpInterface,
|
||||
PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
|
||||
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
|
||||
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
|
||||
DeclareOpInterfaceMethods<TilingInterface,
|
||||
["getIterationDomain",
|
||||
"getLoopIteratorTypes",
|
||||
|
||||
@@ -1199,6 +1199,33 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DecomposeInterfaceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
|
||||
[FunctionalStyleTransformOpTrait,
|
||||
MemoryEffectsOpInterface,
|
||||
TransformOpInterface,
|
||||
TransformEachOpTrait,
|
||||
ReportTrackingListenerFailuresOpTrait]> {
|
||||
let description = [{
|
||||
TODO
|
||||
}];
|
||||
|
||||
let arguments = (ins TransformHandleTypeInterface:$target);
|
||||
let results = (outs TransformHandleTypeInterface:$transformed);
|
||||
let assemblyFormat =
|
||||
"$target attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::transform::TransformRewriter &rewriter,
|
||||
::mlir::Operation *target,
|
||||
::mlir::transform::ApplyToEachResultList &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RewriteInDestinationPassingStyleOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -2323,6 +2323,176 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
|
||||
.reifyResultShapes(b, reifiedReturnShapes);
|
||||
}
|
||||
|
||||
// Helper functions for softmax decomposition.
|
||||
// @{
|
||||
|
||||
// Helper function to produce the iterator types (reduction or parallel) and
|
||||
// affine maps for the iterators used in the decomposition of softmax.
|
||||
// This method creates:
|
||||
// If allParallel == true:
|
||||
// - iterator type: {parallel, ..., parallel}
|
||||
// - affine maps:
|
||||
// -- identity with inputRank dimensions.
|
||||
// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
|
||||
// where N == inputRank.
|
||||
//
|
||||
// If allParallel == false:
|
||||
// - iterator type at dim(i) == parallel for i != \p dim and
|
||||
// dim(dim) == reduction.
|
||||
// - affine map:
|
||||
// -- identity with inputRank dimensions.
|
||||
// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
|
||||
// where N == inputRank.
|
||||
static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
|
||||
computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
|
||||
int64_t dim, bool allParallel = false) {
|
||||
SmallVector<utils::IteratorType> iteratorTypes(inputRank,
|
||||
utils::IteratorType::parallel);
|
||||
if (!allParallel)
|
||||
iteratorTypes[dim] = utils::IteratorType::reduction;
|
||||
MLIRContext *ctxt = builder.getContext();
|
||||
auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
|
||||
SmallVector<AffineExpr, 2> affineExprs;
|
||||
for (int i = 0; i < inputRank; i++) {
|
||||
if (i != dim)
|
||||
affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
|
||||
}
|
||||
auto reductionMap =
|
||||
AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
|
||||
SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
|
||||
return std::make_tuple(iteratorTypes, indexingMaps);
|
||||
}
|
||||
|
||||
// Helper function to produce a linalg.generic that computes a reduction on
|
||||
// dimension \p dim with the operation type \p T.
|
||||
template <typename T>
|
||||
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
|
||||
int64_t dim) {
|
||||
auto inputType = cast<ShapedType>(input.getType());
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
int64_t inputRank = inputShape.size();
|
||||
auto [iteratorTypes, indexingMaps] =
|
||||
computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
|
||||
assert(indexingMaps.size() == 2 &&
|
||||
"We should have two maps: 1 for the input, 1 for the output");
|
||||
assert(indexingMaps[0].isIdentity() && "input map should be identity");
|
||||
|
||||
auto genericOp = builder.create<linalg::GenericOp>(
|
||||
loc, output.getType(), input, output, indexingMaps, iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value result = b.create<T>(loc, args[0], args[1]);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
});
|
||||
return genericOp.getResult(0);
|
||||
}
|
||||
|
||||
/// Produce a linalg generic that computes the second step of the softmax
|
||||
/// decomposition: res = exp(input - max), where \p max is the max of \p input
|
||||
/// on dimension \p dim.
|
||||
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
|
||||
Value max, Value output, int64_t dim) {
|
||||
auto inputType = cast<ShapedType>(input.getType());
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
int64_t inputRank = inputShape.size();
|
||||
auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
|
||||
builder, inputRank, dim, /*allParallel=*/true);
|
||||
assert(indexingMaps.size() == 2 && "We should have one map for each input");
|
||||
assert(indexingMaps[0].isIdentity() && "input map should be identity");
|
||||
// Add the affine map for the output argument.
|
||||
indexingMaps.push_back(indexingMaps[0]);
|
||||
auto genericOp = builder.create<linalg::GenericOp>(
|
||||
loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
|
||||
iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
|
||||
Value result = b.create<math::ExpOp>(loc, diff);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
});
|
||||
return genericOp.getResult(0);
|
||||
}
|
||||
|
||||
/// Produce a linalg generic that computes the final step of the softmax
|
||||
/// decomposition.
|
||||
/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
|
||||
/// yield n / d
|
||||
/// }
|
||||
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
|
||||
Value denominator, Value output, int64_t dim) {
|
||||
auto inputType = cast<ShapedType>(numerator.getType());
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
int64_t inputRank = inputShape.size();
|
||||
auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
|
||||
builder, inputRank, dim, /*allParallel=*/true);
|
||||
assert(indexingMaps.size() == 2 &&
|
||||
"We should have one map for each input (2)");
|
||||
assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
|
||||
// Add the affine map for the output tensor.
|
||||
indexingMaps.push_back(indexingMaps[0]);
|
||||
auto genericOp = builder.create<linalg::GenericOp>(
|
||||
loc, numerator.getType(), ValueRange{numerator, denominator}, output,
|
||||
indexingMaps, iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
});
|
||||
return genericOp.getResult(0);
|
||||
}
|
||||
// @} End helper functions for softmax decomposition.
|
||||
|
||||
/// Given an N-dimensional tensor x, this method converts
|
||||
/// softmax(x) to the following sequence of operations:
|
||||
///
|
||||
/// 1. Compute the max of x along dimension d. This results
|
||||
/// in a N-1 dimensional tensor m.
|
||||
/// m = max(x, dim = d)
|
||||
///
|
||||
/// 2. Subtract a broadcasted m from x and exponentiate. This results in
|
||||
/// a N dimensional tensor z.
|
||||
/// z = exp(x - m)
|
||||
///
|
||||
/// 3. Compute the sum of z along dimension d. This results in
|
||||
/// a N-1 dimensional tensor l.
|
||||
/// l = sum(z, dim = d)
|
||||
///
|
||||
/// 4. Divide z and l. This gives the N-dimensional softmax.
|
||||
/// softmax = z / l
|
||||
///
|
||||
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
|
||||
OpBuilder::InsertionGuard guard(b);
|
||||
b.setInsertionPoint(*this);
|
||||
Location loc = getLoc();
|
||||
Value input = getInput();
|
||||
ShapedType inputType = getInputOperandType();
|
||||
Type elementType = inputType.getElementType();
|
||||
int64_t reductionDim = getDimension();
|
||||
SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
|
||||
Value outputNd = b.create<tensor::EmptyOp>(loc, dims, elementType);
|
||||
dims.erase(dims.begin() + reductionDim);
|
||||
// Step 1: Compute max along dim.
|
||||
Value output = b.create<tensor::EmptyOp>(loc, dims, elementType);
|
||||
Value neutralForMaxF =
|
||||
arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc);
|
||||
Value neutralForMaxFInit =
|
||||
b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, output).result();
|
||||
Value max =
|
||||
reduce<arith::MaxFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
|
||||
|
||||
// Step 2: Subtract max from input and exponentiate.
|
||||
Value numerator =
|
||||
buildSubAndExpOp(b, loc, input, max, outputNd, reductionDim);
|
||||
|
||||
// Step 3: Compute sum along dim.
|
||||
Value zero =
|
||||
arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc);
|
||||
Value zeroInit = b.create<linalg::FillOp>(loc, Value{zero}, output).result();
|
||||
Value denominator =
|
||||
reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
|
||||
|
||||
// Step 4: Compute softmax.
|
||||
Value result =
|
||||
buildDivOp(b, loc, numerator, denominator, outputNd, reductionDim);
|
||||
return SmallVector<Value>{result};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LinalgDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -335,6 +335,38 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
|
||||
return emitDefaultSilenceableFailure(target);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DecomposeInterfaceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Decompose the target operation if it implements the AggregatedOpInterface.
|
||||
// Push the decomposed operations (the ones that replaces the values produced by
|
||||
// \p target) in the `results`.
|
||||
DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
|
||||
transform::TransformRewriter &rewriter, Operation *target,
|
||||
transform::ApplyToEachResultList &results,
|
||||
transform::TransformState &state) {
|
||||
auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
|
||||
if (!decomposableOp) {
|
||||
failed(rewriter.notifyMatchFailure(target,
|
||||
"payload is not a decomposable op"));
|
||||
return emitDefaultSilenceableFailure(target);
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Value>> maybeNewResults =
|
||||
decomposableOp.decomposeOperation(rewriter);
|
||||
if (failed(maybeNewResults))
|
||||
return emitDefaultSilenceableFailure(target);
|
||||
|
||||
rewriter.replaceOp(decomposableOp, *maybeNewResults);
|
||||
for (Value val : *maybeNewResults) {
|
||||
Operation *definition = val.getDefiningOp();
|
||||
if (definition)
|
||||
results.push_back(definition);
|
||||
}
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EliminateLinalgOpAnchoredEmptyTensorsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-LABEL: @conv_2d_nhwc_hwcf
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
|
||||
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
|
||||
@@ -199,8 +202,54 @@ func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32
|
||||
return %0 : tensor<?x?x1x?xf32>
|
||||
}
|
||||
|
||||
func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
|
||||
%1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
|
||||
return %1 : tensor<2x16x32xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @softmax(
|
||||
//CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
|
||||
// CHECK-DAG: %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
|
||||
// CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
|
||||
// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFF800000 : f32
|
||||
// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
|
||||
// CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
|
||||
// CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
|
||||
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
|
||||
// CHECK: %[[D8:.+]] = arith.maxf %[[IN]], %[[OUT]] : f32
|
||||
// CHECK: linalg.yield %[[D8]] : f32
|
||||
// CHECK: } -> tensor<2x16xf32>
|
||||
// CHECK: %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
|
||||
// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
|
||||
// CHECK-SAME: outs(%[[D0]] : tensor<2x16x32xf32>) {
|
||||
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
|
||||
// CHECK: %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
|
||||
// CHECK: %[[D9:.+]] = math.exp %[[D8]] : f32
|
||||
// CHECK: linalg.yield %[[D9]] : f32
|
||||
// CHECK: } -> tensor<2x16x32xf32>
|
||||
// CHECK: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK: %[[D5:.+]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
|
||||
// CHECK: %[[D6:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
|
||||
// CHECK-SAME: "parallel", "reduction"]} ins(%[[D4]] : tensor<2x16x32xf32>) outs(%[[D5]] : tensor<2x16xf32>) {
|
||||
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
|
||||
// CHECK: %[[D8]] = arith.addf %[[IN]], %[[OUT]] : f32
|
||||
// CHECK: linalg.yield %[[D8]] : f32
|
||||
// CHECK: } -> tensor<2x16xf32>
|
||||
// CHECK: %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
|
||||
// CHECK-SAME: ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
|
||||
// CHECK-SAME: outs(%[[D0]] : tensor<2x16x32xf32>) {
|
||||
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
|
||||
// CHECK: %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
|
||||
// CHECK: linalg.yield %[[D8]] : f32
|
||||
// CHECK: } -> tensor<2x16x32xf32>
|
||||
// CHECK: return %[[D7]] : tensor<2x16x32xf32>
|
||||
// CHECK: }
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb1(%arg1: !transform.any_op):
|
||||
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%1 = transform.structured.decompose %0 : (!transform.any_op) -> !transform.any_op
|
||||
|
||||
%2 = transform.structured.match ops{["linalg.softmax"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%3 = transform.structured.decompose_interface %2 : (!transform.any_op) -> !transform.any_op
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user