[mlir][CAPI] Add result type inference to the CAPI.

* Adds a flag to MlirOperationState to enable result type inference using the InferTypeOpInterface.
* I chose this level of implementation for a couple of reasons:
  a) In the creation flow is naturally where generated and custom builder code will be invoking such a thing
  b) it is a bit more efficient to share the data structure and unpacking vs having a standalone entry-point
  c) we can always decide to expose more of these interfaces with first-class APIs, but that doesn't preclude that we will always want to use this one in this way (and less API surface area for common things is better for API stability and evolution).
* I struggled to find an appropriate way to test it since we don't link the test dialect into anything CAPI accessible at present. I opted instead for one of the simplest ops I found in a regular dialect which implements the interface.
* This does not do any trait-based type selection. That will be left to generated tablegen wrappers.

Differential Revision: https://reviews.llvm.org/D95283
This commit is contained in:
Stella Laurenzo
2021-01-22 18:43:50 -08:00
parent 6f2753273e
commit 52586c46b0
3 changed files with 106 additions and 15 deletions

View File

@@ -225,6 +225,7 @@ struct MlirOperationState {
MlirBlock *successors;
intptr_t nAttributes;
MlirNamedAttribute *attributes;
bool enableResultTypeInference;
};
typedef struct MlirOperationState MlirOperationState;
@@ -249,6 +250,14 @@ MLIR_CAPI_EXPORTED void
mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
MlirNamedAttribute const *attributes);
/// Enables result type inference for the operation under construction. If
/// enabled, then the caller must not have called
/// mlirOperationStateAddResults(). Note that if enabled, the
/// mlirOperationCreate() call is failable: it will return a null operation
/// on inference failure and will emit diagnostics.
MLIR_CAPI_EXPORTED void
mlirOperationStateEnableResultTypeInference(MlirOperationState *state);
//===----------------------------------------------------------------------===//
// Op Printing flags API.
// While many of these are simple settings that could be represented in a
@@ -293,8 +302,14 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
//===----------------------------------------------------------------------===//
/// Creates an operation and transfers ownership to the caller.
MLIR_CAPI_EXPORTED MlirOperation
mlirOperationCreate(const MlirOperationState *state);
/// Note that caller owned child objects are transferred in this call and must
/// not be further used. Particularly, this applies to any regions added to
/// the state (the implementation may invalidate any such pointers).
///
/// This call can fail under the following conditions, in which case, it will
/// return a null operation and emit diagnostics:
/// - Result type inference is enabled and cannot be performed.
MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state);
/// Takes an operation owned by the caller and destroys it.
MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op);

View File

