Files
llvm/mlir/test/python/ir/operation.py
Stella Laurenzo ace1d0ad3d [mlir][python] Normalize asm-printing IR behavior.
While working on an integration, I found a lot of inconsistencies in IR printing and verification. It turns out that we were:
  * Only doing "soft fail" verification on IR printing of Operation, not of a Module.
  * Failed verification was interacting badly with binary=True IR printing (causing a TypeError trying to pass an `str` to a `bytes` based handle).
  * For systematic integrations, it is often desirable to control verification yourself so that you can explicitly handle errors.

This patch:
  * Trues up the "soft fail" semantics by having `Module.__str__` delegate to `Operation.__str__` vs having a shortcut implementation.
  * Fixes soft fail in the presence of binary=True (and adds an additional happy path test case to make sure the binary functionality works).
  * Adds an `assume_verified` boolean flag to the `print`/`get_asm` methods which disables internal verification, presupposing that the caller has taken care of it.

It turns out that we had a number of tests which were generating illegal IR, but it wasn't being caught because they were printing the `Module` rather than the operation. All except two were trivially fixed:
  * linalg/ops.py : Had two tests that directly constructed a Matmul incorrectly. Fixing them made them just like the next two tests, so I deleted them (no need to test the verifier only at this level).
  * linalg/opdsl/emit_structured_generic.py : Hand-coded conv and pooling tests appear to be using illegally shaped inputs/outputs, causing a verification failure. I just used the `assume_verified=` flag to restore the original behavior and left a TODO. I will get someone who owns that code to fix it properly in a follow-up (it would also be nice to break this file up into multiple test modules, as it is hard to tell exactly what is failing).

Notes to downstreams:
  * If, like some of our tests, you get verification failures after this patch, it is likely that your IR was always invalid and you will need to fix the root cause. To temporarily revert to prior (broken) behavior, replace calls like `print(module)` with `print(module.operation.get_asm(assume_verified=True))`.

Differential Revision: https://reviews.llvm.org/D114680
2021-11-28 18:02:01 -08:00

919 lines
25 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
import gc
import io
import itertools
from mlir.ir import *
def run(f):
  """Runs the test function `f` right away and hands it back unchanged.

  A "TEST:" banner naming the function is printed first so that each test's
  output can be anchored. After `f` returns, a garbage-collection pass is
  forced and the helper asserts that no MLIR Context objects are still alive.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
  return f
# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
  """Walks module -> regions -> blocks -> ops using only iteration."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""", ctx)
  module_op = module.operation
  assert module_op.context is ctx

  # Collect regions and blocks through the explicitly named collections.
  regions = list(module_op.regions)
  blocks = list(regions[0].blocks)
  # CHECK: MODULE REGIONS=1 BLOCKS=1
  print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")

  # The parsed module is well formed, so verification succeeds.
  # CHECK: .verify = True
  print(f".verify = {module.operation.verify()}")

  # Iterating a region directly yields the same blocks as `.blocks` does.
  regions_again = list(module_op.regions)
  blocks_via_iter = list(regions_again[0])
  assert regions_again == regions
  assert blocks_via_iter == blocks

  # Iterating a block directly yields the same ops as `.operations` does.
  entry = blocks[0]
  ops_via_named = list(entry.operations)
  ops_via_iter = list(entry)
  assert ops_via_iter == ops_via_named

  def walk_operations(indent, op):
    # Recursively print every region, block and op nested under `op`.
    for i, region in enumerate(op.regions):
      print(f"{indent}REGION {i}:")
      for j, block in enumerate(region):
        print(f"{indent} BLOCK {j}:")
        for k, child_op in enumerate(block):
          print(f"{indent} OP {k}: {child_op}")
          walk_operations(indent + " ", child_op)

  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: func
  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: %0 = "custom.addi"
  # CHECK: OP 1: return
  walk_operations("", module_op)
# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
  """Walks a parsed module using len() and integer indexing only.

  Each nested op's parent name is printed as well, exercising `.parent`.
  """
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""", ctx)

  def walk_operations(indent, op):
    # range(len(...)) plus indexing is deliberate here (not enumerate):
    # the __len__/__getitem__ paths of the region, block and operation
    # lists are what this test covers.
    for i in range(len(op.regions)):
      region = op.regions[i]
      print(f"{indent}REGION {i}:")
      for j in range(len(region.blocks)):
        block = region.blocks[j]
        print(f"{indent} BLOCK {j}:")
        for k in range(len(block.operations)):
          child_op = block.operations[k]
          print(f"{indent} OP {k}: {child_op}")
          # Name of the op that owns the block containing child_op.
          print(f"{indent} OP {k}: parent {child_op.operation.parent.name}")
          walk_operations(indent + " ", child_op)

  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: func
  # CHECK: OP 0: parent builtin.module
  # CHECK: REGION 0:
  # CHECK: BLOCK 0:
  # CHECK: OP 0: %0 = "custom.addi"
  # CHECK: OP 0: parent builtin.func
  # CHECK: OP 1: return
  # CHECK: OP 1: parent builtin.func
  walk_operations("", module.operation)
# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
  """Regions and blocks report the operation that owns them."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
builtin.module {
builtin.func @f() {
std.return
}
}
""", ctx)
  module_op = module.operation
  assert module_op.regions[0].owner == module_op
  assert module_op.regions[0].blocks[0].owner == module_op

  # The same holds one level down, for the function inside the module.
  func = module.body.operations[0]
  func_region = func.operation.regions[0]
  assert func_region.owner == func
  assert func_region.blocks[0].owner == func
# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
  """Block argument lists: iteration, retyping, slicing, concatenation."""
  with Context() as ctx:
    module = Module.parse(
        r"""
func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
return
}
""", ctx)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    assert len(entry_block.arguments) == 3

    # Print each argument, then retype it to a signless integer whose width
    # is 8 bits times its 1-based position.
    # CHECK: Argument 0, type i32
    # CHECK: Argument 1, type f64
    # CHECK: Argument 2, type index
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")
      width = 8 * (arg.arg_number + 1)
      arg.set_type(IntegerType.get_signless(width))

    # The new types are visible on a fresh traversal.
    # CHECK: Argument 0, type i8
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments:
      print(f"Argument {arg.arg_number}, type {arg.type}")

    # Slices of the argument list are iterable too.
    # CHECK: Argument 1, type i16
    # CHECK: Argument 2, type i24
    for arg in entry_block.arguments[1:]:
      print(f"Argument {arg.arg_number}, type {arg.type}")

    # Two slices concatenate with `+`; the overlap is kept (2 + 2 items).
    # CHECK: Length: 4
    prefix = entry_block.arguments[:2]
    suffix = entry_block.arguments[1:]
    print("Length: ", len(prefix + suffix))

    # `.types` exposes the argument types directly.
    # CHECK: Type: i8
    # CHECK: Type: i16
    # CHECK: Type: i24
    for arg_type in entry_block.arguments.types:
      print("Type: ", arg_type)
# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
  """An op's operand list reports its length and per-operand types."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
func @f1(%arg0: i32) {
%0 = "test.producer"() : () -> i64
"test.consumer"(%arg0, %0) : (i32, i64) -> ()
return
}""")
    body_ops = module.body.operations[0].regions[0].blocks[0].operations
    consumer = body_ops[1]
    assert len(consumer.operands) == 2

    # CHECK: Operand 0, type i32
    # CHECK: Operand 1, type i64
    for i, operand in enumerate(consumer.operands):
      print(f"Operand {i}, type {operand.type}")
# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
  """Slicing an op's operand list: full, prefix, suffix, strided, nested."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
