[mlir] Expose printing functions in C API

Provide printing functions for most IR objects in C API (except Region that
does not have a `print` function, and Module that is expected to be printed as
Operation instead). The printing is based on a callback that is called with
chunks of the string representation and forwarded user-defined data.

Reviewed By: stellaraccident, Jing, mehdi_amini

Differential Revision: https://reviews.llvm.org/D85748
This commit is contained in:
Alex Zinenko
2020-08-11 18:25:09 +02:00
parent 3b0a4e9584
commit 321aa19ec8
4 changed files with 199 additions and 7 deletions

View File

@@ -71,10 +71,31 @@ owned by the `MlirContext` in which they were created.
### Nullity
A handle may refer to a _null_ object. It is the responsibility of the caller to
check if an object is null by using `MlirXIsNull(MlirX)`. API functions do _not_
check if an object is null by using `mlirXIsNull(MlirX)`. API functions do _not_
expect null objects as arguments unless explicitly stated otherwise. API
functions _may_ return null objects.
### Conversion To String and Printing
IR objects can be converted to a string representation, for example for
printing, using `mlirXPrint(MlirX, MlirPrintCallback, void *)` functions. These
functions accept take arguments a callback with signature `void (*)(const char
*, intptr_t, void *)` and a pointer to user-defined data. They call the callback
and supply it with chunks of the string representation, provided as a pointer to
the first character and a length, and forward the user-defined data unmodified.
It is up to the caller to allocate memory if the string representation must be
stored and perform the copy. There is no guarantee that the pointer supplied to
the callback points to a null-terminated string, the size argument should be
used to find the end of the string. The callback may be called multiple times
with consecutive chunks of the string representation (the printing itself is
bufferred).
*Rationale*: this approach allows the caller to have full control of the
allocation and avoid unnecessary allocation and copying inside the printer.
For convenience, `mlirXDump(MlirX)` functions are provided to print the given
object to the standard error stream.
### Common Patterns
The API adopts the following patterns for recurrent functionality in MLIR.

View File

@@ -60,7 +60,7 @@ DEFINE_C_API_STRUCT(MlirModule, const void);
/** Named MLIR attribute.
*
* A named attribute is essentially a (name, attrbute) pair where the name is
* A named attribute is essentially a (name, attribute) pair where the name is
* a string.
*/
struct MlirNamedAttribute {
@@ -69,6 +69,17 @@ struct MlirNamedAttribute {
};
typedef struct MlirNamedAttribute MlirNamedAttribute;
/** A callback for printing to IR objects.
*
* This function is called back by the printing functions with the following
* arguments:
* - a pointer to the beginning of a string;
* - the length of the string (the pointer may point to a larger buffer, not
* necessarily null-terminated);
* - a pointer to user data forwarded from the printing call.
*/
typedef void (*MlirPrintCallback)(const char *, intptr_t, void *);
/*============================================================================*/
/* Context API. */
/*============================================================================*/
@@ -91,6 +102,12 @@ MlirLocation mlirLocationFileLineColGet(MlirContext context,
/** Creates a location with unknown position owned by the given context. */
MlirLocation mlirLocationUnknownGet(MlirContext context);
/** Prints a location by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback,
void *userData);
/*============================================================================*/
/* Module API. */
/*============================================================================*/
@@ -202,6 +219,14 @@ MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos);
/** Returns an attrbute attached to the operation given its name. */
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
const char *name);
/** Prints an operation by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback,
void *userData);
/** Prints an operation to stderr. */
void mlirOperationDump(MlirOperation op);
/*============================================================================*/
@@ -263,6 +288,12 @@ intptr_t mlirBlockGetNumArguments(MlirBlock block);
/** Returns `pos`-th argument of the block. */
MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos);
/** Prints a block by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback,
void *userData);
/*============================================================================*/
/* Value API. */
/*============================================================================*/
@@ -270,6 +301,12 @@ MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos);
/** Returns the type of the value. */
MlirType mlirValueGetType(MlirValue value);
/** Prints a value by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
void *userData);
/*============================================================================*/
/* Type API. */
/*============================================================================*/
@@ -277,6 +314,11 @@ MlirType mlirValueGetType(MlirValue value);
/** Parses a type. The type is owned by the context. */
MlirType mlirTypeParseGet(MlirContext context, const char *type);
/** Prints a location by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData);
/** Prints the type to the standard error stream. */
void mlirTypeDump(MlirType type);
@@ -287,6 +329,12 @@ void mlirTypeDump(MlirType type);
/** Parses an attribute. The attribute is owned by the context. */
MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
/** Prints an attribute by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback,
void *userData);
/** Prints the attrbute to the standard error stream. */
void mlirAttributeDump(MlirAttribute attr);

View File

