[mlir][bytecode] Return error instead of min version

Can't return a well-formed IR output while enabling version to be bumped
up during emission. Previously it would return min version but
potentially invalid IR which was confusing, instead make it return
error and abort immediately instead.

Differential Revision: https://reviews.llvm.org/D149569
This commit is contained in:
Jacques Pienaar
2023-04-30 22:11:02 -07:00
parent fddd28364c
commit 5c90e1ffb0
9 changed files with 37 additions and 69 deletions

View File

@@ -565,27 +565,16 @@ MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op,
MlirStringCallback callback,
void *userData);
struct MlirBytecodeWriterResult {
int64_t minVersion;
};
typedef struct MlirBytecodeWriterResult MlirBytecodeWriterResult;
/// Same as mlirOperationPrint but writing the bytecode format.
MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op,
MlirStringCallback callback,
void *userData);
inline static bool
mlirBytecodeWriterResultGetMinVersion(MlirBytecodeWriterResult res) {
return res.minVersion;
}
/// Same as mlirOperationPrint but writing the bytecode format and returns the
/// minimum bytecode version the consumer needs to support.
MLIR_CAPI_EXPORTED MlirBytecodeWriterResult mlirOperationWriteBytecode(
MlirOperation op, MlirStringCallback callback, void *userData);
/// Same as mlirOperationWriteBytecode but with writer config.
MLIR_CAPI_EXPORTED MlirBytecodeWriterResult
mlirOperationWriteBytecodeWithConfig(MlirOperation op,
MlirBytecodeWriterConfig config,
MlirStringCallback callback,
void *userData);
/// Same as mlirOperationWriteBytecode but with writer config and returns
/// failure only if desired bytecode could not be honored.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig(
MlirOperation op, MlirBytecodeWriterConfig config,
MlirStringCallback callback, void *userData);
/// Prints an operation to stderr.
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op);

View File

@@ -75,21 +75,15 @@ private:
std::unique_ptr<Impl> impl;
};
/// Status of bytecode serialization.
struct BytecodeWriterResult {
/// The minimum version of the reader required to read the serialized file.
int64_t minVersion;
};
//===----------------------------------------------------------------------===//
// Entry Points
//===----------------------------------------------------------------------===//
/// Write the bytecode for the given operation to the provided output stream.
/// For streams where it matters, the given stream should be in "binary" mode.
BytecodeWriterResult
writeBytecodeToFile(Operation *op, raw_ostream &os,
const BytecodeWriterConfig &config = {});
/// It only ever fails if setDesiredByteCodeVersion can't be honored.
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os,
const BytecodeWriterConfig &config = {});
} // namespace mlir

View File

@@ -16,6 +16,7 @@
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@@ -1134,9 +1135,8 @@ void PyOperationBase::print(py::object fileObject, bool binary,
mlirOpPrintingFlagsDestroy(flags);
}
MlirBytecodeWriterResult
PyOperationBase::writeBytecode(const py::object &fileObject,
std::optional<int64_t> bytecodeVersion) {
void PyOperationBase::writeBytecode(const py::object &fileObject,
std::optional<int64_t> bytecodeVersion) {
PyOperation &operation = getOperation();
operation.checkValid();
PyFileAccumulator accum(fileObject, /*binary=*/true);
@@ -1147,8 +1147,12 @@ PyOperationBase::writeBytecode(const py::object &fileObject,
MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
return mlirOperationWriteBytecodeWithConfig(
MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
operation, config, accum.getCallback(), accum.getUserData());
if (mlirLogicalResultIsFailure(res))
throw py::value_error((Twine("Unable to honor desired bytecode version ") +
Twine(*bytecodeVersion))
.str());
}
py::object PyOperationBase::getAsm(bool binary,
@@ -3378,10 +3382,6 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("from_op"), py::arg("all_sym_uses_visible"),
py::arg("callback"));
py::class_<MlirBytecodeWriterResult>(m, "BytecodeResult", py::module_local())
.def("min_version",
[](MlirBytecodeWriterResult &res) { return res.minVersion; });
// Container bindings.
PyBlockArgumentList::bind(m);
PyBlockIterator::bind(m);

View File

@@ -554,9 +554,8 @@ public:
bool assumeVerified);
// Implement the bound 'writeBytecode' method.
MlirBytecodeWriterResult
writeBytecode(const pybind11::object &fileObject,
std::optional<int64_t> bytecodeVersion);
void writeBytecode(const pybind11::object &fileObject,
std::optional<int64_t> bytecodeVersion);
/// Moves the operation before or after the other operation.
void moveAfter(PyOperationBase &other);

View File

@@ -887,12 +887,10 @@ void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
// Entry Points
//===----------------------------------------------------------------------===//
BytecodeWriterResult
mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
const BytecodeWriterConfig &config) {
LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
const BytecodeWriterConfig &config) {
BytecodeWriter writer(op, config.getImpl());
writer.write(op, os);
// Return the bytecode version emitted - currently there is no additional
// feedback as to minimum beyond the requested one.
return {config.getImpl().bytecodeVersion};
// Currently there is no failure case.
return success();
}

