[MLIR][Python] Pass OpView subclasses instead of Operation in rewrite patterns (#163080)

This is a follow-up PR for #162699.

Currently, in the function where we define rewrite patterns, the `op` we
receive is of type `ir.Operation` rather than a specific `OpView` type
(such as `arith.AddIOp`). This means we can’t conveniently access
certain parts of the operation — for example, we need to use
`op.operands[0]` instead of `op.lhs`. The following example code
illustrates this situation.

```python
def to_muli(op, rewriter):
  # op is typed ir.Operation instead of arith.AddIOp
  pass

patterns.add(arith.AddIOp, to_muli)
```

In this PR, we convert the operation to its corresponding `OpView`
subclass before invoking the rewrite pattern callback, making it much
easier to write patterns.

---------

Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
This commit is contained in:
Twice
2025-10-13 11:56:57 +08:00
committed by GitHub
parent 6785c4f2ff
commit 06e2c78680
3 changed files with 11 additions and 5 deletions

View File

@@ -197,7 +197,12 @@ public:
MlirPatternRewriter rewriter,
void *userData) -> MlirLogicalResult {
nb::handle f(static_cast<PyObject *>(userData));
nb::object res = f(op, PyPatternRewriter(rewriter));
PyMlirContextRef ctx =
PyMlirContext::forContext(mlirOperationGetContext(op));
nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
nb::object res = f(opView, PyPatternRewriter(rewriter));
return logicalResultFromObject(res);
};
MlirRewritePattern pattern = mlirOpRewritePattenCreate(

View File

@@ -92,7 +92,7 @@ class ConstantOp(ConstantOp):
@property
def value(self):
return Attribute(self.operation.attributes["value"])
return self.operation.attributes["value"]
@property
def literal_value(self) -> Union[int, float]:

View File

@@ -17,15 +17,16 @@ def run(f):
def testRewritePattern():
def to_muli(op, rewriter):
with rewriter.ip:
new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
assert isinstance(op, arith.AddIOp)
new_op = arith.muli(op.lhs, op.rhs, loc=op.location)
rewriter.replace_op(op, new_op.owner)
def constant_1_to_2(op, rewriter):
c = op.attributes["value"].value
c = op.value.value
if c != 1:
return True # failed to match
with rewriter.ip:
new_op = arith.constant(op.result.type, 2, loc=op.location)
new_op = arith.constant(op.type, 2, loc=op.location)
rewriter.replace_op(op, [new_op])
with Context():