diff --git a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td index 9a641bcba..4c842db6d 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td +++ b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorBase.td @@ -43,6 +43,7 @@ def TMTensor_Dialect : Dialect { to. }]; let hasCanonicalizer = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } //===----------------------------------------------------------------------===// diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index 267a4c13d..6ce9b502f 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -204,7 +204,7 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc, } auto scfIf = b.create( - loc, TypeRange{}, cond, + loc, cond, [&](OpBuilder &b, Location loc) { if (isInclusive) { auto value = b.create(loc, input(), indices); @@ -266,7 +266,7 @@ static LogicalResult foldMemRefCast(Operation *op) { return success(folded); } -LogicalResult ScanOp::fold(ArrayRef, +LogicalResult ScanOp::fold(FoldAdaptor adaptor, SmallVectorImpl &) { return foldMemRefCast(*this); } diff --git a/externals/llvm-project b/externals/llvm-project index d23516e9a..9acc2f37b 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit d23516e9ad477527a9db4d06b1fa9566680ac67c +Subproject commit 9acc2f37bdfce08ca0c2faec03392db10d1bb7a9 diff --git a/externals/mlir-hlo b/externals/mlir-hlo index 81e87a95b..4a173356b 160000 --- a/externals/mlir-hlo +++ b/externals/mlir-hlo @@ -1 +1 @@ -Subproject commit 81e87a95b8683f1c3c33caf9e933897e0fc4a2b7 +Subproject commit 4a173356bb1291b97046545429d7851cbc771d88 diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchBase.td b/include/torch-mlir/Dialect/Torch/IR/TorchBase.td index 14fe0e661..3110c612c 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchBase.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchBase.td @@ -37,6 +37,7 @@ def Torch_Dialect : Dialect { let hasRegionArgAttrVerify = 1; let hasConstantMaterializer = 1; let useDefaultTypePrinterParser = 0; + let useFoldAPI = kEmitFoldAdaptorFolder; let extraClassDeclaration = [{ /// Parse a type registered to this dialect. 
diff --git a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionBase.td b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionBase.td index fadeb5c8c..2fc7fc45e 100644 --- a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionBase.td +++ b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionBase.td @@ -27,6 +27,7 @@ def TorchConversion_Dialect : Dialect { }]; let hasConstantMaterializer = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // TORCHCONVERSION_BASE diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 6ace4926d..a316e1e9d 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -463,8 +463,8 @@ public: } SmallVector inputSize = getTensorSizes(rewriter, loc, input); - ArrayRef outputShapeInt = llvm::makeArrayRef(outputSizeInt); - ArrayRef inputShapeInt = llvm::makeArrayRef(inputSize); + ArrayRef outputShapeInt = llvm::ArrayRef(outputSizeInt); + ArrayRef inputShapeInt = llvm::ArrayRef(inputSize); // Association indices for expand/collapse ops. These two vectors // are populated such that two entries at the same index corresponds @@ -1136,7 +1136,7 @@ public: Value dimIndex = rewriter.createOrFold( loc, rewriter.getIndexAttr(dim)); - for (auto tensor : makeArrayRef(tensors).drop_front()) { + for (auto tensor : ArrayRef(tensors).drop_front()) { auto size = rewriter.createOrFold(loc, tensor, dimIndex); resultDimSize = rewriter.createOrFold(loc, resultDimSize, size); @@ -1270,7 +1270,7 @@ public: /*resultType=*/selfType, /*inputs=*/broadcastedSrc, /*outputs=*/self, - /*indexingMaps=*/llvm::makeArrayRef({id, id}), + /*indexingMaps=*/llvm::ArrayRef({id, id}), /*iteratorTypes=*/iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { Value result = args[0]; diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp index a773ae652..35776bb88 100644 --- a/lib/Conversion/TorchToMhlo/Basic.cpp +++ b/lib/Conversion/TorchToMhlo/Basic.cpp @@ -1086,7 +1086,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Reshape input auto mhloInput = rewriter.create( op->getLoc(), mhloBatchNormOutTy, input, - mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape), + mhlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape), {static_cast(inputFlattenShape.size())}) .value()); diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index b80c35c14..b64c80c13 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -142,7 +142,7 @@ public: // Finding the maximum value in the input tensor. 
SmallVector maxTensorSizes; ValueTensorType maxTensorType = ValueTensorType::get( - context, llvm::makeArrayRef(maxTensorSizes), + context, llvm::ArrayRef(maxTensorSizes), torchTypeInput.getType().cast().getDtype()); Value maxTensor = rewriter.create(loc, maxTensorType, torchTypeInput); @@ -165,7 +165,7 @@ public: SmallVector expandedInputSizes{ makeShapeTorchCompatible(inputType.getShape())[0], 1}; ValueTensorType expandInputType = ValueTensorType::get( - context, llvm::makeArrayRef(expandedInputSizes), + context, llvm::ArrayRef(expandedInputSizes), torchTypeInput.getType().cast().getDtype()); Value torchCstOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); @@ -286,9 +286,9 @@ public: auto indexTensorType = indexTensor.getType().cast(); int64_t indexTensorSize = indexTensorType.getSizes()[0]; SmallVector expandedIndexTensorSizes{indexTensorSize, 1}; - ValueTensorType expandedIndexTensorType = ValueTensorType::get( - context, llvm::makeArrayRef(expandedIndexTensorSizes), - indexTensorType.getDtype()); + ValueTensorType expandedIndexTensorType = + ValueTensorType::get(context, llvm::ArrayRef(expandedIndexTensorSizes), + indexTensorType.getDtype()); Value torchCstOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); Value expandedIndexTensor = rewriter.create( diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 26ac46e40..b285a8a0d 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -718,8 +718,8 @@ class ConvertAtenMultipleDimsReductionOp "non-const dim parameter unsupported"); int64_t N = reduceDims.size(); auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type()); - reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType, - llvm::makeArrayRef(reduceDims)); + reduceDimsAttr = + DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims)); keepDims = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims))) @@ -748,8 +748,8 @@ class ConvertAtenOneDimReductionOp return rewriter.notifyMatchFailure(op, "non-const dim parameter unsupported"); auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type()); - reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType, - llvm::makeArrayRef({reduceDim})); + reduceDimsAttr = + DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef({reduceDim})); keepDims = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims))) @@ -782,8 +782,8 @@ public: reduceDims.push_back(i); int64_t N = selfTy.getRank(); auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type()); - reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType, - llvm::makeArrayRef(reduceDims)); + reduceDimsAttr = + DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims)); keepDims = false; return success(); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 5e569ce5a..e3fb59fc0 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -507,7 +507,7 @@ bool DerefineOp::areCastCompatible(mlir::TypeRange inputs, return isValidSubtype(inputs[0], outputs[0]); } -OpFoldResult DerefineOp::fold(ArrayRef operands) { +OpFoldResult DerefineOp::fold(FoldAdaptor adaptor) { auto uncheckedCast = getOperand().getDefiningOp(); if (!uncheckedCast) return nullptr; @@ -570,10 +570,10 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) { // Aten__RangeLengthOp 
//===----------------------------------------------------------------------===// -OpFoldResult Aten__RangeLengthOp::fold(ArrayRef operands) { - auto lo = operands[0]; - auto hi = operands[1]; - auto step = operands[2]; +OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) { + auto lo = adaptor.getLo(); + auto hi = adaptor.getHi(); + auto step = adaptor.getStep(); if (!lo || !hi || !step) return nullptr; auto loInt = lo.dyn_cast_or_null().getValue(); @@ -595,10 +595,10 @@ OpFoldResult Aten__RangeLengthOp::fold(ArrayRef operands) { // Aten__DeriveIndexOp //===----------------------------------------------------------------------===// -OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef operands) { - auto index = operands[0]; - auto start = operands[1]; - auto step = operands[2]; +OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) { + auto index = adaptor.getIndex(); + auto start = adaptor.getStart(); + auto step = adaptor.getStep(); if (!index || !start || !step) return nullptr; auto indexInt = index.dyn_cast_or_null().getValue(); @@ -612,7 +612,7 @@ OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef operands) { // Aten__Is__Op //===----------------------------------------------------------------------===// -OpFoldResult Aten__Is__Op::fold(ArrayRef operands) { +OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) { return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true); } @@ -620,7 +620,7 @@ OpFoldResult Aten__Is__Op::fold(ArrayRef operands) { // Aten__Isnot__Op //===----------------------------------------------------------------------===// -OpFoldResult Aten__Isnot__Op::fold(ArrayRef operands) { +OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) { return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false); } @@ -628,7 +628,7 @@ OpFoldResult Aten__Isnot__Op::fold(ArrayRef operands) { // Aten__Not__Op //===----------------------------------------------------------------------===// -OpFoldResult Aten__Not__Op::fold(ArrayRef operands) { +OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) { bool value; if (!matchPattern(getOperand(), m_TorchConstantBool(&value))) return nullptr; @@ -639,7 +639,7 @@ OpFoldResult Aten__Not__Op::fold(ArrayRef operands) { // AtenNeBoolOp //===----------------------------------------------------------------------===// -OpFoldResult AtenNeBoolOp::fold(ArrayRef operands) { +OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) { if (getOperand(0) == getOperand(1)) return IntegerAttr::get(IntegerType::get(getContext(), 1), false); @@ -655,7 +655,7 @@ OpFoldResult AtenNeBoolOp::fold(ArrayRef operands) { // AtenSqueezeOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSqueezeOp::fold(ArrayRef operands) { +OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) { if (auto tensorType = getOperand().getType().dyn_cast()) { if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) return getOperand(); @@ -667,7 +667,7 @@ OpFoldResult AtenSqueezeOp::fold(ArrayRef operands) { // AtenSqueezeDimOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSqueezeDimOp::fold(ArrayRef operands) { +OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { if (auto tensorType = getOperand(0).getType().dyn_cast()) { if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) return getOperand(0); @@ -679,7 +679,7 @@ OpFoldResult AtenSqueezeDimOp::fold(ArrayRef operands) { // AtenRoundOp 
//===----------------------------------------------------------------------===// -OpFoldResult AtenRoundOp::fold(ArrayRef operands) { +OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { if (auto selfType = getSelf().getType().dyn_cast()) { if (selfType.hasDtype() && selfType.getDtype().isa()) return getSelf(); @@ -691,7 +691,7 @@ OpFoldResult AtenRoundOp::fold(ArrayRef operands) { // AtenTypeAsOp //===----------------------------------------------------------------------===// -OpFoldResult AtenTypeAsOp::fold(ArrayRef operands) { +OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) { Type inType = getSelf().getType(); Type newType = getOther().getType(); @@ -705,7 +705,7 @@ OpFoldResult AtenTypeAsOp::fold(ArrayRef operands) { // AtenToDtypeOp //===----------------------------------------------------------------------===// -OpFoldResult AtenToDtypeOp::fold(ArrayRef operands) { +OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) { bool nonBlocking, copyArg; // The non_blocking arg must be `False`. if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) || @@ -736,7 +736,7 @@ OpFoldResult AtenToDtypeOp::fold(ArrayRef operands) { // AtenToDtypeLayoutOp //===----------------------------------------------------------------------===// -OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef operands) { +OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) { // The pin_memory arg should be either constant `False` or `none`. if (!getPinMemory().getType().isa()) { bool pinMemory; @@ -797,7 +797,7 @@ OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef operands) { // AtenViewOp //===----------------------------------------------------------------------===// -OpFoldResult AtenViewOp::fold(ArrayRef operands) { +OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { auto inputType = getOperand(0).getType().dyn_cast(); if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1) return nullptr; @@ -812,7 +812,7 @@ OpFoldResult AtenViewOp::fold(ArrayRef operands) { // AtenDimOp //===----------------------------------------------------------------------===// -OpFoldResult AtenDimOp::fold(ArrayRef operands) { +OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) { if (auto tensorType = getOperand().getType().dyn_cast()) { if (tensorType.hasSizes()) return IntegerAttr::get(IntegerType::get(getContext(), 64), @@ -825,7 +825,7 @@ OpFoldResult AtenDimOp::fold(ArrayRef operands) { // AtenLenTOp //===----------------------------------------------------------------------===// -OpFoldResult AtenLenTOp::fold(ArrayRef operands) { +OpFoldResult AtenLenTOp::fold(FoldAdaptor adaptor) { // `len([1,1,1])` -> `3`, if it is not mutated. 
if (auto listConstruct = getOperand().getDefiningOp()) { @@ -853,7 +853,7 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // AtenLenStrOp //===----------------------------------------------------------------------===// -OpFoldResult AtenLenStrOp::fold(ArrayRef operands) { +OpFoldResult AtenLenStrOp::fold(FoldAdaptor adaptor) { if (auto stringConstruct = getS().getDefiningOp()) return getI64IntegerAttr(getContext(), stringConstruct.getValueAttr().getValue().size()); @@ -1092,7 +1092,7 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // AtenSizeIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSizeIntOp::fold(ArrayRef operands) { +OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) { int64_t dim; if (!matchPattern(this->getDim(), m_TorchConstantInt(&dim))) return nullptr; @@ -1132,7 +1132,7 @@ floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) { // AtenLtFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenLtFloatOp::fold(ArrayRef operands) { +OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a < b; }); } @@ -1141,7 +1141,7 @@ OpFoldResult AtenLtFloatOp::fold(ArrayRef operands) { // AtenGtFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenGtFloatOp::fold(ArrayRef operands) { +OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a > b; }); } @@ -1150,7 +1150,7 @@ OpFoldResult AtenGtFloatOp::fold(ArrayRef operands) { // AtenGeFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenGeFloatOp::fold(ArrayRef operands) { +OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a >= b; }); } @@ -1159,7 +1159,7 @@ OpFoldResult AtenGeFloatOp::fold(ArrayRef operands) { // AtenEqFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenEqFloatOp::fold(ArrayRef operands) { +OpFoldResult AtenEqFloatOp::fold(FoldAdaptor adaptor) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a == b; }); } @@ -1225,7 +1225,7 @@ static OpFoldResult intComparatorFoldHelper(OpTy op, // AtenNeIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenNeIntOp::fold(ArrayRef operands) { +OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a != b; }); } @@ -1234,7 +1234,7 @@ OpFoldResult AtenNeIntOp::fold(ArrayRef operands) { // AtenEqIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenEqIntOp::fold(ArrayRef operands) { +OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a == b; }); } @@ -1243,7 +1243,7 @@ OpFoldResult AtenEqIntOp::fold(ArrayRef operands) { // AtenEqStrOp //===----------------------------------------------------------------------===// -OpFoldResult AtenEqStrOp::fold(ArrayRef operands) { +OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) { if (getOperand(0) == getOperand(1)) return getI1IntegerAttr(getContext(), true); @@ -1259,7 +1259,7 @@ OpFoldResult AtenEqStrOp::fold(ArrayRef 
operands) { // AtenLtIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenLtIntOp::fold(ArrayRef operands) { +OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a < b; }); } @@ -1268,7 +1268,7 @@ OpFoldResult AtenLtIntOp::fold(ArrayRef operands) { // AtenLeIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenLeIntOp::fold(ArrayRef operands) { +OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a <= b; }); } @@ -1277,7 +1277,7 @@ OpFoldResult AtenLeIntOp::fold(ArrayRef operands) { // AtenGtIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenGtIntOp::fold(ArrayRef operands) { +OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a > b; }); } @@ -1286,7 +1286,7 @@ OpFoldResult AtenGtIntOp::fold(ArrayRef operands) { // AtenGeIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenGeIntOp::fold(ArrayRef operands) { +OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a >= b; }); } @@ -1295,7 +1295,7 @@ OpFoldResult AtenGeIntOp::fold(ArrayRef operands) { // AtenBoolFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenBoolFloatOp::fold(ArrayRef operands) { +OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) { double c; if (matchPattern(getOperand(), m_TorchConstantFloat(&c))) return getI1IntegerAttr(getContext(), c != 0.0); @@ -1306,7 +1306,7 @@ OpFoldResult AtenBoolFloatOp::fold(ArrayRef operands) { // AtenBoolIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenBoolIntOp::fold(ArrayRef operands) { +OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) { int64_t c; if (matchPattern(getOperand(), m_TorchConstantInt(&c))) return getI1IntegerAttr(getContext(), c != 0); @@ -1317,9 +1317,9 @@ OpFoldResult AtenBoolIntOp::fold(ArrayRef operands) { // AtenFloatScalarOp //===----------------------------------------------------------------------===// -OpFoldResult AtenFloatScalarOp::fold(ArrayRef operands) { +OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) { // Constant fold int -> float conversion. - if (auto integerAttr = operands[0].dyn_cast_or_null()) { + if (auto integerAttr = adaptor.getA().dyn_cast_or_null()) { return FloatAttr::get( mlir::Float64Type::get(getContext()), static_cast(integerAttr.getValue().getSExtValue())); @@ -1334,9 +1334,9 @@ OpFoldResult AtenFloatScalarOp::fold(ArrayRef operands) { // AtenIntScalarOp //===----------------------------------------------------------------------===// -OpFoldResult AtenIntScalarOp::fold(ArrayRef operands) { +OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) { // Constant fold float -> int conversion. 
- if (auto floatAttr = operands[0].dyn_cast_or_null()) { + if (auto floatAttr = adaptor.getA().dyn_cast_or_null()) { return IntegerAttr::get( mlir::IntegerType::get(getContext(), 64, IntegerType::Signed), static_cast(floatAttr.getValue().convertToDouble())); @@ -1351,7 +1351,7 @@ OpFoldResult AtenIntScalarOp::fold(ArrayRef operands) { // AtenIntBoolOp //===----------------------------------------------------------------------===// -OpFoldResult AtenIntBoolOp::fold(ArrayRef operands) { +OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) { bool b; if (matchPattern(getOperand(), m_TorchConstantBool(&b))) { return getI64IntegerAttr(getContext(), static_cast(b)); @@ -1452,7 +1452,7 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes( return success(); } -OpFoldResult ValueTensorLiteralOp::fold(ArrayRef operands) { +OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } @@ -1557,7 +1557,7 @@ void CopyToValueTensorOp::getEffects( // ConstantNoneOp //===----------------------------------------------------------------------===// -OpFoldResult ConstantNoneOp::fold(ArrayRef operands) { +OpFoldResult ConstantNoneOp::fold(FoldAdaptor adaptor) { return TypeAttr::get(Torch::NoneType::get(getContext())); } @@ -1570,9 +1570,7 @@ void ConstantNoneOp::getAsmResultNames( // ConstantStrOp //===----------------------------------------------------------------------===// -OpFoldResult ConstantStrOp::fold(ArrayRef operands) { - return getValueAttr(); -} +OpFoldResult ConstantStrOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } void ConstantStrOp::getAsmResultNames( function_ref setNameFn) { @@ -1610,7 +1608,7 @@ void ConstantIntOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs(), {"value"}); } -OpFoldResult Torch::ConstantIntOp::fold(ArrayRef operands) { +OpFoldResult Torch::ConstantIntOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } @@ -1626,7 +1624,7 @@ void Torch::ConstantIntOp::getAsmResultNames( // ConstantFloatOp //===----------------------------------------------------------------------===// -OpFoldResult Torch::ConstantFloatOp::fold(ArrayRef operands) { +OpFoldResult Torch::ConstantFloatOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } @@ -1656,7 +1654,7 @@ void Torch::ConstantFloatOp::getAsmResultNames( // ConstantNumberOp //===----------------------------------------------------------------------===// -OpFoldResult Torch::ConstantNumberOp::fold(ArrayRef operands) { +OpFoldResult Torch::ConstantNumberOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } @@ -1684,7 +1682,7 @@ void Torch::ConstantNumberOp::getCanonicalizationPatterns( // ConstantBoolOp //===----------------------------------------------------------------------===// -OpFoldResult Torch::ConstantBoolOp::fold(ArrayRef operands) { +OpFoldResult Torch::ConstantBoolOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } @@ -1702,7 +1700,7 @@ bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs, return isValidSubtype(outputs[0], inputs[0]); } -OpFoldResult PrimUncheckedCastOp::fold(ArrayRef operands) { +OpFoldResult PrimUncheckedCastOp::fold(FoldAdaptor adaptor) { if (auto derefineOp = getX().getDefiningOp()) { if (derefineOp.getOperand().getType() == getType()) return derefineOp.getOperand(); @@ -1836,7 +1834,7 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // AtenEqIntListOp //===----------------------------------------------------------------------===// -OpFoldResult AtenEqIntListOp::fold(ArrayRef operands) { 
+OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) { auto lhsLiteral = getA().getDefiningOp(); if (!lhsLiteral) return nullptr; @@ -1976,7 +1974,7 @@ static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) { // Aten__Getitem__DictStrOp //===----------------------------------------------------------------------===// -OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef operands) { +OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) { auto dictConstruct = getDictConstructIfNotModified(getSelf()); if (!dictConstruct) return nullptr; @@ -1994,7 +1992,7 @@ OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef operands) { // Aten__Contains__StrOp //===----------------------------------------------------------------------===// -OpFoldResult Aten__Contains__StrOp::fold(ArrayRef operands) { +OpFoldResult Aten__Contains__StrOp::fold(FoldAdaptor adaptor) { auto dictConstruct = getDictConstructIfNotModified(getDict()); if (!dictConstruct) return nullptr; @@ -2017,7 +2015,7 @@ static bool isListConstructNotModified(Value torchList) { }); } -OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef operands) { +OpFoldResult Aten__Contains__IntListOp::fold(FoldAdaptor adaptor) { auto itemConstruct = getItem(); if (!isListConstructNotModified(getL())) return nullptr; @@ -2078,43 +2076,44 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef operands, // AtenFloordivIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenFloordivIntOp::fold(ArrayRef operands) { +OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) { return atenBinaryIntOperatorFoldHelper( - operands, [](int64_t a, int64_t b) { return std::floor(a / (double)b); }); + adaptor.getOperands(), + [](int64_t a, int64_t b) { return std::floor(a / (double)b); }); } //===----------------------------------------------------------------------===// // AtenRemainderIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenRemainderIntOp::fold(ArrayRef operands) { +OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) { return atenBinaryIntOperatorFoldHelper( - operands, [](int64_t a, int64_t b) { return a % b; }); + adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; }); } //===----------------------------------------------------------------------===// // AtenAddIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenAddIntOp::fold(ArrayRef operands) { +OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) { return atenBinaryIntOperatorFoldHelper( - operands, [](int64_t a, int64_t b) { return a + b; }); + adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; }); } //===----------------------------------------------------------------------===// // AtenSubIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSubIntOp::fold(ArrayRef operands) { +OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) { return atenBinaryIntOperatorFoldHelper( - operands, [](int64_t a, int64_t b) { return a - b; }); + adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; }); } //===----------------------------------------------------------------------===// // AtenCatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenCatOp::fold(llvm::ArrayRef operands) { +OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) { auto list = getOperand(0).getDefiningOp(); if (!list || 
!list->hasOneUse() || list.getElements().size() != 1) return nullptr; @@ -2125,7 +2124,7 @@ OpFoldResult AtenCatOp::fold(llvm::ArrayRef operands) { // AtenSliceTensorOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef operands) { +OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { auto inType = getOperand(0).getType().dyn_cast(); auto outType = getResult().getType().dyn_cast(); if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) @@ -2144,7 +2143,7 @@ OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef operands) { // AtenMulIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenMulIntOp::fold(ArrayRef operands) { +OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { int64_t lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); @@ -2159,42 +2158,45 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef operands) { // AtenSubOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSubOp::fold(ArrayRef operands) { - if (!operands[0] || !operands[1]) { +OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) { + if (!adaptor.getA() || !adaptor.getB()) { return nullptr; } - if (operands[0].isa() && operands[1].isa()) { + if (adaptor.getA().isa() && adaptor.getB().isa()) { return atenBinaryIntOperatorFoldHelper( - operands, [](int64_t a, int64_t b) -> int64_t { return a - b; }); + adaptor.getOperands(), + [](int64_t a, int64_t b) -> int64_t { return a - b; }); } return atenBinaryFloatOperatorFoldHelper( - operands, [](double a, double b) -> double { return a - b; }); + adaptor.getOperands(), + [](double a, double b) -> double { return a - b; }); } //===----------------------------------------------------------------------===// // AtenDivOp //===----------------------------------------------------------------------===// -OpFoldResult AtenDivOp::fold(ArrayRef operands) { - if (!operands[0] || !operands[1]) { +OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) { + if (!adaptor.getA() || !adaptor.getB()) { return nullptr; } // Since AtenDivOp always returns float value, we don't need to deal with the // case where the operands are both integers separately. 
return atenBinaryFloatOperatorFoldHelper( - operands, [](double a, double b) -> double { return a / b; }); + adaptor.getOperands(), + [](double a, double b) -> double { return a / b; }); } //===----------------------------------------------------------------------===// // AtenCeilScalarOp //===----------------------------------------------------------------------===// -OpFoldResult AtenCeilScalarOp::fold(ArrayRef operands) { - if (!operands[0]) { +OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) { + if (!adaptor.getA()) { return nullptr; } - auto floatValue = operands[0].dyn_cast_or_null(); + auto floatValue = adaptor.getA().dyn_cast_or_null(); if (!floatValue) { return nullptr; } @@ -2207,7 +2209,7 @@ OpFoldResult AtenCeilScalarOp::fold(ArrayRef operands) { // AtenNegIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenNegIntOp::fold(ArrayRef operands) { +OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) { int64_t c; if (matchPattern(getOperand(), m_TorchConstantInt(&c))) return getI64IntegerAttr(getContext(), -c); @@ -2218,7 +2220,7 @@ OpFoldResult AtenNegIntOp::fold(ArrayRef operands) { // AtenSqrtIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenSqrtIntOp::fold(ArrayRef operands) { +OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) { int64_t c; if (matchPattern(getOperand(), m_TorchConstantInt(&c))) return getF64FloatAttr(getContext(), std::sqrt(c)); @@ -2229,7 +2231,7 @@ OpFoldResult AtenSqrtIntOp::fold(ArrayRef operands) { // PrimDtypeOp //===----------------------------------------------------------------------===// -OpFoldResult PrimDtypeOp::fold(ArrayRef operands) { +OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) { BaseTensorType tensorType = getA().getType().cast(); if (tensorType.hasDtype()) { torch_upstream::ScalarType scalarType = @@ -2243,7 +2245,7 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef operands) { // AtenIntTensorOp //===----------------------------------------------------------------------===// -OpFoldResult AtenIntTensorOp::fold(ArrayRef operands) { +OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) { // If a scalar number is converted to a 0-d tensor and passed on to // aten.Int.Tensor, fold to the scalar number. if (auto numToTensorScalar = getA().getDefiningOp()) @@ -2255,7 +2257,7 @@ OpFoldResult AtenIntTensorOp::fold(ArrayRef operands) { // AtenFloatTensorOp //===----------------------------------------------------------------------===// -OpFoldResult AtenFloatTensorOp::fold(ArrayRef operands) { +OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) { // If a scalar number is converted to a 0-d tensor and passed on to // aten.Float.Tensor, fold to the scalar number. 
if (auto numToTensorScalar = getA().getDefiningOp()) @@ -2267,7 +2269,7 @@ OpFoldResult AtenFloatTensorOp::fold(ArrayRef operands) { // AtenDivFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenDivFloatOp::fold(ArrayRef operands) { +OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) { double lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs)); @@ -2284,7 +2286,7 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef operands) { // AtenDivIntOp //===----------------------------------------------------------------------===// -OpFoldResult AtenDivIntOp::fold(ArrayRef operands) { +OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) { int64_t lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); @@ -2297,7 +2299,7 @@ OpFoldResult AtenDivIntOp::fold(ArrayRef operands) { // AtenCeilFloatOp //===----------------------------------------------------------------------===// -OpFoldResult AtenCeilFloatOp::fold(ArrayRef operands) { +OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) { double c; if (matchPattern(getOperand(), m_TorchConstantFloat(&c))) return getI64IntegerAttr(getContext(), std::ceil(c)); @@ -2308,13 +2310,13 @@ OpFoldResult AtenCeilFloatOp::fold(ArrayRef operands) { // PrimMaxIntOp //===----------------------------------------------------------------------===// -OpFoldResult PrimMaxIntOp::fold(ArrayRef operands) { +OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) { // If both operands are the same, then the operation is an identity. if (getA() == getB()) return getA(); - auto lhs = operands[0].dyn_cast_or_null(); - auto rhs = operands[1].dyn_cast_or_null(); + auto lhs = adaptor.getA().dyn_cast_or_null(); + auto rhs = adaptor.getB().dyn_cast_or_null(); if (!lhs || !rhs) return nullptr; // Torch semantics are that !torch.int is 64-bit signed. @@ -2327,7 +2329,7 @@ OpFoldResult PrimMaxIntOp::fold(ArrayRef operands) { // PrimMinSelfIntOp //===----------------------------------------------------------------------===// -OpFoldResult PrimMinSelfIntOp::fold(ArrayRef operands) { +OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) { auto list = getOperand().getDefiningOp(); if (!list) return nullptr; diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index 968f809a0..440d20158 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -463,7 +463,7 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) { } } - return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype); + return lhs.getWithSizesAndDtype(ArrayRef(newSizes), dtype); } ////===----------------------------------------------------------------------===// @@ -505,4 +505,4 @@ DictType::verify(llvm::function_ref emitError, return failure(); } return success(); -} \ No newline at end of file +} diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 6f5984110..9e88a2344 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -72,7 +72,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op, Type resultType = tensorType.getWithSizesAndDtype( sizes.size() == 0 ? 
std::optional>() - : llvm::makeArrayRef(sizes), + : llvm::ArrayRef(sizes), tensorType.getOptionalDtype()); return resultType; } @@ -108,7 +108,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc, valueType .getWithSizesAndDtype( !valueType.hasSizes() ? std::optional>() - : llvm::makeArrayRef(valueType.getSizes()), + : llvm::ArrayRef(valueType.getSizes()), IntegerType::get(op->getContext(), 64, IntegerType::Signed)) .cast(); return rewriter @@ -142,7 +142,7 @@ static Value createRank0Tensor(PatternRewriter &rewriter, Location loc, BaseTensorType inputType, Value scalar) { SmallVector sizes; Type rank0TensorTy = inputType.getWithSizesAndDtype( - makeArrayRef(sizes), inputType.getOptionalDtype()); + ArrayRef(sizes), inputType.getOptionalDtype()); Value dimList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())), ValueRange{}); @@ -940,7 +940,7 @@ public: SmallVector sizes; sizes.append(inputShape.begin(), inputShape.end()); sizes[cstDim] = kUnknownSize; - Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes), + Type sliceTy = selfTy.getWithSizesAndDtype(llvm::ArrayRef(sizes), selfTy.getOptionalDtype()); Value slice0 = rewriter.create( loc, sliceTy, input, dim, negShift, constNone, constOne); @@ -1077,9 +1077,9 @@ public: Type dtype = self.getType().cast().getOptionalDtype(); Type unsqueezedType = ValueTensorType::get( - context, llvm::makeArrayRef(unsqueezedIntSizes), dtype); - Type expandedType = ValueTensorType::get( - context, llvm::makeArrayRef(expandedIntSizes), dtype); + context, llvm::ArrayRef(unsqueezedIntSizes), dtype); + Type expandedType = + ValueTensorType::get(context, llvm::ArrayRef(expandedIntSizes), dtype); auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); Value unsqueezedDims = @@ -2004,7 +2004,7 @@ public: auto inputType = input.getType().cast(); SmallVector empty; - Type tensorType = inputType.getWithSizesAndDtype(llvm::makeArrayRef(empty), + Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty), rewriter.getF64Type()); Value prob = rewriter.create(loc, tensorType, p); Value output; @@ -2082,8 +2082,8 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern { std::vector meanVarSizes(inputRank, 1); for (int i = 0; i < axis; i++) meanVarSizes[i] = input.getSizes()[i]; - auto meanVarType = input.getWithSizesAndDtype( - llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype()); + auto meanVarType = input.getWithSizesAndDtype(llvm::ArrayRef(meanVarSizes), + input.getOptionalDtype()); auto nativeLayerNorm = rewriter.create( loc, op.getType(), meanVarType, meanVarType, op.getInput(), op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps()); @@ -2320,7 +2320,7 @@ class DecomposeAtenNativeBatchNormOp runningStatsShapeInt[1] = kUnknownSize; Type dtype = input.getType().cast().getOptionalDtype(); Type reshapeType = ValueTensorType::get( - context, llvm::makeArrayRef(runningStatsShapeInt), dtype); + context, llvm::ArrayRef(runningStatsShapeInt), dtype); runningMean = rewriter.create(loc, reshapeType, runningMean, runningStatsSizeList); @@ -2466,8 +2466,7 @@ public: SmallVector empty; auto dtype = getTypeForTorchType(op.getContext(), op.getFillValue().getType()); - Type tensorType = - outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype); + Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); Value fillVal = rewriter.create(loc, tensorType, op.getFillValue()); fillVal = convertTensorToDtype(rewriter, loc, fillVal, 
outTy.getDtype()); @@ -2503,7 +2502,7 @@ public: SmallVector transposeShape = llvm::to_vector(llvm::reverse(weightType.getSizes())); Type transposeType = weightType.getWithSizesAndDtype( - llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype()); + llvm::ArrayRef(transposeShape), weightType.getOptionalDtype()); Value transposeWeight = rewriter.create(loc, transposeType, weight); @@ -2573,8 +2572,7 @@ public: SmallVector empty; auto dtype = getTypeForTorchType(op.getContext(), op.getFillValue().getType()); - Type tensorType = - outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype); + Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); Value fillVal = rewriter.create( op.getLoc(), tensorType, op.getFillValue()); fillVal = @@ -3216,7 +3214,7 @@ public: sizes.resize(srcShape.size() + 1, kUnknownSize); } Type srcType = srcTensorType.getWithSizesAndDtype( - llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype()); + llvm::ArrayRef(sizes), srcTensorType.getOptionalDtype()); src = rewriter.create(loc, srcType, src, dim); rewriter.replaceOpWithNewOp( op, op.getSelf().getType(), self, src, dim, start, startPlusOne, @@ -3314,7 +3312,7 @@ public: op, "Expected the input tensor to have sizes"); BaseTensorType subType = inputType - .getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()), + .getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()), resultType.getOptionalDtype()) .cast(); diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index 0308520d2..f8d3651d9 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -129,8 +129,7 @@ public: // Truncate the list of users to the number of users we're going to // interpret. allUsers.resize(numUsersToInterpret); - auto usersToInterpret = - makeArrayRef(allUsers).take_front(numUsersToInterpret); + auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret); // For each mutating op (which must be in the same block), we save the // current state of the list as a vector of Value's. 
These will then @@ -336,7 +335,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op, auto originalResultType = result.getType().cast(); auto impliedTypesFromShape = originalResultType.cast() - .getWithSizesAndDtype(makeArrayRef(sizes), + .getWithSizesAndDtype(ArrayRef(sizes), originalResultType.getOptionalDtype()) .cast(); diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp index 1b83cce37..c858edb62 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp @@ -75,8 +75,8 @@ LogicalResult FromBuiltinTensorOp::verify() { // FromI64Op //===----------------------------------------------------------------------===// -OpFoldResult FromI64Op::fold(llvm::ArrayRef operands) { - auto attr = operands[0].dyn_cast_or_null(); +OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) { + auto attr = adaptor.getOperand().dyn_cast_or_null(); if (attr) { return attr; } else { @@ -88,8 +88,8 @@ OpFoldResult FromI64Op::fold(llvm::ArrayRef operands) { // ToI64Op //===----------------------------------------------------------------------===// -OpFoldResult ToI64Op::fold(llvm::ArrayRef operands) { - auto attr = operands[0].dyn_cast_or_null(); +OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) { + auto attr = adaptor.getOperand().dyn_cast_or_null(); if (attr) { return attr; } else { @@ -101,8 +101,8 @@ OpFoldResult ToI64Op::fold(llvm::ArrayRef operands) { // ToF64Op //===----------------------------------------------------------------------===// -OpFoldResult ToF64Op::fold(llvm::ArrayRef operands) { - auto attr = operands[0].dyn_cast_or_null(); +OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) { + auto attr = adaptor.getOperand().dyn_cast_or_null(); if (attr) { return attr; } else { @@ -114,8 +114,8 @@ OpFoldResult ToF64Op::fold(llvm::ArrayRef operands) { // FromF64Op //===----------------------------------------------------------------------===// -OpFoldResult FromF64Op::fold(llvm::ArrayRef operands) { - auto attr = operands[0].dyn_cast_or_null(); +OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) { + auto attr = adaptor.getOperand().dyn_cast_or_null(); if (attr) { return attr; } else { diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index a8f6766f2..597f46381 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -392,7 +392,7 @@ Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from, loc, /*inputs=*/from, /*outputs=*/to, - /*indexingMaps=*/llvm::makeArrayRef({id, id}), + /*indexingMaps=*/llvm::ArrayRef({id, id}), /*iteratorTypes=*/iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { b.create(loc, args.front()); diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index 84e7d63da..8ce52bd89 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -101,17 +101,17 @@ torch.class_type @c { // ----- // expected-error @+1 {{'torch.type_bound' must be attached to an argument of !torch.tensor/!torch.vtensor type}} -func.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>}) +func.func private @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>}) // ----- // expected-error @+1 {{'torch.type_bound' must be TypeAttr}} -func.func @f(%arg0: i32 {torch.type_bound = 1}) +func.func private @f(%arg0: i32 {torch.type_bound = 1}) // ----- // expected-error @+1 {{'torch.type_bound' must be of !torch.tensor/!torch.vtensor type}} -func.func 
@f(%arg0: i32 {torch.type_bound = i32}) +func.func private @f(%arg0: i32 {torch.type_bound = i32}) // -----
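
Notes on the patch (appended for review context):

The substantive change is the migration from the positional fold API to the generated `FoldAdaptor` API: each dialect opts in via `let useFoldAPI = kEmitFoldAdaptorFolder;` in its TableGen definition, and every folder is rewritten from `fold(ArrayRef<Attribute>)` to `fold(FoldAdaptor)`. Below is a minimal C++ sketch of the pattern, using `AtenFloatScalarOp` from `lib/Dialect/Torch/IR/TorchOps.cpp`; the template arguments and accessor names are taken from the upstream MLIR API, and the snippet is illustrative rather than a standalone translation unit.

```cpp
// Old-style folder: constant operands arrive as a positional attribute array.
//
//   OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
//     if (auto integerAttr = operands[0].dyn_cast_or_null<IntegerAttr>())
//       ...
//   }

// New-style folder: the TableGen-generated FoldAdaptor exposes the same
// constants through named accessors (getA() for operand `a`), returning a
// null Attribute when the operand is not constant.
OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
  // Constant fold int -> float conversion.
  if (auto integerAttr = adaptor.getA().dyn_cast_or_null<IntegerAttr>()) {
    return FloatAttr::get(
        mlir::Float64Type::get(getContext()),
        static_cast<double>(integerAttr.getValue().getSExtValue()));
  }
  // Not a constant integer operand: nothing to fold.
  return nullptr;
}
```

Folders that still want all constants positionally (for example the binary int/float helpers) can use `adaptor.getOperands()`, which is what the arithmetic folders in `TorchOps.cpp` do after this change.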
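The remaining churn is mechanical fallout from the LLVM submodule bump: `llvm::makeArrayRef` was deprecated upstream in favor of `ArrayRef`'s deduction guides, so call sites now construct `ArrayRef` directly. A minimal sketch, assuming a `SmallVector<int64_t>` of sizes like the ones in the conversion code above:

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

void arrayRefExample() {
  llvm::SmallVector<int64_t> sizes = {2, 3, 4};

  // Before (deprecated): llvm::ArrayRef<int64_t> ref = llvm::makeArrayRef(sizes);
  // After: class template argument deduction derives int64_t from `sizes`.
  llvm::ArrayRef<int64_t> ref = llvm::ArrayRef(sizes);

  (void)ref; // ArrayRef is a non-owning view; `sizes` must outlive it.
}
```

The test update at the end is related to the same bump: the bodiless functions in `test/Dialect/Torch/invalid.mlir` are now declared `func.func private`, keeping them valid under the upstream verifier, which requires function declarations without a body to have private visibility.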