[MLIR][Python] Support Python-defined rewrite patterns (#162699)
This PR adds support for defining custom **`RewritePattern`**
implementations directly in the Python bindings.
Previously, users could define similar patterns using the PDL dialect’s
bindings. For more complex patterns, however, this often required
writing multiple Python callbacks as PDL native constraints or rewrite
functions, which made the overall logic less intuitive. On the other
hand, the PDL-based approach can be more performant than a pure Python
implementation, especially for simple patterns.
With this change, we introduce an additional, straightforward way to
define patterns purely in Python, complementing the existing PDL-based
approach.
### Example
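```python
from mlir.ir import *
from mlir.dialects import arith
from mlir.rewrite import *

def to_muli(op, rewriter):
    # Build the replacement at the rewriter's insertion point.
    with rewriter.ip:
        new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
        # `arith.muli` returns a Value; `.owner` is the op that defines it.
        rewriter.replace_op(op, new_op.owner)

with Context():
    patterns = RewritePatternSet()
    patterns.add(arith.AddIOp, to_muli)  # a pattern that rewrites arith.addi to arith.muli
    frozen = patterns.freeze()
    module = ...
    apply_patterns_and_fold_greedily(module, frozen)
```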
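The callback in the example always rewrites, but a callback may also decline. In the accompanying test, `constant_1_to_2` returns `True` (commented "failed to match") when the constant is not 1, while `to_muli` falls off the end and returns `None` after rewriting. Based only on that test, a truthy return value appears to report a match failure, and `None` appears to report success. The sketch below models that contract with a toy greedy driver in plain Python, so it runs without the MLIR bindings. It is not the MLIR implementation. `Op`, `apply_greedily` and `max_iterations` are hypothetical names. The real `apply_patterns_and_fold_greedily` also folds ops and drives rewriting from a worklist rather than from repeated full sweeps.

```python
from dataclasses import dataclass


@dataclass
class Op:
    """Hypothetical stand-in for an MLIR op: a name plus an integer payload."""
    name: str
    value: int = 0


def to_muli(ops, i):
    """Mirrors the example's `to_muli`: always rewrites, implicitly returns None."""
    ops[i] = Op("arith.muli", ops[i].value)


def constant_1_to_2(ops, i):
    """Mirrors the test's `constant_1_to_2`: only rewrites a constant 1."""
    if ops[i].value != 1:
        return True  # truthy result: failed to match, op left unchanged
    ops[i] = Op("arith.constant", 2)
    return None  # None: the rewrite was applied


def apply_greedily(ops, patterns, max_iterations=10):
    """Sweep all ops until no pattern applies; return False if the cap is hit."""
    for _ in range(max_iterations):
        changed = False
        for i in range(len(ops)):
            pattern = patterns.get(ops[i].name)  # patterns keyed by op name
            if pattern is not None and not pattern(ops, i):
                changed = True  # a falsy result means the pattern succeeded
        if not changed:
            return True  # fixpoint reached
    return False


ops = [Op("arith.addi"), Op("arith.constant", 1), Op("arith.constant", 3)]
patterns = {"arith.addi": to_muli, "arith.constant": constant_1_to_2}
converged = apply_greedily(ops, patterns)
print([(op.name, op.value) for op in ops])
# [('arith.muli', 0), ('arith.constant', 2), ('arith.constant', 3)]
print(converged)  # True
```

The rewritten constant 2 no longer matches, so the second sweep changes nothing and the toy driver stops. A pattern that kept reporting success would instead run until `max_iterations` is exhausted.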
---------
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
2025-10-11 11:28:45 +08:00

# RUN: %PYTHON %s 2>&1 | FileCheck %s

from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
from mlir.rewrite import *


def run(f):
    print("\nTEST:", f.__name__)
    f()


# CHECK-LABEL: TEST: testRewritePattern
@run
def testRewritePattern():
    def to_muli(op, rewriter):
        with rewriter.ip:
            assert isinstance(op, arith.AddIOp)
            new_op = arith.muli(op.lhs, op.rhs, loc=op.location)
            # `arith.muli` returns a Value; `.owner` is the op that defines it.
            rewriter.replace_op(op, new_op.owner)

    def constant_1_to_2(op, rewriter):
        c = op.value.value
        if c != 1:
            return True  # failed to match
        with rewriter.ip:
            new_op = arith.constant(op.type, 2, loc=op.location)
            rewriter.replace_op(op, [new_op])

    with Context():
        patterns = RewritePatternSet()
        patterns.add(arith.AddIOp, to_muli)
        patterns.add(arith.ConstantOp, constant_1_to_2)
        frozen = patterns.freeze()

        module = ModuleOp.parse(
            r"""
            module {
              func.func @add(%a: i64, %b: i64) -> i64 {
                %sum = arith.addi %a, %b : i64
                return %sum : i64
              }
            }
            """
        )

        apply_patterns_and_fold_greedily(module, frozen)
        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
        # CHECK: return %0 : i64
        print(module)

        module = ModuleOp.parse(
            r"""
            module {
              func.func @const() -> (i64, i64) {
                %0 = arith.constant 1 : i64
                %1 = arith.constant 3 : i64
                return %0, %1 : i64, i64
              }
            }
            """
        )

        # Only the constant 1 matches; the constant 3 is left unchanged.
        apply_patterns_and_fold_greedily(module, frozen)
        # CHECK: %c2_i64 = arith.constant 2 : i64
        # CHECK: %c3_i64 = arith.constant 3 : i64
        # CHECK: return %c2_i64, %c3_i64 : i64, i64
        print(module)