[mlir][sparse] IR/SparseTensorDialect.cpp: misc code cleanup

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D142072
Author: wren romano
Date:   2023-01-18 18:22:48 -08:00
Parent: 5e5d901feb
Commit: 743fbcb79d


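The recurring pattern in the diff below is collapsing repeated `if (failed(...)) { return ...; }` blocks into small helpers and function-local macros. A minimal stand-alone sketch of the macro idiom, with a trimmed `LogicalResult` standing in for the real `mlir::LogicalResult`:

    // Stand-alone sketch of the early-return macro pattern this commit
    // leans on. LogicalResult, success(), failure(), and failed() are
    // stand-ins for the mlir:: equivalents, trimmed so the snippet
    // compiles on its own.
    #include <cstdio>

    struct LogicalResult {
      bool isFailure;
    };
    static LogicalResult success() { return {false}; }
    static LogicalResult failure() { return {true}; }
    static bool failed(LogicalResult r) { return r.isFailure; }

    // Folds the recurring "if (failed(x)) return failure();" idiom into
    // one line per call site.
    #define RETURN_FAILURE_IF_FAILED(X)                                    \
      if (failed(X)) {                                                     \
        return failure();                                                  \
      }

    static LogicalResult checkPositive(int v) {
      return v > 0 ? success() : failure();
    }

    static LogicalResult verifyBoth(int a, int b) {
      RETURN_FAILURE_IF_FAILED(checkPositive(a))
      RETURN_FAILURE_IF_FAILED(checkPositive(b))
      return success();
    }

    int main() {
      std::printf("ok=%d\n", !failed(verifyBoth(1, 2)));  // ok=1
      std::printf("ok=%d\n", !failed(verifyBoth(1, -2))); // ok=0
      return 0;
    }

Defining each macro locally and `#undef`ing it at the end of the section that uses it, as the diff does for ERROR_IF and RETURN_FAILURE_IF_FAILED, keeps the shorthand from leaking into the rest of the translation unit.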
@@ -101,16 +101,18 @@ SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
<< "expect positive value or ? for slice offset/size/stride";
}
static Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth) {
if (bitwidth)
return IntegerType::get(ctx, bitwidth);
return IndexType::get(ctx);
}
Type SparseTensorEncodingAttr::getPointerType() const {
unsigned ptrWidth = getPointerBitWidth();
Type indexType = IndexType::get(getContext());
return ptrWidth ? IntegerType::get(getContext(), ptrWidth) : indexType;
return getIntegerOrIndexType(getContext(), getPointerBitWidth());
}
Type SparseTensorEncodingAttr::getIndexType() const {
unsigned idxWidth = getIndexBitWidth();
Type indexType = IndexType::get(getContext());
return idxWidth ? IntegerType::get(getContext(), idxWidth) : indexType;
return getIntegerOrIndexType(getContext(), getIndexBitWidth());
}
SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
@@ -157,11 +159,30 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(unsigned lvl) const {
   return getStaticDimSliceStride(toOrigDim(*this, lvl));
 }
 
+const static DimLevelType validDLTs[] = {
+    DimLevelType::Dense,          DimLevelType::Compressed,
+    DimLevelType::CompressedNu,   DimLevelType::CompressedNo,
+    DimLevelType::CompressedNuNo, DimLevelType::Singleton,
+    DimLevelType::SingletonNu,    DimLevelType::SingletonNo,
+    DimLevelType::SingletonNuNo};
+
+static std::optional<DimLevelType> parseDLT(StringRef str) {
+  for (DimLevelType dlt : validDLTs)
+    if (str == toMLIRString(dlt))
+      return dlt;
+  return std::nullopt;
+}
+
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 #define RETURN_ON_FAIL(stmt)                                                   \
   if (failed(stmt)) {                                                          \
     return {};                                                                 \
   }
+#define ERROR_IF(COND, MSG)                                                    \
+  if (COND) {                                                                  \
+    parser.emitError(parser.getNameLoc(), MSG);                                \
+    return {};                                                                 \
+  }
 
   RETURN_ON_FAIL(parser.parseLess())
   RETURN_ON_FAIL(parser.parseLBrace())
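The new parseDLT above turns a nine-way if/else-if string chain (visible in the next hunk) into a table lookup keyed on each enumerator's printed name, so adding a level type means adding one table entry. A trimmed, self-contained sketch of the same lookup; the three-value DimLevelType and toMLIRString here are stand-ins for the real MLIR definitions:

    // Table-driven parsing: iterate the valid enumerators and compare the
    // input against each one's canonical printed name.
    #include <cassert>
    #include <optional>
    #include <string_view>

    enum class DimLevelType { Dense, Compressed, Singleton };

    static std::string_view toMLIRString(DimLevelType dlt) {
      switch (dlt) {
      case DimLevelType::Dense:      return "dense";
      case DimLevelType::Compressed: return "compressed";
      case DimLevelType::Singleton:  return "singleton";
      }
      return "";
    }

    static const DimLevelType validDLTs[] = {
        DimLevelType::Dense, DimLevelType::Compressed,
        DimLevelType::Singleton};

    static std::optional<DimLevelType> parseDLT(std::string_view str) {
      for (DimLevelType dlt : validDLTs)
        if (str == toMLIRString(dlt))
          return dlt;
      return std::nullopt; // unknown name: caller emits the diagnostic
    }

    int main() {
      assert(parseDLT("compressed") == DimLevelType::Compressed);
      assert(!parseDLT("bogus"));
      return 0;
    }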
@@ -191,37 +212,13 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr));
       auto arrayAttr = attr.dyn_cast<ArrayAttr>();
-      if (!arrayAttr) {
-        parser.emitError(parser.getNameLoc(),
-                         "expected an array for dimension level types");
-        return {};
-      }
+      ERROR_IF(!arrayAttr, "expected an array for dimension level types")
       for (auto i : arrayAttr) {
         auto strAttr = i.dyn_cast<StringAttr>();
-        if (!strAttr) {
-          parser.emitError(parser.getNameLoc(),
-                           "expected a string value in dimension level types");
-          return {};
-        }
+        ERROR_IF(!strAttr, "expected a string value in dimension level types")
         auto strVal = strAttr.getValue();
-        if (strVal == "dense") {
-          dlt.push_back(DimLevelType::Dense);
-        } else if (strVal == "compressed") {
-          dlt.push_back(DimLevelType::Compressed);
-        } else if (strVal == "compressed-nu") {
-          dlt.push_back(DimLevelType::CompressedNu);
-        } else if (strVal == "compressed-no") {
-          dlt.push_back(DimLevelType::CompressedNo);
-        } else if (strVal == "compressed-nu-no") {
-          dlt.push_back(DimLevelType::CompressedNuNo);
-        } else if (strVal == "singleton") {
-          dlt.push_back(DimLevelType::Singleton);
-        } else if (strVal == "singleton-nu") {
-          dlt.push_back(DimLevelType::SingletonNu);
-        } else if (strVal == "singleton-no") {
-          dlt.push_back(DimLevelType::SingletonNo);
-        } else if (strVal == "singleton-nu-no") {
-          dlt.push_back(DimLevelType::SingletonNuNo);
+        if (auto optDLT = parseDLT(strVal)) {
+          dlt.push_back(optDLT.value());
         } else {
           parser.emitError(parser.getNameLoc(),
                            "unexpected dimension level type: ")
@@ -232,46 +229,26 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
     } else if (attrName == "dimOrdering") {
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr))
       auto affineAttr = attr.dyn_cast<AffineMapAttr>();
-      if (!affineAttr) {
-        parser.emitError(parser.getNameLoc(),
-                         "expected an affine map for dimension ordering");
-        return {};
-      }
+      ERROR_IF(!affineAttr, "expected an affine map for dimension ordering")
       dimOrd = affineAttr.getValue();
     } else if (attrName == "higherOrdering") {
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr))
       auto affineAttr = attr.dyn_cast<AffineMapAttr>();
-      if (!affineAttr) {
-        parser.emitError(parser.getNameLoc(),
-                         "expected an affine map for higher ordering");
-        return {};
-      }
+      ERROR_IF(!affineAttr, "expected an affine map for higher ordering")
       higherOrd = affineAttr.getValue();
     } else if (attrName == "pointerBitWidth") {
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr))
       auto intAttr = attr.dyn_cast<IntegerAttr>();
-      if (!intAttr) {
-        parser.emitError(parser.getNameLoc(),
-                         "expected an integral pointer bitwidth");
-        return {};
-      }
+      ERROR_IF(!intAttr, "expected an integral pointer bitwidth")
       ptr = intAttr.getInt();
     } else if (attrName == "indexBitWidth") {
       Attribute attr;
       RETURN_ON_FAIL(parser.parseAttribute(attr))
       auto intAttr = attr.dyn_cast<IntegerAttr>();
-      if (!intAttr) {
-        parser.emitError(parser.getNameLoc(),
-                         "expected an integral index bitwidth");
-        return {};
-      }
+      ERROR_IF(!intAttr, "expected an integral index bitwidth")
       ind = intAttr.getInt();
     } else if (attrName == "slice") {
       RETURN_ON_FAIL(parser.parseLSquare())
@@ -298,6 +275,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
   RETURN_ON_FAIL(parser.parseRBrace())
   RETURN_ON_FAIL(parser.parseGreater())
+#undef ERROR_IF
 #undef RETURN_ON_FAIL
 
   // Construct struct-like storage for attribute.
@@ -367,18 +345,21 @@ LogicalResult SparseTensorEncodingAttr::verify(
     return emitError() << "unexpected mismatch in dimension slices and "
                           "dimension level type size";
   }
   return success();
 }
 
+#define RETURN_FAILURE_IF_FAILED(X)                                            \
+  if (failed(X)) {                                                             \
+    return failure();                                                          \
+  }
+
 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
     ArrayRef<int64_t> shape, Type elementType,
     function_ref<InFlightDiagnostic()> emitError) const {
   // Check structural integrity.
-  if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
-                    getHigherOrdering(), getPointerBitWidth(),
-                    getIndexBitWidth(), getDimSlices())))
-    return failure();
+  RETURN_FAILURE_IF_FAILED(verify(
+      emitError, getDimLevelType(), getDimOrdering(), getHigherOrdering(),
+      getPointerBitWidth(), getIndexBitWidth(), getDimSlices()))
   // Check integrity with tensor type specifics. Dimension ordering is optional,
   // but we always should have dimension level types for the full rank.
   unsigned size = shape.size();
@@ -435,23 +416,17 @@ static bool isCOOType(SparseTensorEncodingAttr enc, uint64_t s, bool isUnique) {
 bool mlir::sparse_tensor::isUniqueCOOType(RankedTensorType tp) {
   SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp);
-  if (!enc)
-    return false;
-  return isCOOType(enc, 0, /*isUnique=*/true);
+  return enc && isCOOType(enc, 0, /*isUnique=*/true);
 }
 
 unsigned mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
-  unsigned rank = enc.getDimLevelType().size();
-  if (rank <= 1)
-    return rank;
+  const unsigned rank = enc.getDimLevelType().size();
   // We only consider COO region with at least two dimensions for the purpose
   // of AOS storage optimization.
-  for (unsigned r = 0; r < rank - 1; r++) {
-    if (isCOOType(enc, r, /*isUnique=*/false))
-      return r;
-  }
-
+  if (rank > 1)
+    for (unsigned r = 0; r < rank - 1; r++)
+      if (isCOOType(enc, r, /*isUnique=*/false))
+        return r;
   return rank;
 }
@@ -541,10 +516,8 @@ Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
 Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
                                         std::optional<APInt> dim) const {
-  std::optional<unsigned> intDim;
-  if (dim)
-    intDim = dim.value().getZExtValue();
-  return getFieldType(kind, intDim);
+  return getFieldType(kind, dim ? std::optional(dim.value().getZExtValue())
+                                : std::nullopt);
 }
 
 //===----------------------------------------------------------------------===//
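The rewritten overload above forwards an optional APInt through a conversion in a single expression, using the `cond ? std::optional(x) : std::nullopt` idiom instead of a mutable local. A self-contained sketch of that idiom, with APInt replaced by a plain uint64_t and a hypothetical getField in place of the real overload:

    #include <cassert>
    #include <cstdint>
    #include <optional>

    // Hypothetical stand-in for the narrower getFieldType overload.
    static unsigned getField(std::optional<unsigned> dim) {
      return dim ? *dim : ~0u; // ~0u marks the "no dimension" case
    }

    static unsigned getFieldWide(std::optional<uint64_t> dim) {
      // Equivalent to: std::optional<unsigned> intDim;
      //                if (dim) intDim = unsigned(*dim);
      //                return getField(intDim);
      return getField(dim ? std::optional(unsigned(*dim)) : std::nullopt);
    }

    int main() {
      assert(getFieldWide(uint64_t{7}) == 7);
      assert(getFieldWide(std::nullopt) == ~0u);
      return 0;
    }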
@@ -552,17 +525,12 @@ Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind,
 //===----------------------------------------------------------------------===//
 
 static LogicalResult isInBounds(uint64_t dim, Value tensor) {
-  uint64_t rank = tensor.getType().cast<RankedTensorType>().getRank();
-  if (dim >= rank)
-    return failure();
-  return success(); // in bounds
+  return success(dim < tensor.getType().cast<RankedTensorType>().getRank());
 }
 
 static LogicalResult isMatchingWidth(Value result, unsigned width) {
-  Type etp = result.getType().cast<MemRefType>().getElementType();
-  if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
-    return success();
-  return failure();
+  const Type etp = result.getType().cast<MemRefType>().getElementType();
+  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
 }
 
 static LogicalResult verifySparsifierGetterSetter(
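Both predicates above now funnel through success(bool), the mlir convenience overload that maps true to success() and false to failure(), so an if/return pair collapses into one expression. A stand-alone sketch with the same trimmed LogicalResult stand-in:

    #include <cassert>

    struct LogicalResult {
      bool isFailure;
    };
    static LogicalResult success(bool isSuccess = true) {
      return {!isSuccess};
    }
    static bool failed(LogicalResult r) { return r.isFailure; }

    // Before: if (dim >= rank) return failure(); return success();
    // After:  one expression with the same truth table.
    static LogicalResult isInBounds(unsigned dim, unsigned rank) {
      return success(dim < rank);
    }

    int main() {
      assert(!failed(isInBounds(1, 3)));
      assert(failed(isInBounds(3, 3)));
      return 0;
    }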
@@ -663,11 +631,8 @@ LogicalResult ToValuesOp::verify() {
 }
 
 LogicalResult GetStorageSpecifierOp::verify() {
-  if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(),
-                                          getSpecifier(), getOperation()))) {
-    return failure();
-  }
+  RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
+      getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
   // Checks the result type
   if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) !=
       getResult().getType()) {
@@ -692,11 +657,8 @@ OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
 }
 
 LogicalResult SetStorageSpecifierOp::verify() {
-  if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(),
-                                          getSpecifier(), getOperation()))) {
-    return failure();
-  }
+  RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
+      getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
   // Checks the input type
   if (getSpecifier().getType().getFieldType(getSpecifierKind(), getDim()) !=
       getValue().getType()) {
@@ -748,59 +710,45 @@ LogicalResult BinaryOp::verify() {
   // Check correct number of block arguments and return type for each
   // non-empty region.
-  LogicalResult regionResult = success();
   if (!overlap.empty()) {
-    regionResult = verifyNumBlockArgs(
-        this, overlap, "overlap", TypeRange{leftType, rightType}, outputType);
-    if (failed(regionResult))
-      return regionResult;
+    RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
+        this, overlap, "overlap", TypeRange{leftType, rightType}, outputType))
   }
   if (!left.empty()) {
-    regionResult =
-        verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType);
-    if (failed(regionResult))
-      return regionResult;
+    RETURN_FAILURE_IF_FAILED(
+        verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType))
   } else if (getLeftIdentity()) {
     if (leftType != outputType)
       return emitError("left=identity requires first argument to have the same "
                        "type as the output");
   }
   if (!right.empty()) {
-    regionResult = verifyNumBlockArgs(this, right, "right",
-                                      TypeRange{rightType}, outputType);
-    if (failed(regionResult))
-      return regionResult;
+    RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
+        this, right, "right", TypeRange{rightType}, outputType))
   } else if (getRightIdentity()) {
     if (rightType != outputType)
       return emitError("right=identity requires second argument to have the "
                        "same type as the output");
   }
   return success();
 }
 
 LogicalResult UnaryOp::verify() {
   Type inputType = getX().getType();
   Type outputType = getOutput().getType();
-  LogicalResult regionResult = success();
 
   // Check correct number of block arguments and return type for each
   // non-empty region.
   Region &present = getPresentRegion();
   if (!present.empty()) {
-    regionResult = verifyNumBlockArgs(this, present, "present",
-                                      TypeRange{inputType}, outputType);
-    if (failed(regionResult))
-      return regionResult;
+    RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
+        this, present, "present", TypeRange{inputType}, outputType))
   }
   Region &absent = getAbsentRegion();
   if (!absent.empty()) {
-    regionResult =
-        verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType);
-    if (failed(regionResult))
-      return regionResult;
+    RETURN_FAILURE_IF_FAILED(
+        verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType))
   }
   return success();
 }
@@ -880,8 +828,7 @@ void PushBackOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult PushBackOp::verify() {
-  Value n = getN();
-  if (n) {
+  if (Value n = getN()) {
     auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
     if (nValue && nValue.value() < 1)
       return emitOpError("n must be not less than 1");
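The PushBackOp change above uses a C++ condition declaration, `if (Value n = getN())`, so n is scoped to the one branch that checks it. A sketch of that pattern with a hypothetical handle type (not the real mlir::Value):

    #include <cstdio>

    struct Value {
      int id = 0;
      explicit operator bool() const { return id != 0; } // "is present"
    };

    static Value getN(bool present) { return Value{present ? 42 : 0}; }

    static void verify(bool present) {
      if (Value n = getN(present)) // n is visible only in this if/else
        std::printf("checking n with id %d\n", n.id);
      else
        std::printf("no n operand to check\n");
    }

    int main() {
      verify(true);
      verify(false);
      return 0;
    }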
@@ -972,32 +919,21 @@ LogicalResult ForeachOp::verify() {
 LogicalResult ReduceOp::verify() {
   Type inputType = getX().getType();
-  LogicalResult regionResult = success();
-
   // Check correct number of block arguments and return type.
   Region &formula = getRegion();
-  regionResult = verifyNumBlockArgs(this, formula, "reduce",
-                                    TypeRange{inputType, inputType}, inputType);
-  if (failed(regionResult))
-    return regionResult;
+  RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
+      this, formula, "reduce", TypeRange{inputType, inputType}, inputType))
   return success();
 }
 
 LogicalResult SelectOp::verify() {
   Builder b(getContext());
   Type inputType = getX().getType();
   Type boolType = b.getI1Type();
-  LogicalResult regionResult = success();
-
   // Check correct number of block arguments and return type.
   Region &formula = getRegion();
-  regionResult = verifyNumBlockArgs(this, formula, "select",
-                                    TypeRange{inputType}, boolType);
-  if (failed(regionResult))
-    return regionResult;
+  RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(this, formula, "select",
+                                              TypeRange{inputType}, boolType))
   return success();
 }
@@ -1025,15 +961,8 @@ LogicalResult SortOp::verify() {
     }
     return success();
   };
-
-  LogicalResult result = checkTypes(getXs());
-  if (failed(result))
-    return result;
-
-  if (n)
-    return checkTypes(getYs(), false);
-
-  return success();
+  RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
+  return n ? checkTypes(getYs(), false) : success();
 }
 
 LogicalResult SortCooOp::verify() {
@@ -1084,6 +1013,8 @@ LogicalResult YieldOp::verify() {
"reduce, select or foreach");
}
#undef RETURN_FAILURE_IF_FAILED
//===----------------------------------------------------------------------===//
// TensorDialect Methods.
//===----------------------------------------------------------------------===//