@@ -18,6 +18,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser.h"
using namespace mlir;
@@ -188,6 +189,7 @@ MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
state.successors = nullptr;
state.nAttributes = 0;
state.attributes = nullptr;
state.enableResultTypeInference = false;
return state;
}
@@ -219,11 +221,47 @@ void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
}
void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
state->enableResultTypeInference = true;
}
//===----------------------------------------------------------------------===//
// Operation API.
//===----------------------------------------------------------------------===//
MlirOperation mlirOperationCreate(const MlirOperationState *state) {
static LogicalResult inferOperationTypes(OperationState &state) {
MLIRContext *context = state.getContext();
const AbstractOperation *abstractOp =
AbstractOperation::lookup(state.name.getStringRef(), context);
if (!abstractOp) {
emitError(state.location)
<< "type inference was requested for the operation " << state.name
<< ", but the operation was not registered. Ensure that the dialect "
"containing the operation is linked into MLIR and registered with "
"the context";
return failure();
}
// Fallback to inference via an op interface.
auto *inferInterface = abstractOp->getInterface<InferTypeOpInterface>();
if (!inferInterface) {
emitError(state.location)
<< "type inference was requested for the operation " << state.name
<< ", but the operation does not support type inference. Result "
"types must be specified explicitly.";
return failure();
}
if (succeeded(inferInterface->inferReturnTypes(
context, state.location, state.operands,
state.attributes.getDictionary(context), state.regions, state.types)))
return success();
// Diagnostic emitted by interface.
return failure();
}
MlirOperation mlirOperationCreate(MlirOperationState *state) {
assert(state);
OperationState cppState(unwrap(state->location), unwrap(state->name));
SmallVector<Type, 4> resultStorage;
@@ -243,12 +281,21 @@ MlirOperation mlirOperationCreate(const MlirOperationState *state) {
for (intptr_t i = 0; i < state->nRegions; ++i)
cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
MlirOperation result = wrap(Operation::create(cppState));
free(state->results);
free(state->operands);
free(state->successors);
free(state->regions);
free(state->attributes);
// Infer result types.
if (state->enableResultTypeInference) {
assert(cppState.types.empty() &&
"result type inference enabled and result types provided");
if (failed(inferOperationTypes(cppState)))
return {nullptr};
}
MlirOperation result = wrap(Operation::create(cppState));
return result;
}

View File

@@ -553,6 +553,35 @@ static void buildWithInsertionsAndPrint(MlirContext ctx) {
// clang-format on
}
/// Creates operations with type inference and tests various failure modes.
static int createOperationWithTypeInference(MlirContext ctx) {
MlirLocation loc = mlirLocationUnknownGet(ctx);
MlirAttribute iAttr = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 4);
// The shape.const_size op implements result type inference and is only used
// for that reason.
MlirOperationState state = mlirOperationStateGet(
mlirStringRefCreateFromCString("shape.const_size"), loc);
MlirNamedAttribute valueAttr = mlirNamedAttributeGet(
mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), iAttr);
mlirOperationStateAddAttributes(&state, 1, &valueAttr);
mlirOperationStateEnableResultTypeInference(&state);
// Expect result type inference to succeed.
MlirOperation op = mlirOperationCreate(&state);
if (mlirOperationIsNull(op)) {
fprintf(stderr, "ERROR: Result type inference unexpectedly failed");
return 1;
}
// CHECK: RESULT_TYPE_INFERENCE: !shape.size
fprintf(stderr, "RESULT_TYPE_INFERENCE: ");
mlirTypeDump(mlirValueGetType(mlirOperationGetResult(op, 0)));
fprintf(stderr, "\n");
mlirOperationDestroy(op);
return 0;
}
/// Dumps instances of all builtin types to check that C API works correctly.
/// Additionally, performs simple identity checks that a builtin type
/// constructed with C API can be inspected and has the expected type. The
@@ -957,14 +986,12 @@ int printBuiltinAttributes(MlirContext ctx) {
(uint64_t *)mlirDenseElementsAttrGetRawData(uint64Elements);
int64_t *int64RawData =
(int64_t *)mlirDenseElementsAttrGetRawData(int64Elements);
float *floatRawData =
(float *)mlirDenseElementsAttrGetRawData(floatElements);
float *floatRawData = (float *)mlirDenseElementsAttrGetRawData(floatElements);
double *doubleRawData =
(double *)mlirDenseElementsAttrGetRawData(doubleElements);
if (uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
int32RawData[0] != 0 || int32RawData[1] != 1 ||
uint64RawData[0] != 0u || uint64RawData[1] != 1u ||
int64RawData[0] != 0 || int64RawData[1] != 1 ||
int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0)
return 18;
@@ -1389,19 +1416,21 @@ int main() {
if (constructAndTraverseIr(ctx))
return 1;
buildWithInsertionsAndPrint(ctx);
if (createOperationWithTypeInference(ctx))
return 2;
if (printBuiltinTypes(ctx))
return 2;
if (printBuiltinAttributes(ctx))
return 3;
if (printAffineMap(ctx))
if (printBuiltinAttributes(ctx))
return 4;
if (printAffineExpr(ctx))
if (printAffineMap(ctx))
return 5;
if (affineMapFromExprs(ctx))
if (printAffineExpr(ctx))
return 6;
if (registerOnlyStd())
if (affineMapFromExprs(ctx))
return 7;
if (registerOnlyStd())
return 8;
mlirContextDestroy(ctx);