mirror of
https://github.com/intel/llvm.git
synced 2026-01-14 03:50:17 +08:00
Friendlier wrapper for transform.foreach. To facilitate that friendliness, makes it so that OpResult.owner returns the relevant OpView instead of Operation. For good measure, also changes Value.owner to return OpView instead of Operation, thereby ensuring consistency. That is, makes it so that all op-returning .owner accessors return OpView (and thereby give access to all goodies available on registered OpViews). Reland of #171544 due to fixup for integration test.
327 lines
11 KiB
Python
327 lines
11 KiB
Python
# RUN: %PYTHON %s 2>&1 | FileCheck %s
|
|
|
|
from mlir.dialects import arith, func, pdl
|
|
from mlir.dialects.builtin import module
|
|
from mlir.ir import *
|
|
from mlir.rewrite import *
|
|
|
|
|
|
def construct_and_print_in_module(f):
    """Run ``f`` right away against a freshly created module and print the outcome.

    A ``TEST: <name>`` header is printed first (FileCheck labels key off it).
    ``f`` is then called with a new ``Module`` while the insertion point is set
    to that module's body, inside a new ``Context`` and an unknown ``Location``.
    Whatever ``f`` returns is printed unless it is ``None``.

    Returns:
        ``f`` itself, unchanged, so this can be used as a decorator.
    """
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            # Named `fresh_module` to avoid shadowing the imported `module`
            # decorator from mlir.dialects.builtin.
            fresh_module = Module.create()
            with InsertionPoint(fresh_module.body):
                outcome = f(fresh_module)
            if outcome is not None:
                print(outcome)
    return f
|
|
|
|
|
|
def get_pdl_patterns():
    """Build and freeze a PDL pattern set that rewrites ``arith.addi`` to ``arith.muli``.

    The single pattern ("addi_to_mul") matches an operation that:
      - is named ``arith.addi``,
      - has exactly two operands, both of index type,
      - produces one result of index type.

    Returns:
        The result of ``PDLModule(m).freeze()`` for the module holding the
        pattern.
    """
    with Location.unknown():
        m = Module.create()
        with InsertionPoint(m.body):

            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
            def pat():
                # Matcher: an arith.addi over two index-typed operands.
                idx = pdl.TypeOp(IndexType.get())
                lhs = pdl.OperandOp(idx)
                rhs = pdl.OperandOp(idx)
                matched_add = pdl.OperationOp(
                    name="arith.addi",
                    args=[lhs, rhs],
                    types=[idx],
                )

                # Rewriter: swap the matched op for an arith.muli on the
                # same operands.
                @pdl.rewrite()
                def rew():
                    replacement_mul = pdl.OperationOp(
                        name="arith.muli",
                        args=[lhs, rhs],
                        types=[idx],
                    )
                    pdl.ReplaceOp(matched_add, with_op=replacement_mul)

    # Wrapping the module in a PDLModule and freezing it transfers ownership
    # of the module to the PDL module. That transfer is not yet modelled on
    # the Python side and has sharp edges, so the module and the PDL module
    # are constructed in the same scope.
    # FIXME: This should be made more robust.
    return PDLModule(m).freeze()
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_add_to_mul
|
|
# CHECK: arith.muli
|
|
@construct_and_print_in_module
def test_add_to_mul(module_):
    """Apply the addi-to-muli patterns to a module passed as a ``Module``."""
    idx = IndexType.get()

    # Test input: a nested module "ir" holding one function that adds its
    # two index arguments.
    @module(sym_name="ir")
    def ir():
        @func.func(idx, idx)
        def add_func(a, b):
            return arith.addi(a, b)

    pattern_set = get_pdl_patterns()
    # Could apply frozen pattern set multiple times.
    apply_patterns_and_fold_greedily(module_, pattern_set)
    return module_
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_add_to_mul_with_op
|
|
# CHECK: arith.muli
|
|
@construct_and_print_in_module
def test_add_to_mul_with_op(module_):
    """Apply the addi-to-muli patterns, passing ``module_.operation`` to the driver."""
    idx = IndexType.get()

    # Test input: a nested module "ir" holding one function that adds its
    # two index arguments.
    @module(sym_name="ir")
    def ir():
        @func.func(idx, idx)
        def add_func(a, b):
            return arith.addi(a, b)

    pattern_set = get_pdl_patterns()
    # Unlike test_add_to_mul, the driver receives the underlying operation
    # rather than the Module wrapper.
    root_op = module_.operation
    apply_patterns_and_fold_greedily(root_op, pattern_set)
    return module_
