build: update llvm tag to 9acc2f37 (#1828)

This commit makes the following changes:

- Update dialects to use the fold API `kEmitFoldAdaptorFolder` and update the
signature of `fold` methods (see the PSA at
https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618); a sketch of the signature change is shown after this list
- Replace `makeArrayRef` with `ArrayRef` (see
https://reviews.llvm.org/D140896)
- Remove the `TypeRange{}` arg from `b.create<scf::IfOp>`, since the builder no
longer takes that argument
- Make `func`s in `Torch/invalid.mlir` private, since symbol
declarations cannot be public (see https://discourse.llvm.org/t/rfc-symbol-definition-declaration-x-visibility-checks/2140)
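
A minimal sketch of the fold signature change, using a hypothetical `AtenFooIntOp`
(the real ops in this patch mostly go through shared helpers such as
`atenBinaryIntOperatorFoldHelper`); the `getA()`/`getB()` accessor names are an
assumption that follows from the op's ODS operand names:

```cpp
// Assumes the usual TorchOps.cpp context (MLIR headers, namespaces) and that
// the owning dialect sets `let useFoldAPI = kEmitFoldAdaptorFolder;` in ODS.

// Before: constant operands arrive as a positional attribute list.
OpFoldResult AtenFooIntOp::fold(ArrayRef<Attribute> operands) {
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!lhs || !rhs)
    return nullptr;
  return IntegerAttr::get(lhs.getType(), lhs.getValue() + rhs.getValue());
}

// After: ODS generates a FoldAdaptor with a named accessor per operand.
OpFoldResult AtenFooIntOp::fold(FoldAdaptor adaptor) {
  auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
  auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
  if (!lhs || !rhs)
    return nullptr;
  return IntegerAttr::get(lhs.getType(), lhs.getValue() + rhs.getValue());
}
```
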
Ramiro Leal-Cavazos 2023-01-24 17:29:42 -08:00 committed by GitHub
parent 8ce2fffca5
commit 6c86bec04f
17 changed files with 147 additions and 145 deletions

@ -43,6 +43,7 @@ def TMTensor_Dialect : Dialect {
to.
}];
let hasCanonicalizer = 1;
let useFoldAPI = kEmitFoldAdaptorFolder;
}
//===----------------------------------------------------------------------===//

@ -204,7 +204,7 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
}
auto scfIf = b.create<scf::IfOp>(
loc, TypeRange{}, cond,
loc, cond,
[&](OpBuilder &b, Location loc) {
if (isInclusive) {
auto value = b.create<memref::LoadOp>(loc, input(), indices);
@ -266,7 +266,7 @@ static LogicalResult foldMemRefCast(Operation *op) {
return success(folded);
}
LogicalResult ScanOp::fold(ArrayRef<Attribute>,
LogicalResult ScanOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}

@ -1 +1 @@
Subproject commit d23516e9ad477527a9db4d06b1fa9566680ac67c
Subproject commit 9acc2f37bdfce08ca0c2faec03392db10d1bb7a9

externals/mlir-hlo

@ -1 +1 @@
Subproject commit 81e87a95b8683f1c3c33caf9e933897e0fc4a2b7
Subproject commit 4a173356bb1291b97046545429d7851cbc771d88

@ -37,6 +37,7 @@ def Torch_Dialect : Dialect {
let hasRegionArgAttrVerify = 1;
let hasConstantMaterializer = 1;
let useDefaultTypePrinterParser = 0;
let useFoldAPI = kEmitFoldAdaptorFolder;
let extraClassDeclaration = [{
/// Parse a type registered to this dialect.

@ -27,6 +27,7 @@ def TorchConversion_Dialect : Dialect {
}];
let hasConstantMaterializer = 1;
let useFoldAPI = kEmitFoldAdaptorFolder;
}
#endif // TORCHCONVERSION_BASE

@ -463,8 +463,8 @@ public:
}
SmallVector<Value> inputSize = getTensorSizes(rewriter, loc, input);
ArrayRef<Value> outputShapeInt = llvm::makeArrayRef(outputSizeInt);
ArrayRef<Value> inputShapeInt = llvm::makeArrayRef(inputSize);
ArrayRef<Value> outputShapeInt = llvm::ArrayRef(outputSizeInt);
ArrayRef<Value> inputShapeInt = llvm::ArrayRef(inputSize);
// Association indices for expand/collapse ops. These two vectors
// are populated such that two entries at the same index corresponds
@ -1136,7 +1136,7 @@ public:
Value dimIndex = rewriter.createOrFold<arith::ConstantOp>(
loc, rewriter.getIndexAttr(dim));
for (auto tensor : makeArrayRef(tensors).drop_front()) {
for (auto tensor : ArrayRef(tensors).drop_front()) {
auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex);
resultDimSize =
rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
@ -1270,7 +1270,7 @@ public:
/*resultType=*/selfType,
/*inputs=*/broadcastedSrc,
/*outputs=*/self,
/*indexingMaps=*/llvm::makeArrayRef({id, id}),
/*indexingMaps=*/llvm::ArrayRef({id, id}),
/*iteratorTypes=*/iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
Value result = args[0];

@ -1086,7 +1086,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
// Reshape input
auto mhloInput = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), mhloBatchNormOutTy, input,
mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
mhlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape),
{static_cast<int64_t>(inputFlattenShape.size())})
.value());

@ -142,7 +142,7 @@ public:
// Finding the maximum value in the input tensor.
SmallVector<int64_t> maxTensorSizes;
ValueTensorType maxTensorType = ValueTensorType::get(
context, llvm::makeArrayRef(maxTensorSizes),
context, llvm::ArrayRef(maxTensorSizes),
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
Value maxTensor =
rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput);
@ -165,7 +165,7 @@ public:
SmallVector<int64_t> expandedInputSizes{
makeShapeTorchCompatible(inputType.getShape())[0], 1};
ValueTensorType expandInputType = ValueTensorType::get(
context, llvm::makeArrayRef(expandedInputSizes),
context, llvm::ArrayRef(expandedInputSizes),
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
@ -286,9 +286,9 @@ public:
auto indexTensorType = indexTensor.getType().cast<BaseTensorType>();
int64_t indexTensorSize = indexTensorType.getSizes()[0];
SmallVector<int64_t> expandedIndexTensorSizes{indexTensorSize, 1};
ValueTensorType expandedIndexTensorType = ValueTensorType::get(
context, llvm::makeArrayRef(expandedIndexTensorSizes),
indexTensorType.getDtype());
ValueTensorType expandedIndexTensorType =
ValueTensorType::get(context, llvm::ArrayRef(expandedIndexTensorSizes),
indexTensorType.getDtype());
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value expandedIndexTensor = rewriter.create<AtenUnsqueezeOp>(

@ -718,8 +718,8 @@ class ConvertAtenMultipleDimsReductionOp
"non-const dim parameter unsupported");
int64_t N = reduceDims.size();
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
llvm::makeArrayRef(reduceDims));
reduceDimsAttr =
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims));
keepDims = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims)))
@ -748,8 +748,8 @@ class ConvertAtenOneDimReductionOp
return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported");
auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type());
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
llvm::makeArrayRef({reduceDim}));
reduceDimsAttr =
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef({reduceDim}));
keepDims = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims)))
@ -782,8 +782,8 @@ public:
reduceDims.push_back(i);
int64_t N = selfTy.getRank();
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
llvm::makeArrayRef(reduceDims));
reduceDimsAttr =
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims));
keepDims = false;
return success();

@ -507,7 +507,7 @@ bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
return isValidSubtype(inputs[0], outputs[0]);
}
OpFoldResult DerefineOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult DerefineOp::fold(FoldAdaptor adaptor) {
auto uncheckedCast = getOperand().getDefiningOp<PrimUncheckedCastOp>();
if (!uncheckedCast)
return nullptr;
@ -570,10 +570,10 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
// Aten__RangeLengthOp
//===----------------------------------------------------------------------===//
OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
auto lo = operands[0];
auto hi = operands[1];
auto step = operands[2];
OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
auto lo = adaptor.getLo();
auto hi = adaptor.getHi();
auto step = adaptor.getStep();
if (!lo || !hi || !step)
return nullptr;
auto loInt = lo.dyn_cast_or_null<IntegerAttr>().getValue();
@ -595,10 +595,10 @@ OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
// Aten__DeriveIndexOp
//===----------------------------------------------------------------------===//
OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
auto index = operands[0];
auto start = operands[1];
auto step = operands[2];
OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) {
auto index = adaptor.getIndex();
auto start = adaptor.getStart();
auto step = adaptor.getStep();
if (!index || !start || !step)
return nullptr;
auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue();
@ -612,7 +612,7 @@ OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
// Aten__Is__Op
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Is__Op::fold(ArrayRef<Attribute> operands) {
OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) {
return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true);
}
@ -620,7 +620,7 @@ OpFoldResult Aten__Is__Op::fold(ArrayRef<Attribute> operands) {
// Aten__Isnot__Op
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Isnot__Op::fold(ArrayRef<Attribute> operands) {
OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) {
return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false);
}
@ -628,7 +628,7 @@ OpFoldResult Aten__Isnot__Op::fold(ArrayRef<Attribute> operands) {
// Aten__Not__Op
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
bool value;
if (!matchPattern(getOperand(), m_TorchConstantBool(&value)))
return nullptr;
@ -639,7 +639,7 @@ OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
// AtenNeBoolOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
if (getOperand(0) == getOperand(1))
return IntegerAttr::get(IntegerType::get(getContext(), 1), false);
@ -655,7 +655,7 @@ OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
// AtenSqueezeOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand();
@ -667,7 +667,7 @@ OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
// AtenSqueezeDimOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand(0);
@ -679,7 +679,7 @@ OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
// AtenRoundOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenRoundOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
if (auto selfType = getSelf().getType().dyn_cast<BaseTensorType>()) {
if (selfType.hasDtype() && selfType.getDtype().isa<mlir::IntegerType>())
return getSelf();
@ -691,7 +691,7 @@ OpFoldResult AtenRoundOp::fold(ArrayRef<Attribute> operands) {
// AtenTypeAsOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenTypeAsOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) {
Type inType = getSelf().getType();
Type newType = getOther().getType();
@ -705,7 +705,7 @@ OpFoldResult AtenTypeAsOp::fold(ArrayRef<Attribute> operands) {
// AtenToDtypeOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
bool nonBlocking, copyArg;
// The non_blocking arg must be `False`.
if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
@ -736,7 +736,7 @@ OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
// AtenToDtypeLayoutOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
// The pin_memory arg should be either constant `False` or `none`.
if (!getPinMemory().getType().isa<Torch::NoneType>()) {
bool pinMemory;
@ -797,7 +797,7 @@ OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef<Attribute> operands) {
// AtenViewOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
return nullptr;
@ -812,7 +812,7 @@ OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
// AtenDimOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
if (tensorType.hasSizes())
return IntegerAttr::get(IntegerType::get(getContext(), 64),
@ -825,7 +825,7 @@ OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
// AtenLenTOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenLenTOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenLenTOp::fold(FoldAdaptor adaptor) {
// `len([1,1,1])` -> `3`, if it is not mutated.
if (auto listConstruct =
getOperand().getDefiningOp<Torch::PrimListConstructOp>()) {
@ -853,7 +853,7 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// AtenLenStrOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenLenStrOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenLenStrOp::fold(FoldAdaptor adaptor) {
if (auto stringConstruct = getS().getDefiningOp<ConstantStrOp>())
return getI64IntegerAttr(getContext(),
stringConstruct.getValueAttr().getValue().size());
@ -1092,7 +1092,7 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// AtenSizeIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) {
int64_t dim;
if (!matchPattern(this->getDim(), m_TorchConstantInt(&dim)))
return nullptr;
@ -1132,7 +1132,7 @@ floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) {
// AtenLtFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenLtFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) {
return floatComparatorFoldHelper(*this,
[](double a, double b) { return a < b; });
}
@ -1141,7 +1141,7 @@ OpFoldResult AtenLtFloatOp::fold(ArrayRef<Attribute> operands) {
// AtenGtFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) {
return floatComparatorFoldHelper(*this,
[](double a, double b) { return a > b; });
}
@ -1150,7 +1150,7 @@ OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
// AtenGeFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) {
return floatComparatorFoldHelper(*this,
[](double a, double b) { return a >= b; });
}
@ -1159,7 +1159,7 @@ OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
// AtenEqFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenEqFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenEqFloatOp::fold(FoldAdaptor adaptor) {
return floatComparatorFoldHelper(*this,
[](double a, double b) { return a == b; });
}
@ -1225,7 +1225,7 @@ static OpFoldResult intComparatorFoldHelper(OpTy op,
// AtenNeIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) {
return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a != b; });
}
@ -1234,7 +1234,7 @@ OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
// AtenEqIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) {
return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a == b; });
}
@ -1243,7 +1243,7 @@ OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
// AtenEqStrOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) {
if (getOperand(0) == getOperand(1))
return getI1IntegerAttr(getContext(), true);
@ -1259,7 +1259,7 @@ OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
// AtenLtIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) {
return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a < b; });
}
@ -1268,7 +1268,7 @@ OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
// AtenLeIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) {
return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a <= b; });
}
@ -1277,7 +1277,7 @@ OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
// AtenGtIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) {
return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a > b; });
}
@ -1286,7 +1286,7 @@ OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
// AtenGeIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenGeIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) {
return intComparatorFoldHelper(*this,
[](int64_t a, int64_t b) { return a >= b; });
}
@ -1295,7 +1295,7 @@ OpFoldResult AtenGeIntOp::fold(ArrayRef<Attribute> operands) {
// AtenBoolFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenBoolFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) {
double c;
if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
return getI1IntegerAttr(getContext(), c != 0.0);
@ -1306,7 +1306,7 @@ OpFoldResult AtenBoolFloatOp::fold(ArrayRef<Attribute> operands) {
// AtenBoolIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenBoolIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) {
int64_t c;
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
return getI1IntegerAttr(getContext(), c != 0);
@ -1317,9 +1317,9 @@ OpFoldResult AtenBoolIntOp::fold(ArrayRef<Attribute> operands) {
// AtenFloatScalarOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
// Constant fold int -> float conversion.
if (auto integerAttr = operands[0].dyn_cast_or_null<IntegerAttr>()) {
if (auto integerAttr = adaptor.getA().dyn_cast_or_null<IntegerAttr>()) {
return FloatAttr::get(
mlir::Float64Type::get(getContext()),
static_cast<double>(integerAttr.getValue().getSExtValue()));
@ -1334,9 +1334,9 @@ OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
// AtenIntScalarOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) {
// Constant fold float -> int conversion.
if (auto floatAttr = operands[0].dyn_cast_or_null<FloatAttr>()) {
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
return IntegerAttr::get(
mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
static_cast<long>(floatAttr.getValue().convertToDouble()));
@ -1351,7 +1351,7 @@ OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
// AtenIntBoolOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIntBoolOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) {
bool b;
if (matchPattern(getOperand(), m_TorchConstantBool(&b))) {
return getI64IntegerAttr(getContext(), static_cast<long>(b));
@ -1452,7 +1452,7 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
return success();
}
OpFoldResult ValueTensorLiteralOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}
@ -1557,7 +1557,7 @@ void CopyToValueTensorOp::getEffects(
// ConstantNoneOp
//===----------------------------------------------------------------------===//
OpFoldResult ConstantNoneOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult ConstantNoneOp::fold(FoldAdaptor adaptor) {
return TypeAttr::get(Torch::NoneType::get(getContext()));
}
@ -1570,9 +1570,7 @@ void ConstantNoneOp::getAsmResultNames(
// ConstantStrOp
//===----------------------------------------------------------------------===//
OpFoldResult ConstantStrOp::fold(ArrayRef<Attribute> operands) {
return getValueAttr();
}
OpFoldResult ConstantStrOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
void ConstantStrOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
@ -1610,7 +1608,7 @@ void ConstantIntOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs(), {"value"});
}
OpFoldResult Torch::ConstantIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Torch::ConstantIntOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}
@ -1626,7 +1624,7 @@ void Torch::ConstantIntOp::getAsmResultNames(
// ConstantFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult Torch::ConstantFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Torch::ConstantFloatOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}
@ -1656,7 +1654,7 @@ void Torch::ConstantFloatOp::getAsmResultNames(
// ConstantNumberOp
//===----------------------------------------------------------------------===//
OpFoldResult Torch::ConstantNumberOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Torch::ConstantNumberOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}
@ -1684,7 +1682,7 @@ void Torch::ConstantNumberOp::getCanonicalizationPatterns(
// ConstantBoolOp
//===----------------------------------------------------------------------===//
OpFoldResult Torch::ConstantBoolOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Torch::ConstantBoolOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}
@ -1702,7 +1700,7 @@ bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs,
return isValidSubtype(outputs[0], inputs[0]);
}
OpFoldResult PrimUncheckedCastOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult PrimUncheckedCastOp::fold(FoldAdaptor adaptor) {
if (auto derefineOp = getX().getDefiningOp<Torch::DerefineOp>()) {
if (derefineOp.getOperand().getType() == getType())
return derefineOp.getOperand();
@ -1836,7 +1834,7 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// AtenEqIntListOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenEqIntListOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) {
auto lhsLiteral = getA().getDefiningOp<Torch::PrimListConstructOp>();
if (!lhsLiteral)
return nullptr;
@ -1976,7 +1974,7 @@ static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) {
// Aten__Getitem__DictStrOp
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) {
auto dictConstruct = getDictConstructIfNotModified(getSelf());
if (!dictConstruct)
return nullptr;
@ -1994,7 +1992,7 @@ OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef<Attribute> operands) {
// Aten__Contains__StrOp
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Contains__StrOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Aten__Contains__StrOp::fold(FoldAdaptor adaptor) {
auto dictConstruct = getDictConstructIfNotModified(getDict());
if (!dictConstruct)
return nullptr;
@ -2017,7 +2015,7 @@ static bool isListConstructNotModified(Value torchList) {
});
}
OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult Aten__Contains__IntListOp::fold(FoldAdaptor adaptor) {
auto itemConstruct = getItem();
if (!isListConstructNotModified(getL()))
return nullptr;
@ -2078,43 +2076,44 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
// AtenFloordivIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) {
return atenBinaryIntOperatorFoldHelper(
operands, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
adaptor.getOperands(),
[](int64_t a, int64_t b) { return std::floor(a / (double)b); });
}
//===----------------------------------------------------------------------===//
// AtenRemainderIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
return atenBinaryIntOperatorFoldHelper(
operands, [](int64_t a, int64_t b) { return a % b; });
adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; });
}
//===----------------------------------------------------------------------===//
// AtenAddIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
return atenBinaryIntOperatorFoldHelper(
operands, [](int64_t a, int64_t b) { return a + b; });
adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; });
}
//===----------------------------------------------------------------------===//
// AtenSubIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSubIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
return atenBinaryIntOperatorFoldHelper(
operands, [](int64_t a, int64_t b) { return a - b; });
adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
}
//===----------------------------------------------------------------------===//
// AtenCatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenCatOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
return nullptr;
@ -2125,7 +2124,7 @@ OpFoldResult AtenCatOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
// AtenSliceTensorOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
@ -2144,7 +2143,7 @@ OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
// AtenMulIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
int64_t lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
@ -2159,42 +2158,45 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
// AtenSubOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSubOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0] || !operands[1]) {
OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA() || !adaptor.getB()) {
return nullptr;
}
if (operands[0].isa<IntegerAttr>() && operands[1].isa<IntegerAttr>()) {
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
return atenBinaryIntOperatorFoldHelper(
operands, [](int64_t a, int64_t b) -> int64_t { return a - b; });
adaptor.getOperands(),
[](int64_t a, int64_t b) -> int64_t { return a - b; });
}
return atenBinaryFloatOperatorFoldHelper(
operands, [](double a, double b) -> double { return a - b; });
adaptor.getOperands(),
[](double a, double b) -> double { return a - b; });
}
//===----------------------------------------------------------------------===//
// AtenDivOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenDivOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0] || !operands[1]) {
OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA() || !adaptor.getB()) {
return nullptr;
}
// Since AtenDivOp always returns float value, we don't need to deal with the
// case where the operands are both integers separately.
return atenBinaryFloatOperatorFoldHelper(
operands, [](double a, double b) -> double { return a / b; });
adaptor.getOperands(),
[](double a, double b) -> double { return a / b; });
}
//===----------------------------------------------------------------------===//
// AtenCeilScalarOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0]) {
OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA()) {
return nullptr;
}
auto floatValue = operands[0].dyn_cast_or_null<FloatAttr>();
auto floatValue = adaptor.getA().dyn_cast_or_null<FloatAttr>();
if (!floatValue) {
return nullptr;
}
@ -2207,7 +2209,7 @@ OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
// AtenNegIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) {
int64_t c;
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
return getI64IntegerAttr(getContext(), -c);
@ -2218,7 +2220,7 @@ OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
// AtenSqrtIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) {
int64_t c;
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
return getF64FloatAttr(getContext(), std::sqrt(c));
@ -2229,7 +2231,7 @@ OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
// PrimDtypeOp
//===----------------------------------------------------------------------===//
OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) {
BaseTensorType tensorType = getA().getType().cast<BaseTensorType>();
if (tensorType.hasDtype()) {
torch_upstream::ScalarType scalarType =
@ -2243,7 +2245,7 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
// AtenIntTensorOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
// If a scalar number is converted to a 0-d tensor and passed on to
// aten.Int.Tensor, fold to the scalar number.
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
@ -2255,7 +2257,7 @@ OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
// AtenFloatTensorOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) {
// If a scalar number is converted to a 0-d tensor and passed on to
// aten.Float.Tensor, fold to the scalar number.
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
@ -2267,7 +2269,7 @@ OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
// AtenDivFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) {
double lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs));
@ -2284,7 +2286,7 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
// AtenDivIntOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
int64_t lhs, rhs;
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
@ -2297,7 +2299,7 @@ OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
// AtenCeilFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) {
double c;
if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
return getI64IntegerAttr(getContext(), std::ceil(c));
@ -2308,13 +2310,13 @@ OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
// PrimMaxIntOp
//===----------------------------------------------------------------------===//
OpFoldResult PrimMaxIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
// If both operands are the same, then the operation is an identity.
if (getA() == getB())
return getA();
auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
if (!lhs || !rhs)
return nullptr;
// Torch semantics are that !torch.int is 64-bit signed.
@ -2327,7 +2329,7 @@ OpFoldResult PrimMaxIntOp::fold(ArrayRef<Attribute> operands) {
// PrimMinSelfIntOp
//===----------------------------------------------------------------------===//
OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) {
auto list = getOperand().getDefiningOp<PrimListConstructOp>();
if (!list)
return nullptr;

@ -463,7 +463,7 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
}
}
return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype);
return lhs.getWithSizesAndDtype(ArrayRef(newSizes), dtype);
}
////===----------------------------------------------------------------------===//

@ -72,7 +72,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
Type resultType = tensorType.getWithSizesAndDtype(
sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
: llvm::makeArrayRef(sizes),
: llvm::ArrayRef(sizes),
tensorType.getOptionalDtype());
return resultType;
}
@ -108,7 +108,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
valueType
.getWithSizesAndDtype(
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
: llvm::makeArrayRef(valueType.getSizes()),
: llvm::ArrayRef(valueType.getSizes()),
IntegerType::get(op->getContext(), 64, IntegerType::Signed))
.cast<BaseTensorType>();
return rewriter
@ -142,7 +142,7 @@ static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
BaseTensorType inputType, Value scalar) {
SmallVector<int64_t> sizes;
Type rank0TensorTy = inputType.getWithSizesAndDtype(
makeArrayRef(sizes), inputType.getOptionalDtype());
ArrayRef(sizes), inputType.getOptionalDtype());
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
ValueRange{});
@ -940,7 +940,7 @@ public:
SmallVector<int64_t> sizes;
sizes.append(inputShape.begin(), inputShape.end());
sizes[cstDim] = kUnknownSize;
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::ArrayRef(sizes),
selfTy.getOptionalDtype());
Value slice0 = rewriter.create<AtenSliceTensorOp>(
loc, sliceTy, input, dim, negShift, constNone, constOne);
@ -1077,9 +1077,9 @@ public:
Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype();
Type unsqueezedType = ValueTensorType::get(
context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
Type expandedType = ValueTensorType::get(
context, llvm::makeArrayRef(expandedIntSizes), dtype);
context, llvm::ArrayRef(unsqueezedIntSizes), dtype);
Type expandedType =
ValueTensorType::get(context, llvm::ArrayRef(expandedIntSizes), dtype);
auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value unsqueezedDims =
@ -2004,7 +2004,7 @@ public:
auto inputType = input.getType().cast<BaseTensorType>();
SmallVector<int64_t> empty;
Type tensorType = inputType.getWithSizesAndDtype(llvm::makeArrayRef(empty),
Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty),
rewriter.getF64Type());
Value prob = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType, p);
Value output;
@ -2082,8 +2082,8 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
std::vector<int64_t> meanVarSizes(inputRank, 1);
for (int i = 0; i < axis; i++)
meanVarSizes[i] = input.getSizes()[i];
auto meanVarType = input.getWithSizesAndDtype(
llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype());
auto meanVarType = input.getWithSizesAndDtype(llvm::ArrayRef(meanVarSizes),
input.getOptionalDtype());
auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
loc, op.getType(), meanVarType, meanVarType, op.getInput(),
op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
@ -2320,7 +2320,7 @@ class DecomposeAtenNativeBatchNormOp
runningStatsShapeInt[1] = kUnknownSize;
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
Type reshapeType = ValueTensorType::get(
context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean,
runningStatsSizeList);
@ -2466,8 +2466,7 @@ public:
SmallVector<int64_t> empty;
auto dtype =
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
Type tensorType =
outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType,
op.getFillValue());
fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype());
@ -2503,7 +2502,7 @@ public:
SmallVector<int64_t> transposeShape =
llvm::to_vector(llvm::reverse(weightType.getSizes()));
Type transposeType = weightType.getWithSizesAndDtype(
llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype());
llvm::ArrayRef(transposeShape), weightType.getOptionalDtype());
Value transposeWeight =
rewriter.create<AtenTOp>(loc, transposeType, weight);
@ -2573,8 +2572,7 @@ public:
SmallVector<int64_t> empty;
auto dtype =
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
Type tensorType =
outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(
op.getLoc(), tensorType, op.getFillValue());
fillVal =
@ -3216,7 +3214,7 @@ public:
sizes.resize(srcShape.size() + 1, kUnknownSize);
}
Type srcType = srcTensorType.getWithSizesAndDtype(
llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype());
llvm::ArrayRef(sizes), srcTensorType.getOptionalDtype());
src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
@ -3314,7 +3312,7 @@ public:
op, "Expected the input tensor to have sizes");
BaseTensorType subType =
inputType
.getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()),
.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
resultType.getOptionalDtype())
.cast<BaseTensorType>();

@ -129,8 +129,7 @@ public:
// Truncate the list of users to the number of users we're going to
// interpret.
allUsers.resize(numUsersToInterpret);
auto usersToInterpret =
makeArrayRef(allUsers).take_front(numUsersToInterpret);
auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret);
// For each mutating op (which must be in the same block), we save the
// current state of the list as a vector of Value's. These will then
@ -336,7 +335,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
auto originalResultType = result.getType().cast<BaseTensorType>();
auto impliedTypesFromShape =
originalResultType.cast<BaseTensorType>()
.getWithSizesAndDtype(makeArrayRef(sizes),
.getWithSizesAndDtype(ArrayRef(sizes),
originalResultType.getOptionalDtype())
.cast<BaseTensorType>();

@ -75,8 +75,8 @@ LogicalResult FromBuiltinTensorOp::verify() {
// FromI64Op
//===----------------------------------------------------------------------===//
OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
if (attr) {
return attr;
} else {
@ -88,8 +88,8 @@ OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
// ToI64Op
//===----------------------------------------------------------------------===//
OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
if (attr) {
return attr;
} else {
@ -101,8 +101,8 @@ OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
// ToF64Op
//===----------------------------------------------------------------------===//
OpFoldResult ToF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
if (attr) {
return attr;
} else {
@ -114,8 +114,8 @@ OpFoldResult ToF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
// FromF64Op
//===----------------------------------------------------------------------===//
OpFoldResult FromF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) {
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
if (attr) {
return attr;
} else {

@ -392,7 +392,7 @@ Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
loc,
/*inputs=*/from,
/*outputs=*/to,
/*indexingMaps=*/llvm::makeArrayRef({id, id}),
/*indexingMaps=*/llvm::ArrayRef({id, id}),
/*iteratorTypes=*/iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args.front());

@ -101,17 +101,17 @@ torch.class_type @c {
// -----
// expected-error @+1 {{'torch.type_bound' must be attached to an argument of !torch.tensor/!torch.vtensor type}}
func.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
func.func private @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
// -----
// expected-error @+1 {{'torch.type_bound' must be TypeAttr}}
func.func @f(%arg0: i32 {torch.type_bound = 1})
func.func private @f(%arg0: i32 {torch.type_bound = 1})
// -----
// expected-error @+1 {{'torch.type_bound' must be of !torch.tensor/!torch.vtensor type}}
func.func @f(%arg0: i32 {torch.type_bound = i32})
func.func private @f(%arg0: i32 {torch.type_bound = i32})
// -----