mirror of
https://github.com/intel/llvm.git
synced 2026-01-16 13:35:38 +08:00
[mlir][linalg] move isElementwise() to Linalg/Utils (NFC)
Differential Revision: https://reviews.llvm.org/D128398
This commit is contained in:
@@ -32,6 +32,15 @@ class LinalgDependenceGraph;
|
||||
// General utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Check if all indexing maps are projected permutations.
|
||||
bool allIndexingsAreProjectedPermutation(LinalgOp op);
|
||||
|
||||
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
|
||||
bool hasOnlyScalarElementwiseOp(Region &r);
|
||||
|
||||
/// Check if a LinalgOp is an element-wise operation.
|
||||
bool isElementwise(LinalgOp op);
|
||||
|
||||
/// Check if `permutation` is a permutation of the range
|
||||
/// `[0, permutation.size())`.
|
||||
bool isPermutation(ArrayRef<int64_t> permutation);
|
||||
|
||||
@@ -417,48 +417,6 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
|
||||
llvm::to_vector<4>(returnTypes), op->getAttrs())};
|
||||
}
|
||||
|
||||
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
|
||||
static bool hasOnlyScalarElementwiseOp(Region &r) {
|
||||
if (!llvm::hasSingleElement(r))
|
||||
return false;
|
||||
for (Operation &op : r.front()) {
|
||||
if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
|
||||
linalg::IndexOp>(op) ||
|
||||
OpTrait::hasElementwiseMappableTraits(&op)) ||
|
||||
llvm::any_of(op.getResultTypes(),
|
||||
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns `true` if all indexing maps of the linalg op are projected
|
||||
/// permutations.
|
||||
static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
|
||||
return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
|
||||
return m.isProjectedPermutation(/*allowZeroInResults=*/true);
|
||||
});
|
||||
}
|
||||
|
||||
// Return true if the op is an element-wise linalg op.
|
||||
static bool isElementwise(Operation *op) {
|
||||
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
|
||||
if (!linalgOp)
|
||||
return false;
|
||||
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
|
||||
return false;
|
||||
|
||||
if (!allIndexingsAreProjectedPermutation(linalgOp))
|
||||
return false;
|
||||
|
||||
// TODO: relax the restrictions on indexing map.
|
||||
for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
|
||||
if (!linalgOp.getTiedIndexingMap(opOperand).isPermutation())
|
||||
return false;
|
||||
}
|
||||
return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
|
||||
}
|
||||
|
||||
/// Generic vectorization function that rewrites the body of a `linalgOp` into
|
||||
/// vector form. Generic vectorization proceeds as follows:
|
||||
/// 1. Verify the `linalgOp` has one non-empty region.
|
||||
|
||||
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRLinalgUtils
|
||||
MLIRAffineAnalysis
|
||||
MLIRAffineUtils
|
||||
MLIRArithmeticDialect
|
||||
MLIRFuncDialect
|
||||
MLIRIR
|
||||
MLIRLinalgDialect
|
||||
MLIRSCFDialect
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "mlir/Dialect/Affine/LoopUtils.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
@@ -141,6 +142,41 @@ static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
|
||||
bool allIndexingsAreProjectedPermutation(LinalgOp op) {
|
||||
return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
|
||||
return m.isProjectedPermutation(/*allowZeroInResults=*/true);
|
||||
});
|
||||
}
|
||||
|
||||
bool hasOnlyScalarElementwiseOp(Region &r) {
|
||||
if (!llvm::hasSingleElement(r))
|
||||
return false;
|
||||
for (Operation &op : r.front()) {
|
||||
if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
|
||||
linalg::IndexOp>(op) ||
|
||||
OpTrait::hasElementwiseMappableTraits(&op)) ||
|
||||
llvm::any_of(op.getResultTypes(),
|
||||
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool isElementwise(LinalgOp op) {
|
||||
if (op.getNumLoops() != op.getNumParallelLoops())
|
||||
return false;
|
||||
|
||||
if (!allIndexingsAreProjectedPermutation(op))
|
||||
return false;
|
||||
|
||||
// TODO: relax the restrictions on indexing map.
|
||||
for (OpOperand *opOperand : op.getOutputOperands()) {
|
||||
if (!op.getTiedIndexingMap(opOperand).isPermutation())
|
||||
return false;
|
||||
}
|
||||
return hasOnlyScalarElementwiseOp(op->getRegion(0));
|
||||
}
|
||||
|
||||
bool isPermutation(ArrayRef<int64_t> permutation) {
|
||||
// Count the number of appearances for all indices.
|
||||
SmallVector<int64_t> indexCounts(permutation.size(), 0);
|
||||
|
||||
@@ -7472,6 +7472,7 @@ cc_library(
|
||||
":ArithmeticDialect",
|
||||
":ArithmeticUtils",
|
||||
":DialectUtils",
|
||||
":FuncDialect",
|
||||
":IR",
|
||||
":LinalgAnalysis",
|
||||
":LinalgDialect",
|
||||
|
||||
Reference in New Issue
Block a user