mirror of
https://github.com/intel/llvm.git
synced 2026-01-21 12:19:23 +08:00
[mlir:python] Compute get_op_result_or_value in PyOpView's constructor. (#123953)
This logic is in the critical path for constructing an operation from Python. It is faster to compute this in C++ than it is in Python, and it is a minor change to do this. This change also alters the API contract of _ods_common.get_op_results_or_values to avoid calling get_op_result_or_value on each element of a sequence, since the C++ code will now do this. Most of the diff here is simply reordering the code in IRCore.cpp.
This commit is contained in:
@@ -1481,12 +1481,11 @@ static void maybeInsertOperation(PyOperationRef &op,
|
||||
|
||||
nb::object PyOperation::create(std::string_view name,
|
||||
std::optional<std::vector<PyType *>> results,
|
||||
std::optional<std::vector<PyValue *>> operands,
|
||||
llvm::ArrayRef<MlirValue> operands,
|
||||
std::optional<nb::dict> attributes,
|
||||
std::optional<std::vector<PyBlock *>> successors,
|
||||
int regions, DefaultingPyLocation location,
|
||||
const nb::object &maybeIp, bool inferType) {
|
||||
llvm::SmallVector<MlirValue, 4> mlirOperands;
|
||||
llvm::SmallVector<MlirType, 4> mlirResults;
|
||||
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
|
||||
llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
|
||||
@@ -1495,16 +1494,6 @@ nb::object PyOperation::create(std::string_view name,
|
||||
if (regions < 0)
|
||||
throw nb::value_error("number of regions must be >= 0");
|
||||
|
||||
// Unpack/validate operands.
|
||||
if (operands) {
|
||||
mlirOperands.reserve(operands->size());
|
||||
for (PyValue *operand : *operands) {
|
||||
if (!operand)
|
||||
throw nb::value_error("operand value cannot be None");
|
||||
mlirOperands.push_back(operand->get());
|
||||
}
|
||||
}
|
||||
|
||||
// Unpack/validate results.
|
||||
if (results) {
|
||||
mlirResults.reserve(results->size());
|
||||
@@ -1562,9 +1551,8 @@ nb::object PyOperation::create(std::string_view name,
|
||||
// point, exceptions cannot be thrown or else the state will leak.
|
||||
MlirOperationState state =
|
||||
mlirOperationStateGet(toMlirStringRef(name), location);
|
||||
if (!mlirOperands.empty())
|
||||
mlirOperationStateAddOperands(&state, mlirOperands.size(),
|
||||
mlirOperands.data());
|
||||
if (!operands.empty())
|
||||
mlirOperationStateAddOperands(&state, operands.size(), operands.data());
|
||||
state.enableResultTypeInference = inferType;
|
||||
if (!mlirResults.empty())
|
||||
mlirOperationStateAddResults(&state, mlirResults.size(),
|
||||
@@ -1632,6 +1620,143 @@ void PyOperation::erase() {
|
||||
mlirOperationDestroy(operation);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// CRTP base class for Python MLIR values that subclass Value and should be
|
||||
/// castable from it. The value hierarchy is one level deep and is not supposed
|
||||
/// to accommodate other levels unless core MLIR changes.
|
||||
template <typename DerivedTy>
|
||||
class PyConcreteValue : public PyValue {
|
||||
public:
|
||||
// Derived classes must define statics for:
|
||||
// IsAFunctionTy isaFunction
|
||||
// const char *pyClassName
|
||||
// and redefine bindDerived.
|
||||
using ClassTy = nb::class_<DerivedTy, PyValue>;
|
||||
using IsAFunctionTy = bool (*)(MlirValue);
|
||||
|
||||
PyConcreteValue() = default;
|
||||
PyConcreteValue(PyOperationRef operationRef, MlirValue value)
|
||||
: PyValue(operationRef, value) {}
|
||||
PyConcreteValue(PyValue &orig)
|
||||
: PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
|
||||
|
||||
/// Attempts to cast the original value to the derived type and throws on
|
||||
/// type mismatches.
|
||||
static MlirValue castFrom(PyValue &orig) {
|
||||
if (!DerivedTy::isaFunction(orig.get())) {
|
||||
auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
|
||||
throw nb::value_error((Twine("Cannot cast value to ") +
|
||||
DerivedTy::pyClassName + " (from " + origRepr +
|
||||
")")
|
||||
.str()
|
||||
.c_str());
|
||||
}
|
||||
return orig.get();
|
||||
}
|
||||
|
||||
/// Binds the Python module objects to functions of this class.
|
||||
static void bind(nb::module_ &m) {
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName);
|
||||
cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
|
||||
cls.def_static(
|
||||
"isinstance",
|
||||
[](PyValue &otherValue) -> bool {
|
||||
return DerivedTy::isaFunction(otherValue);
|
||||
},
|
||||
nb::arg("other_value"));
|
||||
cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
|
||||
[](DerivedTy &self) { return self.maybeDownCast(); });
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
/// Implemented by derived classes to add methods to the Python subclass.
|
||||
static void bindDerived(ClassTy &m) {}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Python wrapper for MlirOpResult.
|
||||
class PyOpResult : public PyConcreteValue<PyOpResult> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
|
||||
static constexpr const char *pyClassName = "OpResult";
|
||||
using PyConcreteValue::PyConcreteValue;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_prop_ro("owner", [](PyOpResult &self) {
|
||||
assert(
|
||||
mlirOperationEqual(self.getParentOperation()->get(),
|
||||
mlirOpResultGetOwner(self.get())) &&
|
||||
"expected the owner of the value in Python to match that in the IR");
|
||||
return self.getParentOperation().getObject();
|
||||
});
|
||||
c.def_prop_ro("result_number", [](PyOpResult &self) {
|
||||
return mlirOpResultGetResultNumber(self.get());
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/// Returns the list of types of the values held by container.
|
||||
template <typename Container>
|
||||
static std::vector<MlirType> getValueTypes(Container &container,
|
||||
PyMlirContextRef &context) {
|
||||
std::vector<MlirType> result;
|
||||
result.reserve(container.size());
|
||||
for (int i = 0, e = container.size(); i < e; ++i) {
|
||||
result.push_back(mlirValueGetType(container.getElement(i).get()));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// A list of operation results. Internally, these are stored as consecutive
|
||||
/// elements, random access is cheap. The (returned) result list is associated
|
||||
/// with the operation whose results these are, and thus extends the lifetime of
|
||||
/// this operation.
|
||||
class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
|
||||
public:
|
||||
static constexpr const char *pyClassName = "OpResultList";
|
||||
using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
|
||||
|
||||
PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
|
||||
intptr_t length = -1, intptr_t step = 1)
|
||||
: Sliceable(startIndex,
|
||||
length == -1 ? mlirOperationGetNumResults(operation->get())
|
||||
: length,
|
||||
step),
|
||||
operation(std::move(operation)) {}
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_prop_ro("types", [](PyOpResultList &self) {
|
||||
return getValueTypes(self, self.operation->getContext());
|
||||
});
|
||||
c.def_prop_ro("owner", [](PyOpResultList &self) {
|
||||
return self.operation->createOpView();
|
||||
});
|
||||
}
|
||||
|
||||
PyOperationRef &getOperation() { return operation; }
|
||||
|
||||
private:
|
||||
/// Give the parent CRTP class access to hook implementations below.
|
||||
friend class Sliceable<PyOpResultList, PyOpResult>;
|
||||
|
||||
intptr_t getRawNumElements() {
|
||||
operation->checkValid();
|
||||
return mlirOperationGetNumResults(operation->get());
|
||||
}
|
||||
|
||||
PyOpResult getRawElement(intptr_t index) {
|
||||
PyValue value(operation, mlirOperationGetResult(operation->get(), index));
|
||||
return PyOpResult(value);
|
||||
}
|
||||
|
||||
PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
|
||||
return PyOpResultList(operation, startIndex, length, step);
|
||||
}
|
||||
|
||||
PyOperationRef operation;
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// PyOpView
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -1733,6 +1858,40 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
|
||||
}
|
||||
}
|
||||
|
||||
static MlirValue getUniqueResult(MlirOperation operation) {
|
||||
auto numResults = mlirOperationGetNumResults(operation);
|
||||
if (numResults != 1) {
|
||||
auto name = mlirIdentifierStr(mlirOperationGetName(operation));
|
||||
throw nb::value_error((Twine("Cannot call .result on operation ") +
|
||||
StringRef(name.data, name.length) + " which has " +
|
||||
Twine(numResults) +
|
||||
" results (it is only valid for operations with a "
|
||||
"single result)")
|
||||
.str()
|
||||
.c_str());
|
||||
}
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
|
||||
static MlirValue getOpResultOrValue(nb::handle operand) {
|
||||
if (operand.is_none()) {
|
||||
throw nb::value_error("contained a None item");
|
||||
}
|
||||
PyOperationBase *op;
|
||||
if (nb::try_cast<PyOperationBase *>(operand, op)) {
|
||||
return getUniqueResult(op->getOperation());
|
||||
}
|
||||
PyOpResultList *opResultList;
|
||||
if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
|
||||
return getUniqueResult(opResultList->getOperation()->get());
|
||||
}
|
||||
PyValue *value;
|
||||
if (nb::try_cast<PyValue *>(operand, value)) {
|
||||
return value->get();
|
||||
}
|
||||
throw nb::value_error("is not a Value");
|
||||
}
|
||||
|
||||
nb::object PyOpView::buildGeneric(
|
||||
std::string_view name, std::tuple<int, bool> opRegionSpec,
|
||||
nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
|
||||
@@ -1783,16 +1942,14 @@ nb::object PyOpView::buildGeneric(
|
||||
}
|
||||
|
||||
// Unpack operands.
|
||||
std::vector<PyValue *> operands;
|
||||
llvm::SmallVector<MlirValue, 4> operands;
|
||||
operands.reserve(operands.size());
|
||||
if (operandSegmentSpecObj.is_none()) {
|
||||
// Non-sized operand unpacking.
|
||||
for (const auto &it : llvm::enumerate(operandList)) {
|
||||
try {
|
||||
operands.push_back(nb::cast<PyValue *>(it.value()));
|
||||
if (!operands.back())
|
||||
throw nb::cast_error();
|
||||
} catch (nb::cast_error &err) {
|
||||
operands.push_back(getOpResultOrValue(it.value()));
|
||||
} catch (nb::builtin_exception &err) {
|
||||
throw nb::value_error((llvm::Twine("Operand ") +
|
||||
llvm::Twine(it.index()) + " of operation \"" +
|
||||
name + "\" must be a Value (" + err.what() + ")")
|
||||
@@ -1818,29 +1975,31 @@ nb::object PyOpView::buildGeneric(
|
||||
int segmentSpec = std::get<1>(it.value());
|
||||
if (segmentSpec == 1 || segmentSpec == 0) {
|
||||
// Unpack unary element.
|
||||
try {
|
||||
auto *operandValue = nb::cast<PyValue *>(std::get<0>(it.value()));
|
||||
if (operandValue) {
|
||||
operands.push_back(operandValue);
|
||||
operandSegmentLengths.push_back(1);
|
||||
} else if (segmentSpec == 0) {
|
||||
// Allowed to be optional.
|
||||
operandSegmentLengths.push_back(0);
|
||||
} else {
|
||||
throw nb::value_error(
|
||||
(llvm::Twine("Operand ") + llvm::Twine(it.index()) +
|
||||
" of operation \"" + name +
|
||||
"\" must be a Value (was None and operand is not optional)")
|
||||
.str()
|
||||
.c_str());
|
||||
auto &operand = std::get<0>(it.value());
|
||||
if (!operand.is_none()) {
|
||||
try {
|
||||
|
||||
operands.push_back(getOpResultOrValue(operand));
|
||||
} catch (nb::builtin_exception &err) {
|
||||
throw nb::value_error((llvm::Twine("Operand ") +
|
||||
llvm::Twine(it.index()) +
|
||||
" of operation \"" + name +
|
||||
"\" must be a Value (" + err.what() + ")")
|
||||
.str()
|
||||
.c_str());
|
||||
}
|
||||
} catch (nb::cast_error &err) {
|
||||
throw nb::value_error((llvm::Twine("Operand ") +
|
||||
llvm::Twine(it.index()) + " of operation \"" +
|
||||
name + "\" must be a Value (" + err.what() +
|
||||
")")
|
||||
.str()
|
||||
.c_str());
|
||||
|
||||
operandSegmentLengths.push_back(1);
|
||||
} else if (segmentSpec == 0) {
|
||||
// Allowed to be optional.
|
||||
operandSegmentLengths.push_back(0);
|
||||
} else {
|
||||
throw nb::value_error(
|
||||
(llvm::Twine("Operand ") + llvm::Twine(it.index()) +
|
||||
" of operation \"" + name +
|
||||
"\" must be a Value (was None and operand is not optional)")
|
||||
.str()
|
||||
.c_str());
|
||||
}
|
||||
} else if (segmentSpec == -1) {
|
||||
// Unpack sequence by appending.
|
||||
@@ -1852,10 +2011,7 @@ nb::object PyOpView::buildGeneric(
|
||||
// Unpack the list.
|
||||
auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
|
||||
for (nb::handle segmentItem : segment) {
|
||||
operands.push_back(nb::cast<PyValue *>(segmentItem));
|
||||
if (!operands.back()) {
|
||||
throw nb::type_error("contained a None item");
|
||||
}
|
||||
operands.push_back(getOpResultOrValue(segmentItem));
|
||||
}
|
||||
operandSegmentLengths.push_back(nb::len(segment));
|
||||
}
|
||||
@@ -2269,57 +2425,6 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// CRTP base class for Python MLIR values that subclass Value and should be
|
||||
/// castable from it. The value hierarchy is one level deep and is not supposed
|
||||
/// to accommodate other levels unless core MLIR changes.
|
||||
template <typename DerivedTy>
|
||||
class PyConcreteValue : public PyValue {
|
||||
public:
|
||||
// Derived classes must define statics for:
|
||||
// IsAFunctionTy isaFunction
|
||||
// const char *pyClassName
|
||||
// and redefine bindDerived.
|
||||
using ClassTy = nb::class_<DerivedTy, PyValue>;
|
||||
using IsAFunctionTy = bool (*)(MlirValue);
|
||||
|
||||
PyConcreteValue() = default;
|
||||
PyConcreteValue(PyOperationRef operationRef, MlirValue value)
|
||||
: PyValue(operationRef, value) {}
|
||||
PyConcreteValue(PyValue &orig)
|
||||
: PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
|
||||
|
||||
/// Attempts to cast the original value to the derived type and throws on
|
||||
/// type mismatches.
|
||||
static MlirValue castFrom(PyValue &orig) {
|
||||
if (!DerivedTy::isaFunction(orig.get())) {
|
||||
auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
|
||||
throw nb::value_error((Twine("Cannot cast value to ") +
|
||||
DerivedTy::pyClassName + " (from " + origRepr +
|
||||
")")
|
||||
.str()
|
||||
.c_str());
|
||||
}
|
||||
return orig.get();
|
||||
}
|
||||
|
||||
/// Binds the Python module objects to functions of this class.
|
||||
static void bind(nb::module_ &m) {
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName);
|
||||
cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
|
||||
cls.def_static(
|
||||
"isinstance",
|
||||
[](PyValue &otherValue) -> bool {
|
||||
return DerivedTy::isaFunction(otherValue);
|
||||
},
|
||||
nb::arg("other_value"));
|
||||
cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
|
||||
[](DerivedTy &self) { return self.maybeDownCast(); });
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
/// Implemented by derived classes to add methods to the Python subclass.
|
||||
static void bindDerived(ClassTy &m) {}
|
||||
};
|
||||
|
||||
/// Python wrapper for MlirBlockArgument.
|
||||
class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
|
||||
@@ -2345,39 +2450,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Python wrapper for MlirOpResult.
|
||||
class PyOpResult : public PyConcreteValue<PyOpResult> {
|
||||
public:
|
||||
static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
|
||||
static constexpr const char *pyClassName = "OpResult";
|
||||
using PyConcreteValue::PyConcreteValue;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_prop_ro("owner", [](PyOpResult &self) {
|
||||
assert(
|
||||
mlirOperationEqual(self.getParentOperation()->get(),
|
||||
mlirOpResultGetOwner(self.get())) &&
|
||||
"expected the owner of the value in Python to match that in the IR");
|
||||
return self.getParentOperation().getObject();
|
||||
});
|
||||
c.def_prop_ro("result_number", [](PyOpResult &self) {
|
||||
return mlirOpResultGetResultNumber(self.get());
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/// Returns the list of types of the values held by container.
|
||||
template <typename Container>
|
||||
static std::vector<MlirType> getValueTypes(Container &container,
|
||||
PyMlirContextRef &context) {
|
||||
std::vector<MlirType> result;
|
||||
result.reserve(container.size());
|
||||
for (int i = 0, e = container.size(); i < e; ++i) {
|
||||
result.push_back(mlirValueGetType(container.getElement(i).get()));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// A list of block arguments. Internally, these are stored as consecutive
|
||||
/// elements, random access is cheap. The argument list is associated with the
|
||||
/// operation that contains the block (detached blocks are not allowed in
|
||||
@@ -2484,53 +2556,6 @@ private:
|
||||
PyOperationRef operation;
|
||||
};
|
||||
|
||||
/// A list of operation results. Internally, these are stored as consecutive
|
||||
/// elements, random access is cheap. The (returned) result list is associated
|
||||
/// with the operation whose results these are, and thus extends the lifetime of
|
||||
/// this operation.
|
||||
class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
|
||||
public:
|
||||
static constexpr const char *pyClassName = "OpResultList";
|
||||
using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
|
||||
|
||||
PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
|
||||
intptr_t length = -1, intptr_t step = 1)
|
||||
: Sliceable(startIndex,
|
||||
length == -1 ? mlirOperationGetNumResults(operation->get())
|
||||
: length,
|
||||
step),
|
||||
operation(std::move(operation)) {}
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_prop_ro("types", [](PyOpResultList &self) {
|
||||
return getValueTypes(self, self.operation->getContext());
|
||||
});
|
||||
c.def_prop_ro("owner", [](PyOpResultList &self) {
|
||||
return self.operation->createOpView();
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
/// Give the parent CRTP class access to hook implementations below.
|
||||
friend class Sliceable<PyOpResultList, PyOpResult>;
|
||||
|
||||
intptr_t getRawNumElements() {
|
||||
operation->checkValid();
|
||||
return mlirOperationGetNumResults(operation->get());
|
||||
}
|
||||
|
||||
PyOpResult getRawElement(intptr_t index) {
|
||||
PyValue value(operation, mlirOperationGetResult(operation->get(), index));
|
||||
return PyOpResult(value);
|
||||
}
|
||||
|
||||
PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
|
||||
return PyOpResultList(operation, startIndex, length, step);
|
||||
}
|
||||
|
||||
PyOperationRef operation;
|
||||
};
|
||||
|
||||
/// A list of operation successors. Internally, these are stored as consecutive
|
||||
/// elements, random access is cheap. The (returned) successor list is
|
||||
/// associated with the operation whose successors these are, and thus extends
|
||||
@@ -3123,20 +3148,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
||||
"result",
|
||||
[](PyOperationBase &self) {
|
||||
auto &operation = self.getOperation();
|
||||
auto numResults = mlirOperationGetNumResults(operation);
|
||||
if (numResults != 1) {
|
||||
auto name = mlirIdentifierStr(mlirOperationGetName(operation));
|
||||
throw nb::value_error(
|
||||
(Twine("Cannot call .result on operation ") +
|
||||
StringRef(name.data, name.length) + " which has " +
|
||||
Twine(numResults) +
|
||||
" results (it is only valid for operations with a "
|
||||
"single result)")
|
||||
.str()
|
||||
.c_str());
|
||||
}
|
||||
return PyOpResult(operation.getRef(),
|
||||
mlirOperationGetResult(operation, 0))
|
||||
return PyOpResult(operation.getRef(), getUniqueResult(operation))
|
||||
.maybeDownCast();
|
||||
},
|
||||
"Shortcut to get an op result if it has only one (throws an error "
|
||||
@@ -3233,14 +3245,36 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
||||
nb::arg("walk_order") = MlirWalkPostOrder);
|
||||
|
||||
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
|
||||
.def_static("create", &PyOperation::create, nb::arg("name"),
|
||||
nb::arg("results").none() = nb::none(),
|
||||
nb::arg("operands").none() = nb::none(),
|
||||
nb::arg("attributes").none() = nb::none(),
|
||||
nb::arg("successors").none() = nb::none(),
|
||||
nb::arg("regions") = 0, nb::arg("loc").none() = nb::none(),
|
||||
nb::arg("ip").none() = nb::none(),
|
||||
nb::arg("infer_type") = false, kOperationCreateDocstring)
|
||||
.def_static(
|
||||
"create",
|
||||
[](std::string_view name,
|
||||
std::optional<std::vector<PyType *>> results,
|
||||
std::optional<std::vector<PyValue *>> operands,
|
||||
std::optional<nb::dict> attributes,
|
||||
std::optional<std::vector<PyBlock *>> successors, int regions,
|
||||
DefaultingPyLocation location, const nb::object &maybeIp,
|
||||
bool inferType) {
|
||||
// Unpack/validate operands.
|
||||
llvm::SmallVector<MlirValue, 4> mlirOperands;
|
||||
if (operands) {
|
||||
mlirOperands.reserve(operands->size());
|
||||
for (PyValue *operand : *operands) {
|
||||
if (!operand)
|
||||
throw nb::value_error("operand value cannot be None");
|
||||
mlirOperands.push_back(operand->get());
|
||||
}
|
||||
}
|
||||
|
||||
return PyOperation::create(name, results, mlirOperands, attributes,
|
||||
successors, regions, location, maybeIp,
|
||||
inferType);
|
||||
},
|
||||
nb::arg("name"), nb::arg("results").none() = nb::none(),
|
||||
nb::arg("operands").none() = nb::none(),
|
||||
nb::arg("attributes").none() = nb::none(),
|
||||
nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0,
|
||||
nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
|
||||
nb::arg("infer_type") = false, kOperationCreateDocstring)
|
||||
.def_static(
|
||||
"parse",
|
||||
[](const std::string &sourceStr, const std::string &sourceName,
|
||||
|
||||
@@ -686,7 +686,7 @@ public:
|
||||
/// Creates an operation. See corresponding python docstring.
|
||||
static nanobind::object
|
||||
create(std::string_view name, std::optional<std::vector<PyType *>> results,
|
||||
std::optional<std::vector<PyValue *>> operands,
|
||||
llvm::ArrayRef<MlirValue> operands,
|
||||
std::optional<nanobind::dict> attributes,
|
||||
std::optional<std::vector<PyBlock *>> successors, int regions,
|
||||
DefaultingPyLocation location, const nanobind::object &ip,
|
||||
|
||||
@@ -115,7 +115,10 @@ def get_op_results_or_values(
|
||||
_cext.ir.Operation,
|
||||
_Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
|
||||
]
|
||||
) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
|
||||
) -> _Union[
|
||||
_Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
|
||||
_cext.ir.OpResultList,
|
||||
]:
|
||||
"""Returns the given sequence of values or the results of the given op.
|
||||
|
||||
This is useful to implement op constructors so that they can take other ops as
|
||||
@@ -127,7 +130,7 @@ def get_op_results_or_values(
|
||||
elif isinstance(arg, _cext.ir.Operation):
|
||||
return arg.results
|
||||
else:
|
||||
return [get_op_result_or_value(element) for element in arg]
|
||||
return arg
|
||||
|
||||
|
||||
def get_op_result_or_op_results(
|
||||
|
||||
@@ -27,8 +27,8 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: regions = None
|
||||
// CHECK: operands.append(_get_op_results_or_values(variadic1))
|
||||
// CHECK: operands.append(_get_op_result_or_value(non_variadic))
|
||||
// CHECK: operands.append(_get_op_result_or_value(variadic2) if variadic2 is not None else None)
|
||||
// CHECK: operands.append(non_variadic)
|
||||
// CHECK: operands.append(variadic2)
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(
|
||||
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
|
||||
@@ -173,8 +173,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: regions = None
|
||||
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
|
||||
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
|
||||
// CHECK: operands.append(_gen_arg_0)
|
||||
// CHECK: operands.append(_gen_arg_2)
|
||||
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
|
||||
// CHECK: _ods_get_default_loc_context(loc))
|
||||
// CHECK: if is_ is not None: attributes["is"] = (is_
|
||||
@@ -307,9 +307,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: regions = None
|
||||
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
|
||||
// CHECK: operands.append(_get_op_result_or_value(f32))
|
||||
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
|
||||
// CHECK: operands.append(_gen_arg_0)
|
||||
// CHECK: operands.append(f32)
|
||||
// CHECK: operands.append(_gen_arg_2)
|
||||
// CHECK: results.append(i32)
|
||||
// CHECK: results.append(_gen_res_1)
|
||||
// CHECK: results.append(i64)
|
||||
@@ -349,8 +349,8 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: regions = None
|
||||
// CHECK: operands.append(_get_op_result_or_value(non_optional))
|
||||
// CHECK: if optional is not None: operands.append(_get_op_result_or_value(optional))
|
||||
// CHECK: operands.append(non_optional)
|
||||
// CHECK: if optional is not None: operands.append(optional)
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(
|
||||
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
|
||||
@@ -380,7 +380,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: regions = None
|
||||
// CHECK: operands.append(_get_op_result_or_value(non_variadic))
|
||||
// CHECK: operands.append(non_variadic)
|
||||
// CHECK: operands.extend(_get_op_results_or_values(variadic))
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(
|
||||
@@ -445,7 +445,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: regions = None
|
||||
// CHECK: operands.append(_get_op_result_or_value(in_))
|
||||
// CHECK: operands.append(in_)
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(
|
||||
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
|
||||
@@ -547,8 +547,8 @@ def SimpleOp : TestOp<"simple"> {
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: regions = None
|
||||
// CHECK: operands.append(_get_op_result_or_value(i32))
|
||||
// CHECK: operands.append(_get_op_result_or_value(f32))
|
||||
// CHECK: operands.append(i32)
|
||||
// CHECK: operands.append(f32)
|
||||
// CHECK: results.append(i64)
|
||||
// CHECK: results.append(f64)
|
||||
// CHECK: _ods_successors = None
|
||||
|
||||
@@ -37,7 +37,6 @@ from ._ods_common import (
|
||||
equally_sized_accessor as _ods_equally_sized_accessor,
|
||||
get_default_loc_context as _ods_get_default_loc_context,
|
||||
get_op_result_or_op_results as _get_op_result_or_op_results,
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
segmented_accessor as _ods_segmented_accessor,
|
||||
)
|
||||
@@ -501,17 +500,15 @@ constexpr const char *initTemplate = R"Py(
|
||||
|
||||
/// Template for appending a single element to the operand/result list.
|
||||
/// {0} is the field name.
|
||||
constexpr const char *singleOperandAppendTemplate =
|
||||
"operands.append(_get_op_result_or_value({0}))";
|
||||
constexpr const char *singleOperandAppendTemplate = "operands.append({0})";
|
||||
constexpr const char *singleResultAppendTemplate = "results.append({0})";
|
||||
|
||||
/// Template for appending an optional element to the operand/result list.
|
||||
/// {0} is the field name.
|
||||
constexpr const char *optionalAppendOperandTemplate =
|
||||
"if {0} is not None: operands.append(_get_op_result_or_value({0}))";
|
||||
"if {0} is not None: operands.append({0})";
|
||||
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
|
||||
"operands.append(_get_op_result_or_value({0}) if {0} is not None else "
|
||||
"None)";
|
||||
"operands.append({0})";
|
||||
constexpr const char *optionalAppendResultTemplate =
|
||||
"if {0} is not None: results.append({0})";
|
||||
|
||||
|
||||
Reference in New Issue
Block a user