mirror of
https://github.com/intel/llvm.git
synced 2026-02-04 03:26:06 +08:00
[MLIR] Error handling in MaterializeVectors
This removes assertions as a means to capture NYI behavior and propagates errors up. PiperOrigin-RevId: 224376935
This commit is contained in:
committed by
jpienaar
parent
4adc169bd0
commit
5b610630b2
@@ -209,23 +209,25 @@ struct MaterializeVectors : public FunctionPass {
|
||||
|
||||
char MaterializeVectors::passID = 0;
|
||||
|
||||
// Returns the distance, in number of elements, between a slice in a dimension
|
||||
// and the next slice in the same dimension.
|
||||
// e.g. shape[3, 4, 5] -> strides[20, 5, 1]
|
||||
/// Given a shape with sizes greater than 0 along all dimensions,
|
||||
/// returns the distance, in number of elements, between a slice in a dimension
|
||||
/// and the next slice in the same dimension.
|
||||
/// e.g. shape[3, 4, 5] -> strides[20, 5, 1]
|
||||
static SmallVector<unsigned, 8> makeStrides(ArrayRef<unsigned> shape) {
|
||||
SmallVector<unsigned, 8> tmp;
|
||||
tmp.reserve(shape.size());
|
||||
unsigned running = 1;
|
||||
for (auto rit = shape.rbegin(), reit = shape.rend(); rit != reit; ++rit) {
|
||||
// TODO(ntv): emitError instead of NYI assert.
|
||||
assert(*rit > 0 && "NYI: symbolic or null shape dimension");
|
||||
assert(*rit > 0 && "size must be greater than 0 along all dimensions of "
|
||||
"shape");
|
||||
tmp.push_back(running);
|
||||
running *= *rit;
|
||||
}
|
||||
return SmallVector<unsigned, 8>(tmp.rbegin(), tmp.rend());
|
||||
}
|
||||
|
||||
// Returns the linearized expression.
|
||||
/// Given a shape with sizes greater than 0 along all dimensions, returns the
|
||||
/// delinearized components of linearIndex along shape.
|
||||
static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
|
||||
ArrayRef<unsigned> shape) {
|
||||
SmallVector<unsigned, 8> res;
|
||||
@@ -256,6 +258,8 @@ instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
|
||||
/// insertion.
|
||||
/// For now, this is limited to ConstantOp because we do not vectorize loop
|
||||
/// indices and will need to be extended in the future.
|
||||
///
|
||||
/// If substitution fails, returns nullptr.
|
||||
static MLValue *
|
||||
substitute(SSAValue *v, VectorType hwVectorType,
|
||||
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
|
||||
@@ -271,7 +275,7 @@ substitute(SSAValue *v, VectorType hwVectorType,
|
||||
return res.first->second;
|
||||
}
|
||||
v->getDefiningOperation()->emitError("Missing substitution");
|
||||
assert(false);
|
||||
return nullptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
@@ -400,6 +404,8 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) {
|
||||
/// affine reindexing. Just substitute their SSAValue* operands and be done. For
|
||||
/// this case the actual instance is irrelevant. Just use the SSA values in
|
||||
/// substitutionsMap.
|
||||
///
|
||||
/// If the underlying substitution fails, this fails too and returns nullptr.
|
||||
static OperationStmt *
|
||||
instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
|
||||
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
|
||||
@@ -407,11 +413,18 @@ instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
|
||||
"Should call the function specialized for VectorTransferReadOp");
|
||||
assert(!opStmt->isa<VectorTransferWriteOp>() &&
|
||||
"Should call the function specialized for VectorTransferWriteOp");
|
||||
bool fail = false;
|
||||
auto operands = map(
|
||||
[hwVectorType, substitutionsMap](SSAValue *v) {
|
||||
return substitute(v, hwVectorType, substitutionsMap);
|
||||
[hwVectorType, substitutionsMap, &fail](SSAValue *v) {
|
||||
auto *res =
|
||||
fail ? nullptr : substitute(v, hwVectorType, substitutionsMap);
|
||||
fail |= !res;
|
||||
return res;
|
||||
},
|
||||
opStmt->getOperands());
|
||||
if (fail) {
|
||||
return nullptr;
|
||||
}
|
||||
auto attrs = materializeAttributes(opStmt, hwVectorType);
|
||||
return b->createOperation(opStmt->getLoc(), opStmt->getName(), operands,
|
||||
{hwVectorType}, attrs);
|
||||
@@ -452,7 +465,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer,
|
||||
auto permutationMap = transfer->getPermutationMap();
|
||||
LLVM_DEBUG(projectionMap.print(dbgs() << "\nprojectionMap: "));
|
||||
LLVM_DEBUG(permutationMap.print(dbgs() << "\npermutationMap: "));
|
||||
return composeUnboundedMaps(projectionMap, transfer->getPermutationMap());
|
||||
return composeUnboundedMaps(projectionMap, permutationMap);
|
||||
}
|
||||
|
||||
/// Creates an instantiated version of `read` for the instance of
|
||||
@@ -516,6 +529,8 @@ instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write,
|
||||
/// type, all operands are substituted according to `substitutions`. Thanks
|
||||
/// to the topological order of a slice, the substitution is always
|
||||
/// possible.
|
||||
///
|
||||
/// Returns true on failure.
|
||||
static bool instantiateMaterialization(Statement *stmt,
|
||||
MaterializationState *state) {
|
||||
LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt);
|
||||
@@ -557,6 +572,9 @@ static bool instantiateMaterialization(Statement *stmt,
|
||||
}
|
||||
auto *clone =
|
||||
instantiate(&b, opStmt, state->hwVectorType, state->substitutionsMap);
|
||||
if (!clone) {
|
||||
return true;
|
||||
}
|
||||
state->substitutionsMap->insert(std::make_pair(
|
||||
cast<MLValue>(opStmt->getResult(0)), cast<MLValue>(clone->getResult(0))));
|
||||
return false;
|
||||
@@ -578,10 +596,12 @@ static bool instantiateMaterialization(Statement *stmt,
|
||||
/// equivalent of loop strip-mining + loop sinking and encoded this in the
|
||||
/// vector type.
|
||||
///
|
||||
/// Returns true on failure.
|
||||
///
|
||||
/// TODO(ntv): materialized allocs.
|
||||
/// TODO(ntv): full loops + materialized allocs.
|
||||
/// TODO(ntv): partial unrolling + materialized allocs.
|
||||
static void emitSlice(MaterializationState *state,
|
||||
static bool emitSlice(MaterializationState *state,
|
||||
SetVector<Statement *> *slice) {
|
||||
auto ratio = shapeRatio(state->superVectorType, state->hwVectorType);
|
||||
assert(ratio.hasValue() &&
|
||||
@@ -601,7 +621,7 @@ static void emitSlice(MaterializationState *state,
|
||||
auto fail = instantiateMaterialization(stmt, &scopedState);
|
||||
if (fail) {
|
||||
stmt->emitError("Unhandled super-vector materialization failure");
|
||||
assert(!fail);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -618,6 +638,7 @@ static void emitSlice(MaterializationState *state,
|
||||
LLVM_DEBUG((*slice)[idx]->print(dbgs()));
|
||||
(*slice)[idx]->erase();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Materializes super-vector types into concrete hw vector types as follows:
|
||||
@@ -637,7 +658,7 @@ static void emitSlice(MaterializationState *state,
|
||||
/// Additionally, this set is limited to statements in the same lexical scope
|
||||
/// because we currently disallow vectorization of defs that come from another
|
||||
/// scope.
|
||||
static void materialize(MLFunction *f,
|
||||
static bool materialize(MLFunction *f,
|
||||
const SetVector<OperationStmt *> &terminators,
|
||||
MaterializationState *state) {
|
||||
DenseSet<Statement *> seen;
|
||||
@@ -686,10 +707,14 @@ static void materialize(MLFunction *f,
|
||||
"Only f32 supported for now");
|
||||
state->hwVectorType = VectorType::get(
|
||||
state->hwVectorSize, state->superVectorType.getElementType());
|
||||
emitSlice(state, &slice);
|
||||
auto fail = emitSlice(state, &slice);
|
||||
if (fail) {
|
||||
return true;
|
||||
}
|
||||
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
|
||||
LLVM_DEBUG(f->print(dbgs()));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
PassResult MaterializeVectors::runOnMLFunction(MLFunction *f) {
|
||||
@@ -720,9 +745,9 @@ PassResult MaterializeVectors::runOnMLFunction(MLFunction *f) {
|
||||
}
|
||||
|
||||
// Call materialization.
|
||||
materialize(f, terminators, &state);
|
||||
auto fail = materialize(f, terminators, &state);
|
||||
|
||||
return PassResult::Success;
|
||||
return fail ? PassResult::Failure : PassResult::Success;
|
||||
}
|
||||
|
||||
FunctionPass *mlir::createMaterializeVectors() {
|
||||
|
||||
Reference in New Issue
Block a user