Add operand type iterators to Operation and cleanup usages of operand->getType. This also simplifies some lingering usages of result->getType.

--

PiperOrigin-RevId: 249889174
This commit is contained in:
River Riddle
2019-05-24 13:28:55 -07:00
committed by Mehdi Amini
parent e53b7d2c02
commit 06734badbc
14 changed files with 119 additions and 94 deletions

View File

@@ -97,7 +97,7 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
auto elementType = linalg::convertLinalgType(*op->getResultTypes().begin());
auto elementType = linalg::convertLinalgType(*op->result_type_begin());
Value *viewDescriptor = operands[0];
ArrayRef<Value *> indices = operands.drop_front();
Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);

View File

@@ -226,8 +226,8 @@ public:
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
return llvm::all_of(op->getOperands(), [](mlir::Value *v) {
return !v->getType().cast<ToyArrayType>().isGeneric();
return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
return !ty.cast<ToyArrayType>().isGeneric();
});
});
if (nextop == opWorklist.end())
@@ -308,9 +308,8 @@ public:
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.
std::vector<mlir::Type> funcArgs;
for (auto operand : op->getOperands())
funcArgs.push_back(operand->getType());
std::vector<mlir::Type> funcArgs(op->operand_type_begin(),
op->operand_type_end());
funcWorklist.push_back(
{callee, std::move(mangledName), std::move(funcArgs)});
return mlir::success();

View File

@@ -226,8 +226,8 @@ public:
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
return llvm::all_of(op->getOperands(), [](mlir::Value *v) {
return !v->getType().cast<ToyArrayType>().isGeneric();
return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
return !ty.cast<ToyArrayType>().isGeneric();
});
});
if (nextop == opWorklist.end())
@@ -312,18 +312,15 @@ public:
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.
std::vector<mlir::Type> funcArgs;
for (auto operand : op->getOperands())
funcArgs.push_back(operand->getType());
std::vector<mlir::Type> funcArgs(op->operand_type_begin(),
op->operand_type_end());
funcWorklist.push_back(
{callee, std::move(mangledName), std::move(funcArgs)});
return mlir::success();
}
// Found a specialized callee! Let's turn this into a normal call
// operation.
SmallVector<mlir::Value *, 8> operands;
for (mlir::Value *v : op->getOperands())
operands.push_back(v);
SmallVector<mlir::Value *, 8> operands(op->getOperands());
mlir::FuncBuilder builder(f);
builder.setInsertionPoint(op);
auto newCall =

View File

@@ -78,9 +78,9 @@ public:
/// Get the SSA values corresponding to kernel block size.
KernelDim3 getBlockSize();
/// Get the operand values passed as kernel arguments.
Operation::operand_range getKernelOperandValues();
/// Append the operand types passed as kernel arguments to `out`.
void getKernelOperandTypes(SmallVectorImpl<Type> &out);
operand_range getKernelOperandValues();
/// Get the operand types passed as kernel arguments.
operand_type_range getKernelOperandTypes();
/// Get the SSA values passed as operands to specify the grid size.
KernelDim3 getGridSizeOperandValues();

View File

