build: update llvm tag to 9acc2f37 (#1828)

This commit makes the following changes:

- Update dialects to use the fold API `kEmitFoldAdaptorFolder` and update
the signatures of their `fold` methods accordingly (see the sketch after
this list and the PSA
https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618)
- Replace the deprecated `makeArrayRef` with `ArrayRef` (see
https://reviews.llvm.org/D140896)
- Remove the `TypeRange{}` arg from `b.create<scf::IfOp>`, since the
builder no longer takes that argument
- Make the `func`s in `Torch/invalid.mlir` private, since symbol
declarations cannot be public (see
https://discourse.llvm.org/t/rfc-symbol-definition-declaration-x-visibility-checks/2140)
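As a quick illustration of the first change, here is a minimal hedged sketch (not code from this commit; `MyScalarAddOp` and its `getLhs`/`getRhs` accessors are hypothetical) of what a migrated fold method looks like once a dialect sets `let useFoldAPI = kEmitFoldAdaptorFolder;`:

    // Before: operand attributes arrived as a positional array.
    // OpFoldResult MyScalarAddOp::fold(ArrayRef<Attribute> operands);

    // After: ODS generates a per-op FoldAdaptor whose named accessors
    // return null attributes for operands that are not constant.
    OpFoldResult MyScalarAddOp::fold(FoldAdaptor adaptor) {
      auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
      auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
      if (!lhs || !rhs)
        return nullptr;
      return IntegerAttr::get(lhs.getType(), lhs.getValue() + rhs.getValue());
    }

The same mechanical rewrite (`ArrayRef<Attribute> operands` becomes `FoldAdaptor adaptor`, `operands[i]` becomes a named getter) recurs throughout the diff below.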
Ramiro Leal-Cavazos 2023-01-24 17:29:42 -08:00 committed by GitHub
parent 8ce2fffca5
commit 6c86bec04f
17 changed files with 147 additions and 145 deletions


@@ -43,6 +43,7 @@ def TMTensor_Dialect : Dialect {
     to.
   }];
   let hasCanonicalizer = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }

 //===----------------------------------------------------------------------===//


@@ -204,7 +204,7 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
   }

   auto scfIf = b.create<scf::IfOp>(
-      loc, TypeRange{}, cond,
+      loc, cond,
       [&](OpBuilder &b, Location loc) {
         if (isInclusive) {
           auto value = b.create<memref::LoadOp>(loc, input(), indices);
@@ -266,7 +266,7 @@ static LogicalResult foldMemRefCast(Operation *op) {
   return success(folded);
 }

-LogicalResult ScanOp::fold(ArrayRef<Attribute>,
+LogicalResult ScanOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &) {
   return foldMemRefCast(*this);
 }
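The `scf::IfOp` hunk above is the builder-API change from the commit message: the no-result form of the builder no longer takes a leading `TypeRange`. A minimal hedged sketch of such a call after the update (`cond` is assumed to be an `i1` Value already in scope; the callbacks create their own `scf.yield` terminators):

    auto ifOp = b.create<scf::IfOp>(
        loc, cond,
        [&](OpBuilder &b, Location loc) {
          // "then" region body goes here.
          b.create<scf::YieldOp>(loc);
        },
        [&](OpBuilder &b, Location loc) {
          // "else" region body goes here.
          b.create<scf::YieldOp>(loc);
        });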

@@ -1 +1 @@
-Subproject commit d23516e9ad477527a9db4d06b1fa9566680ac67c
+Subproject commit 9acc2f37bdfce08ca0c2faec03392db10d1bb7a9

externals/mlir-hlo (vendored)

@@ -1 +1 @@
-Subproject commit 81e87a95b8683f1c3c33caf9e933897e0fc4a2b7
+Subproject commit 4a173356bb1291b97046545429d7851cbc771d88


@@ -37,6 +37,7 @@ def Torch_Dialect : Dialect {
   let hasRegionArgAttrVerify = 1;
   let hasConstantMaterializer = 1;
   let useDefaultTypePrinterParser = 0;
+  let useFoldAPI = kEmitFoldAdaptorFolder;

   let extraClassDeclaration = [{
     /// Parse a type registered to this dialect.


@@ -27,6 +27,7 @@ def TorchConversion_Dialect : Dialect {
   }];

   let hasConstantMaterializer = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }

 #endif // TORCHCONVERSION_BASE


@@ -463,8 +463,8 @@ public:
     }

     SmallVector<Value> inputSize = getTensorSizes(rewriter, loc, input);
-    ArrayRef<Value> outputShapeInt = llvm::makeArrayRef(outputSizeInt);
-    ArrayRef<Value> inputShapeInt = llvm::makeArrayRef(inputSize);
+    ArrayRef<Value> outputShapeInt = llvm::ArrayRef(outputSizeInt);
+    ArrayRef<Value> inputShapeInt = llvm::ArrayRef(inputSize);

     // Association indices for expand/collapse ops. These two vectors
     // are populated such that two entries at the same index corresponds
@@ -1136,7 +1136,7 @@ public:
     Value dimIndex = rewriter.createOrFold<arith::ConstantOp>(
         loc, rewriter.getIndexAttr(dim));
-    for (auto tensor : makeArrayRef(tensors).drop_front()) {
+    for (auto tensor : ArrayRef(tensors).drop_front()) {
       auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex);
       resultDimSize =
           rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
@@ -1270,7 +1270,7 @@ public:
         /*resultType=*/selfType,
         /*inputs=*/broadcastedSrc,
         /*outputs=*/self,
-        /*indexingMaps=*/llvm::makeArrayRef({id, id}),
+        /*indexingMaps=*/llvm::ArrayRef({id, id}),
         /*iteratorTypes=*/iteratorTypes,
         [](OpBuilder &b, Location loc, ValueRange args) {
           Value result = args[0];
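The `makeArrayRef` rewrites above (and in the files below) lean on class template argument deduction for `ArrayRef`, which is what made the free-function helper redundant (https://reviews.llvm.org/D140896). A standalone hedged sketch, not taken from this commit:

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"

    void example() {
      llvm::SmallVector<int64_t> sizes = {2, 3, 4};
      // Before: a helper function deduced the element type.
      //   llvm::ArrayRef<int64_t> view = llvm::makeArrayRef(sizes);
      // After: deduction guides let the constructor do the same.
      llvm::ArrayRef<int64_t> view = llvm::ArrayRef(sizes);
      (void)view; // silence unused-variable warnings in this sketch
    }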


@@ -1086,7 +1086,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
   // Reshape input
   auto mhloInput = rewriter.create<mhlo::DynamicReshapeOp>(
       op->getLoc(), mhloBatchNormOutTy, input,
-      mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
+      mhlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape),
                            {static_cast<int64_t>(inputFlattenShape.size())})
           .value());


@@ -142,7 +142,7 @@ public:
     // Finding the maximum value in the input tensor.
     SmallVector<int64_t> maxTensorSizes;
     ValueTensorType maxTensorType = ValueTensorType::get(
-        context, llvm::makeArrayRef(maxTensorSizes),
+        context, llvm::ArrayRef(maxTensorSizes),
         torchTypeInput.getType().cast<ValueTensorType>().getDtype());
     Value maxTensor =
         rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput);
@@ -165,7 +165,7 @@ public:
     SmallVector<int64_t> expandedInputSizes{
         makeShapeTorchCompatible(inputType.getShape())[0], 1};
     ValueTensorType expandInputType = ValueTensorType::get(
-        context, llvm::makeArrayRef(expandedInputSizes),
+        context, llvm::ArrayRef(expandedInputSizes),
         torchTypeInput.getType().cast<ValueTensorType>().getDtype());
     Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(1));
@@ -286,9 +286,9 @@ public:
     auto indexTensorType = indexTensor.getType().cast<BaseTensorType>();
     int64_t indexTensorSize = indexTensorType.getSizes()[0];
     SmallVector<int64_t> expandedIndexTensorSizes{indexTensorSize, 1};
-    ValueTensorType expandedIndexTensorType = ValueTensorType::get(
-        context, llvm::makeArrayRef(expandedIndexTensorSizes),
-        indexTensorType.getDtype());
+    ValueTensorType expandedIndexTensorType =
+        ValueTensorType::get(context, llvm::ArrayRef(expandedIndexTensorSizes),
+                             indexTensorType.getDtype());
     Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(1));
     Value expandedIndexTensor = rewriter.create<AtenUnsqueezeOp>(


@@ -718,8 +718,8 @@ class ConvertAtenMultipleDimsReductionOp
                                          "non-const dim parameter unsupported");
     int64_t N = reduceDims.size();
     auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
-    reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
-                                               llvm::makeArrayRef(reduceDims));
+    reduceDimsAttr =
+        DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims));

     keepDims = false;
     if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims)))
@@ -748,8 +748,8 @@ class ConvertAtenOneDimReductionOp
       return rewriter.notifyMatchFailure(op,
                                          "non-const dim parameter unsupported");
     auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type());
-    reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
-                                               llvm::makeArrayRef({reduceDim}));
+    reduceDimsAttr =
+        DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef({reduceDim}));

     keepDims = false;
     if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims)))
@@ -782,8 +782,8 @@ public:
         reduceDims.push_back(i);
     int64_t N = selfTy.getRank();
     auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
-    reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
-                                               llvm::makeArrayRef(reduceDims));
+    reduceDimsAttr =
+        DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims));

     keepDims = false;
     return success();


@@ -507,7 +507,7 @@ bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
   return isValidSubtype(inputs[0], outputs[0]);
 }

-OpFoldResult DerefineOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DerefineOp::fold(FoldAdaptor adaptor) {
   auto uncheckedCast = getOperand().getDefiningOp<PrimUncheckedCastOp>();
   if (!uncheckedCast)
     return nullptr;
@@ -570,10 +570,10 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
 // Aten__RangeLengthOp
 //===----------------------------------------------------------------------===//

-OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
-  auto lo = operands[0];
-  auto hi = operands[1];
-  auto step = operands[2];
+OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
+  auto lo = adaptor.getLo();
+  auto hi = adaptor.getHi();
+  auto step = adaptor.getStep();
   if (!lo || !hi || !step)
     return nullptr;
   auto loInt = lo.dyn_cast_or_null<IntegerAttr>().getValue();
@@ -595,10 +595,10 @@ OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
 // Aten__DeriveIndexOp
 //===----------------------------------------------------------------------===//

-OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
-  auto index = operands[0];
-  auto start = operands[1];
-  auto step = operands[2];
+OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) {
+  auto index = adaptor.getIndex();
+  auto start = adaptor.getStart();
+  auto step = adaptor.getStep();
   if (!index || !start || !step)
     return nullptr;
   auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue();
@@ -612,7 +612,7 @@ OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
 // Aten__Is__Op
 //===----------------------------------------------------------------------===//

-OpFoldResult Aten__Is__Op::fold(ArrayRef<Attribute> operands) {
+OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) {
   return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true);
 }

@@ -620,7 +620,7 @@ OpFoldResult Aten__Is__Op::fold(ArrayRef<Attribute> operands) {
 // Aten__Isnot__Op
 //===----------------------------------------------------------------------===//

-OpFoldResult Aten__Isnot__Op::fold(ArrayRef<Attribute> operands) {
+OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) {
   return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false);
 }

@@ -628,7 +628,7 @@ OpFoldResult Aten__Isnot__Op::fold(ArrayRef<Attribute> operands) {
 // Aten__Not__Op
 //===----------------------------------------------------------------------===//

-OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
+OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
   bool value;
   if (!matchPattern(getOperand(), m_TorchConstantBool(&value)))
     return nullptr;
@@ -639,7 +639,7 @@ OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
 // AtenNeBoolOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
   if (getOperand(0) == getOperand(1))
     return IntegerAttr::get(IntegerType::get(getContext(), 1), false);

@@ -655,7 +655,7 @@ OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
 // AtenSqueezeOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
   if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
     if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
       return getOperand();
@@ -667,7 +667,7 @@ OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
 // AtenSqueezeDimOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
   if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
     if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
       return getOperand(0);
@@ -679,7 +679,7 @@ OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
 // AtenRoundOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenRoundOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
   if (auto selfType = getSelf().getType().dyn_cast<BaseTensorType>()) {
     if (selfType.hasDtype() && selfType.getDtype().isa<mlir::IntegerType>())
       return getSelf();
@@ -691,7 +691,7 @@ OpFoldResult AtenRoundOp::fold(ArrayRef<Attribute> operands) {
 // AtenTypeAsOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenTypeAsOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) {
   Type inType = getSelf().getType();
   Type newType = getOther().getType();

@@ -705,7 +705,7 @@ OpFoldResult AtenTypeAsOp::fold(ArrayRef<Attribute> operands) {
 // AtenToDtypeOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
   bool nonBlocking, copyArg;
   // The non_blocking arg must be `False`.
   if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
@@ -736,7 +736,7 @@ OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
 // AtenToDtypeLayoutOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
   // The pin_memory arg should be either constant `False` or `none`.
   if (!getPinMemory().getType().isa<Torch::NoneType>()) {
     bool pinMemory;
@@ -797,7 +797,7 @@ OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef<Attribute> operands) {
 // AtenViewOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
   auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
   if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
     return nullptr;
@@ -812,7 +812,7 @@ OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
 // AtenDimOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
   if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
     if (tensorType.hasSizes())
       return IntegerAttr::get(IntegerType::get(getContext(), 64),
@@ -825,7 +825,7 @@ OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
 // AtenLenTOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenLenTOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenLenTOp::fold(FoldAdaptor adaptor) {
   // `len([1,1,1])` -> `3`, if it is not mutated.
   if (auto listConstruct =
           getOperand().getDefiningOp<Torch::PrimListConstructOp>()) {
@@ -853,7 +853,7 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 // AtenLenStrOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenLenStrOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenLenStrOp::fold(FoldAdaptor adaptor) {
   if (auto stringConstruct = getS().getDefiningOp<ConstantStrOp>())
     return getI64IntegerAttr(getContext(),
                              stringConstruct.getValueAttr().getValue().size());
@@ -1092,7 +1092,7 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 // AtenSizeIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) {
   int64_t dim;
   if (!matchPattern(this->getDim(), m_TorchConstantInt(&dim)))
     return nullptr;
@@ -1132,7 +1132,7 @@ floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) {
 // AtenLtFloatOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenLtFloatOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) {
   return floatComparatorFoldHelper(*this,
                                    [](double a, double b) { return a < b; });
 }
@@ -1141,7 +1141,7 @@ OpFoldResult AtenLtFloatOp::fold(ArrayRef<Attribute> operands) {
 // AtenGtFloatOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) {
   return floatComparatorFoldHelper(*this,
                                    [](double a, double b) { return a > b; });
 }
@@ -1150,7 +1150,7 @@ OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
 // AtenGeFloatOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) {
   return floatComparatorFoldHelper(*this,
                                    [](double a, double b) { return a >= b; });
 }
@@ -1159,7 +1159,7 @@ OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
 // AtenEqFloatOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenEqFloatOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenEqFloatOp::fold(FoldAdaptor adaptor) {
   return floatComparatorFoldHelper(*this,
                                    [](double a, double b) { return a == b; });
 }
@@ -1225,7 +1225,7 @@ static OpFoldResult intComparatorFoldHelper(OpTy op,
 // AtenNeIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) {
   return intComparatorFoldHelper(*this,
                                  [](int64_t a, int64_t b) { return a != b; });
 }
@@ -1234,7 +1234,7 @@ OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
 // AtenEqIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) {
   return intComparatorFoldHelper(*this,
                                  [](int64_t a, int64_t b) { return a == b; });
 }
@@ -1243,7 +1243,7 @@ OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
 // AtenEqStrOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) {
   if (getOperand(0) == getOperand(1))
     return getI1IntegerAttr(getContext(), true);

@@ -1259,7 +1259,7 @@ OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
 // AtenLtIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) {
   return intComparatorFoldHelper(*this,
                                  [](int64_t a, int64_t b) { return a < b; });
 }
@@ -1268,7 +1268,7 @@ OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
 // AtenLeIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) {
   return intComparatorFoldHelper(*this,
                                  [](int64_t a, int64_t b) { return a <= b; });
 }
@@ -1277,7 +1277,7 @@ OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
 // AtenGtIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) {
   return intComparatorFoldHelper(*this,
                                  [](int64_t a, int64_t b) { return a > b; });
 }
@@ -1286,7 +1286,7 @@ OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
 // AtenGeIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenGeIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) {
   return intComparatorFoldHelper(*this,
                                  [](int64_t a, int64_t b) { return a >= b; });
 }
@@ -1295,7 +1295,7 @@ OpFoldResult AtenGeIntOp::fold(ArrayRef<Attribute> operands) {
 // AtenBoolFloatOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenBoolFloatOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) {
   double c;
   if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
     return getI1IntegerAttr(getContext(), c != 0.0);
@@ -1306,7 +1306,7 @@ OpFoldResult AtenBoolFloatOp::fold(ArrayRef<Attribute> operands) {
 // AtenBoolIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenBoolIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) {
   int64_t c;
   if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
     return getI1IntegerAttr(getContext(), c != 0);
@@ -1317,9 +1317,9 @@ OpFoldResult AtenBoolIntOp::fold(ArrayRef<Attribute> operands) {
 // AtenFloatScalarOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
   // Constant fold int -> float conversion.
-  if (auto integerAttr = operands[0].dyn_cast_or_null<IntegerAttr>()) {
+  if (auto integerAttr = adaptor.getA().dyn_cast_or_null<IntegerAttr>()) {
     return FloatAttr::get(
         mlir::Float64Type::get(getContext()),
         static_cast<double>(integerAttr.getValue().getSExtValue()));
@@ -1334,9 +1334,9 @@ OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
 // AtenIntScalarOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) {
   // Constant fold float -> int conversion.
-  if (auto floatAttr = operands[0].dyn_cast_or_null<FloatAttr>()) {
+  if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
     return IntegerAttr::get(
         mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
         static_cast<long>(floatAttr.getValue().convertToDouble()));
@@ -1351,7 +1351,7 @@ OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
 // AtenIntBoolOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenIntBoolOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) {
   bool b;
   if (matchPattern(getOperand(), m_TorchConstantBool(&b))) {
     return getI64IntegerAttr(getContext(), static_cast<long>(b));
@@ -1452,7 +1452,7 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
   return success();
 }

-OpFoldResult ValueTensorLiteralOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) {
   return getValueAttr();
 }

@@ -1557,7 +1557,7 @@ void CopyToValueTensorOp::getEffects(
 // ConstantNoneOp
 //===----------------------------------------------------------------------===//

-OpFoldResult ConstantNoneOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ConstantNoneOp::fold(FoldAdaptor adaptor) {
   return TypeAttr::get(Torch::NoneType::get(getContext()));
 }

@@ -1570,9 +1570,7 @@ void ConstantNoneOp::getAsmResultNames(
 // ConstantStrOp
 //===----------------------------------------------------------------------===//

-OpFoldResult ConstantStrOp::fold(ArrayRef<Attribute> operands) {
-  return getValueAttr();
-}
+OpFoldResult ConstantStrOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

 void ConstantStrOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
@@ -1610,7 +1608,7 @@ void ConstantIntOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs(), {"value"});
 }

-OpFoldResult Torch::ConstantIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult Torch::ConstantIntOp::fold(FoldAdaptor adaptor) {
   return getValueAttr();
 }

@@ -1626,7 +1624,7 @@ void Torch::ConstantIntOp::getAsmResultNames(
 // ConstantFloatOp
 //===----------------------------------------------------------------------===//

-OpFoldResult Torch::ConstantFloatOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult Torch::ConstantFloatOp::fold(FoldAdaptor adaptor) {
   return getValueAttr();
 }

@@ -1656,7 +1654,7 @@ void Torch::ConstantFloatOp::getAsmResultNames(
 // ConstantNumberOp
 //===----------------------------------------------------------------------===//

-OpFoldResult Torch::ConstantNumberOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult Torch::ConstantNumberOp::fold(FoldAdaptor adaptor) {
   return getValueAttr();
 }

@@ -1684,7 +1682,7 @@ void Torch::ConstantNumberOp::getCanonicalizationPatterns(
 // ConstantBoolOp
 //===----------------------------------------------------------------------===//

-OpFoldResult Torch::ConstantBoolOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult Torch::ConstantBoolOp::fold(FoldAdaptor adaptor) {
   return getValueAttr();
 }

@@ -1702,7 +1700,7 @@ bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs,
   return isValidSubtype(outputs[0], inputs[0]);
 }

-OpFoldResult PrimUncheckedCastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult PrimUncheckedCastOp::fold(FoldAdaptor adaptor) {
   if (auto derefineOp = getX().getDefiningOp<Torch::DerefineOp>()) {
     if (derefineOp.getOperand().getType() == getType())
       return derefineOp.getOperand();
@@ -1836,7 +1834,7 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 // AtenEqIntListOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenEqIntListOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) {
   auto lhsLiteral = getA().getDefiningOp<Torch::PrimListConstructOp>();
   if (!lhsLiteral)
     return nullptr;
@@ -1976,7 +1974,7 @@ static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) {
 // Aten__Getitem__DictStrOp
 //===----------------------------------------------------------------------===//

-OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) {
   auto dictConstruct = getDictConstructIfNotModified(getSelf());
   if (!dictConstruct)
     return nullptr;
@@ -1994,7 +1992,7 @@ OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef<Attribute> operands) {
 // Aten__Contains__StrOp
 //===----------------------------------------------------------------------===//

-OpFoldResult Aten__Contains__StrOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult Aten__Contains__StrOp::fold(FoldAdaptor adaptor) {
   auto dictConstruct = getDictConstructIfNotModified(getDict());
   if (!dictConstruct)
     return nullptr;
@@ -2017,7 +2015,7 @@ static bool isListConstructNotModified(Value torchList) {
   });
 }

-OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult Aten__Contains__IntListOp::fold(FoldAdaptor adaptor) {
   auto itemConstruct = getItem();
   if (!isListConstructNotModified(getL()))
     return nullptr;
@@ -2078,43 +2076,44 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
 // AtenFloordivIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) {
   return atenBinaryIntOperatorFoldHelper(
-      operands, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
+      adaptor.getOperands(),
+      [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
 }

 //===----------------------------------------------------------------------===//
 // AtenRemainderIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
   return atenBinaryIntOperatorFoldHelper(
-      operands, [](int64_t a, int64_t b) { return a % b; });
+      adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; });
 }

 //===----------------------------------------------------------------------===//
 // AtenAddIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
   return atenBinaryIntOperatorFoldHelper(
-      operands, [](int64_t a, int64_t b) { return a + b; });
+      adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; });
 }

 //===----------------------------------------------------------------------===//
 // AtenSubIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenSubIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
   return atenBinaryIntOperatorFoldHelper(
-      operands, [](int64_t a, int64_t b) { return a - b; });
+      adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
 }

 //===----------------------------------------------------------------------===//
 // AtenCatOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenCatOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
+OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
   auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
   if (!list || !list->hasOneUse() || list.getElements().size() != 1)
     return nullptr;
@@ -2125,7 +2124,7 @@ OpFoldResult AtenCatOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
 // AtenSliceTensorOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
+OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
   auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
   auto outType = getResult().getType().dyn_cast<ValueTensorType>();
   if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
@@ -2144,7 +2143,7 @@ OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
 // AtenMulIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
   int64_t lhs, rhs;
   bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
   bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
@@ -2159,42 +2158,45 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
 // AtenSubOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenSubOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[0] || !operands[1]) {
+OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) {
+  if (!adaptor.getA() || !adaptor.getB()) {
     return nullptr;
   }

-  if (operands[0].isa<IntegerAttr>() && operands[1].isa<IntegerAttr>()) {
+  if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
     return atenBinaryIntOperatorFoldHelper(
-        operands, [](int64_t a, int64_t b) -> int64_t { return a - b; });
+        adaptor.getOperands(),
+        [](int64_t a, int64_t b) -> int64_t { return a - b; });
   }
   return atenBinaryFloatOperatorFoldHelper(
-      operands, [](double a, double b) -> double { return a - b; });
+      adaptor.getOperands(),
+      [](double a, double b) -> double { return a - b; });
 }

 //===----------------------------------------------------------------------===//
 // AtenDivOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenDivOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[0] || !operands[1]) {
+OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) {
+  if (!adaptor.getA() || !adaptor.getB()) {
     return nullptr;
   }
   // Since AtenDivOp always returns float value, we don't need to deal with the
   // case where the operands are both integers separately.
   return atenBinaryFloatOperatorFoldHelper(
-      operands, [](double a, double b) -> double { return a / b; });
+      adaptor.getOperands(),
+      [](double a, double b) -> double { return a / b; });
 }

 //===----------------------------------------------------------------------===//
 // AtenCeilScalarOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[0]) {
+OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) {
+  if (!adaptor.getA()) {
     return nullptr;
   }
-  auto floatValue = operands[0].dyn_cast_or_null<FloatAttr>();
+  auto floatValue = adaptor.getA().dyn_cast_or_null<FloatAttr>();
   if (!floatValue) {
     return nullptr;
   }
@@ -2207,7 +2209,7 @@ OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
 // AtenNegIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) {
   int64_t c;
   if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
     return getI64IntegerAttr(getContext(), -c);
@@ -2218,7 +2220,7 @@ OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
 // AtenSqrtIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) {
   int64_t c;
   if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
     return getF64FloatAttr(getContext(), std::sqrt(c));
@@ -2229,7 +2231,7 @@ OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
 // PrimDtypeOp
 //===----------------------------------------------------------------------===//

-OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) {
   BaseTensorType tensorType = getA().getType().cast<BaseTensorType>();
   if (tensorType.hasDtype()) {
     torch_upstream::ScalarType scalarType =
@@ -2243,7 +2245,7 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
 // AtenIntTensorOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
   // If a scalar number is converted to a 0-d tensor and passed on to
   // aten.Int.Tensor, fold to the scalar number.
   if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
@@ -2255,7 +2257,7 @@ OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
 // AtenFloatTensorOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) {
   // If a scalar number is converted to a 0-d tensor and passed on to
   // aten.Float.Tensor, fold to the scalar number.
   if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
@@ -2267,7 +2269,7 @@ OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
 // AtenDivFloatOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) {
   double lhs, rhs;
   bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs));
   bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs));
@@ -2284,7 +2286,7 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
 // AtenDivIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
   int64_t lhs, rhs;
   bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
   bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
@@ -2297,7 +2299,7 @@ OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
 // AtenCeilFloatOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) {
   double c;
   if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
     return getI64IntegerAttr(getContext(), std::ceil(c));
@@ -2308,13 +2310,13 @@ OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
 // PrimMaxIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult PrimMaxIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
   // If both operands are the same, then the operation is an identity.
   if (getA() == getB())
     return getA();

-  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
-  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
+  auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
   if (!lhs || !rhs)
     return nullptr;
   // Torch semantics are that !torch.int is 64-bit signed.
@@ -2327,7 +2329,7 @@ OpFoldResult PrimMaxIntOp::fold(ArrayRef<Attribute> operands) {
 // PrimMinSelfIntOp
 //===----------------------------------------------------------------------===//

-OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) {
   auto list = getOperand().getDefiningOp<PrimListConstructOp>();
   if (!list)
     return nullptr;
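One detail worth noting from the hunks above: helpers that still take the raw attribute array, such as `atenBinaryIntOperatorFoldHelper`, keep working because the generated adaptor also exposes it via `getOperands()`. A hedged sketch with a hypothetical `MyBinaryOp`:

    OpFoldResult MyBinaryOp::fold(FoldAdaptor adaptor) {
      // getOperands() returns the same ArrayRef<Attribute> the old fold
      // signature received, so positional helpers need no rewrite.
      ArrayRef<Attribute> operands = adaptor.getOperands();
      if (!operands[0] || !operands[1])
        return nullptr;
      return atenBinaryIntOperatorFoldHelper(
          operands, [](int64_t a, int64_t b) { return a + b; });
    }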


@@ -463,7 +463,7 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
     }
   }

-  return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype);
+  return lhs.getWithSizesAndDtype(ArrayRef(newSizes), dtype);
 }

 ////===----------------------------------------------------------------------===//


@@ -72,7 +72,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
   Type resultType = tensorType.getWithSizesAndDtype(
       sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
-                        : llvm::makeArrayRef(sizes),
+                        : llvm::ArrayRef(sizes),
       tensorType.getOptionalDtype());
   return resultType;
 }
@@ -108,7 +108,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
       valueType
           .getWithSizesAndDtype(
               !valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
-                                    : llvm::makeArrayRef(valueType.getSizes()),
+                                    : llvm::ArrayRef(valueType.getSizes()),
               IntegerType::get(op->getContext(), 64, IntegerType::Signed))
           .cast<BaseTensorType>();
   return rewriter
@@ -142,7 +142,7 @@ static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
                                BaseTensorType inputType, Value scalar) {
   SmallVector<int64_t> sizes;
   Type rank0TensorTy = inputType.getWithSizesAndDtype(
-      makeArrayRef(sizes), inputType.getOptionalDtype());
+      ArrayRef(sizes), inputType.getOptionalDtype());
   Value dimList = rewriter.create<PrimListConstructOp>(
       loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
       ValueRange{});
@@ -940,7 +940,7 @@ public:
     SmallVector<int64_t> sizes;
     sizes.append(inputShape.begin(), inputShape.end());
     sizes[cstDim] = kUnknownSize;
-    Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
+    Type sliceTy = selfTy.getWithSizesAndDtype(llvm::ArrayRef(sizes),
                                                selfTy.getOptionalDtype());
     Value slice0 = rewriter.create<AtenSliceTensorOp>(
         loc, sliceTy, input, dim, negShift, constNone, constOne);
@@ -1077,9 +1077,9 @@ public:
     Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype();
     Type unsqueezedType = ValueTensorType::get(
-        context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
-    Type expandedType = ValueTensorType::get(
-        context, llvm::makeArrayRef(expandedIntSizes), dtype);
+        context, llvm::ArrayRef(unsqueezedIntSizes), dtype);
+    Type expandedType =
+        ValueTensorType::get(context, llvm::ArrayRef(expandedIntSizes), dtype);
     auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
     Value unsqueezedDims =
@@ -2004,7 +2004,7 @@ public:
     auto inputType = input.getType().cast<BaseTensorType>();
     SmallVector<int64_t> empty;
-    Type tensorType = inputType.getWithSizesAndDtype(llvm::makeArrayRef(empty),
+    Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty),
                                                      rewriter.getF64Type());
     Value prob = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType, p);
     Value output;
@@ -2082,8 +2082,8 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
     std::vector<int64_t> meanVarSizes(inputRank, 1);
     for (int i = 0; i < axis; i++)
       meanVarSizes[i] = input.getSizes()[i];
-    auto meanVarType = input.getWithSizesAndDtype(
-        llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype());
+    auto meanVarType = input.getWithSizesAndDtype(llvm::ArrayRef(meanVarSizes),
+                                                  input.getOptionalDtype());
     auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
         loc, op.getType(), meanVarType, meanVarType, op.getInput(),
         op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
@@ -2320,7 +2320,7 @@ class DecomposeAtenNativeBatchNormOp
     runningStatsShapeInt[1] = kUnknownSize;
     Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
     Type reshapeType = ValueTensorType::get(
-        context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
+        context, llvm::ArrayRef(runningStatsShapeInt), dtype);
     runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean,
                                               runningStatsSizeList);
@@ -2466,8 +2466,7 @@ public:
     SmallVector<int64_t> empty;
     auto dtype =
         getTypeForTorchType(op.getContext(), op.getFillValue().getType());
-    Type tensorType =
-        outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
+    Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
     Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType,
                                                              op.getFillValue());
     fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype());
@@ -2503,7 +2502,7 @@ public:
     SmallVector<int64_t> transposeShape =
         llvm::to_vector(llvm::reverse(weightType.getSizes()));
     Type transposeType = weightType.getWithSizesAndDtype(
-        llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype());
+        llvm::ArrayRef(transposeShape), weightType.getOptionalDtype());
     Value transposeWeight =
         rewriter.create<AtenTOp>(loc, transposeType, weight);
@@ -2573,8 +2572,7 @@ public:
     SmallVector<int64_t> empty;
     auto dtype =
         getTypeForTorchType(op.getContext(), op.getFillValue().getType());
-    Type tensorType =
-        outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
+    Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
     Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(
         op.getLoc(), tensorType, op.getFillValue());
     fillVal =
@@ -3216,7 +3214,7 @@ public:
       sizes.resize(srcShape.size() + 1, kUnknownSize);
     }
     Type srcType = srcTensorType.getWithSizesAndDtype(
-        llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype());
+        llvm::ArrayRef(sizes), srcTensorType.getOptionalDtype());
     src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
     rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
         op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
@@ -3314,7 +3312,7 @@ public:
           op, "Expected the input tensor to have sizes");
     BaseTensorType subType =
         inputType
-            .getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()),
+            .getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
                                   resultType.getOptionalDtype())
             .cast<BaseTensorType>();
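One detail worth noting across these hunks: getWithSizesAndDtype takes an optional size list, so an empty std::optional keeps the tensor type unranked, while a populated ArrayRef pins the shape. A condensed sketch of the pattern as used above (names as in computeReductionType):

    // Keep the type unranked when no sizes were inferred; otherwise attach
    // the inferred shape. The dtype is carried over unchanged.
    Type resultType = tensorType.getWithSizesAndDtype(
        sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
                          : llvm::ArrayRef(sizes),
        tensorType.getOptionalDtype());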


@@ -129,8 +129,7 @@ public:
     // Truncate the list of users to the number of users we're going to
     // interpret.
     allUsers.resize(numUsersToInterpret);
-    auto usersToInterpret =
-        makeArrayRef(allUsers).take_front(numUsersToInterpret);
+    auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret);
     // For each mutating op (which must be in the same block), we save the
     // current state of the list as a vector of Value's. These will then
@@ -336,7 +335,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
   auto originalResultType = result.getType().cast<BaseTensorType>();
   auto impliedTypesFromShape =
       originalResultType.cast<BaseTensorType>()
-          .getWithSizesAndDtype(makeArrayRef(sizes),
+          .getWithSizesAndDtype(ArrayRef(sizes),
                                 originalResultType.getOptionalDtype())
          .cast<BaseTensorType>();


@@ -75,8 +75,8 @@ LogicalResult FromBuiltinTensorOp::verify() {
 // FromI64Op
 //===----------------------------------------------------------------------===//
-OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
-  auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
+OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
+  auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
   if (attr) {
     return attr;
   } else {
@@ -88,8 +88,8 @@ OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
 // ToI64Op
 //===----------------------------------------------------------------------===//
-OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
-  auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
+OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
+  auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
   if (attr) {
     return attr;
   } else {
@@ -101,8 +101,8 @@ OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
 // ToF64Op
 //===----------------------------------------------------------------------===//
-OpFoldResult ToF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
-  auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
+OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
+  auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
   if (attr) {
     return attr;
   } else {
@@ -114,8 +114,8 @@ OpFoldResult ToF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
 // FromF64Op
 //===----------------------------------------------------------------------===//
-OpFoldResult FromF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
-  auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
+OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) {
+  auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
   if (attr) {
     return attr;
   } else {
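All four folders above follow the new adaptor-based fold API: instead of a raw ArrayRef<Attribute> of constant operands indexed by position, the generated FoldAdaptor exposes one named accessor per ODS operand, returning a null Attribute when that operand is not constant. A before/after sketch for a hypothetical single-operand cast op (MyCastOp and its getOperand() accessor are illustrative; the real names come from the op's ODS definition):

    // Before: positional lookup, null when the operand is not constant.
    //   OpFoldResult MyCastOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
    //     return operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
    //   }

    // After: the adaptor names the operand; same null-propagation rules.
    OpFoldResult MyCastOp::fold(FoldAdaptor adaptor) {
      if (auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>())
        return attr;
      return {}; // no fold
    }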


@@ -392,7 +392,7 @@ Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
       loc,
       /*inputs=*/from,
       /*outputs=*/to,
-      /*indexingMaps=*/llvm::makeArrayRef({id, id}),
+      /*indexingMaps=*/llvm::ArrayRef({id, id}),
       /*iteratorTypes=*/iteratorTypes,
       [](OpBuilder &b, Location loc, ValueRange args) {
         b.create<linalg::YieldOp>(loc, args.front());
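A subtlety in this hunk: the braced list keeps working after the rename because the deduction guides also cover std::initializer_list, so ArrayRef({id, id}) deduces ArrayRef<AffineMap> exactly as makeArrayRef did. A small sketch (consume() is an illustrative helper; the usual ArrayRef lifetime caveat applies, since the ref points into a temporary list):

    #include "mlir/IR/AffineMap.h"
    #include "llvm/ADT/ArrayRef.h"

    void consume(llvm::ArrayRef<mlir::AffineMap> maps); // illustrative

    void sketch(mlir::MLIRContext *context) {
      mlir::AffineMap id =
          mlir::AffineMap::getMultiDimIdentityMap(/*numDims=*/2, context);
      // Deduces llvm::ArrayRef<mlir::AffineMap>; safe because the ref is
      // consumed before the temporary initializer list is destroyed.
      consume(llvm::ArrayRef({id, id}));
    }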


@@ -101,17 +101,17 @@ torch.class_type @c {
 // -----

 // expected-error @+1 {{'torch.type_bound' must be attached to an argument of !torch.tensor/!torch.vtensor type}}
-func.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
+func.func private @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})

 // -----

 // expected-error @+1 {{'torch.type_bound' must be TypeAttr}}
-func.func @f(%arg0: i32 {torch.type_bound = 1})
+func.func private @f(%arg0: i32 {torch.type_bound = 1})

 // -----

 // expected-error @+1 {{'torch.type_bound' must be of !torch.tensor/!torch.vtensor type}}
-func.func @f(%arg0: i32 {torch.type_bound = i32})
+func.func private @f(%arg0: i32 {torch.type_bound = i32})

 // -----
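These test functions are declarations (they have no body), and after the symbol visibility RFC linked in the commit message the verifier rejects public declarations, hence the added private qualifier. For reference, the equivalent when constructing such a declaration from C++ is to mark the symbol non-public; a sketch assuming a builder, loc, and fnType are in scope:

    // A declaration-only func must not be public anymore, or verification
    // of the enclosing module fails.
    auto fn = builder.create<mlir::func::FuncOp>(loc, "f", fnType);
    mlir::SymbolTable::setSymbolVisibility(
        fn, mlir::SymbolTable::Visibility::Private);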