Add support for custom ops in declarative builders.

This CL adds support for named custom instructions in declarative builders.
To allow this, it introduces a templated `CustomInstruction` class.
This CL also splits ValueHandle which can capture only the **value** in single-valued instructions from InstructionHandle which can capture any instruction but provide no typing and sugaring to extract the potential Value*.

PiperOrigin-RevId: 237543222
This commit is contained in:
Nicolas Vasilache
2019-03-08 16:41:25 -08:00
committed by jpienaar
parent 80d3568c0a
commit eb19b4eefc
7 changed files with 198 additions and 56 deletions

View File

@@ -38,6 +38,7 @@ struct index_t {
};
class BlockHandle;
class CapturableHandle;
class NestedBuilder;
class ValueHandle;
@@ -162,7 +163,7 @@ public:
/// In order to be admissible in a nested ArrayRef<ValueHandle>, operator()
/// returns a ValueHandle::null() that cannot be captured.
// TODO(ntv): when loops return escaping ssa-values, this should be adapted.
ValueHandle operator()(ArrayRef<ValueHandle> stmts);
ValueHandle operator()(ArrayRef<CapturableHandle> stmts);
};
/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid
@@ -192,7 +193,7 @@ public:
ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
// TODO(ntv): when loops return escaping ssa-values, this should be adapted.
ValueHandle operator()(ArrayRef<ValueHandle> stmts);
ValueHandle operator()(ArrayRef<CapturableHandle> stmts);
private:
SmallVector<LoopBuilder, 4> loops;
@@ -225,13 +226,20 @@ public:
/// The only purpose of this operator is to serve as a sequence point so that
/// the evaluation of `stmts` (which build IR snippets in a scoped fashion) is
/// sequenced strictly after the constructor of BlockBuilder.
void operator()(ArrayRef<ValueHandle> stmts);
void operator()(ArrayRef<CapturableHandle> stmts);
private:
BlockBuilder(const BlockBuilder &) = delete;
BlockBuilder &operator=(const BlockBuilder &other) = delete;
};
/// Base class for Handles that cannot be constructed explicitly by a user of
/// the API.
struct CapturableHandle {
protected:
CapturableHandle() = default;
};
/// ValueHandle implements a (potentially "delayed") typed Value abstraction.
/// ValueHandle should be captured by pointer but otherwise passed by Value
/// everywhere.
@@ -245,7 +253,13 @@ private:
/// 2. delayed state (empty value), in which case it represents an eagerly
/// typed "delayed" value that can be hold a Value in the future;
/// 3. constructed state,in which case it holds a Value.
class ValueHandle {
///
/// A ValueHandle is meant to capture a single Value* and should be used for
/// instructions that have a single result. For convenience of use, we also
/// include AffineForOp in this category although it does not return a value.
/// In the case of AffineForOp, the captured Value* is the loop induction
/// variable.
class ValueHandle : public CapturableHandle {
public:
/// A ValueHandle in a null state can never be captured;
static ValueHandle null() { return ValueHandle(); }
@@ -275,14 +289,13 @@ public:
/// ValueHandle is a value type, the assignment operator typechecks before
/// assigning.
/// ```
ValueHandle &operator=(const ValueHandle &other);
/// Implicit conversion useful for automatic conversion to Container<Value*>.
operator Value *() const { return getValue(); }
/// Generic mlir::Op create. This is the key to being extensible to the whole
/// of MLIR without duplicating the type system or the AST.
/// of MLIR without duplicating the type system or the op definitions.
template <typename Op, typename... Args>
static ValueHandle create(Args... args);
@@ -291,6 +304,11 @@ public:
static ValueHandle createComposedAffineApply(AffineMap map,
ArrayRef<Value *> operands);
/// Generic create for a named instruction producing a single value.
static ValueHandle create(StringRef name, ArrayRef<ValueHandle> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes = {});
bool hasValue() const { return v != nullptr; }
Value *getValue() const { return v; }
bool hasType() const { return t != Type(); }
@@ -303,12 +321,59 @@ private:
Value *v;
};
/// An InstructionHandle can be used in lieu of ValueHandle to capture the
/// instruction in cases when one does not care about, or cannot extract, a
/// unique Value* from the instruction.
/// This can be used for capturing zero result instructions as well as
/// multi-result instructions that are not supported by ValueHandle.
/// We do not distinguish further between zero and multi-result instructions at
/// this time.
struct InstructionHandle : public CapturableHandle {
InstructionHandle() : inst(nullptr) {}
InstructionHandle(Instruction *inst) : inst(inst) {}
InstructionHandle(const InstructionHandle &) = default;
InstructionHandle &operator=(const InstructionHandle &) = default;
/// Generic mlir::Op create. This is the key to being extensible to the whole
/// of MLIR without duplicating the type system or the op definitions.
template <typename Op, typename... Args>
static InstructionHandle create(Args... args);
/// Generic create for a named instruction.
static InstructionHandle create(StringRef name,
ArrayRef<ValueHandle> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes = {});
operator Instruction *() { return inst; }
private:
Instruction *inst;
};
/// Simple wrapper to build a generic instruction without successor blocks.
template <typename HandleType> struct CustomInstruction {
CustomInstruction(StringRef name) : name(name) {
static_assert(std::is_same<HandleType, ValueHandle>() ||
std::is_same<HandleType, InstructionHandle>(),
"Only CustomInstruction<ValueHandle> or "
"CustomInstruction<InstructionHandle> can be constructed.");
}
HandleType operator()(ArrayRef<ValueHandle> operands = {},
ArrayRef<Type> resultTypes = {},
ArrayRef<NamedAttribute> attributes = {}) {
return HandleType::create(name, operands, resultTypes, attributes);
}
std::string name;
};
/// A BlockHandle represents a (potentially "delayed") Block abstraction.
/// This extra abstraction is necessary because an mlir::Block is not an
/// mlir::Value.
/// A BlockHandle should be captured by pointer but otherwise passed by Value
/// everywhere.
class BlockHandle {
class BlockHandle : public CapturableHandle {
public:
/// A BlockHandle constructed without an mlir::Block* represents a "delayed"
/// Block. A delayed Block represents the declaration (in the PL sense) of a
@@ -338,6 +403,14 @@ private:
mlir::Block *block;
};
template <typename Op, typename... Args>
InstructionHandle InstructionHandle::create(Args... args) {
return InstructionHandle(
ScopedContext::getBuilder()
->create<Op>(ScopedContext::getLocation(), args...)
->getInstruction());
}
template <typename Op, typename... Args>
ValueHandle ValueHandle::create(Args... args) {
Instruction *inst = ScopedContext::getBuilder()
@@ -350,9 +423,8 @@ ValueHandle ValueHandle::create(Args... args) {
f->createBody();
return ValueHandle(f->getInductionVar());
}
return ValueHandle();
}
llvm_unreachable("unsupported inst with > 1 results");
llvm_unreachable("unsupported instruction, use an InstructionHandle instead");
}
namespace op {

View File

@@ -106,7 +106,7 @@ struct IndexedValue {
/// Emits a `store`.
// NOLINTNEXTLINE: unconventional-assign-operator
ValueHandle operator=(ValueHandle rhs) {
InstructionHandle operator=(ValueHandle rhs) {
return intrinsics::STORE(rhs, getBase(), indices);
}
@@ -122,10 +122,10 @@ struct IndexedValue {
ValueHandle operator-(ValueHandle e);
ValueHandle operator*(ValueHandle e);
ValueHandle operator/(ValueHandle e);
ValueHandle operator+=(ValueHandle e);
ValueHandle operator-=(ValueHandle e);
ValueHandle operator*=(ValueHandle e);
ValueHandle operator/=(ValueHandle e);
InstructionHandle operator+=(ValueHandle e);
InstructionHandle operator-=(ValueHandle e);
InstructionHandle operator*=(ValueHandle e);
InstructionHandle operator/=(ValueHandle e);
ValueHandle operator+(IndexedValue e) {
return *this + static_cast<ValueHandle>(e);
}
@@ -138,16 +138,16 @@ struct IndexedValue {
ValueHandle operator/(IndexedValue e) {
return *this / static_cast<ValueHandle>(e);
}
ValueHandle operator+=(IndexedValue e) {
InstructionHandle operator+=(IndexedValue e) {
return this->operator+=(static_cast<ValueHandle>(e));
}
ValueHandle operator-=(IndexedValue e) {
InstructionHandle operator-=(IndexedValue e) {
return this->operator-=(static_cast<ValueHandle>(e));
}
ValueHandle operator*=(IndexedValue e) {
InstructionHandle operator*=(IndexedValue e) {
return this->operator*=(static_cast<ValueHandle>(e));
}
ValueHandle operator/=(IndexedValue e) {
InstructionHandle operator/=(IndexedValue e) {
return this->operator/=(static_cast<ValueHandle>(e));
}

View File

@@ -30,6 +30,7 @@ namespace mlir {
namespace edsc {
class BlockHandle;
class InstructionHandle;
class ValueHandle;
/// Provides a set of first class intrinsics.
@@ -41,7 +42,7 @@ namespace intrinsics {
///
/// Prerequisites:
/// All Handles have already captured previously constructed IR objects.
ValueHandle BR(BlockHandle bh, ArrayRef<ValueHandle> operands);
InstructionHandle BR(BlockHandle bh, ArrayRef<ValueHandle> operands);
/// Creates a new mlir::Block* and branches to it from the current block.
/// Argument types are specified by `operands`.
@@ -56,8 +57,8 @@ ValueHandle BR(BlockHandle bh, ArrayRef<ValueHandle> operands);
/// All `operands` have already captured an mlir::Value*
/// captures.size() == operands.size()
/// captures and operands are pairwise of the same type.
ValueHandle BR(BlockHandle *bh, ArrayRef<ValueHandle *> captures,
ArrayRef<ValueHandle> operands);
InstructionHandle BR(BlockHandle *bh, ArrayRef<ValueHandle *> captures,
ArrayRef<ValueHandle> operands);
/// Branches into the mlir::Block* captured by BlockHandle `trueBranch` with
/// `trueOperands` if `cond` evaluates to `true` (resp. `falseBranch` and
@@ -65,9 +66,10 @@ ValueHandle BR(BlockHandle *bh, ArrayRef<ValueHandle *> captures,
///
/// Prerequisites:
/// All Handles have captured previouly constructed IR objects.
ValueHandle COND_BR(ValueHandle cond, BlockHandle trueBranch,
ArrayRef<ValueHandle> trueOperands, BlockHandle falseBranch,
ArrayRef<ValueHandle> falseOperands);
InstructionHandle COND_BR(ValueHandle cond, BlockHandle trueBranch,
ArrayRef<ValueHandle> trueOperands,
BlockHandle falseBranch,
ArrayRef<ValueHandle> falseOperands);
/// Eagerly creates new mlir::Block* with argument types specified by
/// `trueOperands`/`falseOperands`.
@@ -85,12 +87,12 @@ ValueHandle COND_BR(ValueHandle cond, BlockHandle trueBranch,
/// `falseCaptures`.size() == `falseOperands`.size()
/// `trueCaptures` and `trueOperands` are pairwise of the same type
/// `falseCaptures` and `falseOperands` are pairwise of the same type.
ValueHandle COND_BR(ValueHandle cond, BlockHandle *trueBranch,
ArrayRef<ValueHandle *> trueCaptures,
ArrayRef<ValueHandle> trueOperands,
BlockHandle *falseBranch,
ArrayRef<ValueHandle *> falseCaptures,
ArrayRef<ValueHandle> falseOperands);
InstructionHandle COND_BR(ValueHandle cond, BlockHandle *trueBranch,
ArrayRef<ValueHandle *> trueCaptures,
ArrayRef<ValueHandle> trueOperands,
BlockHandle *falseBranch,
ArrayRef<ValueHandle *> falseCaptures,
ArrayRef<ValueHandle> falseOperands);
////////////////////////////////////////////////////////////////////////////////
// TODO(ntv): Intrinsics below this line should be TableGen'd.
@@ -103,13 +105,13 @@ ValueHandle LOAD(ValueHandle base, llvm::ArrayRef<ValueHandle> indices);
/// Builds an mlir::ReturnOp with the proper `operands` that each must have
/// captured an mlir::Value*.
/// Returns an empty ValueHandle.
ValueHandle RETURN(llvm::ArrayRef<ValueHandle> operands);
InstructionHandle RETURN(llvm::ArrayRef<ValueHandle> operands);
/// Builds an mlir::StoreOp with the proper `operands` that each must have
/// captured an mlir::Value*.
/// Returns an empty ValueHandle.
ValueHandle STORE(ValueHandle value, ValueHandle base,
llvm::ArrayRef<ValueHandle> indices);
InstructionHandle STORE(ValueHandle value, ValueHandle base,
llvm::ArrayRef<ValueHandle> indices);
} // namespace intrinsics

View File

@@ -92,6 +92,38 @@ mlir::edsc::ValueHandle::createComposedAffineApply(AffineMap map,
return ValueHandle(inst->getResult(0));
}
ValueHandle ValueHandle::create(StringRef name, ArrayRef<ValueHandle> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes) {
Instruction *inst =
InstructionHandle::create(name, operands, resultTypes, attributes);
if (auto f = inst->dyn_cast<AffineForOp>()) {
// Immediately create the loop body so we can just insert instructions right
// away.
f->createBody();
return ValueHandle(f->getInductionVar());
}
if (inst->getNumResults() == 1) {
return ValueHandle(inst->getResult(0));
}
llvm_unreachable("unsupported instruction, use an InstructionHandle instead");
}
InstructionHandle
InstructionHandle::create(StringRef name, ArrayRef<ValueHandle> operands,
ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes) {
OperationState state(ScopedContext::getContext(),
ScopedContext::getLocation(), name);
SmallVector<Value *, 4> ops(operands.begin(), operands.end());
state.addOperands(ops);
state.addTypes(resultTypes);
for (const auto &attr : attributes) {
state.addAttribute(attr.first, attr.second);
}
return InstructionHandle(ScopedContext::getBuilder()->createOperation(state));
}
BlockHandle mlir::edsc::BlockHandle::create(ArrayRef<Type> argTypes) {
BlockHandle res;
res.block = ScopedContext::getBuilder()->createBlock();
@@ -139,7 +171,8 @@ mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv,
enter(body);
}
ValueHandle mlir::edsc::LoopBuilder::operator()(ArrayRef<ValueHandle> stmts) {
ValueHandle
mlir::edsc::LoopBuilder::operator()(ArrayRef<CapturableHandle> stmts) {
// Call to `exit` must be explicit and asymmetric (cannot happen in the
// destructor) because of ordering wrt comma operator.
/// The particular use case concerns nested blocks:
@@ -176,7 +209,7 @@ mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
}
ValueHandle
mlir::edsc::LoopNestBuilder::operator()(ArrayRef<ValueHandle> stmts) {
mlir::edsc::LoopNestBuilder::operator()(ArrayRef<CapturableHandle> stmts) {
// Iterate on the calling operator() on all the loops in the nest.
// The iteration order is from innermost to outermost because enter/exit needs
// to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit()
@@ -212,7 +245,7 @@ mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh,
/// Only serves as an ordering point between entering nested block and creating
/// stmts.
void mlir::edsc::BlockBuilder::operator()(ArrayRef<ValueHandle> stmts) {
void mlir::edsc::BlockBuilder::operator()(ArrayRef<CapturableHandle> stmts) {
// Call to `exit` must be explicit and asymmetric (cannot happen in the
// destructor) because of ordering wrt comma operator.
exit();

View File

@@ -71,19 +71,19 @@ ValueHandle mlir::edsc::IndexedValue::operator/(ValueHandle e) {
return static_cast<ValueHandle>(*this) / e;
}
ValueHandle mlir::edsc::IndexedValue::operator+=(ValueHandle e) {
InstructionHandle mlir::edsc::IndexedValue::operator+=(ValueHandle e) {
using op::operator+;
return intrinsics::STORE(*this + e, getBase(), indices);
}
ValueHandle mlir::edsc::IndexedValue::operator-=(ValueHandle e) {
InstructionHandle mlir::edsc::IndexedValue::operator-=(ValueHandle e) {
using op::operator-;
return intrinsics::STORE(*this - e, getBase(), indices);
}
ValueHandle mlir::edsc::IndexedValue::operator*=(ValueHandle e) {
InstructionHandle mlir::edsc::IndexedValue::operator*=(ValueHandle e) {
using op::operator*;
return intrinsics::STORE(*this * e, getBase(), indices);
}
ValueHandle mlir::edsc::IndexedValue::operator/=(ValueHandle e) {
InstructionHandle mlir::edsc::IndexedValue::operator/=(ValueHandle e) {
using op::operator/;
return intrinsics::STORE(*this / e, getBase(), indices);
}

View File

@@ -22,15 +22,15 @@
using namespace mlir;
using namespace mlir::edsc;
ValueHandle mlir::edsc::intrinsics::BR(BlockHandle bh,
ArrayRef<ValueHandle> operands) {
InstructionHandle mlir::edsc::intrinsics::BR(BlockHandle bh,
ArrayRef<ValueHandle> operands) {
assert(bh && "Expected already captured BlockHandle");
for (auto &o : operands) {
(void)o;
assert(o && "Expected already captured ValueHandle");
}
SmallVector<Value *, 4> ops(operands.begin(), operands.end());
return ValueHandle::create<BranchOp>(bh.getBlock(), ops);
return InstructionHandle::create<BranchOp>(bh.getBlock(), ops);
}
static void enforceEmptyCapturesMatchOperands(ArrayRef<ValueHandle *> captures,
ArrayRef<ValueHandle> operands) {
@@ -46,9 +46,9 @@ static void enforceEmptyCapturesMatchOperands(ArrayRef<ValueHandle *> captures,
}
}
ValueHandle mlir::edsc::intrinsics::BR(BlockHandle *bh,
ArrayRef<ValueHandle *> captures,
ArrayRef<ValueHandle> operands) {
InstructionHandle mlir::edsc::intrinsics::BR(BlockHandle *bh,
ArrayRef<ValueHandle *> captures,
ArrayRef<ValueHandle> operands) {
assert(!*bh && "Unexpected already captured BlockHandle");
enforceEmptyCapturesMatchOperands(captures, operands);
{ // Clone the scope explicitly to avoid modifying the insertion point in the
@@ -60,21 +60,21 @@ ValueHandle mlir::edsc::intrinsics::BR(BlockHandle *bh,
BlockBuilder(bh, captures)({/* no body */});
} // Release before adding the branch to the eagerly created block.
SmallVector<Value *, 4> ops(operands.begin(), operands.end());
return ValueHandle::create<BranchOp>(bh->getBlock(), ops);
return InstructionHandle::create<BranchOp>(bh->getBlock(), ops);
}
ValueHandle
InstructionHandle
mlir::edsc::intrinsics::COND_BR(ValueHandle cond, BlockHandle trueBranch,
ArrayRef<ValueHandle> trueOperands,
BlockHandle falseBranch,
ArrayRef<ValueHandle> falseOperands) {
SmallVector<Value *, 4> trueOps(trueOperands.begin(), trueOperands.end());
SmallVector<Value *, 4> falseOps(falseOperands.begin(), falseOperands.end());
return ValueHandle::create<CondBranchOp>(cond, trueBranch.getBlock(), trueOps,
falseBranch.getBlock(), falseOps);
return InstructionHandle::create<CondBranchOp>(
cond, trueBranch.getBlock(), trueOps, falseBranch.getBlock(), falseOps);
}
ValueHandle mlir::edsc::intrinsics::COND_BR(
InstructionHandle mlir::edsc::intrinsics::COND_BR(
ValueHandle cond, BlockHandle *trueBranch,
ArrayRef<ValueHandle *> trueCaptures, ArrayRef<ValueHandle> trueOperands,
BlockHandle *falseBranch, ArrayRef<ValueHandle *> falseCaptures,
@@ -93,7 +93,7 @@ ValueHandle mlir::edsc::intrinsics::COND_BR(
} // Release before adding the branch to the eagerly created block.
SmallVector<Value *, 4> trueOps(trueOperands.begin(), trueOperands.end());
SmallVector<Value *, 4> falseOps(falseOperands.begin(), falseOperands.end());
return ValueHandle::create<CondBranchOp>(
return InstructionHandle::create<CondBranchOp>(
cond, trueBranch->getBlock(), trueOps, falseBranch->getBlock(), falseOps);
}
@@ -107,14 +107,16 @@ mlir::edsc::intrinsics::LOAD(ValueHandle base,
return ValueHandle::create<LoadOp>(base.getValue(), ops);
}
ValueHandle mlir::edsc::intrinsics::RETURN(ArrayRef<ValueHandle> operands) {
InstructionHandle
mlir::edsc::intrinsics::RETURN(ArrayRef<ValueHandle> operands) {
SmallVector<Value *, 4> ops(operands.begin(), operands.end());
return ValueHandle::create<ReturnOp>(ops);
return InstructionHandle::create<ReturnOp>(ops);
}
ValueHandle
InstructionHandle
mlir::edsc::intrinsics::STORE(ValueHandle value, ValueHandle base,
llvm::ArrayRef<ValueHandle> indices = {}) {
SmallVector<Value *, 4> ops(indices.begin(), indices.end());
return ValueHandle::create<StoreOp>(value.getValue(), base.getValue(), ops);
return InstructionHandle::create<StoreOp>(value.getValue(), base.getValue(),
ops);
}

View File

@@ -363,6 +363,39 @@ TEST_FUNC(builder_helpers) {
f->print(llvm::outs());
}
TEST_FUNC(custom_ops) {
using namespace edsc;
using namespace edsc::intrinsics;
using namespace edsc::op;
auto indexType = IndexType::get(&globalContext());
auto f = makeFunction("custom_ops", {}, {indexType, indexType});
ScopedContext scope(f.get());
CustomInstruction<ValueHandle> MY_CUSTOM_OP("my_custom_op");
CustomInstruction<InstructionHandle> MY_CUSTOM_INST_0("my_custom_inst_0");
CustomInstruction<InstructionHandle> MY_CUSTOM_INST_2("my_custom_inst_2");
// clang-format off
ValueHandle vh(indexType);
InstructionHandle ih0, ih2;
IndexHandle m, n, M(f->getArgument(0)), N(f->getArgument(1));
IndexHandle ten(index_t(10)), twenty(index_t(20));
LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})({
vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {}),
ih0 = MY_CUSTOM_INST_0({m, m + n}, {}),
ih2 = MY_CUSTOM_INST_2({m, m + n}, {indexType, indexType}),
});
// CHECK-LABEL: @custom_ops
// CHECK: for %i0 {{.*}}
// CHECK: for %i1 {{.*}}
// CHECK: {{.*}} = "my_custom_op"{{.*}} : (index, index) -> index
// CHECK: "my_custom_inst_0"{{.*}} : (index, index) -> ()
// CHECK: {{.*}} = "my_custom_inst_2"{{.*}} : (index, index) -> (index, index)
// clang-format on
f->print(llvm::outs());
}
int main() {
RUN_TESTS();
return 0;