[mlir][sparse] introduce sparse_tensor::StorageSpecifierToLLVM pass

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140122
Peiming Liu
2022-12-15 18:28:30 +00:00
parent b49ee01fe2
commit 083ddffe47
11 changed files with 890 additions and 5 deletions


@@ -52,9 +52,9 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
let parameters = (ins SparseTensorEncodingAttr : $encoding);
let builders = [
TypeBuilder<(ins "SparseTensorEncodingAttr":$encoding)>,
TypeBuilderWithInferredContext<(ins "SparseTensorEncodingAttr":$encoding), [{
assert(encoding && "sparse tensor encoding should not be null");
return $_get(encoding.getContext(), encoding);
return get(encoding.getContext(), encoding);
}]>,
TypeBuilderWithInferredContext<(ins "Type":$type), [{
return get(getSparseTensorEncoding(type));
@@ -71,6 +71,10 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
Type getFieldType(StorageSpecifierKind kind, std::optional<APInt> dim) const;
}];
// We skip the default builder that simply takes the input sparse tensor encoding
// attribute, since we need to normalize the dimension level types and strip
// fields that are irrelevant to the sparse tensor storage scheme.
let skipDefaultBuilders = 1;
let assemblyFormat="`<` qualified($encoding) `>`";
}
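For orientation, this is how a value of the type above is spelled with the assembly format; the CSR encoding is the one used by the new test in this commit.

```mlir
#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>

// A specifier for a CSR tensor prints with the qualified encoding attribute:
func.func private @use(%arg0: !sparse_tensor.storage_specifier<#CSR>)
```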


@@ -158,6 +158,19 @@ std::unique_ptr<Pass>
createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true,
bool enableConvert = true);
//===----------------------------------------------------------------------===//
// The SparseStorageSpecifierToLLVM pass.
//===----------------------------------------------------------------------===//
class StorageSpecifierToLLVMTypeConverter : public TypeConverter {
public:
StorageSpecifierToLLVMTypeConverter();
};
void populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
RewritePatternSet &patterns);
std::unique_ptr<Pass> createStorageSpecifierToLLVMPass();
//===----------------------------------------------------------------------===//
// Other rewriting rules and passes.
//===----------------------------------------------------------------------===//


@@ -301,4 +301,28 @@ def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> {
];
}
def StorageSpecifierToLLVM : Pass<"sparse-storage-specifier-to-llvm", "ModuleOp"> {
let summary = "Lower sparse storage specifier to llvm structure";
let description = [{
This pass rewrites sparse tensor storage specifier-related operations into
the LLVM dialect, and converts the sparse tensor storage specifier type into
an llvm.struct.
Example of the conversion:
```mlir
Before:
%0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
: !sparse_tensor.storage_specifier<#CSR> to i64
After:
%0 = llvm.extractvalue %arg0[0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
```
}];
let constructor = "mlir::createStorageSpecifierToLLVMPass()";
let dependentDialects = [
"arith::ArithDialect",
"LLVM::LLVMDialect",
"sparse_tensor::SparseTensorDialect",
];
}
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
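The pass flag registered above matches the standalone invocation used by the new regression test in this commit:

```mlir
// RUN: mlir-opt %s -sparse-storage-specifier-to-llvm --cse --canonicalize | FileCheck %s
```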


@@ -323,6 +323,28 @@ uint64_t mlir::sparse_tensor::toStoredDim(RankedTensorType type, uint64_t d) {
// SparseTensorDialect Types.
//===----------------------------------------------------------------------===//
/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique DLT such that "compressed-nu-no" and "compressed-nu" (as well
/// as other variants) lead to the same storage specifier type, and by stripping
/// fields that do not alter the sparse tensor memory layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
SmallVector<DimLevelType> dlts;
for (auto dlt : enc.getDimLevelType())
dlts.push_back(*getDimLevelType(*getLevelFormat(dlt), true, true));
return SparseTensorEncodingAttr::get(
enc.getContext(), dlts,
AffineMap(), // dimOrdering (irrelevant to storage specifier)
AffineMap(), // highLvlOrdering (irrelevant to storage specifier)
enc.getPointerBitWidth(), enc.getIndexBitWidth());
}
StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}
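A hedged sketch of the normalization in effect: the two encodings below differ only in ordering/uniqueness properties, which do not change the storage layout, so both yield the same specifier type (the "compressed-nu-no" spelling is assumed from the dialect's level-type syntax at this revision).

```mlir
// Both encodings normalize to the same storage specifier type:
#A = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed-nu-no"]}>
#B = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
// !sparse_tensor.storage_specifier<#A> and !sparse_tensor.storage_specifier<#B>
// are one and the same type after getNormalizedEncodingForSpecifier.
```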
IntegerType StorageSpecifierType::getSizesType() const {
unsigned idxBitWidth =
getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u;


@@ -3,10 +3,12 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
CodegenEnv.cpp
CodegenUtils.cpp
SparseBufferRewriting.cpp
SparseStorageSpecifierToLLVM.cpp
SparseTensorCodegen.cpp
SparseTensorConversion.cpp
SparseTensorPasses.cpp
SparseTensorRewriting.cpp
SparseTensorStorageLayout.cpp
SparseVectorization.cpp
Sparsification.cpp
SparsificationAndBufferizationPass.cpp


@@ -0,0 +1,184 @@
//===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
#include "SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
using namespace mlir;
using namespace sparse_tensor;
static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
MLIRContext *ctx = tp.getContext();
auto enc = tp.getEncoding();
unsigned rank = enc.getDimLevelType().size();
SmallVector<Type, 2> result;
auto indexType = tp.getSizesType();
auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, rank);
auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType,
getNumDataFieldsFromEncoding(enc));
result.push_back(dimSizes);
result.push_back(memSizes);
return result;
}
static Type convertSpecifier(StorageSpecifierType tp) {
return LLVM::LLVMStructType::getLiteral(tp.getContext(),
getSpecifierFields(tp));
}
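Concretely, for the CSR encoding exercised by the new test, the rank is 2 and there are three data memrefs (pointers, indices, values), so the specifier converts to:

```mlir
// dimSizes: one entry per dimension; memSizes: one entry per data memref.
!llvm.struct<(array<2 x i64>, array<3 x i64>)>
```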
StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
addConversion([](Type type) { return type; });
addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); });
}
constexpr uint64_t kDimSizePosInSpecifier = 0;
constexpr uint64_t kMemSizePosInSpecifier = 1;
class SpecifierStructBuilder : public StructBuilder {
public:
explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
assert(value);
}
// Undef value for dimension sizes, all zero value for memory sizes.
static Value getInitValue(OpBuilder &builder, Location loc, Type structType);
Value dimSize(OpBuilder &builder, Location loc, unsigned dim);
void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size);
Value memSize(OpBuilder &builder, Location loc, unsigned pos);
void setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
};
Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
Type structType) {
Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
SpecifierStructBuilder md(metaData);
auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
.getBody()[kMemSizePosInSpecifier]
.cast<LLVM::LLVMArrayType>();
Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
// Fill memSizes array with zero.
for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
md.setMemSize(builder, loc, i, zero);
return md;
}
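For the CSR specifier this emits an undef struct whose memSizes entries are zero-initialized, as checked by the new test:

```mlir
%c0 = arith.constant 0 : i64
%0 = llvm.mlir.undef : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
%1 = llvm.insertvalue %c0, %0[1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
%2 = llvm.insertvalue %c0, %1[1, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
%3 = llvm.insertvalue %c0, %2[1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
```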
/// Builds IR extracting the `dim`-th dimension size from the descriptor.
Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc,
unsigned dim) {
return builder.create<LLVM::ExtractValueOp>(
loc, value, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
}
/// Builds IR inserting the `dim`-th dimension size into the descriptor.
void SpecifierStructBuilder::setDimSize(OpBuilder &builder, Location loc,
unsigned dim, Value size) {
value = builder.create<LLVM::InsertValueOp>(
loc, value, size, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
}
/// Builds IR extracting the pos-th memory size from the descriptor.
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, value, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
}
/// Builds IR inserting the pos-th memory size into the descriptor.
void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
unsigned pos, Value size) {
value = builder.create<LLVM::InsertValueOp>(
loc, value, size, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
}
template <typename Base, typename SourceOp>
class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
using OpConversionPattern<SourceOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SpecifierStructBuilder spec(adaptor.getSpecifier());
Value v;
if (op.getSpecifierKind() == StorageSpecifierKind::DimSize) {
v = Base::onDimSize(rewriter, op, spec,
op.getDim().value().getZExtValue());
} else {
auto enc = op.getSpecifier().getType().getEncoding();
builder::StorageLayout layout(enc);
Optional<unsigned> dim = std::nullopt;
if (op.getDim())
dim = op.getDim().value().getZExtValue();
unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), dim);
v = Base::onMemSize(rewriter, op, spec, idx);
}
rewriter.replaceOp(op, v);
return success();
}
};
struct StorageSpecifierSetOpConverter
: public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
SetStorageSpecifierOp> {
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
static Value onDimSize(OpBuilder &builder, SetStorageSpecifierOp op,
SpecifierStructBuilder &spec, unsigned d) {
spec.setDimSize(builder, op.getLoc(), d, op.getValue());
return spec;
}
static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
SpecifierStructBuilder &spec, unsigned i) {
spec.setMemSize(builder, op.getLoc(), i, op.getValue());
return spec;
}
};
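With these hooks in place, setting the dimension size lowers to a single insertvalue on the struct (mirroring the new test):

```mlir
// Before:
%0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
    : i64, !sparse_tensor.storage_specifier<#CSR>
// After:
%0 = llvm.insertvalue %arg1, %arg0[0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
```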
struct StorageSpecifierGetOpConverter
: public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
GetStorageSpecifierOp> {
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
static Value onDimSize(OpBuilder &builder, GetStorageSpecifierOp op,
SpecifierStructBuilder &spec, unsigned d) {
return spec.dimSize(builder, op.getLoc(), d);
}
static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
SpecifierStructBuilder &spec, unsigned i) {
return spec.memSize(builder, op.getLoc(), i);
}
};
struct StorageSpecifierInitOpConverter
: public OpConversionPattern<StorageSpecifierInitOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue(
rewriter, op.getLoc(), llvmType));
return success();
}
};
void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,
StorageSpecifierInitOpConverter>(converter,
patterns.getContext());
}


@@ -28,6 +28,7 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
@@ -193,9 +194,14 @@ struct SparseTensorCodegenPass
target.addLegalOp<SortOp>();
target.addLegalOp<SortCooOp>();
target.addLegalOp<PushBackOp>();
// Storage specifier outlives sparse tensor pipeline.
target.addLegalOp<GetStorageSpecifierOp>();
target.addLegalOp<SetStorageSpecifierOp>();
target.addLegalOp<StorageSpecifierInitOp>();
// All dynamic rules below accept new function, call, return, and
// various tensor and bufferization operations as legal output of the
// rewriting provided that all sparse tensor types have been fully
// rewritten.
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType());
});
@@ -271,6 +277,44 @@ struct SparseVectorizationPass
}
};
struct StorageSpecifierToLLVMPass
: public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
StorageSpecifierToLLVMPass() = default;
void runOnOperation() override {
auto *ctx = &getContext();
ConversionTarget target(*ctx);
RewritePatternSet patterns(ctx);
StorageSpecifierToLLVMTypeConverter converter;
// All ops in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
return converter.isSignatureLegal(op.getCalleeType());
});
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
return converter.isLegal(op.getOperandTypes());
});
target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
populateCallOpTypeConversionPattern(patterns, converter);
populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
populateReturnOpTypeConversionPattern(patterns, converter);
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
target);
populateStorageSpecifierToLLVMPatterns(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
} // namespace
//===----------------------------------------------------------------------===//
@@ -355,3 +399,7 @@ mlir::createSparseVectorizationPass(unsigned vectorLength,
return std::make_unique<SparseVectorizationPass>(
vectorLength, enableVLAVectorization, enableSIMDIndex32);
}
std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
return std::make_unique<StorageSpecifierToLLVMPass>();
}


@@ -0,0 +1,188 @@
//===- SparseTensorStorageLayout.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "SparseTensorStorageLayout.h"
#include "CodegenUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace sparse_tensor;
static Value createIndexCast(OpBuilder &builder, Location loc, Value value,
Type to) {
if (value.getType() != to)
return builder.create<arith::IndexCastOp>(loc, to, value);
return value;
}
static IntegerAttr fromOptionalInt(MLIRContext *ctx, Optional<unsigned> dim) {
if (!dim)
return nullptr;
return IntegerAttr::get(IndexType::get(ctx), dim.value());
}
unsigned
builder::StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind,
Optional<unsigned> dim) const {
unsigned fieldIdx = -1u;
foreachFieldInSparseTensor(
enc,
[dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind,
unsigned fDim, DimLevelType dlt) -> bool {
if ((dim && fDim == dim.value() && kind == fKind) ||
(kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
fieldIdx = fIdx;
// Returns false to break the iteration.
return false;
}
return true;
});
assert(fieldIdx != -1u);
return fieldIdx;
}
unsigned
builder::StorageLayout::getMemRefFieldIndex(StorageSpecifierKind kind,
Optional<unsigned> dim) const {
return getMemRefFieldIndex(toFieldKind(kind), dim);
}
Value builder::SparseTensorSpecifier::getInitValue(OpBuilder &builder,
Location loc,
RankedTensorType rtp) {
return builder.create<StorageSpecifierInitOp>(
loc, StorageSpecifierType::get(getSparseTensorEncoding(rtp)));
}
Value builder::SparseTensorSpecifier::getSpecifierField(
OpBuilder &builder, Location loc, StorageSpecifierKind kind,
Optional<unsigned> dim) {
return createIndexCast(builder, loc,
builder.create<GetStorageSpecifierOp>(
loc, getFieldType(kind, dim), specifier, kind,
fromOptionalInt(specifier.getContext(), dim)),
builder.getIndexType());
}
void builder::SparseTensorSpecifier::setSpecifierField(
OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind,
Optional<unsigned> dim) {
specifier = builder.create<SetStorageSpecifierOp>(
loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim),
createIndexCast(builder, loc, v, getFieldType(kind, dim)));
}
constexpr uint64_t kDataFieldStartingIdx = 0;
void sparse_tensor::builder::foreachFieldInSparseTensor(
const SparseTensorEncodingAttr enc,
llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
DimLevelType)>
callback) {
assert(enc);
#define RETURN_ON_FALSE(idx, kind, dim, dlt) \
if (!(callback(idx, kind, dim, dlt))) \
return;
static_assert(kDataFieldStartingIdx == 0);
unsigned fieldIdx = kDataFieldStartingIdx;
// Per-dimension storage.
for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) {
// Dimension level types apply in order to the reordered dimension.
// As a result, the compound type can be constructed directly in the given
// order.
auto dlt = getDimLevelType(enc, r);
if (isCompressedDLT(dlt)) {
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt);
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
} else if (isSingletonDLT(dlt)) {
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
} else {
assert(isDenseDLT(dlt)); // no fields
}
}
// The values array.
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u,
DimLevelType::Undef);
// Put metadata at the end.
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, -1u,
DimLevelType::Undef);
#undef RETURN_ON_FALSE
}
void sparse_tensor::builder::foreachFieldAndTypeInSparseTensor(
RankedTensorType rType,
llvm::function_ref<bool(Type, unsigned, SparseTensorFieldKind, unsigned,
DimLevelType)>
callback) {
auto enc = getSparseTensorEncoding(rType);
assert(enc);
// Construct the basic types.
Type idxType = enc.getIndexType();
Type ptrType = enc.getPointerType();
Type eltType = rType.getElementType();
Type metaDataType = StorageSpecifierType::get(enc);
// memref<? x ptr> pointers
Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType);
// memref<? x idx> indices
Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType);
// memref<? x eltType> values
Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
foreachFieldInSparseTensor(
enc,
[metaDataType, ptrMemType, idxMemType, valMemType,
callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind,
unsigned dim, DimLevelType dlt) -> bool {
switch (fieldKind) {
case SparseTensorFieldKind::StorageSpec:
return callback(metaDataType, fieldIdx, fieldKind, dim, dlt);
case SparseTensorFieldKind::PtrMemRef:
return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt);
case SparseTensorFieldKind::IdxMemRef:
return callback(idxMemType, fieldIdx, fieldKind, dim, dlt);
case SparseTensorFieldKind::ValMemRef:
return callback(valMemType, fieldIdx, fieldKind, dim, dlt);
};
llvm_unreachable("unrecognized field kind");
});
}
unsigned
sparse_tensor::builder::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
unsigned numFields = 0;
foreachFieldInSparseTensor(enc,
[&numFields](unsigned, SparseTensorFieldKind,
unsigned, DimLevelType) -> bool {
numFields++;
return true;
});
return numFields;
}
unsigned sparse_tensor::builder::getNumDataFieldsFromEncoding(
SparseTensorEncodingAttr enc) {
unsigned numFields = 0; // one value memref
foreachFieldInSparseTensor(enc,
[&numFields](unsigned fidx, SparseTensorFieldKind,
unsigned, DimLevelType) -> bool {
if (fidx >= kDataFieldStartingIdx)
numFields++;
return true;
});
numFields -= 1; // the last field is MetaData field
assert(numFields ==
builder::getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1);
return numFields;
}


@@ -0,0 +1,361 @@
//===- SparseTensorStorageLayout.h ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities for lowering and accessing sparse tensor
// types.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_
#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace sparse_tensor {
// FIXME: this is a tmp namespace
namespace builder {
//===----------------------------------------------------------------------===//
// SparseTensorDescriptor and helpers that manage the sparse tensor memory
// layout scheme.
//
// The sparse tensor storage scheme for a rank-dimensional tensor is organized
// as a single compound type with the following fields. Note that every
// memref with ? size actually behaves like a "vector", i.e. the stored
// size is the capacity and the used size resides in the memSizes array.
//
// struct {
// ; per-dimension d:
// ; if dense:
// <nothing>
// ; if compressed:
// memref<? x ptr> pointers-d ; pointers for sparse dim d
// memref<? x idx> indices-d ; indices for sparse dim d
// ; if singleton:
// memref<? x idx> indices-d ; indices for singleton dim d
// memref<? x eltType> values ; values
//
// ; sparse tensor metadata
// struct {
// array<rank x int> dimSizes ; sizes for each dimension
// array<n x int> memSizes; ; sizes for each data memref
// }
// };
//
//===----------------------------------------------------------------------===//
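As a hedged concrete instance of this scheme, a 2-D CSR tensor with explicitly requested 64-bit pointers/indices and f64 values (element type and bit widths chosen here purely for illustration) would carry the following fields, in order:

```mlir
// Field layout for dense x compressed (CSR):
//   memref<?xi64>                            ; pointers for the compressed dim
//   memref<?xi64>                            ; indices for the compressed dim
//   memref<?xf64>                            ; values
//   !sparse_tensor.storage_specifier<#CSR>   ; metadata (dimSizes, memSizes)
```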
enum class SparseTensorFieldKind : uint32_t {
StorageSpec = 0,
PtrMemRef = 1,
IdxMemRef = 2,
ValMemRef = 3
};
static_assert(static_cast<uint32_t>(SparseTensorFieldKind::PtrMemRef) ==
static_cast<uint32_t>(StorageSpecifierKind::PtrMemSize));
static_assert(static_cast<uint32_t>(SparseTensorFieldKind::IdxMemRef) ==
static_cast<uint32_t>(StorageSpecifierKind::IdxMemSize));
static_assert(static_cast<uint32_t>(SparseTensorFieldKind::ValMemRef) ==
static_cast<uint32_t>(StorageSpecifierKind::ValMemSize));
/// For each field that will be allocated for the given sparse tensor encoding,
/// calls the callback with the corresponding field index, field kind, dimension
/// (for sparse tensor level memrefs) and dimension level type.
/// The field index always starts at zero and increments by one between two
/// callback invocations.
/// Ideally, all other methods should rely on this function to query sparse
/// tensor fields instead of relying on ad-hoc index computations.
void foreachFieldInSparseTensor(
SparseTensorEncodingAttr,
llvm::function_ref<bool(unsigned /*fieldIdx*/,
SparseTensorFieldKind /*fieldKind*/,
unsigned /*dim (if applicable)*/,
DimLevelType /*DLT (if applicable)*/)>);
/// Same as above, except that it also builds the Type for the corresponding
/// field.
void foreachFieldAndTypeInSparseTensor(
RankedTensorType,
llvm::function_ref<bool(Type /*fieldType*/, unsigned /*fieldIdx*/,
SparseTensorFieldKind /*fieldKind*/,
unsigned /*dim (if applicable)*/,
DimLevelType /*DLT (if applicable)*/)>);
/// Gets the total number of fields for the given sparse tensor encoding.
unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
/// Gets the total number of data fields (index arrays, pointer arrays, and a
/// value array) for the given sparse tensor encoding.
unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc);
inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) {
assert(kind != SparseTensorFieldKind::StorageSpec);
return static_cast<StorageSpecifierKind>(kind);
}
inline SparseTensorFieldKind toFieldKind(StorageSpecifierKind kind) {
assert(kind != StorageSpecifierKind::DimSize);
return static_cast<SparseTensorFieldKind>(kind);
}
class StorageLayout {
public:
explicit StorageLayout(SparseTensorEncodingAttr enc) : enc(enc) {}
///
/// Getters: get the field index for required field.
///
unsigned getMemRefFieldIndex(SparseTensorFieldKind kind,
Optional<unsigned> dim) const;
unsigned getMemRefFieldIndex(StorageSpecifierKind kind,
Optional<unsigned> dim) const;
private:
unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const;
SparseTensorEncodingAttr enc;
};
class SparseTensorSpecifier {
public:
explicit SparseTensorSpecifier(Value specifier) : specifier(specifier) {}
// Undef value for dimension sizes, all zero value for memory sizes.
static Value getInitValue(OpBuilder &builder, Location loc,
RankedTensorType rtp);
/*implicit*/ operator Value() { return specifier; }
Value getSpecifierField(OpBuilder &builder, Location loc,
StorageSpecifierKind kind, Optional<unsigned> dim);
void setSpecifierField(OpBuilder &builder, Location loc, Value v,
StorageSpecifierKind kind, Optional<unsigned> dim);
Type getFieldType(StorageSpecifierKind kind, Optional<unsigned> dim) {
return specifier.getType().getFieldType(kind, dim);
}
private:
TypedValue<StorageSpecifierType> specifier;
};
/// A helper class around an array of values that correspond to a sparse
/// tensor. It provides a set of meaningful APIs to query and update a
/// particular field in a consistent way.
/// Users should not make assumptions about how a sparse tensor is laid out
/// but should instead rely on this class to access the right value for the
/// right field.
template <bool mut>
class SparseTensorDescriptorImpl {
private:
// Uses ValueRange for immutable descriptors and SmallVectorImpl<Value> &
// for mutable descriptors.
// Using a SmallVector for the mutable descriptor allows users to reuse it as
// a temporary buffer to append values in some special cases, though users are
// responsible for restoring the buffer to a legal state afterwards. This is
// probably not the cleanest approach, but it is the most efficient way to
// avoid copying the fields into another SmallVector. If a cleaner way is
// wanted, we should change it to MutableArrayRef instead.
using ValueArrayRef = typename std::conditional<mut, SmallVectorImpl<Value> &,
ValueRange>::type;
public:
SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
: rType(tp.cast<RankedTensorType>()), fields(fields) {
assert(getSparseTensorEncoding(tp) &&
builder::getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
fields.size());
// We should make sure the class is trivially copyable (and should be small
// enough) such that we can pass it by value.
static_assert(
std::is_trivially_copyable_v<SparseTensorDescriptorImpl<mut>>);
}
// Implicit (and cheap) type conversion from MutSparseTensorDescriptor to
// SparseTensorDescriptor.
template <typename T = SparseTensorDescriptorImpl<true>>
/*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t<!mut, T> &mDesc)
: rType(mDesc.getTensorType()), fields(mDesc.getFields()) {}
unsigned getMemRefFieldIndex(SparseTensorFieldKind kind,
Optional<unsigned> dim) const {
// Delegates to storage layout.
StorageLayout layout(getSparseTensorEncoding(rType));
return layout.getMemRefFieldIndex(kind, dim);
}
unsigned getPtrMemRefIndex(unsigned ptrDim) const {
return getMemRefFieldIndex(SparseTensorFieldKind::PtrMemRef, ptrDim);
}
unsigned getIdxMemRefIndex(unsigned idxDim) const {
return getMemRefFieldIndex(SparseTensorFieldKind::IdxMemRef, idxDim);
}
unsigned getValMemRefIndex() const {
return getMemRefFieldIndex(SparseTensorFieldKind::ValMemRef, std::nullopt);
}
unsigned getNumFields() const { return fields.size(); }
///
/// Getters: get the value for required field.
///
Value getSpecifierField(OpBuilder &builder, Location loc,
StorageSpecifierKind kind,
Optional<unsigned> dim) const {
SparseTensorSpecifier md(fields.back());
return md.getSpecifierField(builder, loc, kind, dim);
}
Value getDimSize(OpBuilder &builder, Location loc, unsigned dim) const {
return getSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim);
}
Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
dim);
}
Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
dim);
}
Value getValMemSize(OpBuilder &builder, Location loc) const {
return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
std::nullopt);
}
Value getPtrMemRef(unsigned ptrDim) const {
return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim);
}
Value getIdxMemRef(unsigned idxDim) const {
return getMemRefField(SparseTensorFieldKind::IdxMemRef, idxDim);
}
Value getValMemRef() const {
return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt);
}
Value getMemRefField(SparseTensorFieldKind kind,
Optional<unsigned> dim) const {
return fields[getMemRefFieldIndex(kind, dim)];
}
Value getMemRefField(unsigned fidx) const {
assert(fidx < fields.size() - 1);
return fields[fidx];
}
Value getField(unsigned fidx) const {
assert(fidx < fields.size());
return fields[fidx];
}
///
/// Setters: update the value for required field (only enabled for
/// MutSparseTensorDescriptor).
///
template <typename T = Value>
void setMemRefField(SparseTensorFieldKind kind, Optional<unsigned> dim,
std::enable_if_t<mut, T> v) {
fields[getMemRefFieldIndex(kind, dim)] = v;
}
template <typename T = Value>
void setMemRefField(unsigned fidx, std::enable_if_t<mut, T> v) {
assert(fidx < fields.size() - 1);
fields[fidx] = v;
}
template <typename T = Value>
void setField(unsigned fidx, std::enable_if_t<mut, T> v) {
assert(fidx < fields.size());
fields[fidx] = v;
}
template <typename T = Value>
void setSpecifierField(OpBuilder &builder, Location loc,
StorageSpecifierKind kind, Optional<unsigned> dim,
std::enable_if_t<mut, T> v) {
SparseTensorSpecifier md(fields.back());
md.setSpecifierField(builder, loc, v, kind, dim);
fields.back() = md;
}
template <typename T = Value>
void setDimSize(OpBuilder &builder, Location loc, unsigned dim,
std::enable_if_t<mut, T> v) {
setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
}
ValueRange getMemRefFields() const {
ValueRange ret = fields;
// Drop the last metadata field.
return ret.slice(0, fields.size() - 1);
}
Type getMemRefElementType(SparseTensorFieldKind kind,
Optional<unsigned> dim) const {
return getMemRefField(kind, dim)
.getType()
.template cast<MemRefType>()
.getElementType();
}
RankedTensorType getTensorType() const { return rType; }
ValueArrayRef getFields() const { return fields; }
private:
RankedTensorType rType;
ValueArrayRef fields;
};
using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>;
using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>;
/// Returns the "tuple" value of the adapted tensor.
inline UnrealizedConversionCastOp getTuple(Value tensor) {
return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
}
/// Packs the given values as a "tuple" value.
inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
ValueRange values) {
return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
.getResult(0);
}
inline Value genTuple(OpBuilder &builder, Location loc,
SparseTensorDescriptor desc) {
return genTuple(builder, loc, desc.getTensorType(), desc.getFields());
}
inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
auto tuple = getTuple(tensor);
return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs());
}
inline MutSparseTensorDescriptor
getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
auto tuple = getTuple(tensor);
fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields);
}
} // namespace builder
} // namespace sparse_tensor
} // namespace mlir
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_


@@ -147,6 +147,7 @@ public:
} else {
pm.addPass(createSparseTensorCodegenPass(enableBufferInitialization));
pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
pm.addPass(createStorageSpecifierToLLVMPass());
}
if (failed(runPipeline(pm, getOperation())))
return signalPassFailure();


@@ -0,0 +1,38 @@
// RUN: mlir-opt %s -sparse-storage-specifier-to-llvm --cse --canonicalize | FileCheck %s
#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
// CHECK-LABEL: func.func @sparse_metadata_init() -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> {
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_1:.*]] = llvm.mlir.undef : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_1]][1, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][1, 1] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1, 2] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
// CHECK: return %[[VAL_4]] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
// CHECK: }
func.func @sparse_metadata_init() -> !sparse_tensor.storage_specifier<#CSR> {
%0 = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#CSR>
return %0 : !sparse_tensor.storage_specifier<#CSR>
}
// CHECK-LABEL: func.func @sparse_get_md(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>) -> i64 {
// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
// CHECK: return %[[VAL_1]] : i64
func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#CSR>) -> i64 {
%0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
: !sparse_tensor.storage_specifier<#CSR> to i64
return %0 : i64
}
// CHECK-LABEL: func.func @sparse_set_md(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)>,
// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.struct<(array<2 x i64>, array<3 x i64>)> {
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
// CHECK: return %[[VAL_2]] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
func.func @sparse_set_md(%arg0: !sparse_tensor.storage_specifier<#CSR>, %arg1: i64)
-> !sparse_tensor.storage_specifier<#CSR> {
%0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
: i64, !sparse_tensor.storage_specifier<#CSR>
return %0 : !sparse_tensor.storage_specifier<#CSR>
}