@@ -328,6 +328,8 @@ template <typename ConcreteType, template <typename> class TraitType>
struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
using operand_iterator = Operation::operand_iterator;
using operand_range = Operation::operand_range;
using operand_type_iterator = Operation::operand_type_iterator;
using operand_type_range = Operation::operand_type_range;
/// Return the number of operands.
unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
@@ -346,6 +348,17 @@ struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
}
operand_iterator operand_end() { return this->getOperation()->operand_end(); }
operand_range getOperands() { return this->getOperation()->getOperands(); }
/// Operand type access.
operand_type_iterator operand_type_begin() {
return this->getOperation()->operand_type_begin();
}
operand_type_iterator operand_type_end() {
return this->getOperation()->operand_type_end();
}
operand_type_range getOperandTypes() {
return this->getOperation()->getOperandTypes();
}
};
} // end namespace detail
@@ -447,6 +460,8 @@ template <typename ConcreteType, template <typename> class TraitType>
struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
using result_iterator = Operation::result_iterator;
using result_range = Operation::result_range;
using result_type_iterator = Operation::result_type_iterator;
using result_type_range = Operation::result_type_range;
/// Return the number of results.
unsigned getNumResults() { return this->getOperation()->getNumResults(); }
@@ -468,6 +483,17 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
}
result_iterator result_end() { return this->getOperation()->result_end(); }
result_range getResults() { return this->getOperation()->getResults(); }
/// Result type access.
result_type_iterator result_type_begin() {
return this->getOperation()->result_type_begin();
}
result_type_iterator result_type_end() {
return this->getOperation()->result_type_end();
}
result_type_range getResultTypes() {
return this->getOperation()->getResultTypes();
}
};
} // end namespace detail
@@ -477,7 +503,6 @@ template <typename ConcreteType>
class OneResult : public TraitBase<ConcreteType, OneResult> {
public:
Value *getResult() { return this->getOperation()->getResult(0); }
Type getType() { return getResult()->getType(); }
/// Replace all uses of 'this' value with the new value, updating anything in

View File

@@ -34,6 +34,7 @@ class BlockAndValueMapping;
class Location;
class MLIRContext;
class OperandIterator;
class OperandTypeIterator;
struct OperationState;
class ResultIterator;
class ResultTypeIterator;
@@ -198,6 +199,13 @@ public:
OpOperand &getOpOperand(unsigned idx) { return getOpOperands()[idx]; }
// Support operand type iteration.
using operand_type_iterator = OperandTypeIterator;
using operand_type_range = llvm::iterator_range<operand_type_iterator>;
operand_type_iterator operand_type_begin();
operand_type_iterator operand_type_end();
operand_type_range getOperandTypes();
//===--------------------------------------------------------------------===//
// Results
//===--------------------------------------------------------------------===//
@@ -226,9 +234,10 @@ public:
// Support result type iteration.
using result_type_iterator = ResultTypeIterator;
using result_type_range = llvm::iterator_range<result_type_iterator>;
result_type_iterator result_type_begin();
result_type_iterator result_type_end();
llvm::iterator_range<result_type_iterator> getResultTypes();
result_type_range getResultTypes();
//===--------------------------------------------------------------------===//
// Attributes
@@ -500,6 +509,19 @@ public:
Value *operator*() const { return this->object->getOperand(this->index); }
};
/// This class implements the operand type iterators for the Operation
/// class in terms of operand_iterator->getType().
class OperandTypeIterator final
: public llvm::mapped_iterator<OperandIterator, Type (*)(Value *)> {
static Type unwrap(Value *value) { return value->getType(); }
public:
/// Initializes the operand type iterator to the specified operand iterator.
OperandTypeIterator(OperandIterator it)
: llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {
}
};
// Implement the inline operand iterator methods.
inline auto Operation::operand_begin() -> operand_iterator {
return operand_iterator(this, 0);
@@ -513,6 +535,18 @@ inline auto Operation::getOperands() -> operand_range {
return {operand_begin(), operand_end()};
}
inline auto Operation::operand_type_begin() -> operand_type_iterator {
return operand_type_iterator(operand_begin());
}
inline auto Operation::operand_type_end() -> operand_type_iterator {
return operand_type_iterator(operand_end());
}
inline auto Operation::getOperandTypes() -> operand_type_range {
return {operand_type_begin(), operand_type_end()};
}
/// This class implements the result iterators for the Operation class
/// in terms of getResult(idx).
class ResultIterator final
@@ -559,8 +593,7 @@ inline auto Operation::result_type_end() -> result_type_iterator {
return result_type_iterator(result_end());
}
inline auto Operation::getResultTypes()
-> llvm::iterator_range<result_type_iterator> {
inline auto Operation::getResultTypes() -> result_type_range {
return {result_type_begin(), result_type_end()};
}

View File

@@ -99,16 +99,12 @@ KernelDim3 LaunchOp::getBlockSize() {
return KernelDim3{args[9], args[10], args[11]};
}
Operation::operand_range LaunchOp::getKernelOperandValues() {
return {getOperation()->operand_begin() + kNumConfigOperands,
getOperation()->operand_end()};
LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
return llvm::drop_begin(getOperands(), kNumConfigOperands);
}
void LaunchOp::getKernelOperandTypes(SmallVectorImpl<Type> &out) {
out.reserve(getNumOperands() - kNumConfigOperands + out.size());
for (unsigned i = kNumConfigOperands; i < getNumOperands(); ++i) {
out.push_back(getOperand(i)->getType());
}
LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
}
KernelDim3 LaunchOp::getGridSizeOperandValues() {

View File

@@ -60,8 +60,7 @@ void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
// Outline the `gpu.launch` operation body into a kernel function.
Function *outlineKernelFunc(Module &module, gpu::LaunchOp &launchOp) {
Location loc = launchOp.getLoc();
SmallVector<Type, 4> kernelOperandTypes;
launchOp.getKernelOperandTypes(kernelOperandTypes);
SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
FunctionType type =
FunctionType::get(kernelOperandTypes, {}, module.getContext());
std::string kernelFuncName =

View File

@@ -220,10 +220,10 @@ void ModuleState::visitAttribute(Attribute attr) {
void ModuleState::visitOperation(Operation *op) {
// Visit all the types used in the operation.
for (auto *operand : op->getOperands())
visitType(operand->getType());
for (auto *result : op->getResults())
visitType(result->getType());
for (auto type : op->getOperandTypes())
visitType(type);
for (auto type : op->getResultTypes())
visitType(type);
// Visit each of the attributes.
for (auto elt : op->getAttrs())

View File

@@ -593,11 +593,7 @@ Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper,
}
}
SmallVector<Type, 8> resultTypes;
resultTypes.reserve(getNumResults());
for (auto *result : getResults())
resultTypes.push_back(result->getType());
SmallVector<Type, 8> resultTypes(getResultTypes());
unsigned numRegions = getNumRegions();
auto *newOp = Operation::create(getLoc(), getName(), operands, resultTypes,
attrs, successors, numRegions,
@@ -718,8 +714,8 @@ static Type getTensorOrVectorElementType(Type type) {
}
LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) {
for (auto *operand : op->getOperands()) {
auto type = getTensorOrVectorElementType(operand->getType());
for (auto opType : op->getOperandTypes()) {
auto type = getTensorOrVectorElementType(opType);
if (!type.isIntOrIndex())
return op->emitOpError() << "requires an integer or index type";
}
@@ -727,8 +723,8 @@ LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) {
}
LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) {
for (auto *operand : op->getOperands()) {
auto type = getTensorOrVectorElementType(operand->getType());
for (auto opType : op->getOperandTypes()) {
auto type = getTensorOrVectorElementType(opType);
if (!type.isa<FloatType>())
return op->emitOpError("requires a float type");
}
@@ -742,8 +738,8 @@ LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) {
return success();
auto type = op->getOperand(0)->getType();
for (unsigned i = 1; i < nOperands; ++i)
if (op->getOperand(i)->getType() != type)
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
if (opType != type)
return op->emitOpError() << "requires all operands to have the same type";
return success();
}
@@ -798,13 +794,13 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
return failure();
auto type = op->getOperand(0)->getType();
for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) {
if (failed(verifyShapeMatch(op->getResult(i)->getType(), type)))
for (auto resultType : op->getResultTypes()) {
if (failed(verifyShapeMatch(resultType, type)))
return op->emitOpError()
<< "requires the same shape for all operands and results";
}
for (unsigned i = 1, e = op->getNumOperands(); i < e; ++i) {
if (failed(verifyShapeMatch(op->getOperand(i)->getType(), type)))
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
if (failed(verifyShapeMatch(opType, type)))
return op->emitOpError()
<< "requires the same shape for all operands and results";
}
@@ -849,13 +845,13 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
return failure();
auto type = op->getResult(0)->getType();
for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
if (op->getResult(i)->getType() != type)
for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
if (resultType != type)
return op->emitOpError()
<< "requires the same type for all operands and results";
}
for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) {
if (op->getOperand(i)->getType() != type)
for (auto opType : op->getOperandTypes()) {
if (opType != type)
return op->emitOpError()
<< "requires the same type for all operands and results";
}
@@ -905,8 +901,8 @@ LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
}
LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
for (auto *result : op->getResults()) {
auto elementType = getTensorOrVectorElementType(result->getType());
for (auto resultType : op->getResultTypes()) {
auto elementType = getTensorOrVectorElementType(resultType);
bool isBoolType = elementType.isInteger(1);
if (!isBoolType)
return op->emitOpError() << "requires a bool result type";
@@ -916,19 +912,17 @@ LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
}
LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) {
for (auto *result : op->getResults())
if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
for (auto resultType : op->getResultTypes())
if (!getTensorOrVectorElementType(resultType).isa<FloatType>())
return op->emitOpError() << "requires a floating point type";
return success();
}
LogicalResult OpTrait::impl::verifyResultsAreIntegerLike(Operation *op) {
for (auto *result : op->getResults()) {
auto type = getTensorOrVectorElementType(result->getType());
if (!type.isIntOrIndex())
for (auto resultType : op->getResultTypes())
if (!getTensorOrVectorElementType(resultType).isIntOrIndex())
return op->emitOpError() << "requires an integer or index type";
}
return success();
}

View File

@@ -177,11 +177,8 @@ static ParseResult parseAllocaOp(OpAsmParser *parser, OperationState *result) {
//===----------------------------------------------------------------------===//
static void printGEPOp(OpAsmPrinter *p, GEPOp &op) {
SmallVector<Type, 8> types;
for (auto *operand : op.getOperands())
types.push_back(operand->getType());
auto funcTy =
FunctionType::get(types, op.getResult()->getType(), op.getContext());
SmallVector<Type, 8> types(op.getOperandTypes());
auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
*p << op.getOperationName() << ' ' << *op.base() << '[';
p->printOperands(std::next(op.operand_begin()), op.operand_end());
@@ -326,11 +323,9 @@ static void printCallOp(OpAsmPrinter *p, CallOp &op) {
p->printOptionalAttrDict(op.getAttrs(), {"callee"});
// Reconstruct the function MLIR function type from operand and result types.
SmallVector<Type, 1> resultTypes(op.getOperation()->getResultTypes());
SmallVector<Type, 8> argTypes;
argTypes.reserve(op.getNumOperands());
for (auto *operand : llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1))
argTypes.push_back(operand->getType());
SmallVector<Type, 1> resultTypes(op.getResultTypes());
SmallVector<Type, 8> argTypes(
llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
*p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
}

View File

@@ -262,17 +262,6 @@ protected:
LLVM::LLVMDialect &dialect;
};
// Given a range of MLIR typed objects, return a list of their types.
template <typename T>
SmallVector<Type, 4> getTypes(llvm::iterator_range<T> range) {
SmallVector<Type, 4> types;
types.reserve(llvm::size(range));
for (auto operand : range) {
types.push_back(operand->getType());
}
return types;
}
// Basic lowering implementation for one-to-one rewriting from Standard Ops to
// LLVM Dialect Ops.
template <typename SourceOp, typename TargetOp>
@@ -288,8 +277,8 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
Type packedType;
if (numResults != 0) {
packedType =
this->lowering.packFunctionResults(getTypes(op->getResults()));
packedType = this->lowering.packFunctionResults(
llvm::to_vector<4>(op->getResultTypes()));
assert(packedType && "type conversion failed, such operation should not "
"have been matched");
}
@@ -832,7 +821,8 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
// Otherwise, we need to pack the arguments into an LLVM struct type before
// returning.
auto packedType = lowering.packFunctionResults(getTypes(op->getOperands()));
auto packedType =
lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes()));
Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
for (unsigned i = 0; i < numArguments; ++i) {

View File

@@ -316,7 +316,7 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
void rewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
auto elementTy = lowering.convertType(*op->getResultTypes().begin());
auto elementTy = lowering.convertType(*op->result_type_begin());
Value *viewDescriptor = operands[0];
ArrayRef<Value *> indices = operands.drop_front();
auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);

View File

@@ -284,8 +284,8 @@ static LogicalResult verify(AllocOp op) {
"operand count does not equal dimension plus symbol operand count");
// Verify that all operands are of type Index.
for (auto *operand : op.getOperands())
if (!operand->getType().isIndex())
for (auto operandType : op.getOperandTypes())
if (!operandType.isIndex())
return op.emitOpError("requires operands to be of type Index");
return success();
}
@@ -475,11 +475,8 @@ static LogicalResult verify(CallOp op) {
}
FunctionType CallOp::getCalleeType() {
SmallVector<Type, 4> resultTypes(getOperation()->getResultTypes());
SmallVector<Type, 8> argTypes;
argTypes.reserve(getNumOperands());
for (auto *operand : getArgOperands())
argTypes.push_back(operand->getType());
SmallVector<Type, 4> resultTypes(getResultTypes());
SmallVector<Type, 8> argTypes(getOperandTypes());
return FunctionType::get(argTypes, resultTypes, getContext());
}