[mlir][py] Reuse more of CAPI build time inference.

This reduces code generated for type inference and instead reuses
facilities CAPI side that performed same role.

Differential Revision: https://reviews.llvm.org/D156041t
This commit is contained in:
Jacques Pienaar
2023-07-23 21:26:52 -07:00
parent c48ed93cf8
commit f573bc24d4
4 changed files with 91 additions and 90 deletions

View File

@@ -78,6 +78,7 @@ Args:
ip: An InsertionPoint (defaults to resolve from context manager or set to
False to disable insertion, even with an insertion point set in the
context manager).
infer_type: Whether to infer result types.
Returns:
A new "detached" Operation object. Detached operations can be added
to blocks, which causes them to become "attached."
@@ -1288,7 +1289,7 @@ py::object PyOperation::create(const std::string &name,
std::optional<py::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
int regions, DefaultingPyLocation location,
const py::object &maybeIp) {
const py::object &maybeIp, bool inferType) {
llvm::SmallVector<MlirValue, 4> mlirOperands;
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1367,6 +1368,7 @@ py::object PyOperation::create(const std::string &name,
if (!mlirOperands.empty())
mlirOperationStateAddOperands(&state, mlirOperands.size(),
mlirOperands.data());
state.enableResultTypeInference = inferType;
if (!mlirResults.empty())
mlirOperationStateAddResults(&state, mlirResults.size(),
mlirResults.data());
@@ -1398,6 +1400,8 @@ py::object PyOperation::create(const std::string &name,
// Construct the operation.
MlirOperation operation = mlirOperationCreate(&state);
if (!operation.ptr)
throw py::value_error("Operation creation failed");
PyOperationRef created =
PyOperation::createDetached(location->getContext(), operation);
maybeInsertOperation(created, maybeIp);
@@ -1441,51 +1445,10 @@ void PyOperation::erase() {
// PyOpView
//------------------------------------------------------------------------------
py::object
PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
py::list operandList, std::optional<py::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions,
DefaultingPyLocation location,
const py::object &maybeIp) {
PyMlirContextRef context = location->getContext();
// Class level operation construction metadata.
std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
// Operand and result segment specs are either none, which does no
// variadic unpacking, or a list of ints with segment sizes, where each
// element is either a positive number (typically 1 for a scalar) or -1 to
// indicate that it is derived from the length of the same-indexed operand
// or result (implying that it is a list at that position).
py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
std::vector<int32_t> operandSegmentLengths;
std::vector<int32_t> resultSegmentLengths;
// Validate/determine region count.
auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
int opMinRegionCount = std::get<0>(opRegionSpec);
bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
if (!regions) {
regions = opMinRegionCount;
}
if (*regions < opMinRegionCount) {
throw py::value_error(
(llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
llvm::Twine(opMinRegionCount) +
" regions but was built with regions=" + llvm::Twine(*regions))
.str());
}
if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
throw py::value_error(
(llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
llvm::Twine(opMinRegionCount) +
" regions but was built with regions=" + llvm::Twine(*regions))
.str());
}
// Unpack results.
std::vector<PyType *> resultTypes;
static void populateResultTypes(StringRef name, py::list resultTypeList,
const py::object &resultSegmentSpecObj,
std::vector<int32_t> &resultSegmentLengths,
std::vector<PyType *> &resultTypes) {
resultTypes.reserve(resultTypeList.size());
if (resultSegmentSpecObj.is_none()) {
// Non-variadic result unpacking.
@@ -1568,6 +1531,56 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
}
}
}
}
py::object PyOpView::buildGeneric(
const py::object &cls, std::optional<py::list> resultTypeList,
py::list operandList, std::optional<py::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
const py::object &maybeIp) {
PyMlirContextRef context = location->getContext();
// Class level operation construction metadata.
std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
// Operand and result segment specs are either none, which does no
// variadic unpacking, or a list of ints with segment sizes, where each
// element is either a positive number (typically 1 for a scalar) or -1 to
// indicate that it is derived from the length of the same-indexed operand
// or result (implying that it is a list at that position).
py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
std::vector<int32_t> operandSegmentLengths;
std::vector<int32_t> resultSegmentLengths;
// Validate/determine region count.
auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
int opMinRegionCount = std::get<0>(opRegionSpec);
bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
if (!regions) {
regions = opMinRegionCount;
}
if (*regions < opMinRegionCount) {
throw py::value_error(
(llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
llvm::Twine(opMinRegionCount) +
" regions but was built with regions=" + llvm::Twine(*regions))
.str());
}
if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
throw py::value_error(
(llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
llvm::Twine(opMinRegionCount) +
" regions but was built with regions=" + llvm::Twine(*regions))
.str());
}
// Unpack results.
std::vector<PyType *> resultTypes;
if (resultTypeList.has_value()) {
populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
resultSegmentLengths, resultTypes);
}
// Unpack operands.
std::vector<PyValue *> operands;
@@ -1694,7 +1707,8 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
/*operands=*/std::move(operands),
/*attributes=*/std::move(attributes),
/*successors=*/std::move(successors),
/*regions=*/*regions, location, maybeIp);
/*regions=*/*regions, location, maybeIp,
!resultTypeList);
}
pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
@@ -2854,7 +2868,7 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("attributes") = py::none(),
py::arg("successors") = py::none(), py::arg("regions") = 0,
py::arg("loc") = py::none(), py::arg("ip") = py::none(),
kOperationCreateDocstring)
py::arg("infer_type") = false, kOperationCreateDocstring)
.def_static(
"parse",
[](const std::string &sourceStr, const std::string &sourceName,

View File

@@ -655,7 +655,8 @@ public:
std::optional<std::vector<PyValue *>> operands,
std::optional<pybind11::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
DefaultingPyLocation location, const pybind11::object &ip);
DefaultingPyLocation location, const pybind11::object &ip,
bool inferType);
/// Creates an OpView suitable for this operation.
pybind11::object createOpView();
@@ -704,13 +705,12 @@ public:
pybind11::object getOperationObject() { return operationObject; }
static pybind11::object
buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList,
pybind11::list operandList,
std::optional<pybind11::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
const pybind11::object &maybeIp);
static pybind11::object buildGeneric(
const pybind11::object &cls, std::optional<pybind11::list> resultTypeList,
pybind11::list operandList, std::optional<pybind11::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
const pybind11::object &maybeIp);
/// Construct an instance of a class deriving from OpView, bypassing its
/// `__init__` method. The derived class will typically define a constructor

