mirror of
https://github.com/intel/llvm.git
synced 2026-01-30 14:07:28 +08:00
[mlir][sparse] Put the implementation for the insertion operation to subroutines.
Previously, we generated inlined implementation for insert operation and observed MLIR compile time increase due to the size of the main routine. We now put the insert operation implementation in subroutines and leave the inlining decision to the MLIR compiler. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D138957
This commit is contained in:
@@ -31,6 +31,11 @@ using namespace mlir::sparse_tensor;
|
||||
|
||||
namespace {
|
||||
|
||||
using FuncGeneratorType =
|
||||
function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, RankedTensorType)>;
|
||||
|
||||
static constexpr const char kInsertFuncNamePrefix[] = "_insert_";
|
||||
|
||||
static constexpr uint64_t dimSizesIdx = 0;
|
||||
static constexpr uint64_t memSizesIdx = 1;
|
||||
static constexpr uint64_t fieldsIdx = 2;
|
||||
@@ -476,12 +481,24 @@ static Value genCompressed(OpBuilder &builder, Location loc,
|
||||
///
|
||||
/// TODO: better unord/not-unique; also generalize, optimize, specialize!
|
||||
///
|
||||
static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
|
||||
SmallVectorImpl<Value> &fields,
|
||||
SmallVectorImpl<Value> &indices, Value value) {
|
||||
static void genInsertBody(OpBuilder &builder, ModuleOp module,
|
||||
func::FuncOp func, RankedTensorType rtp) {
|
||||
OpBuilder::InsertionGuard insertionGuard(builder);
|
||||
Block *entryBlock = func.addEntryBlock();
|
||||
builder.setInsertionPointToStart(entryBlock);
|
||||
|
||||
Location loc = func.getLoc();
|
||||
ValueRange args = entryBlock->getArguments();
|
||||
unsigned rank = rtp.getShape().size();
|
||||
assert(rank == indices.size());
|
||||
unsigned field = fieldsIdx; // start past header
|
||||
|
||||
// Construct fields and indices arrays from parameters.
|
||||
ValueRange tmp = args.drop_back(rank + 1);
|
||||
SmallVector<Value> fields(tmp.begin(), tmp.end());
|
||||
tmp = args.take_back(rank + 1).drop_back();
|
||||
SmallVector<Value> indices(tmp.begin(), tmp.end());
|
||||
Value value = args.back();
|
||||
|
||||
unsigned field = fieldsIdx; // Start past header.
|
||||
Value pos = constantZero(builder, loc, builder.getIndexType());
|
||||
// Generate code for every dimension.
|
||||
for (unsigned d = 0; d < rank; d++) {
|
||||
@@ -519,6 +536,77 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
|
||||
else
|
||||
genStore(builder, loc, value, fields[field++], pos);
|
||||
assert(fields.size() == field);
|
||||
builder.create<func::ReturnOp>(loc, fields);
|
||||
}
|
||||
|
||||
/// Generates a call to a function to perform an insertion operation. If the
|
||||
/// function doesn't exist yet, call `createFunc` to generate the function.
|
||||
static void genInsertionCallHelper(OpBuilder &builder, RankedTensorType rtp,
|
||||
SmallVectorImpl<Value> &fields,
|
||||
SmallVectorImpl<Value> &indices, Value value,
|
||||
func::FuncOp insertPoint,
|
||||
StringRef namePrefix,
|
||||
FuncGeneratorType createFunc) {
|
||||
// The mangled name of the function has this format:
|
||||
// <namePrefix>_[C|S|D]_<shape>_<ordering>_<eltType>
|
||||
// _<indexBitWidth>_<pointerBitWidth>
|
||||
SmallString<32> nameBuffer;
|
||||
llvm::raw_svector_ostream nameOstream(nameBuffer);
|
||||
nameOstream << namePrefix;
|
||||
unsigned rank = rtp.getShape().size();
|
||||
assert(rank == indices.size());
|
||||
for (unsigned d = 0; d < rank; d++) {
|
||||
if (isCompressedDim(rtp, d)) {
|
||||
nameOstream << "C_";
|
||||
} else if (isSingletonDim(rtp, d)) {
|
||||
nameOstream << "S_";
|
||||
} else {
|
||||
nameOstream << "D_";
|
||||
}
|
||||
}
|
||||
// Static dim sizes are used in the generated code while dynamic sizes are
|
||||
// loaded from the dimSizes buffer. This is the reason for adding the shape
|
||||
// to the function name.
|
||||
for (auto d : rtp.getShape())
|
||||
nameOstream << d << "_";
|
||||
SparseTensorEncodingAttr enc = getSparseTensorEncoding(rtp);
|
||||
// Permutation information is also used in generating insertion.
|
||||
if (enc.getDimOrdering() && !enc.getDimOrdering().isIdentity())
|
||||
nameOstream << enc.getDimOrdering() << "_";
|
||||
nameOstream << rtp.getElementType() << "_";
|
||||
nameOstream << enc.getIndexBitWidth() << "_" << enc.getPointerBitWidth();
|
||||
|
||||
// Look up the function.
|
||||
ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
|
||||
MLIRContext *context = module.getContext();
|
||||
auto result = SymbolRefAttr::get(context, nameOstream.str());
|
||||
auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
|
||||
|
||||
// Construct parameters for fields and indices.
|
||||
SmallVector<Value> operands(fields.begin(), fields.end());
|
||||
operands.append(indices.begin(), indices.end());
|
||||
operands.push_back(value);
|
||||
Location loc = insertPoint.getLoc();
|
||||
|
||||
if (!func) {
|
||||
// Create the function.
|
||||
OpBuilder::InsertionGuard insertionGuard(builder);
|
||||
builder.setInsertionPoint(insertPoint);
|
||||
|
||||
func = builder.create<func::FuncOp>(
|
||||
loc, nameOstream.str(),
|
||||
FunctionType::get(context, ValueRange(operands).getTypes(),
|
||||
ValueRange(fields).getTypes()));
|
||||
func.setPrivate();
|
||||
createFunc(builder, module, func, rtp);
|
||||
}
|
||||
|
||||
// Generate a call to perform the insertion and update `fields` with values
|
||||
// returned from the call.
|
||||
func::CallOp call = builder.create<func::CallOp>(loc, func, operands);
|
||||
for (size_t i = 0; i < fields.size(); i++) {
|
||||
fields[i] = call.getResult(i);
|
||||
}
|
||||
}
|
||||
|
||||
/// Generations insertion finalization code.
|
||||
@@ -865,7 +953,9 @@ public:
|
||||
Value value = genLoad(rewriter, loc, values, index);
|
||||
indices.push_back(index);
|
||||
// TODO: faster for subsequent insertions?
|
||||
genInsert(rewriter, loc, dstType, fields, indices, value);
|
||||
auto insertPoint = op->template getParentOfType<func::FuncOp>();
|
||||
genInsertionCallHelper(rewriter, dstType, fields, indices, value,
|
||||
insertPoint, kInsertFuncNamePrefix, genInsertBody);
|
||||
genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values,
|
||||
index);
|
||||
genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index);
|
||||
@@ -899,7 +989,10 @@ public:
|
||||
SmallVector<Value> indices(adaptor.getIndices());
|
||||
// Generate insertion.
|
||||
Value value = adaptor.getValue();
|
||||
genInsert(rewriter, op->getLoc(), dstType, fields, indices, value);
|
||||
auto insertPoint = op->template getParentOfType<func::FuncOp>();
|
||||
genInsertionCallHelper(rewriter, dstType, fields, indices, value,
|
||||
insertPoint, kInsertFuncNamePrefix, genInsertBody);
|
||||
|
||||
// Replace operation with resulting memrefs.
|
||||
rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
|
||||
return success();
|
||||
|
||||
Reference in New Issue
Block a user