From 7c0b9e8b6270d6cda402875facaeb045a4b03d4a Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 6 Mar 2019 13:54:41 -0800 Subject: [PATCH] Add helper classes to declarative builders to help write end-to-end custom ops. This CL adds the same helper classes that exist in the AST form of EDSCs to support a basic indexing notation and emit the proper load and store operations and capture MemRefViews as function arguments. This CL also adds a wrapper class LoopNestBuilder to allow generic rank-agnostic loops over indices. PiperOrigin-RevId: 237113755 --- mlir/include/mlir/EDSC/Builders.h | 34 ++++++ mlir/include/mlir/EDSC/Helpers.h | 162 ++++++++++++++++++++++++++++ mlir/include/mlir/EDSC/Intrinsics.h | 11 ++ mlir/lib/EDSC/Builders.cpp | 26 +++++ mlir/lib/EDSC/Helpers.cpp | 110 +++++++++++++++++++ mlir/lib/EDSC/Intrinsics.cpp | 13 +++ mlir/test/EDSC/builder-api-test.cpp | 53 ++++++++- 7 files changed, 407 insertions(+), 2 deletions(-) create mode 100644 mlir/include/mlir/EDSC/Helpers.h create mode 100644 mlir/lib/EDSC/Helpers.cpp diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index de6f8b291b2a..d5d8c669a327 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -33,6 +33,7 @@ namespace edsc { struct index_t { explicit index_t(int64_t v) : v(v) {} + explicit operator int64_t() { return v; } int64_t v; }; @@ -147,6 +148,39 @@ public: ValueHandle operator()(ArrayRef stmts); }; +/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid +/// explicitly writing all the loops in a nest. This simple functionality is +/// also useful to write rank-agnostic custom ops. +/// +/// Usage: +/// +/// ```c++ +/// LoopNestBuilder({&i, &j, &k}, {lb, lb, lb}, {ub, ub, ub}, {1, 1, 1})({ +/// ... +/// }); +/// ``` +/// +/// ```c++ +/// LoopNestBuilder({&i}, {lb}, {ub}, {1})({ +/// LoopNestBuilder({&j}, {lb}, {ub}, {1})({ +/// LoopNestBuilder({&k}, {lb}, {ub}, {1})({ +/// ... 
+/// }),
+/// }),
+/// });
+/// ```
+class LoopNestBuilder {
+public:
+  LoopNestBuilder(ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
+                  ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps);
+
+  // TODO(ntv): when loops return escaping ssa-values, this should be adapted.
+  ValueHandle operator()(ArrayRef<ValueHandle> stmts);
+
+private:
+  SmallVector<LoopBuilder, 4> loops;
+};
+
 // This class exists solely to handle the C++ vexing parse case when
 // trying to enter a Block that has already been constructed.
 class Append {};
