[MLIR] Error handling in MaterializeVectors

This removes assertions as a means to capture NYI behavior and propagates
errors up.

PiperOrigin-RevId: 224376935
This commit is contained in:
Nicolas Vasilache
2018-12-06 11:37:53 -08:00
committed by jpienaar
parent 4adc169bd0
commit 5b610630b2

View File

@@ -209,23 +209,25 @@ struct MaterializeVectors : public FunctionPass {
char MaterializeVectors::passID = 0;
// Returns the distance, in number of elements, between a slice in a dimension
// and the next slice in the same dimension.
// e.g. shape[3, 4, 5] -> strides[20, 5, 1]
/// Given a shape with sizes greater than 0 along all dimensions,
/// returns the distance, in number of elements, between a slice in a dimension
/// and the next slice in the same dimension.
/// e.g. shape[3, 4, 5] -> strides[20, 5, 1]
static SmallVector<unsigned, 8> makeStrides(ArrayRef<unsigned> shape) {
SmallVector<unsigned, 8> tmp;
tmp.reserve(shape.size());
unsigned running = 1;
for (auto rit = shape.rbegin(), reit = shape.rend(); rit != reit; ++rit) {
// TODO(ntv): emitError instead of NYI assert.
assert(*rit > 0 && "NYI: symbolic or null shape dimension");
assert(*rit > 0 && "size must be greater than 0 along all dimensions of "
"shape");
tmp.push_back(running);
running *= *rit;
}
return SmallVector<unsigned, 8>(tmp.rbegin(), tmp.rend());
}
// Returns the linearized expression.
/// Given a shape with sizes greater than 0 along all dimensions, returns the
/// delinearized components of linearIndex along shape.
static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
ArrayRef<unsigned> shape) {
SmallVector<unsigned, 8> res;
@@ -256,6 +258,8 @@ instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
/// insertion.
/// For now, this is limited to ConstantOp because we do not vectorize loop
/// indices and will need to be extended in the future.
///
/// If substitution fails, returns nullptr.
static MLValue *
substitute(SSAValue *v, VectorType hwVectorType,
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
@@ -271,7 +275,7 @@ substitute(SSAValue *v, VectorType hwVectorType,
return res.first->second;
}
v->getDefiningOperation()->emitError("Missing substitution");
assert(false);
return nullptr;
}
return it->second;
}
@@ -400,6 +404,8 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) {
/// affine reindexing. Just substitute their SSAValue* operands and be done. For
/// this case the actual instance is irrelevant. Just use the SSA values in
/// substitutionsMap.
///
/// If the underlying substitution fails, this fails too and returns nullptr.
static OperationStmt *
instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
@@ -407,11 +413,18 @@ instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
"Should call the function specialized for VectorTransferReadOp");
assert(!opStmt->isa<VectorTransferWriteOp>() &&
"Should call the function specialized for VectorTransferWriteOp");
bool fail = false;
auto operands = map(
[hwVectorType, substitutionsMap](SSAValue *v) {
return substitute(v, hwVectorType, substitutionsMap);
[hwVectorType, substitutionsMap, &fail](SSAValue *v) {
auto *res =
fail ? nullptr : substitute(v, hwVectorType, substitutionsMap);
fail |= !res;
return res;
},
opStmt->getOperands());
if (fail) {
return nullptr;
}
auto attrs = materializeAttributes(opStmt, hwVectorType);
return b->createOperation(opStmt->getLoc(), opStmt->getName(), operands,
{hwVectorType}, attrs);
@@ -452,7 +465,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer,
auto permutationMap = transfer->getPermutationMap();
LLVM_DEBUG(projectionMap.print(dbgs() << "\nprojectionMap: "));
LLVM_DEBUG(permutationMap.print(dbgs() << "\npermutationMap: "));
return composeUnboundedMaps(projectionMap, transfer->getPermutationMap());
return composeUnboundedMaps(projectionMap, permutationMap);
}
/// Creates an instantiated version of `read` for the instance of
@@ -516,6 +529,8 @@ instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write,
/// type, all operands are substituted according to `substitutions`. Thanks
/// to the topological order of a slice, the substitution is always
/// possible.
///
/// Returns true on failure.
static bool instantiateMaterialization(Statement *stmt,
MaterializationState *state) {
LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt);
@@ -557,6 +572,9 @@ static bool instantiateMaterialization(Statement *stmt,
}
auto *clone =
instantiate(&b, opStmt, state->hwVectorType, state->substitutionsMap);
if (!clone) {
return true;
}
state->substitutionsMap->insert(std::make_pair(
cast<MLValue>(opStmt->getResult(0)), cast<MLValue>(clone->getResult(0))));
return false;
@@ -578,10 +596,12 @@ static bool instantiateMaterialization(Statement *stmt,
/// equivalent of loop strip-mining + loop sinking and encoded this in the
/// vector type.
///
/// Returns true on failure.
///
/// TODO(ntv): materialized allocs.
/// TODO(ntv): full loops + materialized allocs.
/// TODO(ntv): partial unrolling + materialized allocs.
static void emitSlice(MaterializationState *state,
static bool emitSlice(MaterializationState *state,
SetVector<Statement *> *slice) {
auto ratio = shapeRatio(state->superVectorType, state->hwVectorType);
assert(ratio.hasValue() &&
@@ -601,7 +621,7 @@ static void emitSlice(MaterializationState *state,
auto fail = instantiateMaterialization(stmt, &scopedState);
if (fail) {
stmt->emitError("Unhandled super-vector materialization failure");
assert(!fail);
return true;
}
}
}
@@ -618,6 +638,7 @@ static void emitSlice(MaterializationState *state,
LLVM_DEBUG((*slice)[idx]->print(dbgs()));
(*slice)[idx]->erase();
}
return false;
}
/// Materializes super-vector types into concrete hw vector types as follows:
@@ -637,7 +658,7 @@ static void emitSlice(MaterializationState *state,
/// Additionally, this set is limited to statements in the same lexical scope
/// because we currently disallow vectorization of defs that come from another
/// scope.
static void materialize(MLFunction *f,
static bool materialize(MLFunction *f,
const SetVector<OperationStmt *> &terminators,
MaterializationState *state) {
DenseSet<Statement *> seen;
@@ -686,10 +707,14 @@ static void materialize(MLFunction *f,
"Only f32 supported for now");
state->hwVectorType = VectorType::get(
state->hwVectorSize, state->superVectorType.getElementType());
emitSlice(state, &slice);
auto fail = emitSlice(state, &slice);
if (fail) {
return true;
}
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
LLVM_DEBUG(f->print(dbgs()));
}
return false;
}
PassResult MaterializeVectors::runOnMLFunction(MLFunction *f) {
@@ -720,9 +745,9 @@ PassResult MaterializeVectors::runOnMLFunction(MLFunction *f) {
}
// Call materialization.
materialize(f, terminators, &state);
auto fail = materialize(f, terminators, &state);
return PassResult::Success;
return fail ? PassResult::Failure : PassResult::Success;
}
FunctionPass *mlir::createMaterializeVectors() {