mirror of
https://github.com/intel/llvm.git
synced 2026-01-19 01:15:50 +08:00
[mlir][sparse] IR/SparseTensorDialect.cpp: misc code cleanup
Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D142072
This commit is contained in:
@@ -101,16 +101,18 @@ SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
<< "expect positive value or ? for slice offset/size/stride";
|
||||
}
|
||||
|
||||
static Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth) {
|
||||
if (bitwidth)
|
||||
return IntegerType::get(ctx, bitwidth);
|
||||
return IndexType::get(ctx);
|
||||
}
|
||||
|
||||
Type SparseTensorEncodingAttr::getPointerType() const {
|
||||
unsigned ptrWidth = getPointerBitWidth();
|
||||
Type indexType = IndexType::get(getContext());
|
||||
return ptrWidth ? IntegerType::get(getContext(), ptrWidth) : indexType;
|
||||
return getIntegerOrIndexType(getContext(), getPointerBitWidth());
|
||||
}
|
||||
|
||||
Type SparseTensorEncodingAttr::getIndexType() const {
|
||||
unsigned idxWidth = getIndexBitWidth();
|
||||
Type indexType = IndexType::get(getContext());
|
||||
return idxWidth ? IntegerType::get(getContext(), idxWidth) : indexType;
|
||||
return getIntegerOrIndexType(getContext(), getIndexBitWidth());
|
||||
}
|
||||
|
||||
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
|
||||
@@ -157,11 +159,30 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(unsigned lvl) const {
|
||||
return getStaticDimSliceStride(toOrigDim(*this, lvl));
|
||||
}
|
||||
|
||||
const static DimLevelType validDLTs[] = {
|
||||
DimLevelType::Dense, DimLevelType::Compressed,
|
||||
DimLevelType::CompressedNu, DimLevelType::CompressedNo,
|
||||
DimLevelType::CompressedNuNo, DimLevelType::Singleton,
|
||||
DimLevelType::SingletonNu, DimLevelType::SingletonNo,
|
||||
DimLevelType::SingletonNuNo};
|
||||
|
||||
static std::optional<DimLevelType> parseDLT(StringRef str) {
|
||||
for (DimLevelType dlt : validDLTs)
|
||||
if (str == toMLIRString(dlt))
|
||||
return dlt;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
#define RETURN_ON_FAIL(stmt) \
|
||||
if (failed(stmt)) { \
|
||||
return {}; \
|
||||
}
|
||||
#define ERROR_IF(COND, MSG) \
|
||||
if (COND) { \
|
||||
parser.emitError(parser.getNameLoc(), MSG); \
|
||||
return {}; \
|
||||
}
|
||||
|
||||
RETURN_ON_FAIL(parser.parseLess())
|
||||
RETURN_ON_FAIL(parser.parseLBrace())
|
||||
@@ -191,37 +212,13 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute attr;
|
||||
RETURN_ON_FAIL(parser.parseAttribute(attr));
|
||||
auto arrayAttr = attr.dyn_cast<ArrayAttr>();
|
||||
if (!arrayAttr) {
|
||||
parser.emitError(parser.getNameLoc(),
|
||||
"expected an array for dimension level types");
|
||||
return {};
|
||||
}
|
||||
ERROR_IF(!arrayAttr, "expected an array for dimension level types")
|
||||
for (auto i : arrayAttr) {
|
||||
auto strAttr = i.dyn_cast<StringAttr>();
|
||||
if (!strAttr) {
|
||||
parser.emitError(parser.getNameLoc(),
|
||||
"expected a string value in dimension level types");
|
||||
return {};
|
||||
}
|
||||
ERROR_IF(!strAttr, "expected a string value in dimension level types")
|
||||
auto strVal = strAttr.getValue();
|
||||
if (strVal == "dense") {
|
||||
dlt.push_back(DimLevelType::Dense);
|
||||
} else if (strVal == "compressed") {
|
||||
dlt.push_back(DimLevelType::Compressed);
|
||||
} else if (strVal == "compressed-nu") {
|
||||
dlt.push_back(DimLevelType::CompressedNu);
|
||||
} else if (strVal == "compressed-no") {
|
||||
dlt.push_back(DimLevelType::CompressedNo);
|
||||
} else if (strVal == "compressed-nu-no") {
|
||||
dlt.push_back(DimLevelType::CompressedNuNo);
|
||||
} else if (strVal == "singleton") {
|
||||
dlt.push_back(DimLevelType::Singleton);
|
||||
} else if (strVal == "singleton-nu") {
|
||||
dlt.push_back(DimLevelType::SingletonNu);
|
||||
} else if (strVal == "singleton-no") {
|
||||
dlt.push_back(DimLevelType::SingletonNo);
|
||||
} else if (strVal == "singleton-nu-no") {
|
||||
dlt.push_back(DimLevelType::SingletonNuNo);
|
||||
if (auto optDLT = parseDLT(strVal)) {
|
||||
dlt.push_back(optDLT.value());
|
||||
} else {
|
||||
parser.emitError(parser.getNameLoc(),
|
||||
"unexpected dimension level type: ")
|
||||
@@ -232,46 +229,26 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
} else if (attrName == "dimOrdering") {
|
||||
Attribute attr;
|
||||
RETURN_ON_FAIL(parser.parseAttribute(attr))
|
||||
|
||||
auto affineAttr = attr.dyn_cast<AffineMapAttr>();
|
||||
if (!affineAttr) {
|
||||
parser.emitError(parser.getNameLoc(),
|
||||
"expected an affine map for dimension ordering");
|
||||
return {};
|
||||
}
|
||||
ERROR_IF(!affineAttr, "expected an affine map for dimension ordering")
|
||||
dimOrd = affineAttr.getValue();
|
||||
} else if (attrName == "higherOrdering") {
|
||||
Attribute attr;
|
||||
RETURN_ON_FAIL(parser.parseAttribute(attr))
|
||||
|
||||
auto affineAttr = attr.dyn_cast<AffineMapAttr>();
|
||||
if (!affineAttr) {
|
||||
parser.emitError(parser.getNameLoc(),
|
||||
"expected an affine map for higher ordering");
|
||||
return {};
|
||||
}
|
||||
ERROR_IF(!affineAttr, "expected an affine map for higher ordering")
|
||||
higherOrd = affineAttr.getValue();
|
||||
} else if (attrName == "pointerBitWidth") {
|
||||
Attribute attr;
|
||||
RETURN_ON_FAIL(parser.parseAttribute(attr))
|
||||
|
||||
auto intAttr = attr.dyn_cast<IntegerAttr>();
|
||||
if (!intAttr) {
|
||||
parser.emitError(parser.getNameLoc(),
|
||||
"expected an integral pointer bitwidth");
|
||||
return {};
|
||||
}
|
||||
ERROR_IF(!intAttr, "expected an integral pointer bitwidth")
|
||||
ptr = intAttr.getInt();
|
||||
} else if (attrName == "indexBitWidth") {
|
||||
Attribute attr;
|
||||
RETURN_ON_FAIL(parser.parseAttribute(attr))
|
||||
|
||||
auto intAttr = attr.dyn_cast<IntegerAttr>();
|
||||
if (!intAttr) {
|
||||
parser.emitError(parser.getNameLoc(),
|
||||
"expected an integral index bitwidth");
|
||||
return {};
|
||||
}
|
||||
ERROR_IF(!intAttr, "expected an integral index bitwidth")
|
||||
ind = intAttr.getInt();
|
||||
} else if (attrName == "slice") {
|
||||
RETURN_ON_FAIL(parser.parseLSquare())
|
||||
@@ -298,6 +275,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
|
||||
RETURN_ON_FAIL(parser.parseRBrace())
|
||||
RETURN_ON_FAIL(parser.parseGreater())
|
||||
#undef ERROR_IF
|
||||
#undef RETURN_ON_FAIL
|
||||
|
||||
// Construct struct-like storage for attribute.
|
||||
@@ -367,18 +345,21 @@ LogicalResult SparseTensorEncodingAttr::verify(
|
||||
return emitError() << "unexpected mismatch in dimension slices and "
|
||||
"dimension level type size";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
#define RETURN_FAILURE_IF_FAILED(X) \
|
||||
if (failed(X)) { \
|
||||
return failure(); \
|
||||
}
|
||||
|
||||
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
|
||||
ArrayRef<int64_t> shape, Type elementType,
|
||||
function_ref<InFlightDiagnostic()> emitError) const {
|
||||
// Check structural integrity.
|
||||
if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
|
||||
getHigherOrdering(), getPointerBitWidth(),
|
||||
getIndexBitWidth(), getDimSlices())))
|
||||
return failure();
|
||||
RETURN_FAILURE_IF_FAILED(verify(
|
||||
emitError, getDimLevelType(), getDimOrdering(), getHigherOrdering(),
|
||||
getPointerBitWidth(), getIndexBitWidth(), getDimSlices()))
|
||||
// Check integrity with tensor type specifics. Dimension ordering is optional,
|
||||
// but we always should have dimension level types for the full rank.
|
||||
unsigned size = shape.size();
|
||||
@@ -435,23 +416,17 @@ static bool isCOOType(SparseTensorEncodingAttr enc, uint64_t s, bool isUnique) {
|
||||
|
||||
bool mlir::sparse_tensor::isUniqueCOOType(RankedTensorType tp) {
|
||||
SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
|
||||
if (!enc)
|
||||
return false;
|
||||
|
||||
return isCOOType(enc, 0, /*isUnique=*/true);
|
||||
return enc && isCOOType(enc, 0, /*isUnique=*/true);
|
||||
}
|
||||
|
||||
unsigned mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
|
||||
unsigned rank = enc.getDimLevelType().size();
|
||||
if (rank <= 1)
|
||||
return rank;
|
||||
|
||||
const unsigned rank = enc.getDimLevelType().size();
|
||||
// We only consider COO region with at least two dimensions for the purpose
|
||||
// of AOS storage optimization.
|
||||
for (unsigned r = 0; r < rank - 1; r++) {
|
||||
if (isCOOType(enc, r, /*isUnique=*/false))
|
||||
return r;
|
||||
}
|
||||
if (rank > 1)
|
||||
for (unsigned r = 0; r < rank - 1; r++)
|
||||
if (isCOOType(enc, r, /*isUnique=*/false))
|
||||
return r;
|
||||
|
||||
return rank;
|
||||
}
|
||||
@@ -541,10 +516,8 @@ Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
|
||||
|
||||
Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
|
||||
std::optional<APInt> dim) const {
|
||||
std::optional<unsigned> intDim;
|
||||
if (dim)
|
||||
intDim = dim.value().getZExtValue();
|
||||
return getFieldType(kind, intDim);
|
||||
return getFieldType(kind, dim ? std::optional(dim.value().getZExtValue())
|
||||
: std::nullopt);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -552,17 +525,12 @@ Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult isInBounds(uint64_t dim, Value tensor) {
|
||||
uint64_t rank = tensor.getType().cast<RankedTensorType>().getRank();
|
||||
if (dim >= rank)
|
||||
return failure();
|
||||
return success(); // in bounds
|
||||
return success(dim < tensor.getType().cast<RankedTensorType>().getRank());
|
||||
}
|
||||
|
||||
static LogicalResult isMatchingWidth(Value result, unsigned width) {
|
||||
Type etp = result.getType().cast<MemRefType>().getElementType();
|
||||
if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
|
||||
return success();
|
||||
return failure();
|
||||
const Type etp = result.getType().cast<MemRefType>().getElementType();
|
||||
return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
|
||||
}
|
||||
|
||||
static LogicalResult verifySparsifierGetterSetter(
|
||||
@@ -663,11 +631,8 @@ LogicalResult ToValuesOp::verify() {
|
||||
}
|
||||
|
||||
LogicalResult GetStorageSpecifierOp::verify() {
|
||||
if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(),
|
||||
getSpecifier(), getOperation()))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
|
||||
getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
|
||||
// Checks the result type
|
||||
if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) !=
|
||||
getResult().getType()) {
|
||||
@@ -692,11 +657,8 @@ OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
|
||||
}
|
||||
|
||||
LogicalResult SetStorageSpecifierOp::verify() {
|
||||
if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(),
|
||||
getSpecifier(), getOperation()))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
|
||||
getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
|
||||
// Checks the input type
|
||||
if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) !=
|
||||
getValue().getType()) {
|
||||
@@ -748,59 +710,45 @@ LogicalResult BinaryOp::verify() {
|
||||
|
||||
// Check correct number of block arguments and return type for each
|
||||
// non-empty region.
|
||||
LogicalResult regionResult = success();
|
||||
if (!overlap.empty()) {
|
||||
regionResult = verifyNumBlockArgs(
|
||||
this, overlap, "overlap", TypeRange{leftType, rightType}, outputType);
|
||||
if (failed(regionResult))
|
||||
return regionResult;
|
||||
RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
|
||||
this, overlap, "overlap", TypeRange{leftType, rightType}, outputType))
|
||||
}
|
||||
if (!left.empty()) {
|
||||
regionResult =
|
||||
verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType);
|
||||
if (failed(regionResult))
|
||||
return regionResult;
|
||||
RETURN_FAILURE_IF_FAILED(
|
||||
verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType))
|
||||
} else if (getLeftIdentity()) {
|
||||
if (leftType != outputType)
|
||||
return emitError("left=identity requires first argument to have the same "
|
||||
"type as the output");
|
||||
}
|
||||
if (!right.empty()) {
|
||||
regionResult = verifyNumBlockArgs(this, right, "right",
|
||||
TypeRange{rightType}, outputType);
|
||||
if (failed(regionResult))
|
||||
return regionResult;
|
||||
RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
|
||||
this, right, "right", TypeRange{rightType}, outputType))
|
||||
} else if (getRightIdentity()) {
|
||||
if (rightType != outputType)
|
||||
return emitError("right=identity requires second argument to have the "
|
||||
"same type as the output");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult UnaryOp::verify() {
|
||||
Type inputType = getX().getType();
|
||||
Type outputType = getOutput().getType();
|
||||
LogicalResult regionResult = success();
|
||||
|
||||
// Check correct number of block arguments and return type for each
|
||||
// non-empty region.
|
||||
Region &present = getPresentRegion();
|
||||
if (!present.empty()) {
|
||||
regionResult = verifyNumBlockArgs(this, present, "present",
|
||||
TypeRange{inputType}, outputType);
|
||||
if (failed(regionResult))
|
||||
return regionResult;
|
||||
RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
|
||||
this, present, "present", TypeRange{inputType}, outputType))
|
||||
}
|
||||
Region &absent = getAbsentRegion();
|
||||
if (!absent.empty()) {
|
||||
regionResult =
|
||||
verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType);
|
||||
if (failed(regionResult))
|
||||
return regionResult;
|
||||
RETURN_FAILURE_IF_FAILED(
|
||||
verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType))
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -880,8 +828,7 @@ void PushBackOp::build(OpBuilder &builder, OperationState &result,
|
||||
}
|
||||
|
||||
LogicalResult PushBackOp::verify() {
|
||||
Value n = getN();
|
||||
if (n) {
|
||||
if (Value n = getN()) {
|
||||
auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
|
||||
if (nValue && nValue.value() < 1)
|
||||
return emitOpError("n must be not less than 1");
|
||||
@@ -972,32 +919,21 @@ LogicalResult ForeachOp::verify() {
|
||||
|
||||
LogicalResult ReduceOp::verify() {
|
||||
Type inputType = getX().getType();
|
||||
LogicalResult regionResult = success();
|
||||
|
||||
// Check correct number of block arguments and return type.
|
||||
Region &formula = getRegion();
|
||||
regionResult = verifyNumBlockArgs(this, formula, "reduce",
|
||||
TypeRange{inputType, inputType}, inputType);
|
||||
if (failed(regionResult))
|
||||
return regionResult;
|
||||
|
||||
RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
|
||||
this, formula, "reduce", TypeRange{inputType, inputType}, inputType))
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SelectOp::verify() {
|
||||
Builder b(getContext());
|
||||
|
||||
Type inputType = getX().getType();
|
||||
Type boolType = b.getI1Type();
|
||||
LogicalResult regionResult = success();
|
||||
|
||||
// Check correct number of block arguments and return type.
|
||||
Region &formula = getRegion();
|
||||
regionResult = verifyNumBlockArgs(this, formula, "select",
|
||||
TypeRange{inputType}, boolType);
|
||||
if (failed(regionResult))
|
||||
return regionResult;
|
||||
|
||||
RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(this, formula, "select",
|
||||
TypeRange{inputType}, boolType))
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1025,15 +961,8 @@ LogicalResult SortOp::verify() {
|
||||
}
|
||||
return success();
|
||||
};
|
||||
|
||||
LogicalResult result = checkTypes(getXs());
|
||||
if (failed(result))
|
||||
return result;
|
||||
|
||||
if (n)
|
||||
return checkTypes(getYs(), false);
|
||||
|
||||
return success();
|
||||
RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
|
||||
return n ? checkTypes(getYs(), false) : success();
|
||||
}
|
||||
|
||||
LogicalResult SortCooOp::verify() {
|
||||
@@ -1084,6 +1013,8 @@ LogicalResult YieldOp::verify() {
|
||||
"reduce, select or foreach");
|
||||
}
|
||||
|
||||
#undef RETURN_FAILURE_IF_FAILED
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorDialect Methods.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Reference in New Issue
Block a user