View File

@@ -245,14 +245,10 @@ def EmptyOp : TestOp<"empty">;
// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
// CHECK: def __init__(self, *, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: _ods_context = _ods_get_default_loc_context(loc)
// CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesImpliedOp).inferReturnTypes(
// CHECK: operands=operands,
// CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
// CHECK: context=_ods_context,
// CHECK: loc=loc)
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
let results = (outs I32:$i32, F32:$f32);
}
@@ -260,13 +256,9 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
// CHECK: def __init__(self, *, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: _ods_context = _ods_get_default_loc_context(loc)
// CHECK: results = _ods_ir.InferTypeOpInterface(InferResultTypesOp).inferReturnTypes(
// CHECK: operands=operands,
// CHECK: attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
// CHECK: context=_ods_context,
// CHECK: loc=loc)
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
let results = (outs AnyType, AnyType, AnyType);
}

View File

@@ -493,9 +493,7 @@ constexpr const char *initTemplate = R"Py(
attributes = {{}
regions = None
{1}
super().__init__(self.build_generic(
attributes=attributes, results=results, operands=operands,
successors=_ods_successors, regions=regions, loc=loc, ip=ip))
super().__init__(self.build_generic({2}))
)Py";
/// Template for appending a single element to the operand/result list.
@@ -755,17 +753,6 @@ _ods_derived_result_type = (
/// Python code template appending {0} type {1} times to the results list.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
/// Python code template for inferring the operation results using the
/// corresponding interface:
/// - {0} is the name of the class for which the types are inferred.
constexpr const char *inferTypeInterfaceTemplate =
R"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
operands=operands,
attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
context=_ods_context,
loc=loc)
)PY";
/// Appends the given multiline string as individual strings into
/// `builderLines`.
static void appendLineByLine(StringRef string,
@@ -805,12 +792,8 @@ populateBuilderLinesResult(const Operator &op,
return;
}
if (hasInferTypeInterface(op)) {
appendLineByLine(
llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
builderLines);
if (hasInferTypeInterface(op))
return;
}
// For each element, find or generate a name.
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
@@ -934,8 +917,20 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
}
functionArgs.push_back("loc=None");
functionArgs.push_back("ip=None");
SmallVector<std::string> initArgs;
initArgs.push_back("attributes=attributes");
if (!hasInferTypeInterface(op))
initArgs.push_back("results=results");
initArgs.push_back("operands=operands");
initArgs.push_back("successors=_ods_successors");
initArgs.push_back("regions=regions");
initArgs.push_back("loc=loc");
initArgs.push_back("ip=ip");
os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
llvm::join(builderLines, "\n "));
llvm::join(builderLines, "\n "),
llvm::join(initArgs, ", "));
}
static void emitSegmentSpec(