|
|
|
|
|
|
# If we used arith.constant and arith.addi here, their C++-defined
# folders/canonicalizations would be applied implicitly by the greedy
# pattern rewrite driver, making our Python-defined folding useless.
# To work around this, we define a new dialect.
|
|
def load_myint_dialect():
    """Define the ``myint`` dialect with IRDL and load it.

    The dialect declares two operations:
      - ``constant``: a "value" attribute based on ``#builtin.integer`` and a
        single i32 result named "cst".
      - ``add``: two single i32 operands ("lhs", "rhs") and a single i32
        result named "res".

    The IRDL module is verified and then handed to ``irdl.load_dialects``.
    """
    from mlir.dialects import irdl

    single = irdl.Variadicity.single

    m = Module.create()
    with InsertionPoint(m.body):
        dialect_op = irdl.dialect("myint")
        with InsertionPoint(dialect_op.body):
            # myint.constant : () -> i32, carrying an integer "value".
            constant_op = irdl.operation_("constant")
            with InsertionPoint(constant_op.body):
                int_attr = irdl.base(base_name="#builtin.integer")
                i32_result = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
                irdl.attributes_([int_attr], ["value"])
                irdl.results_([i32_result], ["cst"], [single])

            # myint.add : (i32, i32) -> i32.
            add_op = irdl.operation_("add")
            with InsertionPoint(add_op.body):
                i32_value = irdl.is_(TypeAttr.get(IntegerType.get_signless(32)))
                irdl.operands_(
                    [i32_value, i32_value],
                    ["lhs", "rhs"],
                    [single, single],
                )
                irdl.results_([i32_value], ["res"], [single])

    m.operation.verify()
    irdl.load_dialects(m)
|
|
|
|
|
|
# This PDL pattern is to fold constant additions,
|
|
# including two patterns:
|
|
# 1. add(constant0, constant1) -> constant2
|
|
# where constant2 = constant0 + constant1;
|
|
# 2. add(x, 0) or add(0, x) -> x.
|
|
def get_pdl_pattern_fold():
    """Build and freeze PDL patterns that fold ``myint.add``.

    Two patterns are defined:
      1. "myint_add_fold": ``add(constant0, constant1)`` is replaced by a
         single ``myint.constant`` whose value comes from the native rewrite
         function "add_fold".
      2. "myint_add_zero_fold": ``add(x, 0)`` / ``add(0, x)`` is replaced by
         the value produced by the native constraint function "has_zero".

    Returns:
        The frozen pattern set, with both native functions registered.
    """
    m = Module.create()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(m.body):

        @pdl.pattern(benefit=1, sym_name="myint_add_fold")
        def pat():
            result_type = pdl.TypeOp(i32)
            lhs_attr = pdl.AttributeOp()
            rhs_attr = pdl.AttributeOp()
            lhs_const = pdl.OperationOp(
                name="myint.constant",
                attributes={"value": lhs_attr},
                types=[result_type],
            )
            rhs_const = pdl.OperationOp(
                name="myint.constant",
                attributes={"value": rhs_attr},
                types=[result_type],
            )
            lhs_value = pdl.ResultOp(lhs_const, 0)
            rhs_value = pdl.ResultOp(rhs_const, 0)
            matched_add = pdl.OperationOp(
                name="myint.add",
                args=[lhs_value, rhs_value],
                types=[result_type],
            )

            @pdl.rewrite()
            def rew():
                # `folded_attr` rather than `sum`, to avoid shadowing the builtin.
                folded_attr = pdl.apply_native_rewrite(
                    [pdl.AttributeType.get()], "add_fold", [lhs_attr, rhs_attr]
                )
                folded_const = pdl.OperationOp(
                    name="myint.constant",
                    attributes={"value": folded_attr},
                    types=[result_type],
                )
                pdl.ReplaceOp(matched_add, with_op=folded_const)

        @pdl.pattern(benefit=1, sym_name="myint_add_zero_fold")
        def pat():
            result_type = pdl.TypeOp(i32)
            lhs_value = pdl.OperandOp()
            rhs_value = pdl.OperandOp()
            # The constraint yields the operand that survives the fold.
            surviving_value = pdl.apply_native_constraint(
                [pdl.ValueType.get()], "has_zero", [lhs_value, rhs_value]
            )
            matched_add = pdl.OperationOp(
                name="myint.add",
                args=[lhs_value, rhs_value],
                types=[result_type],
            )

            @pdl.rewrite()
            def rew():
                pdl.ReplaceOp(matched_add, with_values=[surviving_value])

    def add_fold(rewriter, results, values):
        """Append an i32 IntegerAttr holding the sum of the two attribute values."""
        lhs, rhs = values
        total = lhs.value + rhs.value
        results.append(IntegerAttr.get(i32, total))

    def is_zero(value):
        """Tell whether ``value`` is owned by a ``myint.constant`` OpView whose "value" is 0."""
        owner = value.owner
        return (
            isinstance(owner, OpView)
            and owner.name == "myint.constant"
            and owner.attributes["value"].value == 0
        )

    def has_zero(rewriter, results, values):
        """Check if either operand is a constant zero.

        If so, the other operand is appended to ``results`` and False is
        returned; otherwise True is returned.
        NOTE(review): False appears to signal "constraint satisfied" to the
        PDL driver here -- confirm against the bindings.
        """
        first, second = values
        for candidate, other in ((first, second), (second, first)):
            if is_zero(candidate):
                results.append(other)
                return False
        return True

    pdl_module = PDLModule(m)
    pdl_module.register_rewrite_function("add_fold", add_fold)
    pdl_module.register_constraint_function("has_zero", has_zero)
    return pdl_module.freeze()
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_pdl_register_function
|
|
# CHECK: "myint.constant"() {value = 8 : i32} : () -> i32
|
|
@construct_and_print_in_module
def test_pdl_register_function(module_):
    """Fold chained constant ``myint.add`` ops using the registered rewrite function."""
    load_myint_dialect()

    # The module handed in by the decorator is not used; the test input is
    # parsed from text instead.
    ir_text = """
        %c0 = "myint.constant"() { value = 2 }: () -> (i32)
        %c1 = "myint.constant"() { value = 3 }: () -> (i32)
        %x = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
        "myint.add"(%x, %c1): (i32, i32) -> (i32)
        """
    parsed = Module.parse(ir_text)

    pattern_set = get_pdl_pattern_fold()
    apply_patterns_and_fold_greedily(parsed, pattern_set)

    return parsed
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_pdl_register_function_constraint
|
|
# CHECK: return %arg0 : i32
|
|
@construct_and_print_in_module
def test_pdl_register_function_constraint(module_):
    """Fold additions with zero away using the registered constraint function."""
    load_myint_dialect()

    # The module handed in by the decorator is not used; the test input is
    # parsed from text instead.
    ir_text = """
        func.func @f(%x : i32) -> i32 {
          %c0 = "myint.constant"() { value = 1 }: () -> (i32)
          %c1 = "myint.constant"() { value = -1 }: () -> (i32)
          %a = "myint.add"(%c0, %c1): (i32, i32) -> (i32)
          %b = "myint.add"(%a, %x): (i32, i32) -> (i32)
          %c = "myint.add"(%b, %a): (i32, i32) -> (i32)
          func.return %c : i32
        }
        """
    parsed = Module.parse(ir_text)

    pattern_set = get_pdl_pattern_fold()
    apply_patterns_and_fold_greedily(parsed, pattern_set)

    return parsed
