mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 11:02:04 +08:00
[mlir][py] Reuse more of CAPI build time inference.
This reduces code generated for type inference and instead reuses facilities CAPI side that performed same role. Differential Revision: https://reviews.llvm.org/D156041t
This commit is contained in:
@@ -78,6 +78,7 @@ Args:
|
||||
ip: An InsertionPoint (defaults to resolve from context manager or set to
|
||||
False to disable insertion, even with an insertion point set in the
|
||||
context manager).
|
||||
infer_type: Whether to infer result types.
|
||||
Returns:
|
||||
A new "detached" Operation object. Detached operations can be added
|
||||
to blocks, which causes them to become "attached."
|
||||
@@ -1288,7 +1289,7 @@ py::object PyOperation::create(const std::string &name,
|
||||
std::optional<py::dict> attributes,
|
||||
std::optional<std::vector<PyBlock *>> successors,
|
||||
int regions, DefaultingPyLocation location,
|
||||
const py::object &maybeIp) {
|
||||
const py::object &maybeIp, bool inferType) {
|
||||
llvm::SmallVector<MlirValue, 4> mlirOperands;
|
||||
llvm::SmallVector<MlirType, 4> mlirResults;
|
||||
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
|
||||
@@ -1367,6 +1368,7 @@ py::object PyOperation::create(const std::string &name,
|
||||
if (!mlirOperands.empty())
|
||||
mlirOperationStateAddOperands(&state, mlirOperands.size(),
|
||||
mlirOperands.data());
|
||||
state.enableResultTypeInference = inferType;
|
||||
if (!mlirResults.empty())
|
||||
mlirOperationStateAddResults(&state, mlirResults.size(),
|
||||
mlirResults.data());
|
||||
@@ -1398,6 +1400,8 @@ py::object PyOperation::create(const std::string &name,
|
||||
|
||||
// Construct the operation.
|
||||
MlirOperation operation = mlirOperationCreate(&state);
|
||||
if (!operation.ptr)
|
||||
throw py::value_error("Operation creation failed");
|
||||
PyOperationRef created =
|
||||
PyOperation::createDetached(location->getContext(), operation);
|
||||
maybeInsertOperation(created, maybeIp);
|
||||
@@ -1441,51 +1445,10 @@ void PyOperation::erase() {
|
||||
// PyOpView
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
py::object
|
||||
PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
|
||||
py::list operandList, std::optional<py::dict> attributes,
|
||||
std::optional<std::vector<PyBlock *>> successors,
|
||||
std::optional<int> regions,
|
||||
DefaultingPyLocation location,
|
||||
const py::object &maybeIp) {
|
||||
PyMlirContextRef context = location->getContext();
|
||||
// Class level operation construction metadata.
|
||||
std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
|
||||
// Operand and result segment specs are either none, which does no
|
||||
// variadic unpacking, or a list of ints with segment sizes, where each
|
||||
// element is either a positive number (typically 1 for a scalar) or -1 to
|
||||
// indicate that it is derived from the length of the same-indexed operand
|
||||
// or result (implying that it is a list at that position).
|
||||
py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
|
||||
py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
|
||||
|
||||
std::vector<int32_t> operandSegmentLengths;
|
||||
std::vector<int32_t> resultSegmentLengths;
|
||||
|
||||
// Validate/determine region count.
|
||||
auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
|
||||
int opMinRegionCount = std::get<0>(opRegionSpec);
|
||||
bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
|
||||
if (!regions) {
|
||||
regions = opMinRegionCount;
|
||||
}
|
||||
if (*regions < opMinRegionCount) {
|
||||
throw py::value_error(
|
||||
(llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
|
||||
llvm::Twine(opMinRegionCount) +
|
||||
" regions but was built with regions=" + llvm::Twine(*regions))
|
||||
.str());
|
||||
}
|
||||
if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
|
||||
throw py::value_error(
|
||||
(llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
|
||||
llvm::Twine(opMinRegionCount) +
|
||||
" regions but was built with regions=" + llvm::Twine(*regions))
|
||||
.str());
|
||||
}
|
||||
|
||||
// Unpack results.
|
||||
std::vector<PyType *> resultTypes;
|
||||
static void populateResultTypes(StringRef name, py::list resultTypeList,
|
||||
const py::object &resultSegmentSpecObj,
|
||||
std::vector<int32_t> &resultSegmentLengths,
|
||||
std::vector<PyType *> &resultTypes) {
|
||||
resultTypes.reserve(resultTypeList.size());
|
||||
if (resultSegmentSpecObj.is_none()) {
|
||||
// Non-variadic result unpacking.
|
||||
@@ -1568,6 +1531,56 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
py::object PyOpView::buildGeneric(
|
||||
const py::object &cls, std::optional<py::list> resultTypeList,
|
||||
py::list operandList, std::optional<py::dict> attributes,
|
||||
std::optional<std::vector<PyBlock *>> successors,
|
||||
std::optional<int> regions, DefaultingPyLocation location,
|
||||
const py::object &maybeIp) {
|
||||
PyMlirContextRef context = location->getContext();
|
||||
// Class level operation construction metadata.
|
||||
std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
|
||||
// Operand and result segment specs are either none, which does no
|
||||
// variadic unpacking, or a list of ints with segment sizes, where each
|
||||
// element is either a positive number (typically 1 for a scalar) or -1 to
|
||||
// indicate that it is derived from the length of the same-indexed operand
|
||||
// or result (implying that it is a list at that position).
|
||||
py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
|
||||
py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
|
||||
|
||||
std::vector<int32_t> operandSegmentLengths;
|
||||
std::vector<int32_t> resultSegmentLengths;
|
||||
|
||||
// Validate/determine region count.
|
||||
auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
|
||||
int opMinRegionCount = std::get<0>(opRegionSpec);
|
||||
bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
|
||||
if (!regions) {
|
||||
regions = opMinRegionCount;
|
||||
}
|
||||
if (*regions < opMinRegionCount) {
|
||||
throw py::value_error(
|
||||
(llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
|
||||
llvm::Twine(opMinRegionCount) +
|
||||
" regions but was built with regions=" + llvm::Twine(*regions))
|
||||
.str());
|
||||
}
|
||||
if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
|
||||
throw py::value_error(
|
||||
(llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
|
||||
llvm::Twine(opMinRegionCount) +
|
||||
" regions but was built with regions=" + llvm::Twine(*regions))
|
||||
.str());
|
||||
}
|
||||
|
||||
// Unpack results.
|
||||
std::vector<PyType *> resultTypes;
|
||||
if (resultTypeList.has_value()) {
|
||||
populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
|
||||
resultSegmentLengths, resultTypes);
|
||||
}
|
||||
|
||||
// Unpack operands.
|
||||
std::vector<PyValue *> operands;
|
||||
@@ -1694,7 +1707,8 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
|
||||
/*operands=*/std::move(operands),
|
||||
/*attributes=*/std::move(attributes),
|
||||
/*successors=*/std::move(successors),
|
||||
/*regions=*/*regions, location, maybeIp);
|
||||
/*regions=*/*regions, location, maybeIp,
|
||||
!resultTypeList);
|
||||
}
|
||||
|
||||
pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
|
||||
@@ -2854,7 +2868,7 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
py::arg("attributes") = py::none(),
|
||||
py::arg("successors") = py::none(), py::arg("regions") = 0,
|
||||
py::arg("loc") = py::none(), py::arg("ip") = py::none(),
|
||||
kOperationCreateDocstring)
|
||||
py::arg("infer_type") = false, kOperationCreateDocstring)
|
||||
.def_static(
|
||||
"parse",
|
||||
[](const std::string &sourceStr, const std::string &sourceName,
|
||||
|
||||
@@ -655,7 +655,8 @@ public:
|
||||
std::optional<std::vector<PyValue *>> operands,
|
||||
std::optional<pybind11::dict> attributes,
|
||||
std::optional<std::vector<PyBlock *>> successors, int regions,
|
||||
DefaultingPyLocation location, const pybind11::object &ip);
|
||||
DefaultingPyLocation location, const pybind11::object &ip,
|
||||
bool inferType);
|
||||
|
||||
/// Creates an OpView suitable for this operation.
|
||||
pybind11::object createOpView();
|
||||
@@ -704,13 +705,12 @@ public:
|
||||
|
||||
pybind11::object getOperationObject() { return operationObject; }
|
||||
|
||||
static pybind11::object
|
||||
buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList,
|
||||
pybind11::list operandList,
|
||||
std::optional<pybind11::dict> attributes,
|
||||
std::optional<std::vector<PyBlock *>> successors,
|
||||
std::optional<int> regions, DefaultingPyLocation location,
|
||||
const pybind11::object &maybeIp);
|
||||
static pybind11::object buildGeneric(
|
||||
const pybind11::object &cls, std::optional<pybind11::list> resultTypeList,
|
||||
pybind11::list operandList, std::optional<pybind11::dict> attributes,
|
||||
std::optional<std::vector<PyBlock *>> successors,
|
||||
std::optional<int> regions, DefaultingPyLocation location,
|
||||
const pybind11::object &maybeIp);
|
||||
|
||||
/// Construct an instance of a class deriving from OpView, bypassing its
|
||||
/// `__init__` method. The derived class will typically define a constructor
|
||||
|
||||
@@ -245,14 +245,10 @@ def EmptyOp : TestOp<"empty">;
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
|
||||
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
|
||||
// CHECK: def __init__(self, *, loc=None, ip=None):
|
||||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: _ods_context = _ods_get_default_loc_context(loc)
|
||||
// CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesImpliedOp).inferReturnTypes(
|
||||
// CHECK: operands=operands,
|
||||
// CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
|
||||
// CHECK: context=_ods_context,
|
||||
// CHECK: loc=loc)
|
||||
// CHECK: super().__init__(self.build_generic(
|
||||
// CHECK: attributes=attributes, operands=operands,
|
||||
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
|
||||
let results = (outs I32:$i32, F32:$f32);
|
||||
}
|
||||
|
||||
@@ -260,13 +256,9 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
|
||||
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
|
||||
// CHECK: def __init__(self, *, loc=None, ip=None):
|
||||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: _ods_context = _ods_get_default_loc_context(loc)
|
||||
// CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesOp).inferReturnTypes(
|
||||
// CHECK: operands=operands,
|
||||
// CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
|
||||
// CHECK: context=_ods_context,
|
||||
// CHECK: loc=loc)
|
||||
// CHECK: super().__init__(self.build_generic(
|
||||
// CHECK: attributes=attributes, operands=operands,
|
||||
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
|
||||
let results = (outs AnyType, AnyType, AnyType);
|
||||
}
|
||||
|
||||
|
||||
@@ -493,9 +493,7 @@ constexpr const char *initTemplate = R"Py(
|
||||
attributes = {{}
|
||||
regions = None
|
||||
{1}
|
||||
super().__init__(self.build_generic(
|
||||
attributes=attributes, results=results, operands=operands,
|
||||
successors=_ods_successors, regions=regions, loc=loc, ip=ip))
|
||||
super().__init__(self.build_generic({2}))
|
||||
)Py";
|
||||
|
||||
/// Template for appending a single element to the operand/result list.
|
||||
@@ -755,17 +753,6 @@ _ods_derived_result_type = (
|
||||
/// Python code template appending {0} type {1} times to the results list.
|
||||
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
|
||||
|
||||
/// Python code template for inferring the operation results using the
|
||||
/// corresponding interface:
|
||||
/// - {0} is the name of the class for which the types are inferred.
|
||||
constexpr const char *inferTypeInterfaceTemplate =
|
||||
R"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
|
||||
operands=operands,
|
||||
attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
|
||||
context=_ods_context,
|
||||
loc=loc)
|
||||
)PY";
|
||||
|
||||
/// Appends the given multiline string as individual strings into
|
||||
/// `builderLines`.
|
||||
static void appendLineByLine(StringRef string,
|
||||
@@ -805,12 +792,8 @@ populateBuilderLinesResult(const Operator &op,
|
||||
return;
|
||||
}
|
||||
|
||||
if (hasInferTypeInterface(op)) {
|
||||
appendLineByLine(
|
||||
llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
|
||||
builderLines);
|
||||
if (hasInferTypeInterface(op))
|
||||
return;
|
||||
}
|
||||
|
||||
// For each element, find or generate a name.
|
||||
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
||||
@@ -934,8 +917,20 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
|
||||
}
|
||||
functionArgs.push_back("loc=None");
|
||||
functionArgs.push_back("ip=None");
|
||||
|
||||
SmallVector<std::string> initArgs;
|
||||
initArgs.push_back("attributes=attributes");
|
||||
if (!hasInferTypeInterface(op))
|
||||
initArgs.push_back("results=results");
|
||||
initArgs.push_back("operands=operands");
|
||||
initArgs.push_back("successors=_ods_successors");
|
||||
initArgs.push_back("regions=regions");
|
||||
initArgs.push_back("loc=loc");
|
||||
initArgs.push_back("ip=ip");
|
||||
|
||||
os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
|
||||
llvm::join(builderLines, "\n "));
|
||||
llvm::join(builderLines, "\n "),
|
||||
llvm::join(initArgs, ", "));
|
||||
}
|
||||
|
||||
static void emitSegmentSpec(
|
||||
|
||||
Reference in New Issue
Block a user