func @f1() {
%0 = "test.producer0"() : () -> i64
%1 = "test.producer1"() : () -> i64
%2 = "test.producer2"() : () -> i64
%3 = "test.producer3"() : () -> i64
%4 = "test.producer4"() : () -> i64
"test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
return
}""")
    entry_block = module.body.operations[0].regions[0].blocks[0]
    consumer = entry_block.operations[5]
    operands = consumer.operands
    assert len(operands) == 5

    # Reversing twice must give back the original order.
    for original, round_trip in zip(operands, operands[::-1][::-1]):
      assert original == round_trip

    def print_each(values):
      # Print every value on its own line.
      for value in values:
        print(value)

    # CHECK: test.producer0
    # CHECK: test.producer1
    # CHECK: test.producer2
    # CHECK: test.producer3
    # CHECK: test.producer4
    print_each(operands[:])

    # CHECK: test.producer0
    # CHECK: test.producer1
    print_each(operands[0:2])

    # CHECK: test.producer3
    # CHECK: test.producer4
    print_each(operands[3:])

    # CHECK: test.producer0
    # CHECK: test.producer2
    # CHECK: test.producer4
    print_each(operands[::2])

    # CHECK: test.producer2
    print_each(operands[::2][1::2])
# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
  """Operands can be reassigned through both positive and negative indices."""
  with Context() as ctx, Location.unknown(ctx):
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
func @f1() {
%0 = "test.producer0"() : () -> i64
%1 = "test.producer1"() : () -> i64
%2 = "test.producer2"() : () -> i64
"test.consumer"(%0) : (i64) -> ()
return
}""")
    func = module.body.operations[0]
    entry_block = func.regions[0].blocks[0]
    producer1 = entry_block.operations[1]
    producer2 = entry_block.operations[2]
    consumer = entry_block.operations[3]
    assert len(consumer.operands) == 1
    # (An unused `type = consumer.operands[0].type` local, which also shadowed
    # the builtin `type`, was removed here.)

    # Replace the sole operand through a non-negative index.
    # CHECK: test.producer1
    consumer.operands[0] = producer1.result
    print(consumer.operands[0])

    # With a single operand, index -1 addresses that same slot.
    # CHECK: test.producer2
    consumer.operands[-1] = producer2.result
    print(consumer.operands[0])
# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
  """Creates and prints an op that is not attached to any block."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    si32 = IntegerType.get_signed(32)
    attrs = {
        "foo": StringAttr.get("foo_value"),
        "bar": StringAttr.get("bar_value"),
    }
    op1 = Operation.create(
        "custom.op1", results=[si32, si32], regions=1, attributes=attrs)
    # CHECK: %0:2 = "custom.op1"() ( {
    # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
    print(op1)

  # TODO: Check successors once enough infra exists to do it properly.
# CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
  """Inserts detached ops at a block's start; re-inserting one must fail."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""", ctx)

  with Location.unknown(ctx):
    # Two fresh, detached ops to insert.
    op1 = Operation.create("custom.op1")
    op2 = Operation.create("custom.op2")

    entry_block = module.body.operations[0].regions[0].blocks[0]
    ip = InsertionPoint.at_block_begin(entry_block)
    for new_op in (op1, op2):
      ip.insert(new_op)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.op2"()
    # CHECK: %0 = "custom.addi"
    print(module)

    # op1 now has a parent, so inserting it again must raise ValueError.
    try:
      ip.insert(op1)
    except ValueError:
      pass
    else:
      assert False, "expected insert of attached op to raise"
# CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
  """Builds an op with a populated region, then moves it into a module."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    si32 = IntegerType.get_signed(32)
    op1 = Operation.create("custom.op1", regions=1)
    # Give the single region one block that carries two si32 arguments.
    block = op1.regions[0].blocks.append(si32, si32)
    # CHECK: "custom.op1"() ( {
    # CHECK: ^bb0(%arg0: si32, %arg1: si32): // no predecessors
    # CHECK: "custom.terminator"() : () -> ()
    # CHECK: }) : () -> ()
    terminator = Operation.create("custom.terminator")
    InsertionPoint(block).insert(terminator)
    print(op1)

    # Move the whole op into a function of a separately parsed module.
    # TODO: Verify lifetime hazard by nulling out the new owning module and
    # accessing op1.
    # TODO: Also verify accessing the terminator once both parents are nulled
    # out.
    module = Module.parse(r"""
