mirror of https://github.com/llvm/torch-mlir
build: update llvm tag to 9acc2f37 (#1828)
This commit makes the following changes: - Update dialects to use fold API `kEmitFoldAdaptorFolder` and update signature of `fold` methods (see PSA https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618) - Replace `makeArrayRef` with `ArrayRef` (see https://reviews.llvm.org/D140896) - Remove `TypeRange{}` arg from `b.create<scf::IfOp>` since builder no longer takes that argument - Make `func`s in `Torch/invalid.mlir` private, since symbol declarations cannot be public. (see https://discourse.llvm.org/t/rfc-symbol-definition-declaration-x-visibility-checks/2140)pull/1740/head
parent
8ce2fffca5
commit
6c86bec04f
|
@ -43,6 +43,7 @@ def TMTensor_Dialect : Dialect {
|
|||
to.
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -204,7 +204,7 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
|
|||
}
|
||||
|
||||
auto scfIf = b.create<scf::IfOp>(
|
||||
loc, TypeRange{}, cond,
|
||||
loc, cond,
|
||||
[&](OpBuilder &b, Location loc) {
|
||||
if (isInclusive) {
|
||||
auto value = b.create<memref::LoadOp>(loc, input(), indices);
|
||||
|
@ -266,7 +266,7 @@ static LogicalResult foldMemRefCast(Operation *op) {
|
|||
return success(folded);
|
||||
}
|
||||
|
||||
LogicalResult ScanOp::fold(ArrayRef<Attribute>,
|
||||
LogicalResult ScanOp::fold(FoldAdaptor adaptor,
|
||||
SmallVectorImpl<OpFoldResult> &) {
|
||||
return foldMemRefCast(*this);
|
||||
}
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit d23516e9ad477527a9db4d06b1fa9566680ac67c
|
||||
Subproject commit 9acc2f37bdfce08ca0c2faec03392db10d1bb7a9
|
|
@ -1 +1 @@
|
|||
Subproject commit 81e87a95b8683f1c3c33caf9e933897e0fc4a2b7
|
||||
Subproject commit 4a173356bb1291b97046545429d7851cbc771d88
|
|
@ -37,6 +37,7 @@ def Torch_Dialect : Dialect {
|
|||
let hasRegionArgAttrVerify = 1;
|
||||
let hasConstantMaterializer = 1;
|
||||
let useDefaultTypePrinterParser = 0;
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Parse a type registered to this dialect.
|
||||
|
|
|
@ -27,6 +27,7 @@ def TorchConversion_Dialect : Dialect {
|
|||
}];
|
||||
|
||||
let hasConstantMaterializer = 1;
|
||||
let useFoldAPI = kEmitFoldAdaptorFolder;
|
||||
}
|
||||
|
||||
#endif // TORCHCONVERSION_BASE
|
||||
|
|
|
@ -463,8 +463,8 @@ public:
|
|||
}
|
||||
|
||||
SmallVector<Value> inputSize = getTensorSizes(rewriter, loc, input);
|
||||
ArrayRef<Value> outputShapeInt = llvm::makeArrayRef(outputSizeInt);
|
||||
ArrayRef<Value> inputShapeInt = llvm::makeArrayRef(inputSize);
|
||||
ArrayRef<Value> outputShapeInt = llvm::ArrayRef(outputSizeInt);
|
||||
ArrayRef<Value> inputShapeInt = llvm::ArrayRef(inputSize);
|
||||
|
||||
// Association indices for expand/collapse ops. These two vectors
|
||||
// are populated such that two entries at the same index corresponds
|
||||
|
@ -1136,7 +1136,7 @@ public:
|
|||
|
||||
Value dimIndex = rewriter.createOrFold<arith::ConstantOp>(
|
||||
loc, rewriter.getIndexAttr(dim));
|
||||
for (auto tensor : makeArrayRef(tensors).drop_front()) {
|
||||
for (auto tensor : ArrayRef(tensors).drop_front()) {
|
||||
auto size = rewriter.createOrFold<tensor::DimOp>(loc, tensor, dimIndex);
|
||||
resultDimSize =
|
||||
rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
|
||||
|
@ -1270,7 +1270,7 @@ public:
|
|||
/*resultType=*/selfType,
|
||||
/*inputs=*/broadcastedSrc,
|
||||
/*outputs=*/self,
|
||||
/*indexingMaps=*/llvm::makeArrayRef({id, id}),
|
||||
/*indexingMaps=*/llvm::ArrayRef({id, id}),
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value result = args[0];
|
||||
|
|
|
@ -1086,7 +1086,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
// Reshape input
|
||||
auto mhloInput = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
op->getLoc(), mhloBatchNormOutTy, input,
|
||||
mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
|
||||
mhlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape),
|
||||
{static_cast<int64_t>(inputFlattenShape.size())})
|
||||
.value());
|
||||
|
||||
|
|
|
@ -142,7 +142,7 @@ public:
|
|||
// Finding the maximum value in the input tensor.
|
||||
SmallVector<int64_t> maxTensorSizes;
|
||||
ValueTensorType maxTensorType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(maxTensorSizes),
|
||||
context, llvm::ArrayRef(maxTensorSizes),
|
||||
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
|
||||
Value maxTensor =
|
||||
rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput);
|
||||
|
@ -165,7 +165,7 @@ public:
|
|||
SmallVector<int64_t> expandedInputSizes{
|
||||
makeShapeTorchCompatible(inputType.getShape())[0], 1};
|
||||
ValueTensorType expandInputType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(expandedInputSizes),
|
||||
context, llvm::ArrayRef(expandedInputSizes),
|
||||
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
|
||||
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
|
@ -286,9 +286,9 @@ public:
|
|||
auto indexTensorType = indexTensor.getType().cast<BaseTensorType>();
|
||||
int64_t indexTensorSize = indexTensorType.getSizes()[0];
|
||||
SmallVector<int64_t> expandedIndexTensorSizes{indexTensorSize, 1};
|
||||
ValueTensorType expandedIndexTensorType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(expandedIndexTensorSizes),
|
||||
indexTensorType.getDtype());
|
||||
ValueTensorType expandedIndexTensorType =
|
||||
ValueTensorType::get(context, llvm::ArrayRef(expandedIndexTensorSizes),
|
||||
indexTensorType.getDtype());
|
||||
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
Value expandedIndexTensor = rewriter.create<AtenUnsqueezeOp>(
|
||||
|
|
|
@ -718,8 +718,8 @@ class ConvertAtenMultipleDimsReductionOp
|
|||
"non-const dim parameter unsupported");
|
||||
int64_t N = reduceDims.size();
|
||||
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
|
||||
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
|
||||
llvm::makeArrayRef(reduceDims));
|
||||
reduceDimsAttr =
|
||||
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims));
|
||||
|
||||
keepDims = false;
|
||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims)))
|
||||
|
@ -748,8 +748,8 @@ class ConvertAtenOneDimReductionOp
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"non-const dim parameter unsupported");
|
||||
auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type());
|
||||
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
|
||||
llvm::makeArrayRef({reduceDim}));
|
||||
reduceDimsAttr =
|
||||
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef({reduceDim}));
|
||||
|
||||
keepDims = false;
|
||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDims)))
|
||||
|
@ -782,8 +782,8 @@ public:
|
|||
reduceDims.push_back(i);
|
||||
int64_t N = selfTy.getRank();
|
||||
auto reduceDimsType = RankedTensorType::get({N}, rewriter.getI64Type());
|
||||
reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType,
|
||||
llvm::makeArrayRef(reduceDims));
|
||||
reduceDimsAttr =
|
||||
DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef(reduceDims));
|
||||
keepDims = false;
|
||||
|
||||
return success();
|
||||
|
|
|
@ -507,7 +507,7 @@ bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
|
|||
return isValidSubtype(inputs[0], outputs[0]);
|
||||
}
|
||||
|
||||
OpFoldResult DerefineOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult DerefineOp::fold(FoldAdaptor adaptor) {
|
||||
auto uncheckedCast = getOperand().getDefiningOp<PrimUncheckedCastOp>();
|
||||
if (!uncheckedCast)
|
||||
return nullptr;
|
||||
|
@ -570,10 +570,10 @@ static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
|
|||
// Aten__RangeLengthOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto lo = operands[0];
|
||||
auto hi = operands[1];
|
||||
auto step = operands[2];
|
||||
OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) {
|
||||
auto lo = adaptor.getLo();
|
||||
auto hi = adaptor.getHi();
|
||||
auto step = adaptor.getStep();
|
||||
if (!lo || !hi || !step)
|
||||
return nullptr;
|
||||
auto loInt = lo.dyn_cast_or_null<IntegerAttr>().getValue();
|
||||
|
@ -595,10 +595,10 @@ OpFoldResult Aten__RangeLengthOp::fold(ArrayRef<Attribute> operands) {
|
|||
// Aten__DeriveIndexOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto index = operands[0];
|
||||
auto start = operands[1];
|
||||
auto step = operands[2];
|
||||
OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) {
|
||||
auto index = adaptor.getIndex();
|
||||
auto start = adaptor.getStart();
|
||||
auto step = adaptor.getStep();
|
||||
if (!index || !start || !step)
|
||||
return nullptr;
|
||||
auto indexInt = index.dyn_cast_or_null<IntegerAttr>().getValue();
|
||||
|
@ -612,7 +612,7 @@ OpFoldResult Aten__DeriveIndexOp::fold(ArrayRef<Attribute> operands) {
|
|||
// Aten__Is__Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__Is__Op::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) {
|
||||
return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true);
|
||||
}
|
||||
|
||||
|
@ -620,7 +620,7 @@ OpFoldResult Aten__Is__Op::fold(ArrayRef<Attribute> operands) {
|
|||
// Aten__Isnot__Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__Isnot__Op::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) {
|
||||
return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false);
|
||||
}
|
||||
|
||||
|
@ -628,7 +628,7 @@ OpFoldResult Aten__Isnot__Op::fold(ArrayRef<Attribute> operands) {
|
|||
// Aten__Not__Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) {
|
||||
bool value;
|
||||
if (!matchPattern(getOperand(), m_TorchConstantBool(&value)))
|
||||
return nullptr;
|
||||
|
@ -639,7 +639,7 @@ OpFoldResult Aten__Not__Op::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenNeBoolOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
|
||||
if (getOperand(0) == getOperand(1))
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 1), false);
|
||||
|
||||
|
@ -655,7 +655,7 @@ OpFoldResult AtenNeBoolOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenSqueezeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
|
||||
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
|
||||
return getOperand();
|
||||
|
@ -667,7 +667,7 @@ OpFoldResult AtenSqueezeOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenSqueezeDimOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
|
||||
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
|
||||
return getOperand(0);
|
||||
|
@ -679,7 +679,7 @@ OpFoldResult AtenSqueezeDimOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenRoundOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenRoundOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto selfType = getSelf().getType().dyn_cast<BaseTensorType>()) {
|
||||
if (selfType.hasDtype() && selfType.getDtype().isa<mlir::IntegerType>())
|
||||
return getSelf();
|
||||
|
@ -691,7 +691,7 @@ OpFoldResult AtenRoundOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenTypeAsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenTypeAsOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) {
|
||||
Type inType = getSelf().getType();
|
||||
Type newType = getOther().getType();
|
||||
|
||||
|
@ -705,7 +705,7 @@ OpFoldResult AtenTypeAsOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenToDtypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
|
||||
bool nonBlocking, copyArg;
|
||||
// The non_blocking arg must be `False`.
|
||||
if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
|
||||
|
@ -736,7 +736,7 @@ OpFoldResult AtenToDtypeOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenToDtypeLayoutOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
|
||||
// The pin_memory arg should be either constant `False` or `none`.
|
||||
if (!getPinMemory().getType().isa<Torch::NoneType>()) {
|
||||
bool pinMemory;
|
||||
|
@ -797,7 +797,7 @@ OpFoldResult AtenToDtypeLayoutOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenViewOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
|
||||
auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
||||
if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
|
||||
return nullptr;
|
||||
|
@ -812,7 +812,7 @@ OpFoldResult AtenViewOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenDimOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
|
||||
if (tensorType.hasSizes())
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 64),
|
||||
|
@ -825,7 +825,7 @@ OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenLenTOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLenTOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenLenTOp::fold(FoldAdaptor adaptor) {
|
||||
// `len([1,1,1])` -> `3`, if it is not mutated.
|
||||
if (auto listConstruct =
|
||||
getOperand().getDefiningOp<Torch::PrimListConstructOp>()) {
|
||||
|
@ -853,7 +853,7 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
// AtenLenStrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLenStrOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenLenStrOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto stringConstruct = getS().getDefiningOp<ConstantStrOp>())
|
||||
return getI64IntegerAttr(getContext(),
|
||||
stringConstruct.getValueAttr().getValue().size());
|
||||
|
@ -1092,7 +1092,7 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
// AtenSizeIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSizeIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) {
|
||||
int64_t dim;
|
||||
if (!matchPattern(this->getDim(), m_TorchConstantInt(&dim)))
|
||||
return nullptr;
|
||||
|
@ -1132,7 +1132,7 @@ floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) {
|
|||
// AtenLtFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLtFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) {
|
||||
return floatComparatorFoldHelper(*this,
|
||||
[](double a, double b) { return a < b; });
|
||||
}
|
||||
|
@ -1141,7 +1141,7 @@ OpFoldResult AtenLtFloatOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenGtFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) {
|
||||
return floatComparatorFoldHelper(*this,
|
||||
[](double a, double b) { return a > b; });
|
||||
}
|
||||
|
@ -1150,7 +1150,7 @@ OpFoldResult AtenGtFloatOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenGeFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) {
|
||||
return floatComparatorFoldHelper(*this,
|
||||
[](double a, double b) { return a >= b; });
|
||||
}
|
||||
|
@ -1159,7 +1159,7 @@ OpFoldResult AtenGeFloatOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenEqFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenEqFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenEqFloatOp::fold(FoldAdaptor adaptor) {
|
||||
return floatComparatorFoldHelper(*this,
|
||||
[](double a, double b) { return a == b; });
|
||||
}
|
||||
|
@ -1225,7 +1225,7 @@ static OpFoldResult intComparatorFoldHelper(OpTy op,
|
|||
// AtenNeIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) {
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a != b; });
|
||||
}
|
||||
|
@ -1234,7 +1234,7 @@ OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenEqIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) {
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a == b; });
|
||||
}
|
||||
|
@ -1243,7 +1243,7 @@ OpFoldResult AtenEqIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenEqStrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) {
|
||||
if (getOperand(0) == getOperand(1))
|
||||
return getI1IntegerAttr(getContext(), true);
|
||||
|
||||
|
@ -1259,7 +1259,7 @@ OpFoldResult AtenEqStrOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenLtIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) {
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a < b; });
|
||||
}
|
||||
|
@ -1268,7 +1268,7 @@ OpFoldResult AtenLtIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenLeIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) {
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a <= b; });
|
||||
}
|
||||
|
@ -1277,7 +1277,7 @@ OpFoldResult AtenLeIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenGtIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) {
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a > b; });
|
||||
}
|
||||
|
@ -1286,7 +1286,7 @@ OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenGeIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenGeIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) {
|
||||
return intComparatorFoldHelper(*this,
|
||||
[](int64_t a, int64_t b) { return a >= b; });
|
||||
}
|
||||
|
@ -1295,7 +1295,7 @@ OpFoldResult AtenGeIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenBoolFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenBoolFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) {
|
||||
double c;
|
||||
if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
|
||||
return getI1IntegerAttr(getContext(), c != 0.0);
|
||||
|
@ -1306,7 +1306,7 @@ OpFoldResult AtenBoolFloatOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenBoolIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenBoolIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) {
|
||||
int64_t c;
|
||||
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
|
||||
return getI1IntegerAttr(getContext(), c != 0);
|
||||
|
@ -1317,9 +1317,9 @@ OpFoldResult AtenBoolIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenFloatScalarOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
|
||||
// Constant fold int -> float conversion.
|
||||
if (auto integerAttr = operands[0].dyn_cast_or_null<IntegerAttr>()) {
|
||||
if (auto integerAttr = adaptor.getA().dyn_cast_or_null<IntegerAttr>()) {
|
||||
return FloatAttr::get(
|
||||
mlir::Float64Type::get(getContext()),
|
||||
static_cast<double>(integerAttr.getValue().getSExtValue()));
|
||||
|
@ -1334,9 +1334,9 @@ OpFoldResult AtenFloatScalarOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenIntScalarOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) {
|
||||
// Constant fold float -> int conversion.
|
||||
if (auto floatAttr = operands[0].dyn_cast_or_null<FloatAttr>()) {
|
||||
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
|
||||
return IntegerAttr::get(
|
||||
mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
|
||||
static_cast<long>(floatAttr.getValue().convertToDouble()));
|
||||
|
@ -1351,7 +1351,7 @@ OpFoldResult AtenIntScalarOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenIntBoolOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenIntBoolOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) {
|
||||
bool b;
|
||||
if (matchPattern(getOperand(), m_TorchConstantBool(&b))) {
|
||||
return getI64IntegerAttr(getContext(), static_cast<long>(b));
|
||||
|
@ -1452,7 +1452,7 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
|
|||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult ValueTensorLiteralOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) {
|
||||
return getValueAttr();
|
||||
}
|
||||
|
||||
|
@ -1557,7 +1557,7 @@ void CopyToValueTensorOp::getEffects(
|
|||
// ConstantNoneOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ConstantNoneOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult ConstantNoneOp::fold(FoldAdaptor adaptor) {
|
||||
return TypeAttr::get(Torch::NoneType::get(getContext()));
|
||||
}
|
||||
|
||||
|
@ -1570,9 +1570,7 @@ void ConstantNoneOp::getAsmResultNames(
|
|||
// ConstantStrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ConstantStrOp::fold(ArrayRef<Attribute> operands) {
|
||||
return getValueAttr();
|
||||
}
|
||||
OpFoldResult ConstantStrOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
|
||||
|
||||
void ConstantStrOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
|
@ -1610,7 +1608,7 @@ void ConstantIntOp::print(OpAsmPrinter &p) {
|
|||
p.printOptionalAttrDict((*this)->getAttrs(), {"value"});
|
||||
}
|
||||
|
||||
OpFoldResult Torch::ConstantIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Torch::ConstantIntOp::fold(FoldAdaptor adaptor) {
|
||||
return getValueAttr();
|
||||
}
|
||||
|
||||
|
@ -1626,7 +1624,7 @@ void Torch::ConstantIntOp::getAsmResultNames(
|
|||
// ConstantFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Torch::ConstantFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Torch::ConstantFloatOp::fold(FoldAdaptor adaptor) {
|
||||
return getValueAttr();
|
||||
}
|
||||
|
||||
|
@ -1656,7 +1654,7 @@ void Torch::ConstantFloatOp::getAsmResultNames(
|
|||
// ConstantNumberOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Torch::ConstantNumberOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Torch::ConstantNumberOp::fold(FoldAdaptor adaptor) {
|
||||
return getValueAttr();
|
||||
}
|
||||
|
||||
|
@ -1684,7 +1682,7 @@ void Torch::ConstantNumberOp::getCanonicalizationPatterns(
|
|||
// ConstantBoolOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Torch::ConstantBoolOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Torch::ConstantBoolOp::fold(FoldAdaptor adaptor) {
|
||||
return getValueAttr();
|
||||
}
|
||||
|
||||
|
@ -1702,7 +1700,7 @@ bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs,
|
|||
return isValidSubtype(outputs[0], inputs[0]);
|
||||
}
|
||||
|
||||
OpFoldResult PrimUncheckedCastOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult PrimUncheckedCastOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto derefineOp = getX().getDefiningOp<Torch::DerefineOp>()) {
|
||||
if (derefineOp.getOperand().getType() == getType())
|
||||
return derefineOp.getOperand();
|
||||
|
@ -1836,7 +1834,7 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
// AtenEqIntListOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenEqIntListOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) {
|
||||
auto lhsLiteral = getA().getDefiningOp<Torch::PrimListConstructOp>();
|
||||
if (!lhsLiteral)
|
||||
return nullptr;
|
||||
|
@ -1976,7 +1974,7 @@ static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) {
|
|||
// Aten__Getitem__DictStrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) {
|
||||
auto dictConstruct = getDictConstructIfNotModified(getSelf());
|
||||
if (!dictConstruct)
|
||||
return nullptr;
|
||||
|
@ -1994,7 +1992,7 @@ OpFoldResult Aten__Getitem__DictStrOp::fold(ArrayRef<Attribute> operands) {
|
|||
// Aten__Contains__StrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__Contains__StrOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Aten__Contains__StrOp::fold(FoldAdaptor adaptor) {
|
||||
auto dictConstruct = getDictConstructIfNotModified(getDict());
|
||||
if (!dictConstruct)
|
||||
return nullptr;
|
||||
|
@ -2017,7 +2015,7 @@ static bool isListConstructNotModified(Value torchList) {
|
|||
});
|
||||
}
|
||||
|
||||
OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult Aten__Contains__IntListOp::fold(FoldAdaptor adaptor) {
|
||||
auto itemConstruct = getItem();
|
||||
if (!isListConstructNotModified(getL()))
|
||||
return nullptr;
|
||||
|
@ -2078,43 +2076,44 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
|
|||
// AtenFloordivIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenFloordivIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
operands, [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
|
||||
adaptor.getOperands(),
|
||||
[](int64_t a, int64_t b) { return std::floor(a / (double)b); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenRemainderIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenRemainderIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
operands, [](int64_t a, int64_t b) { return a % b; });
|
||||
adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenAddIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenAddIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
operands, [](int64_t a, int64_t b) { return a + b; });
|
||||
adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSubIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSubIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
operands, [](int64_t a, int64_t b) { return a - b; });
|
||||
adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenCatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenCatOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
|
||||
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
|
||||
if (!list || !list->hasOneUse() || list.getElements().size() != 1)
|
||||
return nullptr;
|
||||
|
@ -2125,7 +2124,7 @@ OpFoldResult AtenCatOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
|||
// AtenSliceTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
||||
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
|
||||
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
|
||||
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
|
||||
|
@ -2144,7 +2143,7 @@ OpFoldResult AtenSliceTensorOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
|||
// AtenMulIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
|
||||
int64_t lhs, rhs;
|
||||
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
|
||||
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
|
||||
|
@ -2159,42 +2158,45 @@ OpFoldResult AtenMulIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenSubOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSubOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!operands[0] || !operands[1]) {
|
||||
OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) {
|
||||
if (!adaptor.getA() || !adaptor.getB()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (operands[0].isa<IntegerAttr>() && operands[1].isa<IntegerAttr>()) {
|
||||
if (adaptor.getA().isa<IntegerAttr>() && adaptor.getB().isa<IntegerAttr>()) {
|
||||
return atenBinaryIntOperatorFoldHelper(
|
||||
operands, [](int64_t a, int64_t b) -> int64_t { return a - b; });
|
||||
adaptor.getOperands(),
|
||||
[](int64_t a, int64_t b) -> int64_t { return a - b; });
|
||||
}
|
||||
return atenBinaryFloatOperatorFoldHelper(
|
||||
operands, [](double a, double b) -> double { return a - b; });
|
||||
adaptor.getOperands(),
|
||||
[](double a, double b) -> double { return a - b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenDivOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDivOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!operands[0] || !operands[1]) {
|
||||
OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) {
|
||||
if (!adaptor.getA() || !adaptor.getB()) {
|
||||
return nullptr;
|
||||
}
|
||||
// Since AtenDivOp always returns float value, we don't need to deal with the
|
||||
// case where the operands are both integers separately.
|
||||
return atenBinaryFloatOperatorFoldHelper(
|
||||
operands, [](double a, double b) -> double { return a / b; });
|
||||
adaptor.getOperands(),
|
||||
[](double a, double b) -> double { return a / b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenCeilScalarOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!operands[0]) {
|
||||
OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) {
|
||||
if (!adaptor.getA()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto floatValue = operands[0].dyn_cast_or_null<FloatAttr>();
|
||||
auto floatValue = adaptor.getA().dyn_cast_or_null<FloatAttr>();
|
||||
if (!floatValue) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -2207,7 +2209,7 @@ OpFoldResult AtenCeilScalarOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenNegIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) {
|
||||
int64_t c;
|
||||
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
|
||||
return getI64IntegerAttr(getContext(), -c);
|
||||
|
@ -2218,7 +2220,7 @@ OpFoldResult AtenNegIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenSqrtIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) {
|
||||
int64_t c;
|
||||
if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
|
||||
return getF64FloatAttr(getContext(), std::sqrt(c));
|
||||
|
@ -2229,7 +2231,7 @@ OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// PrimDtypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) {
|
||||
BaseTensorType tensorType = getA().getType().cast<BaseTensorType>();
|
||||
if (tensorType.hasDtype()) {
|
||||
torch_upstream::ScalarType scalarType =
|
||||
|
@ -2243,7 +2245,7 @@ OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenIntTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
|
||||
// If a scalar number is converted to a 0-d tensor and passed on to
|
||||
// aten.Int.Tensor, fold to the scalar number.
|
||||
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
|
||||
|
@ -2255,7 +2257,7 @@ OpFoldResult AtenIntTensorOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenFloatTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) {
|
||||
// If a scalar number is converted to a 0-d tensor and passed on to
|
||||
// aten.Float.Tensor, fold to the scalar number.
|
||||
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
|
||||
|
@ -2267,7 +2269,7 @@ OpFoldResult AtenFloatTensorOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenDivFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) {
|
||||
double lhs, rhs;
|
||||
bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs));
|
||||
bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs));
|
||||
|
@ -2284,7 +2286,7 @@ OpFoldResult AtenDivFloatOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenDivIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
|
||||
int64_t lhs, rhs;
|
||||
bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs));
|
||||
bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs));
|
||||
|
@ -2297,7 +2299,7 @@ OpFoldResult AtenDivIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// AtenCeilFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) {
|
||||
double c;
|
||||
if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
|
||||
return getI64IntegerAttr(getContext(), std::ceil(c));
|
||||
|
@ -2308,13 +2310,13 @@ OpFoldResult AtenCeilFloatOp::fold(ArrayRef<Attribute> operands) {
|
|||
// PrimMaxIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult PrimMaxIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) {
|
||||
// If both operands are the same, then the operation is an identity.
|
||||
if (getA() == getB())
|
||||
return getA();
|
||||
|
||||
auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
|
||||
auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
|
||||
auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
|
||||
auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
|
||||
if (!lhs || !rhs)
|
||||
return nullptr;
|
||||
// Torch semantics are that !torch.int is 64-bit signed.
|
||||
|
@ -2327,7 +2329,7 @@ OpFoldResult PrimMaxIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
// PrimMinSelfIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) {
|
||||
auto list = getOperand().getDefiningOp<PrimListConstructOp>();
|
||||
if (!list)
|
||||
return nullptr;
|
||||
|
|
|
@ -463,7 +463,7 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
|
|||
}
|
||||
}
|
||||
|
||||
return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype);
|
||||
return lhs.getWithSizesAndDtype(ArrayRef(newSizes), dtype);
|
||||
}
|
||||
|
||||
////===----------------------------------------------------------------------===//
|
||||
|
@ -505,4 +505,4 @@ DictType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
|
|||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -72,7 +72,7 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
|
|||
|
||||
Type resultType = tensorType.getWithSizesAndDtype(
|
||||
sizes.size() == 0 ? std::optional<ArrayRef<int64_t>>()
|
||||
: llvm::makeArrayRef(sizes),
|
||||
: llvm::ArrayRef(sizes),
|
||||
tensorType.getOptionalDtype());
|
||||
return resultType;
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
|
|||
valueType
|
||||
.getWithSizesAndDtype(
|
||||
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
|
||||
: llvm::makeArrayRef(valueType.getSizes()),
|
||||
: llvm::ArrayRef(valueType.getSizes()),
|
||||
IntegerType::get(op->getContext(), 64, IntegerType::Signed))
|
||||
.cast<BaseTensorType>();
|
||||
return rewriter
|
||||
|
@ -142,7 +142,7 @@ static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
|
|||
BaseTensorType inputType, Value scalar) {
|
||||
SmallVector<int64_t> sizes;
|
||||
Type rank0TensorTy = inputType.getWithSizesAndDtype(
|
||||
makeArrayRef(sizes), inputType.getOptionalDtype());
|
||||
ArrayRef(sizes), inputType.getOptionalDtype());
|
||||
Value dimList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
|
||||
ValueRange{});
|
||||
|
@ -940,7 +940,7 @@ public:
|
|||
SmallVector<int64_t> sizes;
|
||||
sizes.append(inputShape.begin(), inputShape.end());
|
||||
sizes[cstDim] = kUnknownSize;
|
||||
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
|
||||
Type sliceTy = selfTy.getWithSizesAndDtype(llvm::ArrayRef(sizes),
|
||||
selfTy.getOptionalDtype());
|
||||
Value slice0 = rewriter.create<AtenSliceTensorOp>(
|
||||
loc, sliceTy, input, dim, negShift, constNone, constOne);
|
||||
|
@ -1077,9 +1077,9 @@ public:
|
|||
|
||||
Type dtype = self.getType().cast<ValueTensorType>().getOptionalDtype();
|
||||
Type unsqueezedType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(unsqueezedIntSizes), dtype);
|
||||
Type expandedType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(expandedIntSizes), dtype);
|
||||
context, llvm::ArrayRef(unsqueezedIntSizes), dtype);
|
||||
Type expandedType =
|
||||
ValueTensorType::get(context, llvm::ArrayRef(expandedIntSizes), dtype);
|
||||
|
||||
auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
|
||||
Value unsqueezedDims =
|
||||
|
@ -2004,7 +2004,7 @@ public:
|
|||
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
SmallVector<int64_t> empty;
|
||||
Type tensorType = inputType.getWithSizesAndDtype(llvm::makeArrayRef(empty),
|
||||
Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty),
|
||||
rewriter.getF64Type());
|
||||
Value prob = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType, p);
|
||||
Value output;
|
||||
|
@ -2082,8 +2082,8 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
|
|||
std::vector<int64_t> meanVarSizes(inputRank, 1);
|
||||
for (int i = 0; i < axis; i++)
|
||||
meanVarSizes[i] = input.getSizes()[i];
|
||||
auto meanVarType = input.getWithSizesAndDtype(
|
||||
llvm::makeArrayRef(meanVarSizes), input.getOptionalDtype());
|
||||
auto meanVarType = input.getWithSizesAndDtype(llvm::ArrayRef(meanVarSizes),
|
||||
input.getOptionalDtype());
|
||||
auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
|
||||
loc, op.getType(), meanVarType, meanVarType, op.getInput(),
|
||||
op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps());
|
||||
|
@ -2320,7 +2320,7 @@ class DecomposeAtenNativeBatchNormOp
|
|||
runningStatsShapeInt[1] = kUnknownSize;
|
||||
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
|
||||
Type reshapeType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(runningStatsShapeInt), dtype);
|
||||
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
|
||||
|
||||
runningMean = rewriter.create<AtenViewOp>(loc, reshapeType, runningMean,
|
||||
runningStatsSizeList);
|
||||
|
@ -2466,8 +2466,7 @@ public:
|
|||
SmallVector<int64_t> empty;
|
||||
auto dtype =
|
||||
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
|
||||
Type tensorType =
|
||||
outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
|
||||
Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
|
||||
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(loc, tensorType,
|
||||
op.getFillValue());
|
||||
fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype());
|
||||
|
@ -2503,7 +2502,7 @@ public:
|
|||
SmallVector<int64_t> transposeShape =
|
||||
llvm::to_vector(llvm::reverse(weightType.getSizes()));
|
||||
Type transposeType = weightType.getWithSizesAndDtype(
|
||||
llvm::makeArrayRef(transposeShape), weightType.getOptionalDtype());
|
||||
llvm::ArrayRef(transposeShape), weightType.getOptionalDtype());
|
||||
Value transposeWeight =
|
||||
rewriter.create<AtenTOp>(loc, transposeType, weight);
|
||||
|
||||
|
@ -2573,8 +2572,7 @@ public:
|
|||
SmallVector<int64_t> empty;
|
||||
auto dtype =
|
||||
getTypeForTorchType(op.getContext(), op.getFillValue().getType());
|
||||
Type tensorType =
|
||||
outTy.getWithSizesAndDtype(llvm::makeArrayRef(empty), dtype);
|
||||
Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype);
|
||||
Value fillVal = rewriter.create<PrimNumToTensorScalarOp>(
|
||||
op.getLoc(), tensorType, op.getFillValue());
|
||||
fillVal =
|
||||
|
@ -3216,7 +3214,7 @@ public:
|
|||
sizes.resize(srcShape.size() + 1, kUnknownSize);
|
||||
}
|
||||
Type srcType = srcTensorType.getWithSizesAndDtype(
|
||||
llvm::makeArrayRef(sizes), srcTensorType.getOptionalDtype());
|
||||
llvm::ArrayRef(sizes), srcTensorType.getOptionalDtype());
|
||||
src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
|
||||
rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
|
||||
op, op.getSelf().getType(), self, src, dim, start, startPlusOne,
|
||||
|
@ -3314,7 +3312,7 @@ public:
|
|||
op, "Expected the input tensor to have sizes");
|
||||
BaseTensorType subType =
|
||||
inputType
|
||||
.getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()),
|
||||
.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()),
|
||||
resultType.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
|
||||
|
|
|
@ -129,8 +129,7 @@ public:
|
|||
// Truncate the list of users to the number of users we're going to
|
||||
// interpret.
|
||||
allUsers.resize(numUsersToInterpret);
|
||||
auto usersToInterpret =
|
||||
makeArrayRef(allUsers).take_front(numUsersToInterpret);
|
||||
auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret);
|
||||
|
||||
// For each mutating op (which must be in the same block), we save the
|
||||
// current state of the list as a vector of Value's. These will then
|
||||
|
@ -336,7 +335,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
|
|||
auto originalResultType = result.getType().cast<BaseTensorType>();
|
||||
auto impliedTypesFromShape =
|
||||
originalResultType.cast<BaseTensorType>()
|
||||
.getWithSizesAndDtype(makeArrayRef(sizes),
|
||||
.getWithSizesAndDtype(ArrayRef(sizes),
|
||||
originalResultType.getOptionalDtype())
|
||||
.cast<BaseTensorType>();
|
||||
|
||||
|
|
|
@ -75,8 +75,8 @@ LogicalResult FromBuiltinTensorOp::verify() {
|
|||
// FromI64Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
@ -88,8 +88,8 @@ OpFoldResult FromI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
|||
// ToI64Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
auto attr = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
@ -101,8 +101,8 @@ OpFoldResult ToI64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
|||
// ToF64Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ToF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
|
||||
OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
@ -114,8 +114,8 @@ OpFoldResult ToF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
|||
// FromF64Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult FromF64Op::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
auto attr = operands[0].dyn_cast_or_null<mlir::FloatAttr>();
|
||||
OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) {
|
||||
auto attr = adaptor.getOperand().dyn_cast_or_null<mlir::FloatAttr>();
|
||||
if (attr) {
|
||||
return attr;
|
||||
} else {
|
||||
|
|
|
@ -392,7 +392,7 @@ Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
|
|||
loc,
|
||||
/*inputs=*/from,
|
||||
/*outputs=*/to,
|
||||
/*indexingMaps=*/llvm::makeArrayRef({id, id}),
|
||||
/*indexingMaps=*/llvm::ArrayRef({id, id}),
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args.front());
|
||||
|
|
|
@ -101,17 +101,17 @@ torch.class_type @c {
|
|||
// -----
|
||||
|
||||
// expected-error @+1 {{'torch.type_bound' must be attached to an argument of !torch.tensor/!torch.vtensor type}}
|
||||
func.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
|
||||
func.func private @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{'torch.type_bound' must be TypeAttr}}
|
||||
func.func @f(%arg0: i32 {torch.type_bound = 1})
|
||||
func.func private @f(%arg0: i32 {torch.type_bound = 1})
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{'torch.type_bound' must be of !torch.tensor/!torch.vtensor type}}
|
||||
func.func @f(%arg0: i32 {torch.type_bound = i32})
|
||||
func.func private @f(%arg0: i32 {torch.type_bound = i32})
|
||||
|
||||
// -----
|
||||
|
||||
|
|
Loading…
Reference in New Issue