[mlir][sparse] change variable dimension to fixed attribute pointers/indices

The "sparsification" pass does not need the ability to use runtime values for
the dimension, so the only source for variability would have been user code.
Restricting the dimension to constants simplifies code generation.

Reviewed By: Peiming, wrengr

Differential Revision: https://reviews.llvm.org/D133458
This commit is contained in:
Aart Bik
2022-09-07 15:08:29 -07:00
parent bb6966aa53
commit 610b09074a
34 changed files with 430 additions and 479 deletions

View File

@@ -254,13 +254,6 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
fields.push_back(createAllocation(builder, loc, eltType, valuesSz));
}
/// Returns integral constant, if defined.
static Optional<int64_t> getConstantInt(Value val) {
if (auto constantOp = val.getDefiningOp<arith::ConstantOp>())
return constantOp.getValue().cast<IntegerAttr>().getInt();
return {};
}
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
@@ -354,7 +347,7 @@ public:
auto enc = getSparseTensorEncoding(op.getSource().getType());
if (!enc)
return failure();
Optional<int64_t> index = getConstantInt(adaptor.getIndex());
Optional<int64_t> index = op.getConstantIndex();
if (!index)
return failure();
// Access into static dimension can query original type directly.
@@ -473,13 +466,10 @@ public:
// conversion.
auto tuple = llvm::cast<UnrealizedConversionCastOp>(
adaptor.getTensor().getDefiningOp());
auto idx = Base::getIndexForOp(tuple, op);
if (!idx)
// Failed to get the index.
return failure();
unsigned idx = Base::getIndexForOp(tuple, op);
auto fields = tuple.getInputs();
assert(*idx < fields.size());
rewriter.replaceOp(op, fields[*idx]);
assert(idx < fields.size());
rewriter.replaceOp(op, fields[idx]);
return success();
}
};
@@ -490,12 +480,10 @@ class SparseToPointersConverter
public:
using SparseGetterOpConverter::SparseGetterOpConverter;
// Callback for SparseGetterOpConverter.
static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
ToPointersOp op) {
Optional<int64_t> dim = getConstantInt(op.getDim());
if (!dim)
return llvm::None; // variable dim
return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/*dim, -1);
static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
ToPointersOp op) {
uint64_t dim = op.getDimension().getZExtValue();
return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1);
}
};
@@ -505,12 +493,10 @@ class SparseToIndicesConverter
public:
using SparseGetterOpConverter::SparseGetterOpConverter;
// Callback for SparseGetterOpConverter.
static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
ToIndicesOp op) {
Optional<int64_t> dim = getConstantInt(op.getDim());
if (!dim)
return llvm::None; // variable dim
return getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/*dim);
static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
ToIndicesOp op) {
uint64_t dim = op.getDimension().getZExtValue();
return getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/dim);
}
};
@@ -520,8 +506,8 @@ class SparseToValuesConverter
public:
using SparseGetterOpConverter::SparseGetterOpConverter;
// Callback for SparseGetterOpConverter.
static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp tuple,
ToValuesOp /*op*/) {
static unsigned getIndexForOp(UnrealizedConversionCastOp tuple,
ToValuesOp /*op*/) {
// The last field holds the value buffer.
return tuple.getInputs().size() - 1;
}