func @f1(%arg0: i32) -> i32 {
%1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
return %1 : i32
}
""")
    entry_block = module.body.operations[0].regions[0].blocks[0]
    InsertionPoint.at_block_begin(entry_block).insert(op1)
    # CHECK: func @f1
    # CHECK: "custom.op1"()
    # CHECK: "custom.terminator"
    # CHECK: %0 = "custom.addi"
    print(module)
# CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
  """An op's result list reports its length, result numbers and types."""
  ctx = Context()
  module = Module.parse(
      r"""
func @f1() {
%0:3 = call @f2() : () -> (i32, f64, index)
return
}
func private @f2() -> (i32, f64, index)
""", ctx)
  caller_func = module.body.operations[0]
  call_op = caller_func.regions[0].blocks[0].operations[0]
  assert len(call_op.results) == 3

  # CHECK: Result 0, type i32
  # CHECK: Result 1, type f64
  # CHECK: Result 2, type index
  for res in call_op.results:
    print(f"Result {res.result_number}, type {res.type}")

  # CHECK: Result type i32
  # CHECK: Result type f64
  # CHECK: Result type index
  for t in call_op.results.types:
    print(f"Result type {t}")
# CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
  """Slicing an op's result list, including negative strides."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    module = Module.parse(r"""
func @f1() {
"some.op"() : () -> (i1, i2, i3, i4, i5)
return
}
""")
    producer = module.body.operations[0].regions[0].blocks[0].operations[0]
    results = producer.results
    assert len(results) == 5

    # A double reversal reproduces the original results, numbers included.
    for original, round_trip in zip(results, results[::-1][::-1]):
      assert original == round_trip
      assert original.result_number == round_trip.result_number

    def print_results(values):
      # One line per result: its number and its type.
      for res in values:
        print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result 0, type i1
    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    # CHECK: Result 4, type i5
    print_results(results[:])

    # CHECK: Result 1, type i2
    # CHECK: Result 2, type i3
    # CHECK: Result 3, type i4
    print_results(results[1:4])

    # CHECK: Result 1, type i2
    # CHECK: Result 3, type i4
    print_results(results[1::2])

    # CHECK: Result 3, type i4
    # CHECK: Result 1, type i2
    print_results(results[-2:0:-2])
# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
  """Reads op attributes by name and by iteration; checks lookup errors."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
"some.op"() { some.attribute = 1 : i8,
other.attribute = 3.0,
dependent = "text" } : () -> ()
""", ctx)
  op = module.body.operations[0]
  attributes = op.attributes
  assert len(attributes) == 3
  iattr = IntegerAttr(attributes["some.attribute"])
  fattr = FloatAttr(attributes["other.attribute"])
  sattr = StringAttr(attributes["dependent"])

  # CHECK: Attribute type i8, value 1
  print(f"Attribute type {iattr.type}, value {iattr.value}")
  # CHECK: Attribute type f64, value 3.0
  print(f"Attribute type {fattr.type}, value {fattr.value}")
  # CHECK: Attribute value text
  print(f"Attribute value {sattr.value}")

  # Attribute storage order is unspecified, so these matches are unordered.
  # CHECK-DAG: NamedAttribute(dependent="text")
  # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
  # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
  for named_attr in attributes:
    print(named_attr)

  # An unknown name raises KeyError...
  try:
    attributes["does_not_exist"]
  except KeyError:
    pass
  else:
    assert False, "expected KeyError on accessing a non-existent attribute"

  # ...and an out-of-range integer position raises IndexError.
  try:
    attributes[42]
  except IndexError:
    pass
  else:
    assert False, "expected IndexError on accessing an out-of-bounds attribute"
# CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
  """Operation.print to stdout, to text and binary streams, and with options."""
  ctx = Context()
  # The asm body is indented on purpose: the debug-info match at the end
  # expects `return` at line 4, column 7 of this string ("-:4:7"). With the
  # body flush-left, the printed location would be "-:4:1" instead.
  module = Module.parse(
      r"""
    func @f1(%arg0: i32) -> i32 {
      %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
      return %arg0 : i32
    }
  """, ctx)
  module_op = module.operation

  # With no `file` argument, the asm goes to stdout.
  # CHECK: return %arg0 : i32
  module_op.print()

  # A text stream receives `str` output.
  text_buffer = io.StringIO()
  # CHECK: <class 'str'>
  # CHECK: return %arg0 : i32
  module_op.print(file=text_buffer)
  text = text_buffer.getvalue()
  print(text.__class__)
  print(text)

  # With binary=True, a bytes stream receives `bytes` output.
  binary_buffer = io.BytesIO()
  # CHECK: <class 'bytes'>
  # CHECK: return %arg0 : i32
  module_op.print(file=binary_buffer, binary=True)
  data = binary_buffer.getvalue()
  print(data.__class__)
  print(data)

  # Elide large constants, print generic form with pretty debug locations.
  # CHECK: value = opaque<"_", "0xDEADBEEF"> : tensor<4xi32>
  # CHECK: "std.return"(%arg0) : (i32) -> () -:4:7
  module_op.print(
      large_elements_limit=2,
      enable_debug_info=True,
      pretty_debug_info=True,
      print_generic_op_form=True,
      use_local_scope=True)
# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
  """Registered ops resolve to generated OpView classes, others to OpView."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
