2022-12-15 18:28:30 +00:00
|
|
|
//===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===//
|
|
|
|
|
//
|
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
|
//
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
#include "CodegenUtils.h"
|
|
|
|
|
#include "SparseTensorStorageLayout.h"
|
|
|
|
|
|
|
|
|
|
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
2023-01-13 21:05:06 -08:00
|
|
|
#include <optional>
|
2022-12-15 18:28:30 +00:00
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
using namespace sparse_tensor;
|
|
|
|
|
|
2023-01-03 18:06:54 -08:00
|
|
|
namespace {
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Helper methods.
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2022-12-15 18:28:30 +00:00
|
|
|
static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
|
|
|
|
|
MLIRContext *ctx = tp.getContext();
|
|
|
|
|
auto enc = tp.getEncoding();
|
[mlir][sparse] Factoring out SparseTensorType class
This change adds a new `SparseTensorType` class for making the "dim" vs "lvl" distinction more overt, and for abstracting over the differences between sparse-tensors and dense-tensors. In addition, this change also adds new type aliases `Dimension`, `Level`, and `FieldIndex` to make code more self-documenting.
Although the diff is very large, the majority of the changes are mechanical in nature (e.g., changing types to use the new aliases, updating variable names to match, etc). Along the way I also made many variables `const` when they could be; the majority of which required only adding the keyword. A few places had conditional definitions of these variables, requiring actual code changes; however, that was only done when the overall change was extremely local and easy to extract. All these changes are included in the current patch only because it would be too onerous to split them off into a separate patch.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D143800
2023-02-14 18:20:45 -08:00
|
|
|
const Level lvlRank = enc.getLvlRank();
|
2022-12-15 18:28:30 +00:00
|
|
|
|
|
|
|
|
SmallVector<Type, 2> result;
|
2023-02-22 18:44:00 +00:00
|
|
|
// TODO: how can we get the lowering type for index type in the later pipeline
|
|
|
|
|
// to be consistent? LLVM::StructureType does not allow index fields.
|
|
|
|
|
auto indexType = IntegerType::get(tp.getContext(), 64);
|
[mlir][sparse] Factoring out SparseTensorType class
This change adds a new `SparseTensorType` class for making the "dim" vs "lvl" distinction more overt, and for abstracting over the differences between sparse-tensors and dense-tensors. In addition, this change also adds new type aliases `Dimension`, `Level`, and `FieldIndex` to make code more self-documenting.
Although the diff is very large, the majority of the changes are mechanical in nature (e.g., changing types to use the new aliases, updating variable names to match, etc). Along the way I also made many variables `const` when they could be; the majority of which required only adding the keyword. A few places had conditional definitions of these variables, requiring actual code changes; however, that was only done when the overall change was extremely local and easy to extract. All these changes are included in the current patch only because it would be too onerous to split them off into a separate patch.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D143800
2023-02-14 18:20:45 -08:00
|
|
|
auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, lvlRank);
|
2022-12-15 18:28:30 +00:00
|
|
|
auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType,
|
|
|
|
|
getNumDataFieldsFromEncoding(enc));
|
|
|
|
|
result.push_back(dimSizes);
|
|
|
|
|
result.push_back(memSizes);
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static Type convertSpecifier(StorageSpecifierType tp) {
|
|
|
|
|
return LLVM::LLVMStructType::getLiteral(tp.getContext(),
|
|
|
|
|
getSpecifierFields(tp));
|
|
|
|
|
}
|
|
|
|
|
|
2023-01-03 18:06:54 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Specifier struct builder.
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
2022-12-15 18:28:30 +00:00
|
|
|
|
|
|
|
|
constexpr uint64_t kDimSizePosInSpecifier = 0;
|
|
|
|
|
constexpr uint64_t kMemSizePosInSpecifier = 1;
|
|
|
|
|
|
|
|
|
|
class SpecifierStructBuilder : public StructBuilder {
|
2023-02-22 18:44:00 +00:00
|
|
|
private:
|
|
|
|
|
Value extractField(OpBuilder &builder, Location loc,
|
|
|
|
|
ArrayRef<int64_t> indices) {
|
|
|
|
|
return genCast(builder, loc,
|
|
|
|
|
builder.create<LLVM::ExtractValueOp>(loc, value, indices),
|
|
|
|
|
builder.getIndexType());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void insertField(OpBuilder &builder, Location loc, ArrayRef<int64_t> indices,
|
|
|
|
|
Value v) {
|
|
|
|
|
value = builder.create<LLVM::InsertValueOp>(
|
|
|
|
|
loc, value, genCast(builder, loc, v, builder.getIntegerType(64)),
|
|
|
|
|
indices);
|
|
|
|
|
}
|
|
|
|
|
|
2022-12-15 18:28:30 +00:00
|
|
|
public:
|
|
|
|
|
explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
|
|
|
|
|
assert(value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Undef value for dimension sizes, all zero value for memory sizes.
|
|
|
|
|
static Value getInitValue(OpBuilder &builder, Location loc, Type structType);
|
|
|
|
|
|
|
|
|
|
Value dimSize(OpBuilder &builder, Location loc, unsigned dim);
|
|
|
|
|
void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size);
|
|
|
|
|
|
|
|
|
|
Value memSize(OpBuilder &builder, Location loc, unsigned pos);
|
|
|
|
|
void setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
|
|
|
|
|
Type structType) {
|
|
|
|
|
Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
|
|
|
|
|
SpecifierStructBuilder md(metaData);
|
|
|
|
|
auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
|
|
|
|
|
.getBody()[kMemSizePosInSpecifier]
|
|
|
|
|
.cast<LLVM::LLVMArrayType>();
|
|
|
|
|
|
|
|
|
|
Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
|
|
|
|
|
// Fill memSizes array with zero.
|
|
|
|
|
for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
|
|
|
|
|
md.setMemSize(builder, loc, i, zero);
|
|
|
|
|
|
|
|
|
|
return md;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Builds IR inserting the pos-th size into the descriptor.
|
|
|
|
|
Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc,
|
|
|
|
|
unsigned dim) {
|
2023-02-22 18:44:00 +00:00
|
|
|
return extractField(builder, loc,
|
|
|
|
|
ArrayRef<int64_t>{kDimSizePosInSpecifier, dim});
|
2022-12-15 18:28:30 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Builds IR inserting the pos-th size into the descriptor.
|
|
|
|
|
void SpecifierStructBuilder::setDimSize(OpBuilder &builder, Location loc,
|
|
|
|
|
unsigned dim, Value size) {
|
2023-02-22 18:44:00 +00:00
|
|
|
|
|
|
|
|
insertField(builder, loc, ArrayRef<int64_t>{kDimSizePosInSpecifier, dim},
|
|
|
|
|
size);
|
2022-12-15 18:28:30 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Builds IR extracting the pos-th memory size into the descriptor.
|
|
|
|
|
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
|
|
|
|
|
unsigned pos) {
|
2023-02-22 18:44:00 +00:00
|
|
|
return extractField(builder, loc,
|
|
|
|
|
ArrayRef<int64_t>{kMemSizePosInSpecifier, pos});
|
2022-12-15 18:28:30 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Builds IR inserting the pos-th memory size into the descriptor.
|
|
|
|
|
void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
|
|
|
|
|
unsigned pos, Value size) {
|
2023-02-22 18:44:00 +00:00
|
|
|
insertField(builder, loc, ArrayRef<int64_t>{kMemSizePosInSpecifier, pos},
|
|
|
|
|
size);
|
2022-12-15 18:28:30 +00:00
|
|
|
}
|
|
|
|
|
|
2023-01-03 18:06:54 -08:00
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// The sparse storage specifier type converter (defined in Passes.h).
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
  // Identity rule: any type without a more specific rule is kept unchanged.
  addConversion([](Type t) { return t; });
  // Storage specifiers are lowered to a literal LLVM struct
  // (see convertSpecifier).
  addConversion(
      [](StorageSpecifierType specTp) { return convertSpecifier(specTp); });
}
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Storage specifier conversion rules.
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2022-12-15 18:28:30 +00:00
|
|
|
template <typename Base, typename SourceOp>
|
|
|
|
|
class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpAdaptor = typename SourceOp::Adaptor;
|
|
|
|
|
using OpConversionPattern<SourceOp>::OpConversionPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
SpecifierStructBuilder spec(adaptor.getSpecifier());
|
|
|
|
|
Value v;
|
|
|
|
|
if (op.getSpecifierKind() == StorageSpecifierKind::DimSize) {
|
|
|
|
|
v = Base::onDimSize(rewriter, op, spec,
|
|
|
|
|
op.getDim().value().getZExtValue());
|
|
|
|
|
} else {
|
|
|
|
|
auto enc = op.getSpecifier().getType().getEncoding();
|
2023-01-10 12:33:10 -08:00
|
|
|
StorageLayout layout(enc);
|
2023-01-14 14:06:18 -08:00
|
|
|
std::optional<unsigned> dim;
|
2022-12-15 18:28:30 +00:00
|
|
|
if (op.getDim())
|
|
|
|
|
dim = op.getDim().value().getZExtValue();
|
|
|
|
|
unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), dim);
|
|
|
|
|
v = Base::onMemSize(rewriter, op, spec, idx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, v);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct StorageSpecifierSetOpConverter
|
|
|
|
|
: public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
|
|
|
|
|
SetStorageSpecifierOp> {
|
|
|
|
|
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
|
|
|
|
|
static Value onDimSize(OpBuilder &builder, SetStorageSpecifierOp op,
|
|
|
|
|
SpecifierStructBuilder &spec, unsigned d) {
|
|
|
|
|
spec.setDimSize(builder, op.getLoc(), d, op.getValue());
|
|
|
|
|
return spec;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
|
|
|
|
|
SpecifierStructBuilder &spec, unsigned i) {
|
|
|
|
|
spec.setMemSize(builder, op.getLoc(), i, op.getValue());
|
|
|
|
|
return spec;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct StorageSpecifierGetOpConverter
|
|
|
|
|
: public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
|
|
|
|
|
GetStorageSpecifierOp> {
|
|
|
|
|
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
|
|
|
|
|
static Value onDimSize(OpBuilder &builder, GetStorageSpecifierOp op,
|
|
|
|
|
SpecifierStructBuilder &spec, unsigned d) {
|
|
|
|
|
return spec.dimSize(builder, op.getLoc(), d);
|
|
|
|
|
}
|
|
|
|
|
static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
|
|
|
|
|
SpecifierStructBuilder &spec, unsigned i) {
|
|
|
|
|
return spec.memSize(builder, op.getLoc(), i);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct StorageSpecifierInitOpConverter
|
|
|
|
|
: public OpConversionPattern<StorageSpecifierInitOp> {
|
|
|
|
|
public:
|
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
|
|
|
|
|
rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue(
|
|
|
|
|
rewriter, op.getLoc(), llvmType));
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
2023-01-03 18:06:54 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Public method for populating conversion rules.
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
2022-12-15 18:28:30 +00:00
|
|
|
/// Registers the storage-specifier get/set/init lowering patterns. All of them
/// use `converter` for type conversion.
void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,
               StorageSpecifierInitOpConverter>(converter, ctx);
}
|