diff --git a/mlir/include/mlir/EDSC/Helpers.h b/mlir/include/mlir/EDSC/Helpers.h
new file mode 100644
index 000000000000..9c4618211e74
--- /dev/null
+++ b/mlir/include/mlir/EDSC/Helpers.h
@@ -0,0 +1,162 @@
+//===- Helpers.h - MLIR Declarative Helper Functionality --------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Provides helper classes and syntactic sugar for declarative builders.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EDSC_HELPERS_H_
+#define MLIR_EDSC_HELPERS_H_
+
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+
+namespace mlir {
+namespace edsc {
+
+class IndexedValue;
+
+/// An IndexHandle is a simple wrapper around a ValueHandle.
+/// IndexHandles are ubiquitous enough to justify a new type to allow simple +/// declarations without boilerplate such as: +/// +/// ```c++ +/// IndexHandle i, j, k; +/// ``` +struct IndexHandle : public ValueHandle { + explicit IndexHandle() + : ValueHandle(ScopedContext::getBuilder()->getIndexType()) {} + explicit IndexHandle(index_t v) : ValueHandle(v) {} + explicit IndexHandle(Value *v) : ValueHandle(v) { + assert(v->getType() == ScopedContext::getBuilder()->getIndexType() && + "Expected index type"); + } + explicit IndexHandle(ValueHandle v) : ValueHandle(v) {} +}; + +/// A MemRefView represents the information required to step through a +/// MemRef. It has placeholders for non-contiguous tensors that fit within the +/// Fortran subarray model. +/// At the moment it can only capture a MemRef with an identity layout map. +// TODO(ntv): Support MemRefs with layoutMaps. +class MemRefView { +public: + explicit MemRefView(Value *v); + MemRefView(const MemRefView &) = default; + MemRefView &operator=(const MemRefView &) = default; + + unsigned rank() const { return lbs.size(); } + unsigned fastestVarying() const { return rank() - 1; } + + std::tuple range(unsigned idx) { + return std::make_tuple(lbs[idx], ubs[idx], steps[idx]); + } + +private: + friend IndexedValue; + + ValueHandle base; + SmallVector lbs; + SmallVector ubs; + SmallVector steps; +}; + +ValueHandle operator+(ValueHandle v, IndexedValue i); +ValueHandle operator-(ValueHandle v, IndexedValue i); +ValueHandle operator*(ValueHandle v, IndexedValue i); +ValueHandle operator/(ValueHandle v, IndexedValue i); + +/// This helper class is an abstraction over memref, that purely for sugaring +/// purposes and allows writing compact expressions such as: +/// +/// ```mlir +/// IndexedValue A(...), B(...), C(...); +/// For(ivs, zeros, shapeA, ones, { +/// C(ivs) = A(ivs) + B(ivs) +/// }); +/// ``` +/// +/// Assigning to an IndexedValue emits an actual store operation, while using +/// converting an IndexedValue 
to a ValueHandle emits an actual load operation. +struct IndexedValue { + explicit IndexedValue(MemRefView &v, llvm::ArrayRef indices = {}) + : view(v), indices(indices.begin(), indices.end()) {} + + IndexedValue(const IndexedValue &rhs) = default; + IndexedValue &operator=(const IndexedValue &rhs) = default; + + /// Returns a new `IndexedValue`. + IndexedValue operator()(llvm::ArrayRef indices = {}) { + return IndexedValue(view, indices); + } + + /// Emits a `store`. + // NOLINTNEXTLINE: unconventional-assign-operator + ValueHandle operator=(ValueHandle rhs) { + return intrinsics::STORE(rhs, getBase(), indices); + } + + ValueHandle getBase() const { return view.base; } + + /// Emits a `load` when converting to a ValueHandle. + explicit operator ValueHandle() { + return intrinsics::LOAD(getBase(), indices); + } + + /// Operator overloadings. + ValueHandle operator+(ValueHandle e); + ValueHandle operator-(ValueHandle e); + ValueHandle operator*(ValueHandle e); + ValueHandle operator/(ValueHandle e); + ValueHandle operator+=(ValueHandle e); + ValueHandle operator-=(ValueHandle e); + ValueHandle operator*=(ValueHandle e); + ValueHandle operator/=(ValueHandle e); + ValueHandle operator+(IndexedValue e) { + return *this + static_cast(e); + } + ValueHandle operator-(IndexedValue e) { + return *this - static_cast(e); + } + ValueHandle operator*(IndexedValue e) { + return *this * static_cast(e); + } + ValueHandle operator/(IndexedValue e) { + return *this / static_cast(e); + } + ValueHandle operator+=(IndexedValue e) { + return this->operator+=(static_cast(e)); + } + ValueHandle operator-=(IndexedValue e) { + return this->operator-=(static_cast(e)); + } + ValueHandle operator*=(IndexedValue e) { + return this->operator*=(static_cast(e)); + } + ValueHandle operator/=(IndexedValue e) { + return this->operator/=(static_cast(e)); + } + +private: + MemRefView &view; + llvm::SmallVector indices; +}; + +} // namespace edsc +} // namespace mlir + +#endif // MLIR_EDSC_HELPERS_H_ 
diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h
index fcaf7a1c5010..6e69506fb213 100644
--- a/mlir/include/mlir/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/EDSC/Intrinsics.h
@@ -95,11 +95,22 @@ ValueHandle COND_BR(ValueHandle cond, BlockHandle *trueBranch,
 ////////////////////////////////////////////////////////////////////////////////
 // TODO(ntv): Intrinsics below this line should be TableGen'd.
 ////////////////////////////////////////////////////////////////////////////////
+/// Builds an mlir::LoadOp with the proper `operands` that each must have
+/// captured an mlir::Value*.
+/// Returns a ValueHandle to the produced mlir::Value*.
+ValueHandle LOAD(ValueHandle base, llvm::ArrayRef<ValueHandle> indices);
+
 /// Builds an mlir::ReturnOp with the proper `operands` that each must have
 /// captured an mlir::Value*.
 /// Returns an empty ValueHandle.
 ValueHandle RETURN(llvm::ArrayRef<ValueHandle> operands);
 
+/// Builds an mlir::StoreOp with the proper `operands` that each must have
+/// captured an mlir::Value*.
+/// Returns an empty ValueHandle.
+ValueHandle STORE(ValueHandle value, ValueHandle base,
+                  llvm::ArrayRef<ValueHandle> indices);
+
 } // namespace intrinsics
 } // namespace edsc
diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp
index 5cbcb289b0d3..7a5da8ab0dda 100644
--- a/mlir/lib/EDSC/Builders.cpp
+++ b/mlir/lib/EDSC/Builders.cpp
@@ -162,6 +162,32 @@ ValueHandle mlir::edsc::LoopBuilder::operator()(ArrayRef<ValueHandle> stmts) {
   return ValueHandle::null();
 }
 
+mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
+                                             ArrayRef<ValueHandle> lbs,
+                                             ArrayRef<ValueHandle> ubs,
+                                             ArrayRef<int64_t> steps) {
+  assert(ivs.size() == lbs.size() && "Mismatch in number of arguments");
+  assert(ivs.size() == ubs.size() && "Mismatch in number of arguments");
+  assert(ivs.size() == steps.size() && "Mismatch in number of arguments");
+  for (auto it : llvm::zip(ivs, lbs, ubs, steps)) {
+    loops.emplace_back(std::get<0>(it), std::get<1>(it), std::get<2>(it),
+                       std::get<3>(it));
+  }
+}
+
+ValueHandle
+mlir::edsc::LoopNestBuilder::operator()(ArrayRef<ValueHandle> stmts) {
+  // Iterate on the calling operator() on all the loops in the nest.
+  // The iteration order is from innermost to outermost because enter/exit needs
+  // to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit()
+  // occurs on calling operator()). The asymmetry is required for properly
+  // nesting imperfectly nested regions (see LoopBuilder::operator()).
+  for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) {
+    (*lit)({});
+  }
+  return ValueHandle::null();
+}
+
 mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) {
   assert(bh && "Expected already captured BlockHandle");
   enter(bh.getBlock());
diff --git a/mlir/lib/EDSC/Helpers.cpp b/mlir/lib/EDSC/Helpers.cpp
new file mode 100644
index 000000000000..fcf7c8213153
--- /dev/null
+++ b/mlir/lib/EDSC/Helpers.cpp
@@ -0,0 +1,110 @@
+//===- Helpers.cpp - MLIR Declarative Helper Functionality ------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+
+static SmallVector<ValueHandle, 8> getMemRefSizes(Value *memRef) {
+  MemRefType memRefType = memRef->getType().cast<MemRefType>();
+
+  auto maps = memRefType.getAffineMaps();
+  assert((maps.empty() || (maps.size() == 1 && maps[0].isIdentity())) &&
+         "Layout maps not supported");
+  SmallVector<ValueHandle, 8> res;
+  res.reserve(memRefType.getShape().size());
+  const auto &shape = memRefType.getShape();
+  for (unsigned idx = 0, n = shape.size(); idx < n; ++idx) {
+    if (shape[idx] == -1) {
+      res.push_back(IndexHandle(ValueHandle::create<DimOp>(memRef, idx)));
+    } else {
+      res.push_back(IndexHandle(static_cast<index_t>(shape[idx])));
+    }
+  }
+  return res;
+}
+
+mlir::edsc::MemRefView::MemRefView(Value *v) : base(v) {
+  assert(v->getType().isa<MemRefType>() && "MemRefType expected");
+
+  auto memrefSizeValues = getMemRefSizes(v);
+  for (auto &size : memrefSizeValues) {
+    lbs.push_back(IndexHandle(static_cast<index_t>(0)));
+    ubs.push_back(size);
+    steps.push_back(1);
+  }
+}
+
+/// Operator overloadings.
+ValueHandle mlir::edsc::IndexedValue::operator+(ValueHandle e) { + using op::operator+; + return static_cast(*this) + e; +} +ValueHandle mlir::edsc::IndexedValue::operator-(ValueHandle e) { + using op::operator-; + return static_cast(*this) - e; +} +ValueHandle mlir::edsc::IndexedValue::operator*(ValueHandle e) { + using op::operator*; + return static_cast(*this) * e; +} +ValueHandle mlir::edsc::IndexedValue::operator/(ValueHandle e) { + using op::operator/; + return static_cast(*this) / e; +} + +ValueHandle mlir::edsc::IndexedValue::operator+=(ValueHandle e) { + using op::operator+; + return intrinsics::STORE(*this + e, getBase(), indices); +} +ValueHandle mlir::edsc::IndexedValue::operator-=(ValueHandle e) { + using op::operator-; + return intrinsics::STORE(*this - e, getBase(), indices); +} +ValueHandle mlir::edsc::IndexedValue::operator*=(ValueHandle e) { + using op::operator*; + return intrinsics::STORE(*this * e, getBase(), indices); +} +ValueHandle mlir::edsc::IndexedValue::operator/=(ValueHandle e) { + using op::operator/; + return intrinsics::STORE(*this / e, getBase(), indices); +} + +ValueHandle mlir::edsc::operator+(ValueHandle v, IndexedValue i) { + using op::operator+; + return v + static_cast(i); +} + +ValueHandle mlir::edsc::operator-(ValueHandle v, IndexedValue i) { + using op::operator-; + return v - static_cast(i); +} + +ValueHandle mlir::edsc::operator*(ValueHandle v, IndexedValue i) { + using op::operator*; + return v * static_cast(i); +} + +ValueHandle mlir::edsc::operator/(ValueHandle v, IndexedValue i) { + using op::operator/; + return v / static_cast(i); +} diff --git a/mlir/lib/EDSC/Intrinsics.cpp b/mlir/lib/EDSC/Intrinsics.cpp index 58ea87ceb8b0..87c9fec5bdc7 100644 --- a/mlir/lib/EDSC/Intrinsics.cpp +++ b/mlir/lib/EDSC/Intrinsics.cpp @@ -100,8 +100,21 @@ ValueHandle mlir::edsc::intrinsics::COND_BR( //////////////////////////////////////////////////////////////////////////////// // TODO(ntv): Intrinsics below this line should be 
TableGen'd.
 ////////////////////////////////////////////////////////////////////////////////
+ValueHandle
+mlir::edsc::intrinsics::LOAD(ValueHandle base,
+                             llvm::ArrayRef<ValueHandle> indices = {}) {
+  SmallVector<Value *, 4> ops(indices.begin(), indices.end());
+  return ValueHandle::create<LoadOp>(base.getValue(), ops);
+}
 ValueHandle mlir::edsc::intrinsics::RETURN(ArrayRef<ValueHandle> operands) {
   SmallVector<Value *, 4> ops(operands.begin(), operands.end());
   return ValueHandle::create<ReturnOp>(ops);
 }
+
+ValueHandle
+mlir::edsc::intrinsics::STORE(ValueHandle value, ValueHandle base,
+                              llvm::ArrayRef<ValueHandle> indices = {}) {
+  SmallVector<Value *, 4> ops(indices.begin(), indices.end());
+  return ValueHandle::create<StoreOp>(value.getValue(), base.getValue(), ops);
+}
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 09cb26777750..9119f3b68c50 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -20,9 +20,8 @@
 #include "mlir/AffineOps/AffineOps.h"
 #include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Helpers.h"
 #include "mlir/EDSC/Intrinsics.h"
-#include "mlir/EDSC/MLIREmitter.h"
-#include "mlir/EDSC/Types.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
@@ -317,6 +316,56 @@ TEST_FUNC(builder_cond_branch_eager) {
   f->print(llvm::outs());
 }
 
+TEST_FUNC(builder_helpers) {
+  using namespace edsc;
+  using namespace edsc::intrinsics;
+  using namespace edsc::op;
+  auto f32Type = FloatType::getF32(&globalContext());
+  auto memrefType = MemRefType::get({-1, -1, -1}, f32Type, {}, 0);
+  auto f =
+      makeFunction("builder_helpers", {}, {memrefType, memrefType, memrefType});
+
+  ScopedContext scope(f.get());
+  // clang-format off
+  ValueHandle f7(
+      ValueHandle::create<ConstantFloatOp>(llvm::APFloat(7.0f), f32Type));
+  MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), vC(f->getArgument(2));
+  IndexedValue A(vA), B(vB), C(vC);
+  IndexHandle i, j, k1, k2, lb0, lb1, lb2, ub0, ub1, ub2;
+  int64_t step0, step1, step2;
+  std::tie(lb0, ub0, step0) = vA.range(0);
+
std::tie(lb1, ub1, step1) = vA.range(1); + std::tie(lb2, ub2, step2) = vA.range(2); + LoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})({ + LoopBuilder(&k1, lb2, ub2, step2)({ + C({i, j, k1}) = f7 + A({i, j, k1}) + B({i, j, k1}), + }), + LoopBuilder(&k2, lb2, ub2, step2)({ + C({i, j, k2}) += A({i, j, k2}) + B({i, j, k2}), + }), + }); + + // CHECK-LABEL: @builder_helpers + // CHECK: for %i0 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) { + // CHECK-NEXT: for %i1 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) { + // CHECK-NEXT: for %i2 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) { + // CHECK-NEXT: [[a:%.*]] = load %arg0[%i0, %i1, %i2] : memref + // CHECK-NEXT: [[b:%.*]] = addf {{.*}}, [[a]] : f32 + // CHECK-NEXT: [[c:%.*]] = load %arg1[%i0, %i1, %i2] : memref + // CHECK-NEXT: [[d:%.*]] = addf [[b]], [[c]] : f32 + // CHECK-NEXT: store [[d]], %arg2[%i0, %i1, %i2] : memref + // CHECK-NEXT: } + // CHECK-NEXT: for %i3 = (d0) -> (d0)(%c0_1) to (d0) -> (d0)(%2) { + // CHECK-NEXT: [[a:%.*]] = load %arg1[%i0, %i1, %i3] : memref + // CHECK-NEXT: [[b:%.*]] = load %arg0[%i0, %i1, %i3] : memref + // CHECK-NEXT: [[c:%.*]] = addf [[b]], [[a]] : f32 + // CHECK-NEXT: [[d:%.*]] = load %arg2[%i0, %i1, %i3] : memref + // CHECK-NEXT: [[e:%.*]] = addf [[d]], [[c]] : f32 + // CHECK-NEXT: store [[e]], %arg2[%i0, %i1, %i3] : memref + // clang-format on + f->print(llvm::outs()); +} + int main() { RUN_TESTS(); return 0;