%1 = "custom.f32"() : () -> f32
%2 = "custom.f32"() : () -> f32
%3 = arith.addf %1, %2 : f32
""")
    print(module)
    body_ops = module.body.operations

    # arith.addf maps to a known OpView class from the arith dialect (not std),
    # and that class exposes its left operand as `lhs`.
    addf = body_ops[2]
    # CHECK: <mlir.dialects._arith_ops_gen._AddFOp object
    print(repr(addf))
    # CHECK: "custom.f32"()
    print(addf.lhs)

    # An unregistered custom op falls back to the default OpView.
    custom = body_ops[0]
    # CHECK: OpView object
    print(repr(custom))

    # Fetch it a second time to make sure negative caching works.
    custom_again = body_ops[0]
    # CHECK: OpView object
    print(repr(custom_again))
# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
  """`.result` is only usable on ops that have exactly one result."""
  with Context(), Location.unknown():
    Context.current.allow_unregistered_dialects = True
    module = Module.parse(r"""
"custom.no_result"() : () -> ()
%0:2 = "custom.two_result"() : () -> (f32, f32)
%1 = "custom.one_result"() : () -> f32
""")
    print(module)
    ops = module.body.operations

    def print_result_error(op):
      # Accessing `.result` on `op` must raise ValueError; print the message.
      try:
        op.result
      except ValueError as e:
        print(e)
      else:
        assert False, "Expected exception"

    # CHECK: Cannot call .result on operation custom.no_result which has 0 results
    print_result_error(ops[0])
    # CHECK: Cannot call .result on operation custom.two_result which has 2 results
    print_result_error(ops[1])

    # CHECK: %1 = "custom.one_result"() : () -> f32
    print(ops[2])
def create_invalid_operation():
  """Returns a deliberately invalid `builtin.module` operation.

  The op is created with two regions, and only the first one receives a
  block. Such an op does not pass verification. Tests use it to confirm that
  printing falls back to the generic form for safety.
  """
  op = Operation.create("builtin.module", regions=2)
  first_region = op.regions[0]
  first_region.blocks.append()
  return op
# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
def testInvalidOperationStrSoftFails():
  """str() of an invalid op soft-fails to the generic printed form."""
  ctx = Context()
  with Location.unknown(ctx):
    invalid_op = create_invalid_operation()
    # Instead of raising, printing emits a warning comment followed by the
    # generic-form dump.
    # CHECK: // Verification failed, printing generic form
    # CHECK: "builtin.module"() ( {
    # CHECK: }) : () -> ()
    print(invalid_op)

    # Explicit verification reports the failure as well.
    # CHECK: .verify = False
    print(f".verify = {invalid_op.operation.verify()}")
# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@run
def testInvalidModuleStrSoftFails():
  """str() of a Module that contains an invalid op soft-fails as well.

  Printing a Module is expected to fall back to the generic form in the same
  way that printing an Operation does.
  """
  ctx = Context()
  with Location.unknown(ctx):
    module = Module.create()
    with InsertionPoint(module.body):
      # Only the side effect matters: the invalid op lands in module.body.
      # The returned handle was previously bound to an unused local.
      create_invalid_operation()
    # CHECK: // Verification failed, printing generic form
    print(module)
# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
@run
def testInvalidOperationGetAsmBinarySoftFails():
  """get_asm(binary=True) on an invalid op soft-fails and returns bytes."""
  ctx = Context()
  with Location.unknown(ctx):
    invalid_op = create_invalid_operation()
    # The warning banner must also appear when the output is bytes.
    # CHECK: b'// Verification failed, printing generic form\n
    asm_bytes = invalid_op.get_asm(binary=True)
    print(asm_bytes)
# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
  """Operation.create rejects malformed attribute dictionaries.

  Every attempt must raise. The exception message is printed so the expected
  diagnostic text can be matched.
  """
  ctx = Context()
  with Location.unknown(ctx):

    def expect_create_failure(attributes):
      # Create a builtin.module with `attributes`. The call must raise, and
      # its message is printed. Like the other tests in this file, an
      # unexpected success now fails here as well, instead of only showing up
      # as missing output.
      try:
        Operation.create("builtin.module", attributes=attributes)
      except Exception as e:
        print(e)
      else:
        assert False, "expected Operation.create to raise"

    # Attribute keys must be strings.
    # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
    expect_create_failure({None: StringAttr.get("name")})
    # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
    expect_create_failure({42: StringAttr.get("name")})
    # Attribute values must be attributes...
    # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
    expect_create_failure({"some_key": ctx})
    # ...and in particular must not be None.
    # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
    expect_create_failure({"some_key": None})
# CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
  """Each op reports its fully qualified name via `.operation.name`."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  module = Module.parse(
      r"""
%0 = "custom.op1"() : () -> f32
%1 = "custom.op2"() : () -> i32
%2 = "custom.op1"() : () -> f32
""", ctx)

  # CHECK: custom.op1
  # CHECK: custom.op2
  # CHECK: custom.op1
  for child in module.body.operations:
    print(child.operation.name)
# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
  """An Operation round-trips through its C-API capsule to the same object."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    operation = Operation.create("custom.op1").operation
    capsule = operation._CAPIPtr
    assert '"mlir.ir.Operation._CAPIPtr"' in repr(capsule)
    restored = Operation._CAPICreate(capsule)
    # Identity, not just equality: the live Python object is reused.
    assert restored is operation
# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
  """Erasing an op removes it from the module; creation still works after."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with Location.unknown(ctx):
    module = Module.create()
    with InsertionPoint(module.body):
      op = Operation.create("custom.op1")

      # CHECK: "custom.op1"
      print(module)

      op.operation.erase()

      # CHECK-NOT: "custom.op1"
      print(module)

      # Creating another operation must still succeed after the erase.
      Operation.create("custom.op2")
# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
  """The location passed to Operation.create is reported back by the op."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with ctx:
    named_loc = Location.name("loc")
    op = Operation.create("custom.op", loc=named_loc)
    # Both the returned object and its underlying operation agree.
    assert op.location == named_loc
    assert op.operation.location == named_loc
# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
  """Ops move before/after an op that lives in a different module."""
  with Context():
    target = Module.parse("func private @foo()")
    donor = Module.parse("""
func private @bar()
func private @qux()
""")
    foo = target.body.operations[0]
    bar = donor.body.operations[0]
    qux = donor.body.operations[1]
    # Both of the donor's functions end up around @foo in the target.
    bar.move_before(foo)
    qux.move_after(foo)

    # CHECK: module
    # CHECK: func private @bar
    # CHECK: func private @foo
    # CHECK: func private @qux
    print(target)

    # The donor is left with an empty body.
    # CHECK: module {
    # CHECK-NEXT: }
    print(donor)
# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
@run
def testAppendMoveFromAnotherBlock():
  """Appending an attached op to another block moves it out of its old one."""
  with Context():
    source = Module.parse("func private @foo()")
    dest = Module.parse("func private @bar()")
    dest.body.append(source.body.operations[0])

    # CHECK: module
    # CHECK: func private @bar
    # CHECK: func private @foo
    print(dest)

    # The source module is now empty.
    # CHECK: module {
    # CHECK-NEXT: }
    print(source)
# CHECK-LABEL: TEST: testDetachFromParent
@run
def testDetachFromParent():
  """Detaching removes an op from its block; detaching twice is an error."""
  with Context():
    module = Module.parse("func private @foo()")
    func = module.body.operations[0].detach_from_parent()

    # A second detach must fail because the op no longer has a parent.
    try:
      func.detach_from_parent()
    except ValueError as e:
      if "has no parent" not in str(e):
        raise
    else:
      assert False, "expected ValueError when detaching a detached operation"

    # The module no longer contains the detached function.
    # CHECK-NOT: func private @foo
    print(module)
# CHECK-LABEL: TEST: testSymbolTable
@run
def testSymbolTable():
  """SymbolTable lookup, erase, insert (with renaming) and error paths."""
  with Context() as ctx:
    ctx.allow_unregistered_dialects = True
    m1 = Module.parse("""
