mirror of
https://github.com/intel/llvm.git
synced 2026-01-31 16:29:50 +08:00
[mlir][sparse] change variable dimension to fixed attribute pointers/indices
The "sparsification" pass does not need the ability to use runtime values for the dimension, so the only source for variability would have been user code. Restricting the dimension to constants simplifies code generation. Reviewed By: Peiming, wrengr Differential Revision: https://reviews.llvm.org/D133458
This commit is contained in:
@@ -254,13 +254,6 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
|
||||
fields.push_back(createAllocation(builder, loc, eltType, valuesSz));
|
||||
}
|
||||
|
||||
/// Returns integral constant, if defined.
|
||||
static Optional<int64_t> getConstantInt(Value val) {
|
||||
if (auto constantOp = val.getDefiningOp<arith::ConstantOp>())
|
||||
return constantOp.getValue().cast<IntegerAttr>().getInt();
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Codegen rules.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -354,7 +347,7 @@ public:
|
||||
auto enc = getSparseTensorEncoding(op.getSource().getType());
|
||||
if (!enc)
|
||||
return failure();
|
||||
Optional<int64_t> index = getConstantInt(adaptor.getIndex());
|
||||
Optional<int64_t> index = op.getConstantIndex();
|
||||
if (!index)
|
||||
return failure();
|
||||
// Access into static dimension can query original type directly.
|
||||
@@ -473,13 +466,10 @@ public:
|
||||
// conversion.
|
||||
auto tuple = llvm::cast<UnrealizedConversionCastOp>(
|
||||
adaptor.getTensor().getDefiningOp());
|
||||
auto idx = Base::getIndexForOp(tuple, op);
|
||||
if (!idx)
|
||||
// Failed to get the index.
|
||||
return failure();
|
||||
unsigned idx = Base::getIndexForOp(tuple, op);
|
||||
auto fields = tuple.getInputs();
|
||||
assert(*idx < fields.size());
|
||||
rewriter.replaceOp(op, fields[*idx]);
|
||||
assert(idx < fields.size());
|
||||
rewriter.replaceOp(op, fields[idx]);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -490,12 +480,10 @@ class SparseToPointersConverter
|
||||
public:
|
||||
using SparseGetterOpConverter::SparseGetterOpConverter;
|
||||
// Callback for SparseGetterOpConverter.
|
||||
static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
|
||||
ToPointersOp op) {
|
||||
Optional<int64_t> dim = getConstantInt(op.getDim());
|
||||
if (!dim)
|
||||
return llvm::None; // variable dim
|
||||
return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/*dim, -1);
|
||||
static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
|
||||
ToPointersOp op) {
|
||||
uint64_t dim = op.getDimension().getZExtValue();
|
||||
return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -505,12 +493,10 @@ class SparseToIndicesConverter
|
||||
public:
|
||||
using SparseGetterOpConverter::SparseGetterOpConverter;
|
||||
// Callback for SparseGetterOpConverter.
|
||||
static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
|
||||
ToIndicesOp op) {
|
||||
Optional<int64_t> dim = getConstantInt(op.getDim());
|
||||
if (!dim)
|
||||
return llvm::None; // variable dim
|
||||
return getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/*dim);
|
||||
static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
|
||||
ToIndicesOp op) {
|
||||
uint64_t dim = op.getDimension().getZExtValue();
|
||||
return getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/dim);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -520,8 +506,8 @@ class SparseToValuesConverter
|
||||
public:
|
||||
using SparseGetterOpConverter::SparseGetterOpConverter;
|
||||
// Callback for SparseGetterOpConverter.
|
||||
static Optional<unsigned> getIndexForOp(UnrealizedConversionCastOp tuple,
|
||||
ToValuesOp /*op*/) {
|
||||
static unsigned getIndexForOp(UnrealizedConversionCastOp tuple,
|
||||
ToValuesOp /*op*/) {
|
||||
// The last field holds the value buffer.
|
||||
return tuple.getInputs().size() - 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user