View File

@@ -524,25 +524,18 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
unwrap(op)->print(stream, *unwrap(flags));
}
MlirBytecodeWriterResult mlirOperationWriteBytecode(MlirOperation op,
MlirStringCallback callback,
void *userData) {
void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
void *userData) {
detail::CallbackOstream stream(callback, userData);
MlirBytecodeWriterResult res;
BytecodeWriterResult r = writeBytecodeToFile(unwrap(op), stream);
res.minVersion = r.minVersion;
return res;
// As no desired version is set, no failure can occur.
(void)writeBytecodeToFile(unwrap(op), stream);
}
MlirBytecodeWriterResult mlirOperationWriteBytecodeWithConfig(
MlirLogicalResult mlirOperationWriteBytecodeWithConfig(
MlirOperation op, MlirBytecodeWriterConfig config,
MlirStringCallback callback, void *userData) {
detail::CallbackOstream stream(callback, userData);
BytecodeWriterResult r =
writeBytecodeToFile(unwrap(op), stream, *unwrap(config));
MlirBytecodeWriterResult res;
res.minVersion = r.minVersion;
return res;
return wrap(writeBytecodeToFile(unwrap(op), stream, *unwrap(config)));
}
void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }

View File

@@ -885,7 +885,8 @@ MLIRDocument::convertToBytecode() {
std::string rawBytecodeBuffer;
llvm::raw_string_ostream os(rawBytecodeBuffer);
writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
// No desired bytecode version set, so no need to check for error.
(void)writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
result.output = llvm::encodeBase64(rawBytecodeBuffer);
}
return result;

View File

@@ -264,14 +264,9 @@ performActions(raw_ostream &os,
TimingScope outputTiming = timing.nest("Output");
if (config.shouldEmitBytecode()) {
BytecodeWriterConfig writerConfig(fallbackResourceMap);
if (auto v = config.bytecodeVersionToEmit()) {
if (auto v = config.bytecodeVersionToEmit())
writerConfig.setDesiredBytecodeVersion(*v);
// Returns failure if requested version couldn't be used for opt tools.
return success(
writeBytecodeToFile(op.get(), os, writerConfig).minVersion <= *v);
}
writeBytecodeToFile(op.get(), os, writerConfig);
return success();
return writeBytecodeToFile(op.get(), os, writerConfig);
}
if (config.bytecodeVersionToEmit().has_value())

View File

@@ -571,8 +571,7 @@ def testOperationPrint():
# Test roundtrip to bytecode.
bytecode_stream = io.BytesIO()
result = module.operation.write_bytecode(bytecode_stream, desired_version=1)
assert result.min_version() == 1, "Requested version not serialized to"
module.operation.write_bytecode(bytecode_stream, desired_version=1)
bytecode = bytecode_stream.getvalue()
assert bytecode.startswith(b'ML\xefR'), "Expected bytecode to start with MLïR"
module_roundtrip = Module.parse(bytecode, ctx)