func private @foo()
func private @bar()""")
    m2 = Module.parse("""
func private @qux()
func private @foo()
"foo.bar"() : () -> ()""")

    symbol_table = SymbolTable(m1.operation)

    # Look up both of m1's symbols by name.
    # CHECK: func private @foo
    # CHECK: func private @bar
    assert "foo" in symbol_table
    print(symbol_table["foo"])
    assert "bar" in symbol_table
    bar = symbol_table["bar"]
    # Print the op that was just looked up instead of querying the table again.
    print(bar)

    assert "qux" not in symbol_table

    # Once deleted, looking "bar" up again raises KeyError.
    del symbol_table["bar"]
    try:
      symbol_table.erase(symbol_table["bar"])
    except KeyError:
      pass
    else:
      assert False, "expected KeyError"

    # CHECK: module
    # CHECK: func private @foo()
    print(m1)
    assert "bar" not in symbol_table

    # The Python handle to the erased op has been invalidated.
    try:
      print(bar)
    except RuntimeError as e:
      if "the operation has been invalidated" not in str(e):
        raise
    else:
      assert False, "expected RuntimeError due to invalidated operation"

    # Move @qux from m2 into m1 and register it in the table.
    qux = m2.body.operations[0]
    m1.body.append(qux)
    symbol_table.insert(qux)
    assert "qux" in symbol_table

    # Inserting a second @foo must rename it to avoid the clash; insert()
    # returns the name that was actually used.
    foo2 = m2.body.operations[0]
    m1.body.append(foo2)
    updated_name = symbol_table.insert(foo2)
    assert foo2.name.value != "foo"
    assert foo2.name == updated_name

    # CHECK: module
    # CHECK: func private @foo()
    # CHECK: func private @qux()
    # CHECK: func private @foo{{.*}}
    print(m1)

    # The op left in m2 ("foo.bar") has no symbol name, so insert() rejects it.
    try:
      symbol_table.insert(m2.body.operations[0])
    except ValueError as e:
      if "Expected operation to have a symbol name" not in str(e):
        raise
    else:
      assert False, "expected ValueError when adding a non-symbol"
# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
  """Whatever Operation.create returns hashes the same as its `.operation`."""
  ctx = Context()
  ctx.allow_unregistered_dialects = True
  with ctx, Location.unknown():
    created = Operation.create("custom.op1")
    assert hash(created) == hash(created.operation)