@@ -13,6 +13,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Parser.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@@ -56,6 +57,33 @@ static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first,
return storage;
}
/* ========================================================================== */
/* Printing helper. */
/* ========================================================================== */
namespace {
/// A simple raw ostream subclass that forwards write_impl calls to the
/// user-supplied callback together with opaque user-supplied data.
class CallbackOstream : public llvm::raw_ostream {
public:
CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback,
void *opaqueData)
: callback(callback), opaqueData(opaqueData), pos(0u) {}
void write_impl(const char *ptr, size_t size) override {
callback(ptr, size, opaqueData);
pos += size;
}
uint64_t current_pos() const override { return pos; }
private:
std::function<void(const char *, intptr_t, void *)> callback;
void *opaqueData;
uint64_t pos;
};
} // end namespace
/* ========================================================================== */
/* Context API. */
/* ========================================================================== */
@@ -81,6 +109,13 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) {
return wrap(UnknownLoc::get(unwrap(context)));
}
void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback,
void *userData) {
CallbackOstream stream(callback, userData);
unwrap(location).print(stream);
stream.flush();
}
/* ========================================================================== */
/* Module API. */
/* ========================================================================== */
@@ -239,6 +274,13 @@ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
return wrap(unwrap(op)->getAttr(name));
}
void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback,
void *userData) {
CallbackOstream stream(callback, userData);
unwrap(op)->print(stream);
stream.flush();
}
void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
/* ========================================================================== */
@@ -314,6 +356,13 @@ MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
}
void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback,
void *userData) {
CallbackOstream stream(callback, userData);
unwrap(block)->print(stream);
stream.flush();
}
/* ========================================================================== */
/* Value API. */
/* ========================================================================== */
@@ -322,6 +371,13 @@ MlirType mlirValueGetType(MlirValue value) {
return wrap(unwrap(value).getType());
}
void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
void *userData) {
CallbackOstream stream(callback, userData);
unwrap(value).print(stream);
stream.flush();
}
/* ========================================================================== */
/* Type API. */
/* ========================================================================== */
@@ -330,6 +386,12 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type) {
return wrap(mlir::parseType(type, unwrap(context)));
}
void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) {
CallbackOstream stream(callback, userData);
unwrap(type).print(stream);
stream.flush();
}
void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
/* ========================================================================== */
@@ -340,6 +402,13 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
return wrap(mlir::parseAttribute(attr, unwrap(context)));
}
void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback,
void *userData) {
CallbackOstream stream(callback, userData);
unwrap(attr).print(stream);
stream.flush();
}
void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {

View File

@@ -197,11 +197,47 @@ void collectStats(MlirOperation operation) {
head = next;
} while (head);
printf("Number of operations: %u\n", stats.numOperations);
printf("Number of attributes: %u\n", stats.numAttributes);
printf("Number of blocks: %u\n", stats.numBlocks);
printf("Number of regions: %u\n", stats.numRegions);
printf("Number of values: %u\n", stats.numValues);
fprintf(stderr, "Number of operations: %u\n", stats.numOperations);
fprintf(stderr, "Number of attributes: %u\n", stats.numAttributes);
fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
fprintf(stderr, "Number of values: %u\n", stats.numValues);
}
static void printToStderr(const char *str, intptr_t len, void *userData) {
(void)userData;
fwrite(str, 1, len, stderr);
}
static void printFirstOfEach(MlirOperation operation) {
// Assuming we are given a module, go to the first operation of the first
// function.
MlirRegion region = mlirOperationGetRegion(operation, 0);
MlirBlock block = mlirRegionGetFirstBlock(region);
operation = mlirBlockGetFirstOperation(block);
region = mlirOperationGetRegion(operation, 0);
block = mlirRegionGetFirstBlock(region);
operation = mlirBlockGetFirstOperation(block);
// In the module we created, the first operation of the first function is an
// "std.dim", which has an attribute an a single result that we can use to
// test the printing mechanism.
mlirBlockPrint(block, printToStderr, NULL);
fprintf(stderr, "\n");
mlirOperationPrint(operation, printToStderr, NULL);
fprintf(stderr, "\n");
MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation, 0);
mlirAttributePrint(namedAttr.attribute, printToStderr, NULL);
fprintf(stderr, "\n");
MlirValue value = mlirOperationGetResult(operation, 0);
mlirValuePrint(value, printToStderr, NULL);
fprintf(stderr, "\n");
MlirType type = mlirValueGetType(value);
mlirTypePrint(type, printToStderr, NULL);
fprintf(stderr, "\n");
}
int main() {
@@ -238,6 +274,24 @@ int main() {
// CHECK: Number of values: 9
// clang-format on
printFirstOfEach(module);
// clang-format off
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM:.*]] = dim %{{.*}}, %[[C0]] : memref<?xf32>
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
// CHECK: %[[LHS:.*]] = load %{{.*}}[%[[I]]] : memref<?xf32>
// CHECK: %[[RHS:.*]] = load %{{.*}}[%[[I]]] : memref<?xf32>
// CHECK: %[[SUM:.*]] = addf %[[LHS]], %[[RHS]] : f32
// CHECK: store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
// CHECK: }
// CHECK: return
// CHECK: constant 0 : index
// CHECK: 0 : index
// CHECK: constant 0 : index
// CHECK: index
// clang-format on
mlirModuleDestroy(moduleOp);
mlirContextDestroy(ctx);