mirror of
https://github.com/intel/llvm.git
synced 2026-01-27 06:06:34 +08:00
Add operand type iterators to Operation and cleanup usages of operand->getType. This also simplifies some lingering usages of result->getType.
-- PiperOrigin-RevId: 249889174
This commit is contained in:
committed by
Mehdi Amini
parent
e53b7d2c02
commit
06734badbc
@@ -97,7 +97,7 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
edsc::ScopedContext edscContext(rewriter, op->getLoc());
|
||||
auto elementType = linalg::convertLinalgType(*op->getResultTypes().begin());
|
||||
auto elementType = linalg::convertLinalgType(*op->result_type_begin());
|
||||
Value *viewDescriptor = operands[0];
|
||||
ArrayRef<Value *> indices = operands.drop_front();
|
||||
Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
|
||||
|
||||
@@ -226,8 +226,8 @@ public:
|
||||
// Find the next operation ready for inference, that is an operation
|
||||
// with all operands already resolved (non-generic).
|
||||
auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
|
||||
return llvm::all_of(op->getOperands(), [](mlir::Value *v) {
|
||||
return !v->getType().cast<ToyArrayType>().isGeneric();
|
||||
return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
|
||||
return !ty.cast<ToyArrayType>().isGeneric();
|
||||
});
|
||||
});
|
||||
if (nextop == opWorklist.end())
|
||||
@@ -308,9 +308,8 @@ public:
|
||||
if (!mangledCallee) {
|
||||
// Can't find the target, this is where we queue the request for the
|
||||
// callee and stop the inference for the current function now.
|
||||
std::vector<mlir::Type> funcArgs;
|
||||
for (auto operand : op->getOperands())
|
||||
funcArgs.push_back(operand->getType());
|
||||
std::vector<mlir::Type> funcArgs(op->operand_type_begin(),
|
||||
op->operand_type_end());
|
||||
funcWorklist.push_back(
|
||||
{callee, std::move(mangledName), std::move(funcArgs)});
|
||||
return mlir::success();
|
||||
|
||||
@@ -226,8 +226,8 @@ public:
|
||||
// Find the next operation ready for inference, that is an operation
|
||||
// with all operands already resolved (non-generic).
|
||||
auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
|
||||
return llvm::all_of(op->getOperands(), [](mlir::Value *v) {
|
||||
return !v->getType().cast<ToyArrayType>().isGeneric();
|
||||
return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
|
||||
return !ty.cast<ToyArrayType>().isGeneric();
|
||||
});
|
||||
});
|
||||
if (nextop == opWorklist.end())
|
||||
@@ -312,18 +312,15 @@ public:
|
||||
if (!mangledCallee) {
|
||||
// Can't find the target, this is where we queue the request for the
|
||||
// callee and stop the inference for the current function now.
|
||||
std::vector<mlir::Type> funcArgs;
|
||||
for (auto operand : op->getOperands())
|
||||
funcArgs.push_back(operand->getType());
|
||||
std::vector<mlir::Type> funcArgs(op->operand_type_begin(),
|
||||
op->operand_type_end());
|
||||
funcWorklist.push_back(
|
||||
{callee, std::move(mangledName), std::move(funcArgs)});
|
||||
return mlir::success();
|
||||
}
|
||||
// Found a specialized callee! Let's turn this into a normal call
|
||||
// operation.
|
||||
SmallVector<mlir::Value *, 8> operands;
|
||||
for (mlir::Value *v : op->getOperands())
|
||||
operands.push_back(v);
|
||||
SmallVector<mlir::Value *, 8> operands(op->getOperands());
|
||||
mlir::FuncBuilder builder(f);
|
||||
builder.setInsertionPoint(op);
|
||||
auto newCall =
|
||||
|
||||
@@ -78,9 +78,9 @@ public:
|
||||
/// Get the SSA values corresponding to kernel block size.
|
||||
KernelDim3 getBlockSize();
|
||||
/// Get the operand values passed as kernel arguments.
|
||||
Operation::operand_range getKernelOperandValues();
|
||||
/// Append the operand types passed as kernel arguments to `out`.
|
||||
void getKernelOperandTypes(SmallVectorImpl<Type> &out);
|
||||
operand_range getKernelOperandValues();
|
||||
/// Get the operand types passed as kernel arguments.
|
||||
operand_type_range getKernelOperandTypes();
|
||||
|
||||
/// Get the SSA values passed as operands to specify the grid size.
|
||||
KernelDim3 getGridSizeOperandValues();
|
||||
|
||||
@@ -328,6 +328,8 @@ template <typename ConcreteType, template <typename> class TraitType>
|
||||
struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
|
||||
using operand_iterator = Operation::operand_iterator;
|
||||
using operand_range = Operation::operand_range;
|
||||
using operand_type_iterator = Operation::operand_type_iterator;
|
||||
using operand_type_range = Operation::operand_type_range;
|
||||
|
||||
/// Return the number of operands.
|
||||
unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
|
||||
@@ -346,6 +348,17 @@ struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
|
||||
}
|
||||
operand_iterator operand_end() { return this->getOperation()->operand_end(); }
|
||||
operand_range getOperands() { return this->getOperation()->getOperands(); }
|
||||
|
||||
/// Operand type access.
|
||||
operand_type_iterator operand_type_begin() {
|
||||
return this->getOperation()->operand_type_begin();
|
||||
}
|
||||
operand_type_iterator operand_type_end() {
|
||||
return this->getOperation()->operand_type_end();
|
||||
}
|
||||
operand_type_range getOperandTypes() {
|
||||
return this->getOperation()->getOperandTypes();
|
||||
}
|
||||
};
|
||||
} // end namespace detail
|
||||
|
||||
@@ -447,6 +460,8 @@ template <typename ConcreteType, template <typename> class TraitType>
|
||||
struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
|
||||
using result_iterator = Operation::result_iterator;
|
||||
using result_range = Operation::result_range;
|
||||
using result_type_iterator = Operation::result_type_iterator;
|
||||
using result_type_range = Operation::result_type_range;
|
||||
|
||||
/// Return the number of results.
|
||||
unsigned getNumResults() { return this->getOperation()->getNumResults(); }
|
||||
@@ -468,6 +483,17 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
|
||||
}
|
||||
result_iterator result_end() { return this->getOperation()->result_end(); }
|
||||
result_range getResults() { return this->getOperation()->getResults(); }
|
||||
|
||||
/// Result type access.
|
||||
result_type_iterator result_type_begin() {
|
||||
return this->getOperation()->result_type_begin();
|
||||
}
|
||||
result_type_iterator result_type_end() {
|
||||
return this->getOperation()->result_type_end();
|
||||
}
|
||||
result_type_range getResultTypes() {
|
||||
return this->getOperation()->getResultTypes();
|
||||
}
|
||||
};
|
||||
} // end namespace detail
|
||||
|
||||
@@ -477,7 +503,6 @@ template <typename ConcreteType>
|
||||
class OneResult : public TraitBase<ConcreteType, OneResult> {
|
||||
public:
|
||||
Value *getResult() { return this->getOperation()->getResult(0); }
|
||||
|
||||
Type getType() { return getResult()->getType(); }
|
||||
|
||||
/// Replace all uses of 'this' value with the new value, updating anything in
|
||||
|
||||
@@ -34,6 +34,7 @@ class BlockAndValueMapping;
|
||||
class Location;
|
||||
class MLIRContext;
|
||||
class OperandIterator;
|
||||
class OperandTypeIterator;
|
||||
struct OperationState;
|
||||
class ResultIterator;
|
||||
class ResultTypeIterator;
|
||||
@@ -198,6 +199,13 @@ public:
|
||||
|
||||
OpOperand &getOpOperand(unsigned idx) { return getOpOperands()[idx]; }
|
||||
|
||||
// Support operand type iteration.
|
||||
using operand_type_iterator = OperandTypeIterator;
|
||||
using operand_type_range = llvm::iterator_range<operand_type_iterator>;
|
||||
operand_type_iterator operand_type_begin();
|
||||
operand_type_iterator operand_type_end();
|
||||
operand_type_range getOperandTypes();
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Results
|
||||
//===--------------------------------------------------------------------===//
|
||||
@@ -226,9 +234,10 @@ public:
|
||||
|
||||
// Support result type iteration.
|
||||
using result_type_iterator = ResultTypeIterator;
|
||||
using result_type_range = llvm::iterator_range<result_type_iterator>;
|
||||
result_type_iterator result_type_begin();
|
||||
result_type_iterator result_type_end();
|
||||
llvm::iterator_range<result_type_iterator> getResultTypes();
|
||||
result_type_range getResultTypes();
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Attributes
|
||||
@@ -500,6 +509,19 @@ public:
|
||||
Value *operator*() const { return this->object->getOperand(this->index); }
|
||||
};
|
||||
|
||||
/// This class implements the operand type iterators for the Operation
|
||||
/// class in terms of operand_iterator->getType().
|
||||
class OperandTypeIterator final
|
||||
: public llvm::mapped_iterator<OperandIterator, Type (*)(Value *)> {
|
||||
static Type unwrap(Value *value) { return value->getType(); }
|
||||
|
||||
public:
|
||||
/// Initializes the operand type iterator to the specified operand iterator.
|
||||
OperandTypeIterator(OperandIterator it)
|
||||
: llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {
|
||||
}
|
||||
};
|
||||
|
||||
// Implement the inline operand iterator methods.
|
||||
inline auto Operation::operand_begin() -> operand_iterator {
|
||||
return operand_iterator(this, 0);
|
||||
@@ -513,6 +535,18 @@ inline auto Operation::getOperands() -> operand_range {
|
||||
return {operand_begin(), operand_end()};
|
||||
}
|
||||
|
||||
inline auto Operation::operand_type_begin() -> operand_type_iterator {
|
||||
return operand_type_iterator(operand_begin());
|
||||
}
|
||||
|
||||
inline auto Operation::operand_type_end() -> operand_type_iterator {
|
||||
return operand_type_iterator(operand_end());
|
||||
}
|
||||
|
||||
inline auto Operation::getOperandTypes() -> operand_type_range {
|
||||
return {operand_type_begin(), operand_type_end()};
|
||||
}
|
||||
|
||||
/// This class implements the result iterators for the Operation class
|
||||
/// in terms of getResult(idx).
|
||||
class ResultIterator final
|
||||
@@ -559,8 +593,7 @@ inline auto Operation::result_type_end() -> result_type_iterator {
|
||||
return result_type_iterator(result_end());
|
||||
}
|
||||
|
||||
inline auto Operation::getResultTypes()
|
||||
-> llvm::iterator_range<result_type_iterator> {
|
||||
inline auto Operation::getResultTypes() -> result_type_range {
|
||||
return {result_type_begin(), result_type_end()};
|
||||
}
|
||||
|
||||
|
||||
@@ -99,16 +99,12 @@ KernelDim3 LaunchOp::getBlockSize() {
|
||||
return KernelDim3{args[9], args[10], args[11]};
|
||||
}
|
||||
|
||||
Operation::operand_range LaunchOp::getKernelOperandValues() {
|
||||
return {getOperation()->operand_begin() + kNumConfigOperands,
|
||||
getOperation()->operand_end()};
|
||||
LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
|
||||
return llvm::drop_begin(getOperands(), kNumConfigOperands);
|
||||
}
|
||||
|
||||
void LaunchOp::getKernelOperandTypes(SmallVectorImpl<Type> &out) {
|
||||
out.reserve(getNumOperands() - kNumConfigOperands + out.size());
|
||||
for (unsigned i = kNumConfigOperands; i < getNumOperands(); ++i) {
|
||||
out.push_back(getOperand(i)->getType());
|
||||
}
|
||||
LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
|
||||
return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
|
||||
}
|
||||
|
||||
KernelDim3 LaunchOp::getGridSizeOperandValues() {
|
||||
|
||||
@@ -60,8 +60,7 @@ void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
|
||||
// Outline the `gpu.launch` operation body into a kernel function.
|
||||
Function *outlineKernelFunc(Module &module, gpu::LaunchOp &launchOp) {
|
||||
Location loc = launchOp.getLoc();
|
||||
SmallVector<Type, 4> kernelOperandTypes;
|
||||
launchOp.getKernelOperandTypes(kernelOperandTypes);
|
||||
SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
|
||||
FunctionType type =
|
||||
FunctionType::get(kernelOperandTypes, {}, module.getContext());
|
||||
std::string kernelFuncName =
|
||||
|
||||
@@ -220,10 +220,10 @@ void ModuleState::visitAttribute(Attribute attr) {
|
||||
|
||||
void ModuleState::visitOperation(Operation *op) {
|
||||
// Visit all the types used in the operation.
|
||||
for (auto *operand : op->getOperands())
|
||||
visitType(operand->getType());
|
||||
for (auto *result : op->getResults())
|
||||
visitType(result->getType());
|
||||
for (auto type : op->getOperandTypes())
|
||||
visitType(type);
|
||||
for (auto type : op->getResultTypes())
|
||||
visitType(type);
|
||||
|
||||
// Visit each of the attributes.
|
||||
for (auto elt : op->getAttrs())
|
||||
|
||||
@@ -593,11 +593,7 @@ Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper,
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Type, 8> resultTypes;
|
||||
resultTypes.reserve(getNumResults());
|
||||
for (auto *result : getResults())
|
||||
resultTypes.push_back(result->getType());
|
||||
|
||||
SmallVector<Type, 8> resultTypes(getResultTypes());
|
||||
unsigned numRegions = getNumRegions();
|
||||
auto *newOp = Operation::create(getLoc(), getName(), operands, resultTypes,
|
||||
attrs, successors, numRegions,
|
||||
@@ -718,8 +714,8 @@ static Type getTensorOrVectorElementType(Type type) {
|
||||
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) {
|
||||
for (auto *operand : op->getOperands()) {
|
||||
auto type = getTensorOrVectorElementType(operand->getType());
|
||||
for (auto opType : op->getOperandTypes()) {
|
||||
auto type = getTensorOrVectorElementType(opType);
|
||||
if (!type.isIntOrIndex())
|
||||
return op->emitOpError() << "requires an integer or index type";
|
||||
}
|
||||
@@ -727,8 +723,8 @@ LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) {
|
||||
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) {
|
||||
for (auto *operand : op->getOperands()) {
|
||||
auto type = getTensorOrVectorElementType(operand->getType());
|
||||
for (auto opType : op->getOperandTypes()) {
|
||||
auto type = getTensorOrVectorElementType(opType);
|
||||
if (!type.isa<FloatType>())
|
||||
return op->emitOpError("requires a float type");
|
||||
}
|
||||
@@ -742,8 +738,8 @@ LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) {
|
||||
return success();
|
||||
|
||||
auto type = op->getOperand(0)->getType();
|
||||
for (unsigned i = 1; i < nOperands; ++i)
|
||||
if (op->getOperand(i)->getType() != type)
|
||||
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
|
||||
if (opType != type)
|
||||
return op->emitOpError() << "requires all operands to have the same type";
|
||||
return success();
|
||||
}
|
||||
@@ -798,13 +794,13 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
|
||||
return failure();
|
||||
|
||||
auto type = op->getOperand(0)->getType();
|
||||
for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) {
|
||||
if (failed(verifyShapeMatch(op->getResult(i)->getType(), type)))
|
||||
for (auto resultType : op->getResultTypes()) {
|
||||
if (failed(verifyShapeMatch(resultType, type)))
|
||||
return op->emitOpError()
|
||||
<< "requires the same shape for all operands and results";
|
||||
}
|
||||
for (unsigned i = 1, e = op->getNumOperands(); i < e; ++i) {
|
||||
if (failed(verifyShapeMatch(op->getOperand(i)->getType(), type)))
|
||||
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
|
||||
if (failed(verifyShapeMatch(opType, type)))
|
||||
return op->emitOpError()
|
||||
<< "requires the same shape for all operands and results";
|
||||
}
|
||||
@@ -849,13 +845,13 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
|
||||
return failure();
|
||||
|
||||
auto type = op->getResult(0)->getType();
|
||||
for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
|
||||
if (op->getResult(i)->getType() != type)
|
||||
for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
|
||||
if (resultType != type)
|
||||
return op->emitOpError()
|
||||
<< "requires the same type for all operands and results";
|
||||
}
|
||||
for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) {
|
||||
if (op->getOperand(i)->getType() != type)
|
||||
for (auto opType : op->getOperandTypes()) {
|
||||
if (opType != type)
|
||||
return op->emitOpError()
|
||||
<< "requires the same type for all operands and results";
|
||||
}
|
||||
@@ -905,8 +901,8 @@ LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
|
||||
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
|
||||
for (auto *result : op->getResults()) {
|
||||
auto elementType = getTensorOrVectorElementType(result->getType());
|
||||
for (auto resultType : op->getResultTypes()) {
|
||||
auto elementType = getTensorOrVectorElementType(resultType);
|
||||
bool isBoolType = elementType.isInteger(1);
|
||||
if (!isBoolType)
|
||||
return op->emitOpError() << "requires a bool result type";
|
||||
@@ -916,19 +912,17 @@ LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
|
||||
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) {
|
||||
for (auto *result : op->getResults())
|
||||
if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
|
||||
for (auto resultType : op->getResultTypes())
|
||||
if (!getTensorOrVectorElementType(resultType).isa<FloatType>())
|
||||
return op->emitOpError() << "requires a floating point type";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifyResultsAreIntegerLike(Operation *op) {
|
||||
for (auto *result : op->getResults()) {
|
||||
auto type = getTensorOrVectorElementType(result->getType());
|
||||
if (!type.isIntOrIndex())
|
||||
for (auto resultType : op->getResultTypes())
|
||||
if (!getTensorOrVectorElementType(resultType).isIntOrIndex())
|
||||
return op->emitOpError() << "requires an integer or index type";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
@@ -177,11 +177,8 @@ static ParseResult parseAllocaOp(OpAsmParser *parser, OperationState *result) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printGEPOp(OpAsmPrinter *p, GEPOp &op) {
|
||||
SmallVector<Type, 8> types;
|
||||
for (auto *operand : op.getOperands())
|
||||
types.push_back(operand->getType());
|
||||
auto funcTy =
|
||||
FunctionType::get(types, op.getResult()->getType(), op.getContext());
|
||||
SmallVector<Type, 8> types(op.getOperandTypes());
|
||||
auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
|
||||
|
||||
*p << op.getOperationName() << ' ' << *op.base() << '[';
|
||||
p->printOperands(std::next(op.operand_begin()), op.operand_end());
|
||||
@@ -326,11 +323,9 @@ static void printCallOp(OpAsmPrinter *p, CallOp &op) {
|
||||
p->printOptionalAttrDict(op.getAttrs(), {"callee"});
|
||||
|
||||
// Reconstruct the function MLIR function type from operand and result types.
|
||||
SmallVector<Type, 1> resultTypes(op.getOperation()->getResultTypes());
|
||||
SmallVector<Type, 8> argTypes;
|
||||
argTypes.reserve(op.getNumOperands());
|
||||
for (auto *operand : llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1))
|
||||
argTypes.push_back(operand->getType());
|
||||
SmallVector<Type, 1> resultTypes(op.getResultTypes());
|
||||
SmallVector<Type, 8> argTypes(
|
||||
llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
|
||||
|
||||
*p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
|
||||
}
|
||||
|
||||
@@ -262,17 +262,6 @@ protected:
|
||||
LLVM::LLVMDialect &dialect;
|
||||
};
|
||||
|
||||
// Given a range of MLIR typed objects, return a list of their types.
|
||||
template <typename T>
|
||||
SmallVector<Type, 4> getTypes(llvm::iterator_range<T> range) {
|
||||
SmallVector<Type, 4> types;
|
||||
types.reserve(llvm::size(range));
|
||||
for (auto operand : range) {
|
||||
types.push_back(operand->getType());
|
||||
}
|
||||
return types;
|
||||
}
|
||||
|
||||
// Basic lowering implementation for one-to-one rewriting from Standard Ops to
|
||||
// LLVM Dialect Ops.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
@@ -288,8 +277,8 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
||||
|
||||
Type packedType;
|
||||
if (numResults != 0) {
|
||||
packedType =
|
||||
this->lowering.packFunctionResults(getTypes(op->getResults()));
|
||||
packedType = this->lowering.packFunctionResults(
|
||||
llvm::to_vector<4>(op->getResultTypes()));
|
||||
assert(packedType && "type conversion failed, such operation should not "
|
||||
"have been matched");
|
||||
}
|
||||
@@ -832,7 +821,8 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
|
||||
|
||||
// Otherwise, we need to pack the arguments into an LLVM struct type before
|
||||
// returning.
|
||||
auto packedType = lowering.packFunctionResults(getTypes(op->getOperands()));
|
||||
auto packedType =
|
||||
lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes()));
|
||||
|
||||
Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
|
||||
for (unsigned i = 0; i < numArguments; ++i) {
|
||||
|
||||
@@ -316,7 +316,7 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
|
||||
void rewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
PatternRewriter &rewriter) const override {
|
||||
edsc::ScopedContext edscContext(rewriter, op->getLoc());
|
||||
auto elementTy = lowering.convertType(*op->getResultTypes().begin());
|
||||
auto elementTy = lowering.convertType(*op->result_type_begin());
|
||||
Value *viewDescriptor = operands[0];
|
||||
ArrayRef<Value *> indices = operands.drop_front();
|
||||
auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
|
||||
|
||||
@@ -284,8 +284,8 @@ static LogicalResult verify(AllocOp op) {
|
||||
"operand count does not equal dimension plus symbol operand count");
|
||||
|
||||
// Verify that all operands are of type Index.
|
||||
for (auto *operand : op.getOperands())
|
||||
if (!operand->getType().isIndex())
|
||||
for (auto operandType : op.getOperandTypes())
|
||||
if (!operandType.isIndex())
|
||||
return op.emitOpError("requires operands to be of type Index");
|
||||
return success();
|
||||
}
|
||||
@@ -475,11 +475,8 @@ static LogicalResult verify(CallOp op) {
|
||||
}
|
||||
|
||||
FunctionType CallOp::getCalleeType() {
|
||||
SmallVector<Type, 4> resultTypes(getOperation()->getResultTypes());
|
||||
SmallVector<Type, 8> argTypes;
|
||||
argTypes.reserve(getNumOperands());
|
||||
for (auto *operand : getArgOperands())
|
||||
argTypes.push_back(operand->getType());
|
||||
SmallVector<Type, 4> resultTypes(getResultTypes());
|
||||
SmallVector<Type, 8> argTypes(getOperandTypes());
|
||||
return FunctionType::get(argTypes, resultTypes, getContext());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user