|
|
|
|
|
|
# This pattern is to expand constant to additions
|
|
# unless the constant is no more than 1,
|
|
# e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
|
|
def get_pdl_pattern_expand():
    """Build and freeze a PDL pattern that expands ``myint.constant`` into additions.

    The pattern ("myint_constant_expand") matches a ``myint.constant`` subject
    to the native constraint "is_one" and replaces it with the operation
    produced by the native rewrite "expand", e.g. 3 -> 1 + 2 -> 1 + (1 + 1).

    Returns:
        The frozen pattern set, with both native functions registered.
    """
    m = Module.create()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(m.body):

        @pdl.pattern(benefit=1, sym_name="myint_constant_expand")
        def pat():
            result_type = pdl.TypeOp(i32)
            value_attr = pdl.AttributeOp()
            pdl.apply_native_constraint([], "is_one", [value_attr])
            matched_constant = pdl.OperationOp(
                name="myint.constant",
                attributes={"value": value_attr},
                types=[result_type],
            )

            @pdl.rewrite()
            def rew():
                replacement = pdl.apply_native_rewrite(
                    [pdl.OperationType.get()], "expand", [value_attr]
                )
                pdl.ReplaceOp(matched_constant, with_op=replacement)

    def is_one(rewriter, results, values):
        """Return True when the attribute's value is at most 1.

        NOTE(review): True appears to signal "constraint failed" to the PDL
        driver, so constants <= 1 are left alone -- confirm against the
        bindings.
        """
        return values[0].value <= 1

    def expand(rewriter, results, values):
        """Split constant ``c`` into ``add(constant(c // 2), constant(c - c // 2))``.

        The three ops are created at the rewriter's insertion point and the
        ``myint.add`` op is appended to ``results``.
        """
        whole = values[0].value
        lower_part = whole // 2
        upper_part = whole - lower_part
        with rewriter.ip:
            # Creation order matters: lower constant, upper constant, then add.
            parts = [
                Operation.create(
                    "myint.constant",
                    results=[i32],
                    attributes={"value": IntegerAttr.get(i32, part)},
                )
                for part in (lower_part, upper_part)
            ]
            combined = Operation.create(
                "myint.add",
                results=[i32],
                operands=[part_op.result for part_op in parts],
            )
        results.append(combined)

    pdl_module = PDLModule(m)
    pdl_module.register_constraint_function("is_one", is_one)
    pdl_module.register_rewrite_function("expand", expand)
    return pdl_module.freeze()
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_pdl_register_function_expand
|
|
# CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32
|
|
# CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32
|
|
# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
|
|
# CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32
|
|
# CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32
|
|
# CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32
|
|
# CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32
|
|
# CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32
|
|
# CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32
|
|
# CHECK: return %8 : i32
|
|
@construct_and_print_in_module
def test_pdl_register_function_expand(module_):
    """Expand the constant 5 into a tree of additions over constant 1s."""
    load_myint_dialect()

    # The module handed in by the decorator is not used; the test input is
    # parsed from text instead.
    ir_text = """
        func.func @f() -> i32 {
          %0 = "myint.constant"() { value = 5 }: () -> (i32)
          return %0 : i32
        }
        """
    parsed = Module.parse(ir_text)

    pattern_set = get_pdl_pattern_expand()
    apply_patterns_and_fold_greedily(parsed, pattern_set)

    return parsed
|