From e904ddf3154766cdad84b2808dbf9a1dd259bb62 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 14 Mar 2019 05:04:38 -0700 Subject: [PATCH] Python bindings: expose various Ops through declarative builders In particular, expose `cond_br`, `select` and `call` operations with syntax similar to that of the previous emitter-based EDSC interface. These are provided for backwards-compatibility. Ideally, we want them to be Table-generated from the Op definitions when those definitions are declarative. Additionally, expose the ability to construct any op given its canonical name, which also exercises the construction of unregistered ops. PiperOrigin-RevId: 238421583 --- mlir/bindings/python/pybind.cpp | 46 ++++++++++++++++++++++- mlir/bindings/python/test/test_py2and3.py | 41 ++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/mlir/bindings/python/pybind.cpp b/mlir/bindings/python/pybind.cpp index fb48aa904bc1..d03a8f88f463 100644 --- a/mlir/bindings/python/pybind.cpp +++ b/mlir/bindings/python/pybind.cpp @@ -76,6 +76,17 @@ struct PythonValueHandle { return std::to_string(reinterpret_cast(value.getValue())); } + PythonValueHandle call(const std::vector &args) { + assert(value.hasType() && value.getType().isa() && + "can only call function-typed values"); + + std::vector argValues; + argValues.reserve(args.size()); + for (auto arg : args) + argValues.push_back(arg.value.getValue()); + return ValueHandle::create(value, argValues); + } + mlir::edsc::ValueHandle value; }; @@ -958,6 +969,38 @@ PYBIND11_MODULE(pybind, m) { return PythonValueHandle(nullptr); }, py::arg("dest"), py::arg("args") = std::vector()); + m.def( + "cond_br", + [](PythonValueHandle condition, const PythonBlockHandle &trueDest, + const std::vector &trueArgs, + const PythonBlockHandle &falseDest, + const std::vector &falseArgs) -> PythonValueHandle { + std::vector trueArguments(trueArgs.begin(), + trueArgs.end()); + std::vector falseArguments(falseArgs.begin(), + falseArgs.end()); + intrinsics::COND_BR(condition, trueDest, trueArguments, falseDest, + falseArguments); + return PythonValueHandle(nullptr); + }); + m.def("select", + [](PythonValueHandle condition, PythonValueHandle trueValue, + PythonValueHandle falseValue) -> PythonValueHandle { + return ValueHandle::create(condition.value, trueValue.value, + falseValue.value); + }); + m.def("op", + [](const std::string &name, + const std::vector &operands, + const std::vector &resultTypes) -> PythonValueHandle { + std::vector operandHandles(operands.begin(), + operands.end()); + std::vector types; + types.reserve(resultTypes.size()); + for (auto t : resultTypes) + types.push_back(Type::getFromOpaquePointer(t.type)); + return ValueHandle::create(name, operandHandles, types); + }); m.def("Max", [](const py::list &args) { SmallVector owning; @@ -1163,7 +1206,8 @@ PYBIND11_MODULE(pybind, m) { -> PythonValueHandle { return lhs.value / rhs.value; }) .def("__mod__", [](PythonValueHandle lhs, PythonValueHandle rhs) - -> PythonValueHandle { return lhs.value % rhs.value; }); + -> PythonValueHandle { return lhs.value % rhs.value; }) + .def("__call__", &PythonValueHandle::call); } py::class_( diff --git a/mlir/bindings/python/test/test_py2and3.py b/mlir/bindings/python/test/test_py2and3.py index b89e90c77f33..779135fdee9f 100644 --- a/mlir/bindings/python/test/test_py2and3.py +++ b/mlir/bindings/python/test/test_py2and3.py @@ -227,6 +227,18 @@ class EdscTest(unittest.TestCase): self.assertIn("^bb1(%0: index, %1: index):", code) self.assertIn(" br ^bb1(%1, %0 : index, index)", code) + def testCondBr(self): + with self.module.function_context("foo", [self.boolType], []) as fun: + with E.BlockContext() as blk1: + E.ret([]) + with E.BlockContext([self.indexType]) as blk2: + E.ret([]) + cst = E.constant_index(0) + E.cond_br(fun.arg(0), blk1, [], blk2, [cst]) + + code = str(fun) + self.assertIn("cond_br %arg0, ^bb1, ^bb2(%c0 : index)", code) + def testRet(self): with self.module.function_context("foo", [], [self.indexType, self.indexType]) as fun: @@ -238,6 +250,35 @@ class EdscTest(unittest.TestCase): self.assertIn(" %c0 = constant 0 : index", code) self.assertIn(" return %c42, %c0 : index, index", code) + def testSelectOp(self): + with self.module.function_context("foo", [self.boolType], + [self.i32Type]) as fun: + a = E.constant_int(42, 32) + b = E.constant_int(0, 32) + E.ret([E.select(fun.arg(0), a, b)]) + + code = str(fun) + self.assertIn("%0 = select %arg0, %c42_i32, %c0_i32 : i32", code) + + def testCallOp(self): + callee = self.module.declare_function("sqrtf", [self.f32Type], + [self.f32Type]) + with self.module.function_context("call", [self.f32Type], []) as fun: + funCst = E.constant_function(callee) + funCst([fun.arg(0)]) + E.constant_float(42., self.f32Type) + + code = str(self.module) + self.assertIn("func @sqrtf(f32) -> f32", code) + self.assertIn("%f = constant @sqrtf : (f32) -> f32", code) + self.assertIn("%0 = call_indirect %f(%arg0) : (f32) -> f32", code) + + def testCustom(self): + with self.module.function_context("custom", [self.indexType, self.f32Type], + []) as fun: + E.op("foo", [fun.arg(0)], [self.f32Type]) + fun.arg(1) + code = str(fun) + self.assertIn('%0 = "foo"(%arg0) : (index) -> f32', code) + self.assertIn("%1 = addf %0, %arg1 : f32", code) def testConstants(self): with self.module.function_context("constants", [], []) as fun: