mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 19:08:21 +08:00
While working on an integration, I found a lot of inconsistencies in IR printing and verification. It turns out that we were:

* Only doing "soft fail" verification on IR printing of Operation, not of a Module.
* Failed verification was interacting badly with binary=True IR printing (causing a TypeError trying to pass a `str` to a `bytes` based handle).
* For systematic integrations, it is often desirable to control verification yourself so that you can explicitly handle errors.

This patch:

* Trues up the "soft fail" semantics by having `Module.__str__` delegate to `Operation.__str__` vs having a shortcut implementation.
* Fixes soft fail in the presence of binary=True (and adds an additional happy path test case to make sure the binary functionality works).
* Adds an `assume_verified` boolean flag to the `print`/`get_asm` methods which disables internal verification, presupposing that the caller has taken care of it.

It turns out that we had a number of tests which were generating illegal IR but it wasn't being caught because they were doing a print on the `Module` vs operation. All except two were trivially fixed:

* linalg/ops.py : Had two tests for direct constructing a Matmul incorrectly. Fixing them made them just like the next two tests so just deleted (no need to test the verifier only at this level).
* linalg/opdsl/emit_structured_generic.py : Hand coded conv and pooling tests appear to be using illegal shaped inputs/outputs, causing a verification failure. I just used the `assume_verified=` flag to restore the original behavior and left a TODO. Will get someone who owns that to fix it properly in a followup (would also be nice to break this file up into multiple test modules as it is hard to tell exactly what is failing).

Notes to downstreams:

* If, like some of our tests, you get verification failures after this patch, it is likely that your IR was always invalid and you will need to fix the root cause.
* To temporarily revert to prior (broken) behavior, replace calls like `print(module)` with `print(module.operation.get_asm(assume_verified=True))`.

Differential Revision: https://reviews.llvm.org/D114680
919 lines
25 KiB
Python
919 lines
25 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
import gc
|
|
import io
|
|
import itertools
|
|
from mlir.ir import *
|
|
|
|
|
|
def run(f):
    """Execute test function ``f`` right away and return it unchanged.

    Intended for use as a decorator. Prints a ``TEST:`` banner with the
    function name, calls the function, forces a garbage collection pass and
    then asserts that ``Context._get_live_count()`` is zero.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect first so that unreachable objects are not counted as live.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
|
|
|
|
|
|
# Verify iterator based traversal of the op/region/block hierarchy.
|
|
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
|
|
@run
def testTraverseOpRegionBlockIterators():
    """Traverse a parsed module through the iterator protocol.

    Regions, blocks and operations are collected both from the named
    collections (``.regions``, ``.blocks``, ``.operations``) and by iterating
    the parent object directly; the results are asserted to be equal before
    the whole hierarchy is printed recursively.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    asm = r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
"""
    module = Module.parse(asm, context)
    root = module.operation
    assert root.context is context
    # Get the block using iterators off of the named collections.
    regions = list(root.regions)
    blocks = list(regions[0].blocks)
    # CHECK: MODULE REGIONS=1 BLOCKS=1
    print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")

    # Should verify.
    # CHECK: .verify = True
    print(f".verify = {module.operation.verify()}")

    # Get the regions and blocks from the default collections.
    iterated_regions = list(root.regions)
    iterated_blocks = list(iterated_regions[0])
    # They should compare equal regardless of how obtained.
    assert iterated_regions == regions
    assert iterated_blocks == blocks

    # Should be able to get the operations from either the named collection
    # or the block.
    named_ops = list(blocks[0].operations)
    iterated_ops = list(blocks[0])
    assert iterated_ops == named_ops

    def dump_hierarchy(indent, op):
        # Depth-first print of every region, block and operation under `op`.
        for i, region in enumerate(op.regions):
            print(f"{indent}REGION {i}:")
            for j, block in enumerate(region):
                print(f"{indent} BLOCK {j}:")
                for k, child_op in enumerate(block):
                    print(f"{indent} OP {k}: {child_op}")
                    dump_hierarchy(indent + " ", child_op)

    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: func
    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: %0 = "custom.addi"
    # CHECK: OP 1: return
    dump_hierarchy("", root)
|
|
|
|
|
|
# Verify index based traversal of the op/region/block hierarchy.
|
|
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
|
|
@run
def testTraverseOpRegionBlockIndices():
    """Traverse a parsed module using only ``len()`` and integer indexing.

    The iterator protocol is deliberately avoided so that the length and
    item-access paths of the region, block and operation collections are the
    ones being used. The parent name of every operation is printed as well.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    asm = r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
"""
    module = Module.parse(asm, context)

    def dump_hierarchy(indent, op):
        # Index-based depth-first print of everything nested under `op`.
        for i in range(len(op.regions)):
            cur_region = op.regions[i]
            print(f"{indent}REGION {i}:")
            for j in range(len(cur_region.blocks)):
                cur_block = cur_region.blocks[j]
                print(f"{indent} BLOCK {j}:")
                for k in range(len(cur_block.operations)):
                    child_op = cur_block.operations[k]
                    print(f"{indent} OP {k}: {child_op}")
                    print(f"{indent} OP {k}: parent {child_op.operation.parent.name}")
                    dump_hierarchy(indent + " ", child_op)

    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: func
    # CHECK: OP 0: parent builtin.module
    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: %0 = "custom.addi"
    # CHECK: OP 0: parent builtin.func
    # CHECK: OP 1: return
    # CHECK: OP 1: parent builtin.func
    dump_hierarchy("", module.operation)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testBlockAndRegionOwners
|
|
@run
def testBlockAndRegionOwners():
    """Check that ``owner`` of a region and of a block is the enclosing op.

    Verified once for the top-level module operation and once for the
    function nested inside it.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    asm = r"""
builtin.module {
builtin.func @f() {
std.return
}
}
"""
    module = Module.parse(asm, context)

    module_region = module.operation.regions[0]
    assert module_region.owner == module.operation
    assert module_region.blocks[0].owner == module.operation

    func = module.body.operations[0]
    func_region = func.operation.regions[0]
    assert func_region.owner == func
    assert func_region.blocks[0].owner == func
|
|
|
|
|
|
# CHECK-LABEL: TEST: testBlockArgumentList
|
|
@run
def testBlockArgumentList():
    """Exercise the argument list of a function's entry block.

    Covers length, iteration, in-place retyping through ``set_type``,
    slicing, concatenation of two slices and the ``types`` accessor.
    """
    with Context() as ctx:
        asm = r"""
func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
return
}
"""
        module = Module.parse(asm, ctx)
        func_op = module.body.operations[0]
        entry = func_op.regions[0].blocks[0]
        assert len(entry.arguments) == 3
        # CHECK: Argument 0, type i32
        # CHECK: Argument 1, type f64
        # CHECK: Argument 2, type index
        for arg in entry.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")
            # Retype argument N to a signless integer of width 8 * (N + 1).
            arg.set_type(IntegerType.get_signless(8 * (arg.arg_number + 1)))

        # CHECK: Argument 0, type i8
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")

        # Check that slicing works for block argument lists.
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry.arguments[1:]:
            print(f"Argument {arg.arg_number}, type {arg.type}")

        # Check that we can concatenate slices of argument lists.
        # CHECK: Length: 4
        joined = entry.arguments[:2] + entry.arguments[1:]
        print("Length: ", len(joined))

        # CHECK: Type: i8
        # CHECK: Type: i16
        # CHECK: Type: i24
        for arg_type in entry.arguments.types:
            print("Type: ", arg_type)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOperationOperands
|
|
@run
def testOperationOperands():
    """Check the length of an operation's operand list and iterate over it."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        asm = r"""
func @f1(%arg0: i32) {
%0 = "test.producer"() : () -> i64
"test.consumer"(%arg0, %0) : (i32, i64) -> ()
return
}"""
        module = Module.parse(asm)
        func_op = module.body.operations[0]
        entry = func_op.regions[0].blocks[0]
        # The consumer is the second operation in the entry block.
        consumer = entry.operations[1]
        assert len(consumer.operands) == 2
        # CHECK: Operand 0, type i32
        # CHECK: Operand 1, type i64
        for i, operand in enumerate(consumer.operands):
            print(f"Operand {i}, type {operand.type}")
|
|
|
|
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOperationOperandsSlice
|
|
@run
def testOperationOperandsSlice():
    """Exercise slicing of an operation's operand list.

    Covers a double reversal (which must reproduce the original order), the
    full slice, prefix and suffix slices, a strided slice and a slice taken
    from another slice.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        asm = r"""
func @f1() {
%0 = "test.producer0"() : () -> i64
%1 = "test.producer1"() : () -> i64
%2 = "test.producer2"() : () -> i64
%3 = "test.producer3"() : () -> i64
%4 = "test.producer4"() : () -> i64
"test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
return
}"""
        module = Module.parse(asm)
        func_op = module.body.operations[0]
        entry = func_op.regions[0].blocks[0]
        consumer = entry.operations[5]
        assert len(consumer.operands) == 5
        for lhs, rhs in zip(consumer.operands, consumer.operands[::-1][::-1]):
            assert lhs == rhs

        # CHECK: test.producer0
        # CHECK: test.producer1
        # CHECK: test.producer2
        # CHECK: test.producer3
        # CHECK: test.producer4
        for value in consumer.operands[:]:
            print(value)

        # CHECK: test.producer0
        # CHECK: test.producer1
        for value in consumer.operands[0:2]:
            print(value)

        # CHECK: test.producer3
        # CHECK: test.producer4
        for value in consumer.operands[3:]:
            print(value)

        # CHECK: test.producer0
        # CHECK: test.producer2
        # CHECK: test.producer4
        for value in consumer.operands[::2]:
            print(value)

        # CHECK: test.producer2
        for value in consumer.operands[::2][1::2]:
            print(value)
|
|
|
|
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOperationOperandsSet
|
|
@run
def testOperationOperandsSet():
    """Replace an operand through item assignment on ``Operation.operands``.

    Both a non-negative and a negative index are used for the assignment.
    The consumer has exactly one operand (asserted below), so index ``-1``
    and index ``0`` address the same slot and the value is read back through
    index ``0`` both times.
    """
    with Context() as ctx, Location.unknown(ctx):
        ctx.allow_unregistered_dialects = True
        module = Module.parse(r"""
func @f1() {
%0 = "test.producer0"() : () -> i64
%1 = "test.producer1"() : () -> i64
%2 = "test.producer2"() : () -> i64
"test.consumer"(%0) : (i64) -> ()
return
}""")
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer1 = entry_block.operations[1]
        producer2 = entry_block.operations[2]
        consumer = entry_block.operations[3]
        assert len(consumer.operands) == 1

        # CHECK: test.producer1
        consumer.operands[0] = producer1.result
        print(consumer.operands[0])

        # CHECK: test.producer2
        consumer.operands[-1] = producer2.result
        print(consumer.operands[0])
|
|
|
|
|
|
|
|
|
|
# CHECK-LABEL: TEST: testDetachedOperation
|
|
@run
def testDetachedOperation():
    """Create an operation that is not inserted anywhere and print it.

    The op is created with two results, one region and two string
    attributes.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        si32 = IntegerType.get_signed(32)
        result_types = [si32, si32]
        attrs = {
            "foo": StringAttr.get("foo_value"),
            "bar": StringAttr.get("bar_value"),
        }
        op1 = Operation.create(
            "custom.op1", results=result_types, regions=1, attributes=attrs)
        # CHECK: %0:2 = "custom.op1"() ( {
        # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
        print(op1)

    # TODO: Check successors once enough infra exists to do it properly.
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOperationInsertionPoint
|
|
@run
def testOperationInsertionPoint():
    """Insert detached operations at the beginning of a block.

    Two ops are inserted through the same insertion point and the module is
    printed; inserting an already attached op again must raise ``ValueError``.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    asm = r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
"""
    module = Module.parse(asm, context)

    # Create test op.
    with Location.unknown(context):
        op1 = Operation.create("custom.op1")
        op2 = Operation.create("custom.op2")

        func_op = module.body.operations[0]
        entry = func_op.regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry)
        ip.insert(op1)
        ip.insert(op2)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK: "custom.op2"()
        # CHECK: %0 = "custom.addi"
        print(module)

    # Trying to add a previously added op should raise.
    raised = False
    try:
        ip.insert(op1)
    except ValueError:
        raised = True
    assert raised, "expected insert of attached op to raise"
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOperationWithRegion
|
|
@run
def testOperationWithRegion():
    """Build an op with a region, then move it into a parsed function.

    A block with two arguments is appended to the region and a terminator is
    inserted into it. The op is printed on its own, inserted at the start of
    a parsed function body, and the resulting module is printed.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        si32 = IntegerType.get_signed(32)
        op1 = Operation.create("custom.op1", regions=1)
        body = op1.regions[0].blocks.append(si32, si32)
        # CHECK: "custom.op1"() ( {
        # CHECK: ^bb0(%arg0: si32, %arg1: si32): // no predecessors
        # CHECK: "custom.terminator"() : () -> ()
        # CHECK: }) : () -> ()
        terminator = Operation.create("custom.terminator")
        InsertionPoint(body).insert(terminator)
        print(op1)

        # Now add the whole operation to another op.
        # TODO: Verify lifetime hazard by nulling out the new owning module and
        # accessing op1.
        # TODO: Also verify accessing the terminator once both parents are nulled
        # out.
        asm = r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
"""
        module = Module.parse(asm)
        func_op = module.body.operations[0]
        entry = func_op.regions[0].blocks[0]
        InsertionPoint.at_block_begin(entry).insert(op1)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK: "custom.terminator"
        # CHECK: %0 = "custom.addi"
        print(module)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOperationResultList
|
|
@run
def testOperationResultList():
    """Check length, iteration and ``types`` of an operation's result list."""
    context = Context()
    asm = r"""
func @f1() {
%0:3 = call @f2() : () -> (i32, f64, index)
return
}
func private @f2() -> (i32, f64, index)
"""
    module = Module.parse(asm, context)
    caller = module.body.operations[0]
    # The call is the first operation of the caller's entry block.
    call_op = caller.regions[0].blocks[0].operations[0]
    assert len(call_op.results) == 3
    # CHECK: Result 0, type i32
    # CHECK: Result 1, type f64
    # CHECK: Result 2, type index
    for res in call_op.results:
        print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result type i32
    # CHECK: Result type f64
    # CHECK: Result type index
    for t in call_op.results.types:
        print(f"Result type {t}")
|
|
|
|
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOperationResultListSlice
|
|
@run
def testOperationResultListSlice():
    """Exercise slicing of an operation's result list.

    Covers a double reversal (which must reproduce the original order and
    result numbers), the full slice, a middle range, a strided slice and a
    slice with a negative start and negative step.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        asm = r"""
func @f1() {
"some.op"() : () -> (i1, i2, i3, i4, i5)
return
}
"""
        module = Module.parse(asm)
        func_op = module.body.operations[0]
        entry = func_op.regions[0].blocks[0]
        producer = entry.operations[0]

        assert len(producer.results) == 5
        for lhs, rhs in zip(producer.results, producer.results[::-1][::-1]):
            assert lhs == rhs
            assert lhs.result_number == rhs.result_number

        # CHECK: Result 0, type i1
        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        # CHECK: Result 4, type i5
        for res in producer.results[:]:
            print(f"Result {res.result_number}, type {res.type}")

        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        for res in producer.results[1:4]:
            print(f"Result {res.result_number}, type {res.type}")

        # CHECK: Result 1, type i2
        # CHECK: Result 3, type i4
        for res in producer.results[1::2]:
            print(f"Result {res.result_number}, type {res.type}")

        # CHECK: Result 3, type i4
        # CHECK: Result 1, type i2
        for res in producer.results[-2:0:-2]:
            print(f"Result {res.result_number}, type {res.type}")
|
|
|
|
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOperationAttributes
|
|
@run
def testOperationAttributes():
    """Access operation attributes by name and by iteration.

    Also checks the failure modes: an unknown name must raise ``KeyError``
    and an out-of-range integer index must raise ``IndexError``.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    asm = r"""
"some.op"() { some.attribute = 1 : i8,
other.attribute = 3.0,
dependent = "text" } : () -> ()
"""
    module = Module.parse(asm, context)
    op = module.body.operations[0]
    assert len(op.attributes) == 3
    iattr = IntegerAttr(op.attributes["some.attribute"])
    fattr = FloatAttr(op.attributes["other.attribute"])
    sattr = StringAttr(op.attributes["dependent"])
    # CHECK: Attribute type i8, value 1
    print(f"Attribute type {iattr.type}, value {iattr.value}")
    # CHECK: Attribute type f64, value 3.0
    print(f"Attribute type {fattr.type}, value {fattr.value}")
    # CHECK: Attribute value text
    print(f"Attribute value {sattr.value}")

    # We don't know in which order the attributes are stored.
    # CHECK-DAG: NamedAttribute(dependent="text")
    # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
    # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
    for named_attr in op.attributes:
        print(str(named_attr))

    # Check that exceptions are raised as expected.
    missing_key_raised = False
    try:
        op.attributes["does_not_exist"]
    except KeyError:
        missing_key_raised = True
    assert missing_key_raised, "expected KeyError on accessing a non-existent attribute"

    bad_index_raised = False
    try:
        op.attributes[42]
    except IndexError:
        bad_index_raised = True
    assert bad_index_raised, "expected IndexError on accessing an out-of-bounds attribute"
|
|
|
|
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOperationPrint
|
|
@run
def testOperationPrint():
    """Exercise ``Operation.print`` with its different sinks and options.

    Prints to stdout, to a text stream, to a binary stream (``binary=True``)
    and finally to stdout again with several printing flags enabled.

    The leading whitespace inside the MLIR text below matters: the last
    expected line contains the source location ``-:4:7``, i.e. the ``return``
    must start at column 7 of line 4 of the parsed text.
    """
    ctx = Context()
    module = Module.parse(
        r"""
    func @f1(%arg0: i32) -> i32 {
      %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
      return %arg0 : i32
    }
  """, ctx)

    # Test print to stdout.
    # CHECK: return %arg0 : i32
    module.operation.print()

    # Test print to text file.
    f = io.StringIO()
    # CHECK: <class 'str'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=f)
    str_value = f.getvalue()
    print(str_value.__class__)
    print(f.getvalue())

    # Test print to binary file.
    f = io.BytesIO()
    # CHECK: <class 'bytes'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=f, binary=True)
    bytes_value = f.getvalue()
    print(bytes_value.__class__)
    print(bytes_value)

    # Test print with options.
    # CHECK: value = opaque<"_", "0xDEADBEEF"> : tensor<4xi32>
    # CHECK: "std.return"(%arg0) : (i32) -> () -:4:7
    module.operation.print(
        large_elements_limit=2,
        enable_debug_info=True,
        pretty_debug_info=True,
        print_generic_op_form=True,
        use_local_scope=True)
|
|
|
|
|
|
|
|
|
|
# CHECK-LABEL: TEST: testKnownOpView
|
|
@run
def testKnownOpView():
    """Check which ``OpView`` class is returned for known and unknown ops.

    ``arith.addf`` is expected to come back as the generated ``_AddFOp``
    class, while an unregistered ``custom.f32`` op is expected to come back
    as a plain ``OpView``. The unregistered op is looked up twice.
    """
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        asm = r"""
%1 = "custom.f32"() : () -> f32
%2 = "custom.f32"() : () -> f32
%3 = arith.addf %1, %2 : f32
"""
        module = Module.parse(asm)
        print(module)

        # addf should map to a known OpView class in the arith dialect.
        # We know the OpView for it defines an 'lhs' attribute.
        addf = module.body.operations[2]
        # CHECK: <mlir.dialects._arith_ops_gen._AddFOp object
        print(repr(addf))
        # CHECK: "custom.f32"()
        print(addf.lhs)

        # One of the custom ops should resolve to the default OpView.
        first_lookup = module.body.operations[0]
        # CHECK: OpView object
        print(repr(first_lookup))

        # Check again to make sure negative caching works.
        second_lookup = module.body.operations[0]
        # CHECK: OpView object
        print(repr(second_lookup))
|
|
|
|
|
|
# CHECK-LABEL: TEST: testSingleResultProperty
|
|
@run
def testSingleResultProperty():
    """Check that ``.result`` raises ``ValueError`` unless the op has one result.

    The first two operations of the module (zero and two results) must
    raise; the error text is printed so that it can be matched. The third
    operation (one result) is printed as a whole.
    """
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        asm = r"""
"custom.no_result"() : () -> ()
%0:2 = "custom.two_result"() : () -> (f32, f32)
%1 = "custom.one_result"() : () -> f32
"""
        module = Module.parse(asm)
        print(module)

        # CHECK: Cannot call .result on operation custom.no_result which has 0 results
        # CHECK: Cannot call .result on operation custom.two_result which has 2 results
        for op_index in (0, 1):
            try:
                module.body.operations[op_index].result
            except ValueError as e:
                print(e)
            else:
                assert False, "Expected exception"

        # CHECK: %1 = "custom.one_result"() : () -> f32
        print(module.body.operations[2])
|
|
|
|
|
|
def create_invalid_operation():
    """Create and return a ``builtin.module`` operation with two regions.

    A module with two regions is invalid. The tests that call this helper
    verify that printing falls back to the generic printer for safety.
    An empty block is appended to the first region only.
    """
    invalid_module = Operation.create("builtin.module", regions=2)
    first_region = invalid_module.regions[0]
    first_region.blocks.append()
    return invalid_module
|
|
|
|
# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
|
|
@run
def testInvalidOperationStrSoftFails():
    """Printing an invalid operation must not raise.

    The output is expected to carry the verification-failed banner followed
    by the generic form, and ``verify()`` is expected to report ``False``.
    """
    context = Context()
    with Location.unknown(context):
        invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: // Verification failed, printing generic form
        # CHECK: "builtin.module"() ( {
        # CHECK: }) : () -> ()
        print(invalid_op)
        # CHECK: .verify = False
        print(f".verify = {invalid_op.operation.verify()}")
|
|
|
|
|
|
# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
|
|
@run
def testInvalidModuleStrSoftFails():
    """Printing a module that contains an invalid operation must not raise.

    The invalid op is created while an insertion point at the module body is
    active; printing the module is expected to emit the verification-failed
    banner.
    """
    context = Context()
    with Location.unknown(context):
        module = Module.create()
        body_ip = InsertionPoint(module.body)
        with body_ip:
            invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: // Verification failed, printing generic form
        print(module)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
|
|
@run
def testInvalidOperationGetAsmBinarySoftFails():
    """``get_asm(binary=True)`` on an invalid operation must not raise.

    The printed value is expected to be a ``bytes`` repr that starts with
    the verification-failed banner.
    """
    context = Context()
    with Location.unknown(context):
        invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: b'// Verification failed, printing generic form\n
        binary_asm = invalid_op.get_asm(binary=True)
        print(binary_asm)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
|
|
@run
def testCreateWithInvalidAttributes():
    """``Operation.create`` must reject malformed ``attributes`` dictionaries.

    Four cases are covered: a ``None`` key, an integer key, a value that is
    not an attribute, and a ``None`` value. The error text of each case is
    printed so that it can be matched. A call that unexpectedly succeeds
    fails the test immediately, in line with the other negative tests in
    this file.
    """
    ctx = Context()

    def expect_create_error(attributes):
        # Print the error raised for `attributes`; fail if none is raised.
        try:
            Operation.create("builtin.module", attributes=attributes)
        except Exception as e:
            print(e)
        else:
            assert False, "expected Operation.create to reject invalid attributes"

    with Location.unknown(ctx):
        # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
        expect_create_error({None: StringAttr.get("name")})
        # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
        expect_create_error({42: StringAttr.get("name")})
        # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
        expect_create_error({"some_key": ctx})
        # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
        expect_create_error({"some_key": None})
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOperationName
|
|
@run
def testOperationName():
    """Print the ``name`` of every operation in a parsed module body."""
    context = Context()
    context.allow_unregistered_dialects = True
    asm = r"""
%0 = "custom.op1"() : () -> f32
%1 = "custom.op2"() : () -> i32
%2 = "custom.op1"() : () -> f32
"""
    module = Module.parse(asm, context)

    # CHECK: custom.op1
    # CHECK: custom.op2
    # CHECK: custom.op1
    for body_op in module.body.operations:
        print(body_op.operation.name)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testCapsuleConversions
|
|
@run
def testCapsuleConversions():
    """Round-trip an operation through its ``_CAPIPtr`` capsule.

    Re-creating the operation from the capsule must give back the very same
    Python object.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        original = Operation.create("custom.op1").operation
        capsule = original._CAPIPtr
        assert '"mlir.ir.Operation._CAPIPtr"' in repr(capsule)
        restored = Operation._CAPICreate(capsule)
        assert restored is original
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOperationErase
|
|
@run
def testOperationErase():
    """Erase an operation from a module and create another one afterwards.

    The module is printed before and after the erase; the second print must
    no longer mention the erased op.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        m = Module.create()
        with InsertionPoint(m.body):
            doomed = Operation.create("custom.op1")

            # CHECK: "custom.op1"
            print(m)

            doomed.operation.erase()

            # CHECK-NOT: "custom.op1"
            print(m)

            # Ensure we can create another operation
            Operation.create("custom.op2")
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOperationLoc
|
|
@run
def testOperationLoc():
    """An op created with an explicit ``loc`` must report that location.

    Checked on the created object and on its ``.operation``.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with context:
        named_loc = Location.name("loc")
        created = Operation.create("custom.op", loc=named_loc)
        assert created.location == named_loc
        assert created.operation.location == named_loc
|
|
|
|
|
|
# CHECK-LABEL: TEST: testModuleMerge
|
|
@run
def testModuleMerge():
    """Move every function of one module into another module.

    ``move_before`` and ``move_after`` place the two functions of the second
    module around the single function of the first; the second module is
    expected to print as empty afterwards.
    """
    with Context():
        m1 = Module.parse("func private @foo()")
        source_asm = """
func private @bar()
func private @qux()
"""
        m2 = Module.parse(source_asm)
        anchor = m1.body.operations[0]
        first = m2.body.operations[0]
        second = m2.body.operations[1]
        first.move_before(anchor)
        second.move_after(anchor)

        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo
        # CHECK: func private @qux
        print(m1)

        # CHECK: module {
        # CHECK-NEXT: }
        print(m2)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
|
|
@run
def testAppendMoveFromAnotherBlock():
    """Appending an op that lives in another block moves it.

    After the append, the destination module is expected to print both
    functions and the source module is expected to print as empty.
    """
    with Context():
        m1 = Module.parse("func private @foo()")
        m2 = Module.parse("func private @bar()")
        moved = m1.body.operations[0]
        m2.body.append(moved)

        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo

        print(m2)
        # CHECK: module {
        # CHECK-NEXT: }
        print(m1)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testDetachFromParent
|
|
@run
def testDetachFromParent():
    """Detach an operation from its module.

    A second detach of the same operation must raise a ``ValueError`` whose
    message contains "has no parent"; any other ``ValueError`` is re-raised.
    The module printed afterwards must no longer contain the function.
    """
    with Context():
        m1 = Module.parse("func private @foo()")
        detached = m1.body.operations[0].detach_from_parent()

        raised = False
        try:
            detached.detach_from_parent()
        except ValueError as e:
            if "has no parent" not in str(e):
                raise
            raised = True
        assert raised, "expected ValueError when detaching a detached operation"

        print(m1)
        # CHECK-NOT: func private @foo
|
|
|
|
|
|
# CHECK-LABEL: TEST: testSymbolTable
|
|
@run
def testSymbolTable():
    """Exercise lookup, deletion and insertion on ``SymbolTable``.

    Steps, in order:
      * membership tests and item lookup by symbol name;
      * ``del`` of an entry, after which looking the name up again must raise
        ``KeyError`` and printing the previously fetched operation must raise
        ``RuntimeError`` ("the operation has been invalidated");
      * insertion of a function moved over from another module;
      * insertion of a function whose name is already taken, which must
        rename the inserted function;
      * insertion of an operation without a symbol name, which must raise
        ``ValueError``.
    """
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        m1 = Module.parse("""
func private @foo()
func private @bar()""")
        m2 = Module.parse("""
func private @qux()
func private @foo()
"foo.bar"() : () -> ()""")

        symbol_table = SymbolTable(m1.operation)

        # CHECK: func private @foo
        # CHECK: func private @bar
        assert "foo" in symbol_table
        print(symbol_table["foo"])
        assert "bar" in symbol_table
        # Keep a handle so that its invalidation can be observed further down.
        bar = symbol_table["bar"]
        print(symbol_table["bar"])

        assert "qux" not in symbol_table

        del symbol_table["bar"]
        try:
            symbol_table.erase(symbol_table["bar"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError"

        # CHECK: module
        # CHECK: func private @foo()
        print(m1)
        assert "bar" not in symbol_table

        try:
            print(bar)
        except RuntimeError as e:
            if "the operation has been invalidated" not in str(e):
                raise
        else:
            assert False, "expected RuntimeError due to invalidated operation"

        qux = m2.body.operations[0]
        m1.body.append(qux)
        symbol_table.insert(qux)
        assert "qux" in symbol_table

        # Check that insertion actually renames this symbol in the symbol table.
        # With @qux moved out, the first operation of m2 is now its @foo.
        foo2 = m2.body.operations[0]
        m1.body.append(foo2)
        updated_name = symbol_table.insert(foo2)
        assert foo2.name.value != "foo"
        assert foo2.name == updated_name

        # CHECK: module
        # CHECK: func private @foo()
        # CHECK: func private @qux()
        # CHECK: func private @foo{{.*}}
        print(m1)

        try:
            symbol_table.insert(m2.body.operations[0])
        except ValueError as e:
            if "Expected operation to have a symbol name" not in str(e):
                raise
        else:
            assert False, "expected ValueError when adding a non-symbol"
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOperationHash
|
|
@run
def testOperationHash():
    """The hash of a created op must equal the hash of its ``.operation``."""
    context = Context()
    context.allow_unregistered_dialects = True
    with context, Location.unknown():
        created = Operation.create("custom.op1")
        view_hash = hash(created)
        operation_hash = hash(created.operation)
        assert view_hash == operation_hash
|