[mlir][sparse] admit un-sparsifiable operations if all their operands are loaded from dense input
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D153998
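In short, the Merger now tags every sub-expression it builds with a flag recording whether the value (transitively) depends on a load from a sparse tensor. An operation the sparsifier has no rule for is then admitted as a new kDenseOp leaf when every operand expression was built successfully and carries no sparse dependence. A minimal standalone sketch of that admission rule, with invented names; the real logic lives in Merger::buildTensorExp in the hunks below:

    #include <algorithm>
    #include <optional>
    #include <utility>
    #include <vector>

    using ExprId = unsigned;
    // Each operand is summarized as (built sub-expression, depends-on-sparse).
    using SubExp = std::pair<std::optional<ExprId>, bool>;

    // An un-sparsifiable operation is admitted as kDenseOp only when every
    // operand expression was built and none of them has a sparse dependence.
    bool admitAsDenseOp(const std::vector<SubExp> &operands) {
      return std::all_of(operands.begin(), operands.end(), [](const SubExp &e) {
        return e.first.has_value() && !e.second;
      });
    }

    int main() {
      std::vector<SubExp> ops = {{1u, false}, {2u, false}};
      return admitAsDenseOp(ops) ? 0 : 1;
    }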
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

@@ -70,6 +70,9 @@ struct TensorExp final {
   /// and kSelect, this holds the original operation with all regions. For
   /// kBinaryBranch, this holds the YieldOp for the left or right half
   /// to be merged into a nested scf loop.
+  ///
+  /// Or the actual operation that we cannot sparsify but that has all
+  /// dense operands, for kDenseOp.
   Operation *op;
 
   /// An optional attribute that is required to determine the semantics of the
@@ -157,8 +160,9 @@ enum class TensorExp::Kind {
   kShrS, // signed
   kShrU, // unsigned
   kShlI,
-  kBinary, // semiring binary op
-  kReduce, // semiring reduction op
+  kBinary,  // semiring binary op
+  kReduce,  // semiring reduction op
+  kDenseOp, // special category of operations requiring all dense operands
 };
 
 //===----------------------------------------------------------------------===//
@@ -645,7 +649,11 @@ private:
   Type inferType(ExprId e, Value src) const;
 
   /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
-  std::optional<ExprId> buildTensorExp(linalg::GenericOp op, Value v);
+  /// The boolean value returned indicates whether the result of the current
+  /// operation being built depends on any value that is loaded from a sparse
+  /// tensor.
+  std::pair<std::optional<ExprId>, bool> buildTensorExp(linalg::GenericOp op,
+                                                        Value v);
 
   /// Merger data structures.
   const TensorId outTensor;
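The signature change above is internal; the public entry point, buildTensorExpFromLinalg (see the Merger.cpp hunks that follow), keeps its std::optional<ExprId> result by taking .first. A small self-contained sketch of the two consumption patterns used in this patch, with a stub standing in for the real method:

    #include <iostream>
    #include <optional>
    #include <utility>

    // Stand-in for the new return type of Merger::buildTensorExp.
    using BuiltExp = std::pair<std::optional<unsigned>, bool>;

    // Stub modeling a successfully built expression with no sparse dependence.
    BuiltExp buildTensorExpStub() { return {42u, /*hasSpDep=*/false}; }

    int main() {
      // Structured bindings where the flag matters...
      auto [exprId, hasSpDep] = buildTensorExpStub();
      if (exprId.has_value() && !hasSpDep)
        std::cout << "admissible as kDenseOp: e" << *exprId << "\n";
      // ...and .first where only the expression id is needed.
      std::optional<unsigned> idOnly = buildTensorExpStub().first;
      return idOnly.has_value() ? 0 : 1;
    }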
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

@@ -92,6 +92,7 @@ static ExpArity getExpArity(TensorExp::Kind k) {
   case TensorExp::Kind::kSubI:
   case TensorExp::Kind::kCmpF:
   case TensorExp::Kind::kCmpI:
+  case TensorExp::Kind::kDenseOp: // kDenseOp can *at most* have two operands
     return ExpArity::kBinary;
   }
   llvm_unreachable("unexpected kind");
@@ -210,6 +211,11 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
     children.e0 = x;
     children.e1 = y;
     return;
+  case TensorExp::Kind::kDenseOp:
+    assert(x != detail::kInvalidId && !v && o);
+    children.e0 = x;
+    children.e1 = y;
+    return;
   }
   llvm_unreachable("unexpected kind");
 }
@@ -393,7 +399,8 @@ LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
 
 LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
                         Operation *op) {
-  assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect);
+  assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
+         TensorExp::Kind::kDenseOp == kind);
   const LatSetId sNew = addSet();
   auto &setNew = latSets[sNew];
   for (const LatPointId p : set(s0)) {
@@ -546,6 +553,12 @@ bool Merger::hasNegateOnOut(ExprId e) const {
   case TensorExp::Kind::kSubI:
     return expContainsTensor(expr.children.e1, outTensor) ||
            hasNegateOnOut(expr.children.e0);
+  case TensorExp::Kind::kDenseOp: {
+    bool lhsNeg = hasNegateOnOut(expr.children.e0);
+    if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
+      return hasNegateOnOut(expr.children.e1);
+    return lhsNeg;
+  }
   default: {
     switch (getExpArity(expr.kind)) {
     case ExpArity::kNullary:
@@ -646,6 +659,10 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
   case TensorExp::Kind::kCmpI:
   case TensorExp::Kind::kBinary:
     return false;
+  case TensorExp::Kind::kDenseOp:
+    // Since Merger guarantees that all operands of a kDenseOp are dense,
+    // the operation must be single-condition.
+    return true;
   }
   llvm_unreachable("unexpected kind");
 }
@@ -771,6 +788,8 @@ static const char *kindToOpSymbol(TensorExp::Kind kind) {
     return "binary";
   case TensorExp::Kind::kReduce:
     return "reduce";
+  case TensorExp::Kind::kDenseOp:
+    return "dense";
   }
   llvm_unreachable("unexpected kind for symbol");
 }
@@ -857,14 +876,19 @@ void Merger::dumpExp(ExprId e) const {
   case TensorExp::Kind::kCmpI:
   case TensorExp::Kind::kBinary:
   case TensorExp::Kind::kReduce:
+  case TensorExp::Kind::kDenseOp:
     llvm::dbgs() << "(";
     dumpExp(expr.children.e0);
     llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
     if (expr.attr)
       llvm::dbgs() << "{" << expr.attr << "}";
-    llvm::dbgs() << " ";
-    dumpExp(expr.children.e1);
-    llvm::dbgs() << ")";
+    if (expr.children.e1 != detail::kInvalidId) {
+      llvm::dbgs() << " ";
+      dumpExp(expr.children.e1);
+      llvm::dbgs() << ")";
+    } else {
+      assert(expr.kind == TensorExp::Kind::kDenseOp);
+    }
     break;
   }
 }
@@ -1142,6 +1166,21 @@ LatSetId Merger::buildLattices(ExprId e, LoopId i) {
     Operation *const op = expr.op;
     return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
   }
+  case TensorExp::Kind::kDenseOp: {
+    // It does not really matter whether we use a conjunctive or a disjunctive
+    // set here: since all operands of a kDenseOp must be dense, a disjunctive
+    // set would be optimized into a conjunctive set eventually anyway.
+    if (expr.children.e1 == detail::kInvalidId) {
+      const ExprId e0 = expr.children.e0;
+      Operation *const op = expr.op;
+      return mapSet(kind, buildLattices(e0, i), Value(), op);
+    }
+
+    const ExprId e0 = expr.children.e0;
+    const ExprId e1 = expr.children.e1;
+    Operation *const op = expr.op;
+    return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
+  }
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -1150,7 +1189,7 @@ std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
   // Build the linalg semantics backward from yield.
   Operation *yield = op.getRegion().front().getTerminator();
   assert(isa<linalg::YieldOp>(yield));
-  return buildTensorExp(op, yield->getOperand(0));
+  return buildTensorExp(op, yield->getOperand(0)).first;
 }
 
 /// Only returns false if we are certain this is a nonzero.
@@ -1210,7 +1249,9 @@ static bool isAdmissibleBranch(Operation *op, Region &region) {
   return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
 }
 
-std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
+std::pair<std::optional<ExprId>, bool>
+Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   // Recursion leaves.
   if (auto arg = dyn_cast<BlockArgument>(v)) {
     const TensorId tid = makeTensorId(arg.getArgNumber());
     // Any argument of the generic op that is not marked as a scalar
@@ -1218,96 +1259,98 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
     // bounds. This includes rank-0 tensor arguments.
     if (arg.getOwner()->getParentOp() == op) {
       OpOperand &t = op->getOpOperand(tid);
+      bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
       if (!op.isScalar(&t))
-        return addTensorExp(tid);
+        return {addTensorExp(tid), hasSpDep};
       v = t.get(); // get scalar value
     }
     // Any other argument (marked as scalar argument for the generic op
     // or belonging to an enveloping op) is considered invariant.
-    return addInvariantExp(v);
+    return {addInvariantExp(v), /*hasSpDep=*/false};
   }
   // Something defined outside is invariant.
   Operation *def = v.getDefiningOp();
   if (def->getBlock() != &op.getRegion().front())
-    return addInvariantExp(v);
+    return {addInvariantExp(v), /*hasSpDep=*/false};
   // Construct index operations.
   if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
-      return addLoopVarExp(makeLoopId(indexOp.getDim()));
+      return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
   }
 
   // Construct unary operations if subexpression can be built.
   if (def->getNumOperands() == 1) {
-    const auto x = buildTensorExp(op, def->getOperand(0));
+    const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
     if (x.has_value()) {
       const ExprId e = *x;
       if (isa<math::AbsFOp>(def))
-        return addExp(TensorExp::Kind::kAbsF, e);
+        return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
       if (isa<complex::AbsOp>(def))
-        return addExp(TensorExp::Kind::kAbsC, e);
+        return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
       if (isa<math::AbsIOp>(def))
-        return addExp(TensorExp::Kind::kAbsI, e);
+        return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
       if (isa<math::CeilOp>(def))
-        return addExp(TensorExp::Kind::kCeilF, e);
+        return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
       if (isa<math::FloorOp>(def))
-        return addExp(TensorExp::Kind::kFloorF, e);
+        return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
       if (isa<math::SqrtOp>(def))
-        return addExp(TensorExp::Kind::kSqrtF, e);
+        return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
       if (isa<complex::SqrtOp>(def))
-        return addExp(TensorExp::Kind::kSqrtC, e);
+        return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
       if (isa<math::ExpM1Op>(def))
-        return addExp(TensorExp::Kind::kExpm1F, e);
+        return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
       if (isa<complex::Expm1Op>(def))
-        return addExp(TensorExp::Kind::kExpm1C, e);
+        return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
       if (isa<math::Log1pOp>(def))
-        return addExp(TensorExp::Kind::kLog1pF, e);
+        return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
       if (isa<complex::Log1pOp>(def))
-        return addExp(TensorExp::Kind::kLog1pC, e);
+        return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
       if (isa<math::SinOp>(def))
-        return addExp(TensorExp::Kind::kSinF, e);
+        return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
       if (isa<complex::SinOp>(def))
-        return addExp(TensorExp::Kind::kSinC, e);
+        return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
       if (isa<math::TanhOp>(def))
-        return addExp(TensorExp::Kind::kTanhF, e);
+        return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
       if (isa<complex::TanhOp>(def))
-        return addExp(TensorExp::Kind::kTanhC, e);
+        return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
       if (isa<arith::NegFOp>(def))
-        return addExp(TensorExp::Kind::kNegF, e); // no negi in std
+        return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
       if (isa<complex::NegOp>(def))
-        return addExp(TensorExp::Kind::kNegC, e);
+        return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
       if (isa<arith::TruncFOp>(def))
-        return addExp(TensorExp::Kind::kTruncF, e, v);
+        return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
       if (isa<arith::ExtFOp>(def))
-        return addExp(TensorExp::Kind::kExtF, e, v);
+        return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
      if (isa<arith::FPToSIOp>(def))
-        return addExp(TensorExp::Kind::kCastFS, e, v);
+        return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
       if (isa<arith::FPToUIOp>(def))
-        return addExp(TensorExp::Kind::kCastFU, e, v);
+        return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
       if (isa<arith::SIToFPOp>(def))
-        return addExp(TensorExp::Kind::kCastSF, e, v);
+        return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
       if (isa<arith::UIToFPOp>(def))
-        return addExp(TensorExp::Kind::kCastUF, e, v);
+        return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
       if (isa<arith::ExtSIOp>(def))
-        return addExp(TensorExp::Kind::kCastS, e, v);
+        return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
       if (isa<arith::ExtUIOp>(def))
-        return addExp(TensorExp::Kind::kCastU, e, v);
+        return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
       if (isa<arith::IndexCastOp>(def))
-        return addExp(TensorExp::Kind::kCastIdx, e, v);
+        return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
       if (isa<arith::TruncIOp>(def))
-        return addExp(TensorExp::Kind::kTruncI, e, v);
+        return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
       if (isa<complex::ImOp>(def))
-        return addExp(TensorExp::Kind::kCIm, e);
+        return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
       if (isa<complex::ReOp>(def))
-        return addExp(TensorExp::Kind::kCRe, e);
+        return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
       if (isa<arith::BitcastOp>(def))
-        return addExp(TensorExp::Kind::kBitCast, e, v);
+        return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
       if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
         if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
             isAdmissibleBranch(unop, unop.getAbsentRegion()))
-          return addExp(TensorExp::Kind::kUnary, e, Value(), def);
+          return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
       }
       if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
         if (isAdmissibleBranch(selop, selop.getRegion()))
-          return addExp(TensorExp::Kind::kSelect, e, Value(), def);
+          return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
       }
     }
   }
@@ -1315,49 +1358,50 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   // See buildLattices() for an explanation of rejecting certain
   // division and shift operations.
   if (def->getNumOperands() == 2) {
-    const auto x = buildTensorExp(op, def->getOperand(0));
-    const auto y = buildTensorExp(op, def->getOperand(1));
+    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
+    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
+    bool hasSpDep = xDepSp || yDepSp;
     if (x.has_value() && y.has_value()) {
       const ExprId e0 = *x;
       const ExprId e1 = *y;
       if (isa<arith::MulFOp>(def))
-        return addExp(TensorExp::Kind::kMulF, e0, e1);
+        return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
       if (isa<complex::MulOp>(def))
-        return addExp(TensorExp::Kind::kMulC, e0, e1);
+        return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
       if (isa<arith::MulIOp>(def))
-        return addExp(TensorExp::Kind::kMulI, e0, e1);
+        return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
-        return addExp(TensorExp::Kind::kDivF, e0, e1);
+        return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
       if (isa<complex::DivOp>(def) && !maybeZero(e1))
-        return addExp(TensorExp::Kind::kDivC, e0, e1);
+        return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
-        return addExp(TensorExp::Kind::kDivS, e0, e1);
+        return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
-        return addExp(TensorExp::Kind::kDivU, e0, e1);
+        return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
       if (isa<arith::AddFOp>(def))
-        return addExp(TensorExp::Kind::kAddF, e0, e1);
+        return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
       if (isa<complex::AddOp>(def))
-        return addExp(TensorExp::Kind::kAddC, e0, e1);
+        return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
       if (isa<arith::AddIOp>(def))
-        return addExp(TensorExp::Kind::kAddI, e0, e1);
+        return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
       if (isa<arith::SubFOp>(def))
-        return addExp(TensorExp::Kind::kSubF, e0, e1);
+        return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
       if (isa<complex::SubOp>(def))
-        return addExp(TensorExp::Kind::kSubC, e0, e1);
+        return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
       if (isa<arith::SubIOp>(def))
-        return addExp(TensorExp::Kind::kSubI, e0, e1);
+        return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
       if (isa<arith::AndIOp>(def))
-        return addExp(TensorExp::Kind::kAndI, e0, e1);
+        return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
       if (isa<arith::OrIOp>(def))
-        return addExp(TensorExp::Kind::kOrI, e0, e1);
+        return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
       if (isa<arith::XOrIOp>(def))
-        return addExp(TensorExp::Kind::kXorI, e0, e1);
+        return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
-        return addExp(TensorExp::Kind::kShrS, e0, e1);
+        return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
-        return addExp(TensorExp::Kind::kShrU, e0, e1);
+        return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
-        return addExp(TensorExp::Kind::kShlI, e0, e1);
+        return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
       if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
         if (ci.getPredicate() == arith::CmpIPredicate::eq ||
             ci.getPredicate() == arith::CmpIPredicate::sle ||
@@ -1366,11 +1410,12 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
             ci.getPredicate() == arith::CmpIPredicate::uge) {
           // We cannot sparsify a comparison that admits equality, because
           // 0 <= 0 yields true and thus densifies the result.
-          return std::nullopt;
+          return {std::nullopt, false};
         }
 
-        return addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
-                      ci.getPredicateAttr());
+        auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
+                        ci.getPredicateAttr());
+        return {e, hasSpDep};
       }
       if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
         if (cf.getPredicate() == arith::CmpFPredicate::OEQ ||
@@ -1384,10 +1429,11 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
             cf.getPredicate() == arith::CmpFPredicate::UNO) {
           // We cannot sparsify a comparison that admits equality, because
           // 0 <= 0 yields true and thus densifies the result.
-          return std::nullopt;
+          return {std::nullopt, false};
         }
-        return addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
-                      cf.getPredicateAttr());
+        auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
+                        cf.getPredicateAttr());
+        return {e, hasSpDep};
       }
       if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
         if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
@@ -1395,26 +1441,54 @@ std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
             isAdmissibleBranch(binop, binop.getLeftRegion())) &&
            (binop.getRightIdentity() ||
             isAdmissibleBranch(binop, binop.getRightRegion())))
-          return addExp(TensorExp::Kind::kBinary, e0, e1, def);
+          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
       }
     }
   }
   // Construct ternary operations if subexpressions can be built.
   if (def->getNumOperands() == 3) {
-    const auto x = buildTensorExp(op, def->getOperand(0));
-    const auto y = buildTensorExp(op, def->getOperand(1));
-    const auto z = buildTensorExp(op, def->getOperand(2));
+    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
+    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
+    const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
+    bool hasSpDep = xDepSp || yDepSp || zDepSp;
     if (x.has_value() && y.has_value() && z.has_value()) {
       const ExprId e0 = *x;
       const ExprId e1 = *y;
       if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
         if (isAdmissibleBranch(redop, redop.getRegion()))
-          return addExp(TensorExp::Kind::kReduce, e0, e1, def);
+          return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
       }
     }
   }
+
+  // If we reach here, we are dealing with an operation that is not currently
+  // sparsifiable. We can still generate code for it if all its operands only
+  // have dense dependencies (i.e., all the values are loaded from dense
+  // tensors).
+  if (def->getNumResults() != 1) // only handle single result operation.
+    return {std::nullopt, false};
+
+  SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
+  // Build all the subexpressions.
+  for (Value operand : def->getOperands())
+    subExp.push_back(buildTensorExp(op, operand));
+
+  if (llvm::all_of(subExp,
+                   [](auto e) { return e.first.has_value() && !e.second; })) {
+    // All the subexpressions can be built and have *no* sparse dependencies.
+    if (subExp.size() == 2) {
+      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
+                      *subExp[1].first, def);
+      return {e, false};
+    }
+    if (subExp.size() == 1) {
+      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
+                      detail::kInvalidId, def);
+      return {e, false};
+    }
+  }
   // Cannot build.
-  return std::nullopt;
+  return {std::nullopt, false};
 }
 
 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
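Note how the fallback encodes arity: a binary kDenseOp stores both sub-expressions, while a unary one passes detail::kInvalidId as its second child, which the dumpExp, hasNegateOnOut, and buildLattices hunks above all test for. A toy mirror of that encoding (types and the sentinel value here are assumptions, not the actual Merger internals):

    #include <cassert>
    #include <cstdint>
    #include <limits>

    using ExprId = std::uint32_t;
    // Assumed sentinel, mirroring detail::kInvalidId in spirit only.
    constexpr ExprId kInvalidId = std::numeric_limits<ExprId>::max();

    // A binary kDenseOp keeps both children; a unary one stores its single
    // child in e0 and the sentinel in e1.
    struct DenseOpExp {
      ExprId e0;
      ExprId e1;
      bool isUnary() const { return e1 == kInvalidId; }
    };

    int main() {
      DenseOpExp unary{7, kInvalidId};
      DenseOpExp binary{7, 9};
      assert(unary.isUnary() && !binary.isUnary());
      return 0;
    }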
@@ -1609,6 +1683,14 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
     ReduceOp redOp = cast<ReduceOp>(expr.op);
     return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
   }
+  case TensorExp::Kind::kDenseOp: {
+    Operation *actualOp = expr.op;
+    IRMapping mapping;
+    mapping.map(actualOp->getOperand(0), v0);
+    if (actualOp->getNumOperands() == 2)
+      mapping.map(actualOp->getOperand(1), v1);
+    return rewriter.clone(*actualOp, mapping)->getResult(0);
+  }
   }
   llvm_unreachable("unexpected expression kind in build");
 }
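Code generation for kDenseOp is thus just operand-remapped cloning: by the time buildExp runs, the loop emitter has already materialized the operand values (v0, v1) inside the generated loop nest, and the original operation is cloned onto them via IRMapping. A toy model of that remapping step (invented types; the real version is the rewriter.clone call above):

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Toy stand-in for an operation with named SSA operands.
    struct ToyOp {
      std::string name;
      std::vector<std::string> operands;
    };

    // Clone an op while substituting operands per the given mapping,
    // analogous to IRMapping + rewriter.clone in the hunk above.
    ToyOp cloneWithMapping(const ToyOp &op,
                           const std::map<std::string, std::string> &mapping) {
      ToyOp clone = op;
      for (std::string &operand : clone.operands) {
        auto it = mapping.find(operand);
        if (it != mapping.end())
          operand = it->second; // remap to the value loaded inside the loops
      }
      return clone;
    }

    int main() {
      ToyOp rsqrt{"math.rsqrt", {"%184"}};
      ToyOp cloned = cloneWithMapping(rsqrt, {{"%184", "%loop_local"}});
      std::cout << cloned.name << " " << cloned.operands[0] << "\n";
      return 0;
    }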
mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir (new file, 95 lines)
@@ -0,0 +1,95 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#trait = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
+    affine_map<(d0, d1, d2, d3) -> (d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+#VEC = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 32, crdWidth = 32 }>
+#COO = #sparse_tensor.encoding<{ lvlTypes = [ "compressed-nu", "singleton" ], posWidth = 32, crdWidth = 32 }>
+#CCC = #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "compressed" ], posWidth = 32, crdWidth = 32 }>
+
+//
+// This kernel can be sparsified, as all the un-sparsifiable operations'
+// operands are loaded from dense tensors.
+//
+// CHECK-LABEL: func @dense_op_without_sp_dep
+// CHECK-NOT: linalg.generic {{.*}}
+func.func @dense_op_without_sp_dep(%169: tensor<2x10x8xf32>,
+                                   %expanded_54: tensor<2x10x1xf32>,
+                                   %expanded_56: tensor<2x10x1xf32>,
+                                   %expanded_57: tensor<2x10x1xf32>,
+                                   %176: tensor<8xf32, #VEC>,
+                                   %177: tensor<8xf32, #VEC>,
+                                   %9: tensor<100x8xf32, #COO>) -> tensor<2x10x100xf32> {
+  %cst_13 = arith.constant -3.40282347E+38 : f32
+  %178 = tensor.empty() : tensor<2x10x100xf32>
+  %179 = linalg.generic #trait
+      ins(%169, %expanded_54, %expanded_56, %expanded_57, %176, %177, %9 :
+          tensor<2x10x8xf32>, tensor<2x10x1xf32>, tensor<2x10x1xf32>, tensor<2x10x1xf32>,
+          tensor<8xf32, #VEC>, tensor<8xf32, #VEC>, tensor<100x8xf32, #COO>)
+      outs(%178 : tensor<2x10x100xf32>) {
+  ^bb0(%in: f32, %in_58: f32, %in_59: f32, %in_60: f32, %in_61: f32, %in_62: f32, %in_63: f32, %out: f32):
+    %180 = arith.mulf %in_60, %in_60 : f32
+    %181 = arith.mulf %in_59, %cst_13 : f32
+    %182 = arith.subf %181, %180 : f32
+    %183 = arith.maxf %182, %cst_13 : f32
+    %184 = arith.addf %183, %cst_13 : f32
+    %185 = math.rsqrt %184 : f32 // un-sparsifiable op, but all operands are dense.
+    %186 = arith.mulf %185, %in_61 : f32
+    %187 = arith.subf %in, %in_58 : f32
+    %188 = arith.mulf %187, %186 : f32
+    %189 = arith.addf %188, %in_62 : f32
+    %190 = arith.mulf %189, %in_63 : f32
+    %191 = arith.addf %out, %190 : f32
+    linalg.yield %191 : f32
+  } -> tensor<2x10x100xf32>
+  return %179 : tensor<2x10x100xf32>
+}
+
+//
+// This kernel cannot be sparsified, as some un-sparsifiable operations'
+// operands are loaded from sparse tensors.
+//
+// CHECK-LABEL: func @dense_op_with_sp_dep
+// CHECK: linalg.generic {{.*}}
+func.func @dense_op_with_sp_dep(%169: tensor<2x10x8xf32>,
+                                %expanded_54: tensor<2x10x1xf32, #CCC>,
+                                %expanded_56: tensor<2x10x1xf32, #CCC>,
+                                %expanded_57: tensor<2x10x1xf32, #CCC>,
+                                %176: tensor<8xf32, #VEC>,
+                                %177: tensor<8xf32, #VEC>,
+                                %9: tensor<100x8xf32, #COO>) -> tensor<2x10x100xf32> {
+  %cst_13 = arith.constant -3.40282347E+38 : f32
+  %178 = tensor.empty() : tensor<2x10x100xf32>
+  %179 = linalg.generic #trait
+      ins(%169, %expanded_54, %expanded_56, %expanded_57, %176, %177, %9 :
+          tensor<2x10x8xf32>, tensor<2x10x1xf32, #CCC>, tensor<2x10x1xf32, #CCC>, tensor<2x10x1xf32, #CCC>,
+          tensor<8xf32, #VEC>, tensor<8xf32, #VEC>, tensor<100x8xf32, #COO>)
+      outs(%178 : tensor<2x10x100xf32>) {
+  ^bb0(%in: f32, %in_58: f32, %in_59: f32, %in_60: f32, %in_61: f32, %in_62: f32, %in_63: f32, %out: f32):
+    %180 = arith.mulf %in_60, %in_60 : f32
+    %181 = arith.mulf %in_59, %cst_13 : f32
+    %182 = arith.subf %181, %180 : f32
+    %183 = arith.maxf %182, %cst_13 : f32
+    %184 = arith.addf %183, %cst_13 : f32
+    %185 = math.rsqrt %184 : f32 // un-sparsifiable op, data-dependent on sparse values.
+    %186 = arith.mulf %185, %in_61 : f32
+    %187 = arith.subf %in, %in_58 : f32
+    %188 = arith.mulf %187, %186 : f32
+    %189 = arith.addf %188, %in_62 : f32
+    %190 = arith.mulf %189, %in_63 : f32
+    %191 = arith.addf %out, %190 : f32
+    linalg.yield %191 : f32
+  } -> tensor<2x10x100xf32>
+  return %179 : tensor<2x10x100xf32>
+}
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

@@ -305,6 +305,12 @@ protected:
   case TensorExp::Kind::kReduce:
     return compareExpression(tensorExp.children.e0, pattern.children.e0) &&
            compareExpression(tensorExp.children.e1, pattern.children.e1);
+  case TensorExp::Kind::kDenseOp: {
+    bool eq = compareExpression(tensorExp.children.e0, pattern.children.e0);
+    if (eq && tensorExp.children.e1 != sparse_tensor::detail::kInvalidId)
+      return compareExpression(tensorExp.children.e1, pattern.children.e1);
+    return eq;
+  }
   }
   llvm_unreachable("unexpected kind");
 }