mirror of
https://github.com/intel/llvm.git
synced 2026-01-16 21:55:39 +08:00
Implement python iteration over the operation/region/block hierarchy.
* Removes the half-completed prior attempt at region/block mutation in favor of new approach to ownership. * Will re-add mutation more correctly in a follow-on. * Eliminates the detached state on blocks and regions, simplifying the ownership hierarchy. * Adds both iterator and index based access at each level. Differential Revision: https://reviews.llvm.org/D87982
This commit is contained in:
@@ -91,6 +91,12 @@ int mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
|
||||
/** Takes an MLIR context owned by the caller and destroys it. */
|
||||
void mlirContextDestroy(MlirContext context);
|
||||
|
||||
/** Sets whether unregistered dialects are allowed in this context. */
|
||||
void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow);
|
||||
|
||||
/** Returns whether the context allows unregistered dialects. */
|
||||
int mlirContextGetAllowUnregisteredDialects(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* Location API. */
|
||||
/*============================================================================*/
|
||||
|
||||
@@ -46,45 +46,6 @@ static const char kContextGetUnknownLocationDocstring[] =
|
||||
static const char kContextGetFileLocationDocstring[] =
|
||||
R"(Gets a Location representing a file, line and column)";
|
||||
|
||||
static const char kContextCreateBlockDocstring[] =
|
||||
R"(Creates a detached block)";
|
||||
|
||||
static const char kContextCreateRegionDocstring[] =
|
||||
R"(Creates a detached region)";
|
||||
|
||||
static const char kRegionAppendBlockDocstring[] =
|
||||
R"(Appends a block to a region.
|
||||
|
||||
Raises:
|
||||
ValueError: If the block is already attached to another region.
|
||||
)";
|
||||
|
||||
static const char kRegionInsertBlockDocstring[] =
|
||||
R"(Inserts a block at a postiion in a region.
|
||||
|
||||
Raises:
|
||||
ValueError: If the block is already attached to another region.
|
||||
)";
|
||||
|
||||
static const char kRegionFirstBlockDocstring[] =
|
||||
R"(Gets the first block in a region.
|
||||
|
||||
Blocks can also be accessed via the `blocks` container.
|
||||
|
||||
Raises:
|
||||
IndexError: If the region has no blocks.
|
||||
)";
|
||||
|
||||
static const char kBlockNextInRegionDocstring[] =
|
||||
R"(Gets the next block in the enclosing region.
|
||||
|
||||
Blocks can also be accessed via the `blocks` container of the owning region.
|
||||
This method exists to mirror the lower level API and should not be preferred.
|
||||
|
||||
Raises:
|
||||
IndexError: If there are no further blocks.
|
||||
)";
|
||||
|
||||
static const char kOperationStrDunderDocstring[] =
|
||||
R"(Prints the assembly form of the operation with default options.
|
||||
|
||||
@@ -170,6 +131,241 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) {
|
||||
|
||||
} // namespace
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Collections.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
namespace {
|
||||
|
||||
class PyRegionIterator {
|
||||
public:
|
||||
PyRegionIterator(PyOperationRef operation)
|
||||
: operation(std::move(operation)) {}
|
||||
|
||||
PyRegionIterator &dunderIter() { return *this; }
|
||||
|
||||
PyRegion dunderNext() {
|
||||
operation->checkValid();
|
||||
if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
|
||||
throw py::stop_iteration();
|
||||
}
|
||||
MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
|
||||
return PyRegion(operation, region);
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyRegionIterator>(m, "RegionIterator")
|
||||
.def("__iter__", &PyRegionIterator::dunderIter)
|
||||
.def("__next__", &PyRegionIterator::dunderNext);
|
||||
}
|
||||
|
||||
private:
|
||||
PyOperationRef operation;
|
||||
int nextIndex = 0;
|
||||
};
|
||||
|
||||
/// Regions of an op are fixed length and indexed numerically so are represented
|
||||
/// with a sequence-like container.
|
||||
class PyRegionList {
|
||||
public:
|
||||
PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
|
||||
|
||||
intptr_t dunderLen() {
|
||||
operation->checkValid();
|
||||
return mlirOperationGetNumRegions(operation->get());
|
||||
}
|
||||
|
||||
PyRegion dunderGetItem(intptr_t index) {
|
||||
// dunderLen checks validity.
|
||||
if (index < 0 || index >= dunderLen()) {
|
||||
throw SetPyError(PyExc_IndexError,
|
||||
"attempt to access out of bounds region");
|
||||
}
|
||||
MlirRegion region = mlirOperationGetRegion(operation->get(), index);
|
||||
return PyRegion(operation, region);
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyRegionList>(m, "ReqionSequence")
|
||||
.def("__len__", &PyRegionList::dunderLen)
|
||||
.def("__getitem__", &PyRegionList::dunderGetItem);
|
||||
}
|
||||
|
||||
private:
|
||||
PyOperationRef operation;
|
||||
};
|
||||
|
||||
class PyBlockIterator {
|
||||
public:
|
||||
PyBlockIterator(PyOperationRef operation, MlirBlock next)
|
||||
: operation(std::move(operation)), next(next) {}
|
||||
|
||||
PyBlockIterator &dunderIter() { return *this; }
|
||||
|
||||
PyBlock dunderNext() {
|
||||
operation->checkValid();
|
||||
if (mlirBlockIsNull(next)) {
|
||||
throw py::stop_iteration();
|
||||
}
|
||||
|
||||
PyBlock returnBlock(operation, next);
|
||||
next = mlirBlockGetNextInRegion(next);
|
||||
return returnBlock;
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyBlockIterator>(m, "BlockIterator")
|
||||
.def("__iter__", &PyBlockIterator::dunderIter)
|
||||
.def("__next__", &PyBlockIterator::dunderNext);
|
||||
}
|
||||
|
||||
private:
|
||||
PyOperationRef operation;
|
||||
MlirBlock next;
|
||||
};
|
||||
|
||||
/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
|
||||
/// we present them as a more full-featured list-like container but optimzie
|
||||
/// it for forward iteration. Blocks are always owned by a region.
|
||||
class PyBlockList {
|
||||
public:
|
||||
PyBlockList(PyOperationRef operation, MlirRegion region)
|
||||
: operation(std::move(operation)), region(region) {}
|
||||
|
||||
PyBlockIterator dunderIter() {
|
||||
operation->checkValid();
|
||||
return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
|
||||
}
|
||||
|
||||
intptr_t dunderLen() {
|
||||
operation->checkValid();
|
||||
intptr_t count = 0;
|
||||
MlirBlock block = mlirRegionGetFirstBlock(region);
|
||||
while (!mlirBlockIsNull(block)) {
|
||||
count += 1;
|
||||
block = mlirBlockGetNextInRegion(block);
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
PyBlock dunderGetItem(intptr_t index) {
|
||||
operation->checkValid();
|
||||
if (index < 0) {
|
||||
throw SetPyError(PyExc_IndexError,
|
||||
"attempt to access out of bounds block");
|
||||
}
|
||||
MlirBlock block = mlirRegionGetFirstBlock(region);
|
||||
while (!mlirBlockIsNull(block)) {
|
||||
if (index == 0) {
|
||||
return PyBlock(operation, block);
|
||||
}
|
||||
block = mlirBlockGetNextInRegion(block);
|
||||
index -= 1;
|
||||
}
|
||||
throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyBlockList>(m, "BlockList")
|
||||
.def("__getitem__", &PyBlockList::dunderGetItem)
|
||||
.def("__iter__", &PyBlockList::dunderIter)
|
||||
.def("__len__", &PyBlockList::dunderLen);
|
||||
}
|
||||
|
||||
private:
|
||||
PyOperationRef operation;
|
||||
MlirRegion region;
|
||||
};
|
||||
|
||||
class PyOperationIterator {
|
||||
public:
|
||||
PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
|
||||
: parentOperation(std::move(parentOperation)), next(next) {}
|
||||
|
||||
PyOperationIterator &dunderIter() { return *this; }
|
||||
|
||||
py::object dunderNext() {
|
||||
parentOperation->checkValid();
|
||||
if (mlirOperationIsNull(next)) {
|
||||
throw py::stop_iteration();
|
||||
}
|
||||
|
||||
PyOperationRef returnOperation =
|
||||
PyOperation::forOperation(parentOperation->getContext(), next);
|
||||
next = mlirOperationGetNextInBlock(next);
|
||||
return returnOperation.releaseObject();
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyOperationIterator>(m, "OperationIterator")
|
||||
.def("__iter__", &PyOperationIterator::dunderIter)
|
||||
.def("__next__", &PyOperationIterator::dunderNext);
|
||||
}
|
||||
|
||||
private:
|
||||
PyOperationRef parentOperation;
|
||||
MlirOperation next;
|
||||
};
|
||||
|
||||
/// Operations are exposed by the C-API as a forward-only linked list. In
|
||||
/// Python, we present them as a more full-featured list-like container but
|
||||
/// optimzie it for forward iteration. Iterable operations are always owned
|
||||
/// by a block.
|
||||
class PyOperationList {
|
||||
public:
|
||||
PyOperationList(PyOperationRef parentOperation, MlirBlock block)
|
||||
: parentOperation(std::move(parentOperation)), block(block) {}
|
||||
|
||||
PyOperationIterator dunderIter() {
|
||||
parentOperation->checkValid();
|
||||
return PyOperationIterator(parentOperation,
|
||||
mlirBlockGetFirstOperation(block));
|
||||
}
|
||||
|
||||
intptr_t dunderLen() {
|
||||
parentOperation->checkValid();
|
||||
intptr_t count = 0;
|
||||
MlirOperation childOp = mlirBlockGetFirstOperation(block);
|
||||
while (!mlirOperationIsNull(childOp)) {
|
||||
count += 1;
|
||||
childOp = mlirOperationGetNextInBlock(childOp);
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
py::object dunderGetItem(intptr_t index) {
|
||||
parentOperation->checkValid();
|
||||
if (index < 0) {
|
||||
throw SetPyError(PyExc_IndexError,
|
||||
"attempt to access out of bounds operation");
|
||||
}
|
||||
MlirOperation childOp = mlirBlockGetFirstOperation(block);
|
||||
while (!mlirOperationIsNull(childOp)) {
|
||||
if (index == 0) {
|
||||
return PyOperation::forOperation(parentOperation->getContext(), childOp)
|
||||
.releaseObject();
|
||||
}
|
||||
childOp = mlirOperationGetNextInBlock(childOp);
|
||||
index -= 1;
|
||||
}
|
||||
throw SetPyError(PyExc_IndexError,
|
||||
"attempt to access out of bounds operation");
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyOperationList>(m, "OperationList")
|
||||
.def("__getitem__", &PyOperationList::dunderGetItem)
|
||||
.def("__iter__", &PyOperationList::dunderIter)
|
||||
.def("__len__", &PyOperationList::dunderLen);
|
||||
}
|
||||
|
||||
private:
|
||||
PyOperationRef parentOperation;
|
||||
MlirBlock block;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyMlirContext
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -309,24 +505,6 @@ void PyOperation::checkValid() {
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyBlock, PyRegion.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
void PyRegion::attachToParent() {
|
||||
if (!detached) {
|
||||
throw SetPyError(PyExc_ValueError, "Region is already attached to an op");
|
||||
}
|
||||
detached = false;
|
||||
}
|
||||
|
||||
void PyBlock::attachToParent() {
|
||||
if (!detached) {
|
||||
throw SetPyError(PyExc_ValueError, "Block is already attached to an op");
|
||||
}
|
||||
detached = false;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyAttribute.
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -967,6 +1145,14 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
||||
return ref.releaseObject();
|
||||
})
|
||||
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
|
||||
.def_property(
|
||||
"allow_unregistered_dialects",
|
||||
[](PyMlirContext &self) -> bool {
|
||||
return mlirContextGetAllowUnregisteredDialects(self.get());
|
||||
},
|
||||
[](PyMlirContext &self, bool value) {
|
||||
mlirContextSetAllowUnregisteredDialects(self.get(), value);
|
||||
})
|
||||
.def(
|
||||
"parse_module",
|
||||
[](PyMlirContext &self, const std::string moduleAsm) {
|
||||
@@ -1026,37 +1212,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
||||
self.get(), filename.c_str(), line, col));
|
||||
},
|
||||
kContextGetFileLocationDocstring, py::arg("filename"),
|
||||
py::arg("line"), py::arg("col"))
|
||||
.def(
|
||||
"create_region",
|
||||
[](PyMlirContext &self) {
|
||||
// The creating context is explicitly captured on regions to
|
||||
// facilitate illegal assemblies of objects from multiple contexts
|
||||
// that would invalidate the memory model.
|
||||
return PyRegion(self.get(), mlirRegionCreate(),
|
||||
/*detached=*/true);
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextCreateRegionDocstring)
|
||||
.def(
|
||||
"create_block",
|
||||
[](PyMlirContext &self, std::vector<PyType> pyTypes) {
|
||||
// In order for the keep_alive extend the proper lifetime, all
|
||||
// types must be from the same context.
|
||||
for (auto pyType : pyTypes) {
|
||||
if (!mlirContextEqual(mlirTypeGetContext(pyType.type),
|
||||
self.get())) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
"All types used to construct a block must be from "
|
||||
"the same context as the block");
|
||||
}
|
||||
}
|
||||
llvm::SmallVector<MlirType, 4> types(pyTypes.begin(),
|
||||
pyTypes.end());
|
||||
return PyBlock(self.get(), mlirBlockCreate(types.size(), &types[0]),
|
||||
/*detached=*/true);
|
||||
},
|
||||
py::keep_alive<0, 1>(), kContextCreateBlockDocstring);
|
||||
py::arg("line"), py::arg("col"));
|
||||
|
||||
py::class_<PyLocation>(m, "Location").def("__repr__", [](PyLocation &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
@@ -1096,17 +1252,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
||||
// Mapping of Operation.
|
||||
py::class_<PyOperation>(m, "Operation")
|
||||
.def_property_readonly(
|
||||
"first_region",
|
||||
[](PyOperation &self) {
|
||||
self.checkValid();
|
||||
if (mlirOperationGetNumRegions(self.get()) == 0) {
|
||||
throw SetPyError(PyExc_IndexError, "Operation has no regions");
|
||||
}
|
||||
return PyRegion(self.getContext()->get(),
|
||||
mlirOperationGetRegion(self.get(), 0),
|
||||
/*detached=*/false);
|
||||
},
|
||||
py::keep_alive<0, 1>(), "Gets the operation's first region")
|
||||
"regions",
|
||||
[](PyOperation &self) { return PyRegionList(self.getRef()); })
|
||||
.def("__iter__",
|
||||
[](PyOperation &self) { return PyRegionIterator(self.getRef()); })
|
||||
.def(
|
||||
"__str__",
|
||||
[](PyOperation &self) {
|
||||
@@ -1120,63 +1269,62 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
||||
|
||||
// Mapping of PyRegion.
|
||||
py::class_<PyRegion>(m, "Region")
|
||||
.def(
|
||||
"append_block",
|
||||
[](PyRegion &self, PyBlock &block) {
|
||||
if (!mlirContextEqual(self.context, block.context)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
"Block must have been created from the same context as "
|
||||
"this region");
|
||||
}
|
||||
|
||||
block.attachToParent();
|
||||
mlirRegionAppendOwnedBlock(self.region, block.block);
|
||||
},
|
||||
kRegionAppendBlockDocstring)
|
||||
.def(
|
||||
"insert_block",
|
||||
[](PyRegion &self, int pos, PyBlock &block) {
|
||||
if (!mlirContextEqual(self.context, block.context)) {
|
||||
throw SetPyError(
|
||||
PyExc_ValueError,
|
||||
"Block must have been created from the same context as "
|
||||
"this region");
|
||||
}
|
||||
block.attachToParent();
|
||||
// TODO: Make this return a failure and raise if out of bounds.
|
||||
mlirRegionInsertOwnedBlock(self.region, pos, block.block);
|
||||
},
|
||||
kRegionInsertBlockDocstring)
|
||||
.def_property_readonly(
|
||||
"first_block",
|
||||
"blocks",
|
||||
[](PyRegion &self) {
|
||||
MlirBlock block = mlirRegionGetFirstBlock(self.region);
|
||||
if (mlirBlockIsNull(block)) {
|
||||
throw SetPyError(PyExc_IndexError, "Region has no blocks");
|
||||
}
|
||||
return PyBlock(self.context, block, /*detached=*/false);
|
||||
return PyBlockList(self.getParentOperation(), self.get());
|
||||
},
|
||||
kRegionFirstBlockDocstring);
|
||||
"Returns a forward-optimized sequence of blocks.")
|
||||
.def(
|
||||
"__iter__",
|
||||
[](PyRegion &self) {
|
||||
self.checkValid();
|
||||
MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
|
||||
return PyBlockIterator(self.getParentOperation(), firstBlock);
|
||||
},
|
||||
"Iterates over blocks in the region.")
|
||||
.def("__eq__", [](PyRegion &self, py::object &other) {
|
||||
try {
|
||||
PyRegion *otherRegion = other.cast<PyRegion *>();
|
||||
return self.get().ptr == otherRegion->get().ptr;
|
||||
} catch (std::exception &e) {
|
||||
return false;
|
||||
}
|
||||
});
|
||||
|
||||
// Mapping of PyBlock.
|
||||
py::class_<PyBlock>(m, "Block")
|
||||
.def_property_readonly(
|
||||
"next_in_region",
|
||||
"operations",
|
||||
[](PyBlock &self) {
|
||||
MlirBlock block = mlirBlockGetNextInRegion(self.block);
|
||||
if (mlirBlockIsNull(block)) {
|
||||
throw SetPyError(PyExc_IndexError,
|
||||
"Attempt to read past last block");
|
||||
}
|
||||
return PyBlock(self.context, block, /*detached=*/false);
|
||||
return PyOperationList(self.getParentOperation(), self.get());
|
||||
},
|
||||
py::keep_alive<0, 1>(), kBlockNextInRegionDocstring)
|
||||
"Returns a forward-optimized sequence of operations.")
|
||||
.def(
|
||||
"__iter__",
|
||||
[](PyBlock &self) {
|
||||
self.checkValid();
|
||||
MlirOperation firstOperation =
|
||||
mlirBlockGetFirstOperation(self.get());
|
||||
return PyOperationIterator(self.getParentOperation(),
|
||||
firstOperation);
|
||||
},
|
||||
"Iterates over operations in the block.")
|
||||
.def("__eq__",
|
||||
[](PyBlock &self, py::object &other) {
|
||||
try {
|
||||
PyBlock *otherBlock = other.cast<PyBlock *>();
|
||||
return self.get().ptr == otherBlock->get().ptr;
|
||||
} catch (std::exception &e) {
|
||||
return false;
|
||||
}
|
||||
})
|
||||
.def(
|
||||
"__str__",
|
||||
[](PyBlock &self) {
|
||||
self.checkValid();
|
||||
PyPrintAccumulator printAccum;
|
||||
mlirBlockPrint(self.block, printAccum.getCallback(),
|
||||
mlirBlockPrint(self.get(), printAccum.getCallback(),
|
||||
printAccum.getUserData());
|
||||
return printAccum.join();
|
||||
},
|
||||
@@ -1310,4 +1458,12 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
||||
PyMemRefType::bind(m);
|
||||
PyUnrankedMemRefType::bind(m);
|
||||
PyTupleType::bind(m);
|
||||
|
||||
// Container bindings.
|
||||
PyBlockIterator::bind(m);
|
||||
PyBlockList::bind(m);
|
||||
PyOperationIterator::bind(m);
|
||||
PyOperationList::bind(m);
|
||||
PyRegionIterator::bind(m);
|
||||
PyRegionList::bind(m);
|
||||
}
|
||||
|
||||
@@ -249,69 +249,43 @@ private:
|
||||
};
|
||||
|
||||
/// Wrapper around an MlirRegion.
|
||||
/// Note that region can exist in a detached state (where this instance is
|
||||
/// responsible for clearing) or an attached state (where its owner is
|
||||
/// responsible).
|
||||
///
|
||||
/// This python wrapper retains a redundant reference to its creating context
|
||||
/// in order to facilitate checking that parts of the operation hierarchy
|
||||
/// are only assembled from the same context.
|
||||
/// Regions are managed completely by their containing operation. Unlike the
|
||||
/// C++ API, the python API does not support detached regions.
|
||||
class PyRegion {
|
||||
public:
|
||||
PyRegion(MlirContext context, MlirRegion region, bool detached)
|
||||
: context(context), region(region), detached(detached) {}
|
||||
PyRegion(PyRegion &&other)
|
||||
: context(other.context), region(other.region), detached(other.detached) {
|
||||
other.detached = false;
|
||||
}
|
||||
~PyRegion() {
|
||||
if (detached)
|
||||
mlirRegionDestroy(region);
|
||||
PyRegion(PyOperationRef parentOperation, MlirRegion region)
|
||||
: parentOperation(std::move(parentOperation)), region(region) {
|
||||
assert(!mlirRegionIsNull(region) && "python region cannot be null");
|
||||
}
|
||||
|
||||
// Call prior to attaching the region to a parent.
|
||||
// This will transition to the attached state and will throw an exception
|
||||
// if already attached.
|
||||
void attachToParent();
|
||||
MlirRegion get() { return region; }
|
||||
PyOperationRef &getParentOperation() { return parentOperation; }
|
||||
|
||||
MlirContext context;
|
||||
MlirRegion region;
|
||||
void checkValid() { return parentOperation->checkValid(); }
|
||||
|
||||
private:
|
||||
bool detached;
|
||||
PyOperationRef parentOperation;
|
||||
MlirRegion region;
|
||||
};
|
||||
|
||||
/// Wrapper around an MlirBlock.
|
||||
/// Note that blocks can exist in a detached state (where this instance is
|
||||
/// responsible for clearing) or an attached state (where its owner is
|
||||
/// responsible).
|
||||
///
|
||||
/// This python wrapper retains a redundant reference to its creating context
|
||||
/// in order to facilitate checking that parts of the operation hierarchy
|
||||
/// are only assembled from the same context.
|
||||
/// Blocks are managed completely by their containing operation. Unlike the
|
||||
/// C++ API, the python API does not support detached blocks.
|
||||
class PyBlock {
|
||||
public:
|
||||
PyBlock(MlirContext context, MlirBlock block, bool detached)
|
||||
: context(context), block(block), detached(detached) {}
|
||||
PyBlock(PyBlock &&other)
|
||||
: context(other.context), block(other.block), detached(other.detached) {
|
||||
other.detached = false;
|
||||
}
|
||||
~PyBlock() {
|
||||
if (detached)
|
||||
mlirBlockDestroy(block);
|
||||
PyBlock(PyOperationRef parentOperation, MlirBlock block)
|
||||
: parentOperation(std::move(parentOperation)), block(block) {
|
||||
assert(!mlirBlockIsNull(block) && "python block cannot be null");
|
||||
}
|
||||
|
||||
// Call prior to attaching the block to a parent.
|
||||
// This will transition to the attached state and will throw an exception
|
||||
// if already attached.
|
||||
void attachToParent();
|
||||
MlirBlock get() { return block; }
|
||||
PyOperationRef &getParentOperation() { return parentOperation; }
|
||||
|
||||
MlirContext context;
|
||||
MlirBlock block;
|
||||
void checkValid() { return parentOperation->checkValid(); }
|
||||
|
||||
private:
|
||||
bool detached;
|
||||
PyOperationRef parentOperation;
|
||||
MlirBlock block;
|
||||
};
|
||||
|
||||
/// Wrapper around the generic MlirAttribute.
|
||||
|
||||
@@ -12,6 +12,7 @@ add_mlir_library(MLIRCAPIIR
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir-c
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRStandardOps
|
||||
MLIRIR
|
||||
MLIRParser
|
||||
MLIRSupport
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Utils.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
@@ -25,6 +26,10 @@ using namespace mlir;
|
||||
|
||||
MlirContext mlirContextCreate() {
|
||||
auto *context = new MLIRContext(/*loadAllDialects=*/false);
|
||||
// TODO: Come up with a story for which dialects to load into the context
|
||||
// and do not expand this beyond StandardOps until done so. This is loaded
|
||||
// by default here because it is hard to make progress otherwise.
|
||||
context->loadDialect<StandardOpsDialect>();
|
||||
return wrap(context);
|
||||
}
|
||||
|
||||
@@ -34,6 +39,14 @@ int mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
|
||||
|
||||
void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
|
||||
|
||||
void mlirContextSetAllowUnregisteredDialects(MlirContext context, int allow) {
|
||||
unwrap(context)->allowUnregisteredDialects(allow);
|
||||
}
|
||||
|
||||
int mlirContextGetAllowUnregisteredDialects(MlirContext context) {
|
||||
return unwrap(context)->allowsUnregisteredDialects();
|
||||
}
|
||||
|
||||
/* ========================================================================== */
|
||||
/* Location API. */
|
||||
/* ========================================================================== */
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import gc
|
||||
import itertools
|
||||
import mlir
|
||||
|
||||
def run(f):
|
||||
@@ -10,65 +11,91 @@ def run(f):
|
||||
assert mlir.ir.Context._get_live_count() == 0
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDetachedRegionBlock
|
||||
def testDetachedRegionBlock():
|
||||
# Verify iterator based traversal of the op/region/block hierarchy.
|
||||
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
|
||||
def testTraverseOpRegionBlockIterators():
|
||||
ctx = mlir.ir.Context()
|
||||
t = mlir.ir.F32Type(ctx)
|
||||
region = ctx.create_region()
|
||||
block = ctx.create_block([t, t])
|
||||
# CHECK: <<UNLINKED BLOCK>>
|
||||
print(block)
|
||||
ctx.allow_unregistered_dialects = True
|
||||
module = ctx.parse_module(r"""
|
||||
func @f1(%arg0: i32) -> i32 {
|
||||
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
|
||||
return %1 : i32
|
||||
}
|
||||
""")
|
||||
op = module.operation
|
||||
# Get the block using iterators off of the named collections.
|
||||
regions = list(op.regions)
|
||||
blocks = list(regions[0].blocks)
|
||||
# CHECK: MODULE REGIONS=1 BLOCKS=1
|
||||
print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
|
||||
|
||||
run(testDetachedRegionBlock)
|
||||
# Get the regions and blocks from the default collections.
|
||||
default_regions = list(op)
|
||||
default_blocks = list(default_regions[0])
|
||||
# They should compare equal regardless of how obtained.
|
||||
assert default_regions == regions
|
||||
assert default_blocks == blocks
|
||||
|
||||
# Should be able to get the operations from either the named collection
|
||||
# or the block.
|
||||
operations = list(blocks[0].operations)
|
||||
default_operations = list(blocks[0])
|
||||
assert default_operations == operations
|
||||
|
||||
def walk_operations(indent, op):
|
||||
for i, region in enumerate(op):
|
||||
print(f"{indent}REGION {i}:")
|
||||
for j, block in enumerate(region):
|
||||
print(f"{indent} BLOCK {j}:")
|
||||
for k, child_op in enumerate(block):
|
||||
print(f"{indent} OP {k}: {child_op}")
|
||||
walk_operations(indent + " ", child_op)
|
||||
|
||||
# CHECK: REGION 0:
|
||||
# CHECK: BLOCK 0:
|
||||
# CHECK: OP 0: func
|
||||
# CHECK: REGION 0:
|
||||
# CHECK: BLOCK 0:
|
||||
# CHECK: OP 0: %0 = "custom.addi"
|
||||
# CHECK: OP 1: return
|
||||
# CHECK: OP 1: "module_terminator"
|
||||
walk_operations("", op)
|
||||
|
||||
run(testTraverseOpRegionBlockIterators)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testBlockTypeContextMismatch
|
||||
def testBlockTypeContextMismatch():
|
||||
c1 = mlir.ir.Context()
|
||||
c2 = mlir.ir.Context()
|
||||
t1 = mlir.ir.F32Type(c1)
|
||||
t2 = mlir.ir.F32Type(c2)
|
||||
try:
|
||||
block = c1.create_block([t1, t2])
|
||||
except ValueError as e:
|
||||
# CHECK: ERROR: All types used to construct a block must be from the same context as the block
|
||||
print("ERROR:", e)
|
||||
|
||||
run(testBlockTypeContextMismatch)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testBlockAppend
|
||||
def testBlockAppend():
|
||||
# Verify index based traversal of the op/region/block hierarchy.
|
||||
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
|
||||
def testTraverseOpRegionBlockIndices():
|
||||
ctx = mlir.ir.Context()
|
||||
t = mlir.ir.F32Type(ctx)
|
||||
region = ctx.create_region()
|
||||
try:
|
||||
region.first_block
|
||||
except IndexError:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError("Expected exception not raised")
|
||||
ctx.allow_unregistered_dialects = True
|
||||
module = ctx.parse_module(r"""
|
||||
func @f1(%arg0: i32) -> i32 {
|
||||
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
|
||||
return %1 : i32
|
||||
}
|
||||
""")
|
||||
|
||||
block = ctx.create_block([t, t])
|
||||
region.append_block(block)
|
||||
try:
|
||||
region.append_block(block)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError("Expected exception not raised")
|
||||
def walk_operations(indent, op):
|
||||
for i in range(len(op.regions)):
|
||||
region = op.regions[i]
|
||||
print(f"{indent}REGION {i}:")
|
||||
for j in range(len(region.blocks)):
|
||||
block = region.blocks[j]
|
||||
print(f"{indent} BLOCK {j}:")
|
||||
for k in range(len(block.operations)):
|
||||
child_op = block.operations[k]
|
||||
print(f"{indent} OP {k}: {child_op}")
|
||||
walk_operations(indent + " ", child_op)
|
||||
|
||||
block2 = ctx.create_block([t])
|
||||
region.insert_block(1, block2)
|
||||
# CHECK: <<UNLINKED BLOCK>>
|
||||
block_first = region.first_block
|
||||
print(block_first)
|
||||
block_next = block_first.next_in_region
|
||||
try:
|
||||
block_next = block_next.next_in_region
|
||||
except IndexError:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError("Expected exception not raised")
|
||||
# CHECK: REGION 0:
|
||||
# CHECK: BLOCK 0:
|
||||
# CHECK: OP 0: func
|
||||
# CHECK: REGION 0:
|
||||
# CHECK: BLOCK 0:
|
||||
# CHECK: OP 0: %0 = "custom.addi"
|
||||
# CHECK: OP 1: return
|
||||
# CHECK: OP 1: "module_terminator"
|
||||
walk_operations("", module.operation)
|
||||
|
||||
run(testBlockAppend)
|
||||
run(testTraverseOpRegionBlockIndices)
|
||||
|
||||
Reference in New Issue
Block a user