mirror of https://github.com/llvm/torch-mlir
Allow running DecomposeComplexOps more than once (#1671)
The current implementation of `DecomposeComplexOps` fails if an op expected to be decomposed does not get decomposed in the first iteration of the `createTorchSimplificationPipeline` in `LowerToBackendContractPass`. However, some graphs require multiple iterations of `createTorchSimplificationPipeline` to fully propagate all statically knowable information, such as dtypes and shapes, to the entire graph, sometimes resulting in the need to run `DecomposeComplexOps` more than once. This commit changes `DecomposeComplexOps` to use a greedy algorithm for pattern application and moves the legalization check of ops to the `LowerToBackendContractPass` to allow for the `DecomposeComplexOps` to run more than once.pull/1699/head
parent
e8511840c3
commit
a54b334578
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
@ -18,6 +19,7 @@
|
|||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include <cstdint>
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -675,13 +677,15 @@ public:
|
|||
int lhsRank = getTensorRank(lhs);
|
||||
int rhsRank = getTensorRank(rhs);
|
||||
|
||||
// If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
|
||||
if (lhsRank == 2 && rhsRank == 2)
|
||||
if (lhsRank == 2 && rhsRank == 2) {
|
||||
// If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
|
||||
rewriter.replaceOpWithNewOp<AtenMmOp>(op, op.getType(), lhs, rhs);
|
||||
|
||||
// If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
|
||||
if (lhsRank == 3 && rhsRank == 3)
|
||||
} else if (lhsRank == 3 && rhsRank == 3) {
|
||||
// If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
|
||||
rewriter.replaceOpWithNewOp<AtenBmmOp>(op, op.getType(), lhs, rhs);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -3298,6 +3302,24 @@ public:
|
|||
namespace {
|
||||
class DecomposeComplexOpsPass
|
||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||
private:
|
||||
llvm::StringSet<> legalOpsSet;
|
||||
|
||||
template <typename DecomposePattern>
|
||||
void addPatternIfTargetOpIsIllegal(RewritePatternSet &patterns) {
|
||||
MLIRContext *context = &getContext();
|
||||
Optional<OperationName> opName = DecomposePattern(context).getRootKind();
|
||||
// Because the `DecomposeComplexOpsPass` uses a greedy algorithm
|
||||
// to apply patterns, only patterns that we for sure know we want to run
|
||||
// must be added. This restricts the set of patterns allowed in this file to
|
||||
// patterns that apply to a single op. In other words, patterns that match
|
||||
// on `Operation *` are not allowed, since there is no way of telling if
|
||||
// that pattern will match on an op in the `legalOpsSet` or not.
|
||||
assert(opName && "All decomposition patterns must target a single op");
|
||||
if (!legalOpsSet.contains(opName->getStringRef()))
|
||||
patterns.add<DecomposePattern>(context);
|
||||
}
|
||||
|
||||
public:
|
||||
DecomposeComplexOpsPass() = default;
|
||||
DecomposeComplexOpsPass(ArrayRef<std::string> legalOps) {
|
||||
|
@ -3306,215 +3328,128 @@ public:
|
|||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<Torch::TorchDialect>();
|
||||
// The strings in the `legalOps` ArrayRef don't exist during the call to the
|
||||
// constructor `DecomposeComplexOpsPass`, so the creation of the
|
||||
// `legalOpsSet` must be delayed to when `runOnOperation` gets called.
|
||||
legalOpsSet.clear();
|
||||
legalOpsSet.insert(legalOps.begin(), legalOps.end());
|
||||
|
||||
patterns.add<DecomposeAtenSoftmaxIntOp>(context);
|
||||
target.addIllegalOp<AtenSoftmaxIntOp>();
|
||||
patterns.add<DecomposeAten_SoftmaxOp>(context);
|
||||
target.addIllegalOp<Aten_SoftmaxOp>();
|
||||
patterns.add<DecomposeAten_LogSoftmaxOp>(context);
|
||||
target.addIllegalOp<Aten_LogSoftmaxOp>();
|
||||
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
|
||||
target.addIllegalOp<AtenLogSoftmaxIntOp>();
|
||||
patterns.add<DecomposeAtenEmptyLikeOp>(context);
|
||||
target.addIllegalOp<AtenEmptyLikeOp>();
|
||||
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(
|
||||
context);
|
||||
target.addIllegalOp<AtenOnesLikeOp>();
|
||||
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(
|
||||
context);
|
||||
target.addIllegalOp<AtenZerosLikeOp>();
|
||||
patterns.add<DecomposeAtenRollOp>(context);
|
||||
target.addIllegalOp<AtenRollOp>();
|
||||
patterns.add<DecomposeAtenRepeatOp>(context);
|
||||
target.addIllegalOp<AtenRepeatOp>();
|
||||
patterns.add<DecomposeAtenExpandOp>(context);
|
||||
target.addIllegalOp<AtenExpandOp>();
|
||||
patterns.add<DecomposeAtenFlattenUsingIntsOp>(context);
|
||||
target.addIllegalOp<AtenFlattenUsingIntsOp>();
|
||||
patterns.add<DecomposeAtenWhereScalarOp>(context);
|
||||
target.addIllegalOp<AtenWhereScalarOp>();
|
||||
patterns.add<DecomposeAtenWhereScalarOtherOp>(context);
|
||||
target.addIllegalOp<AtenWhereScalarOtherOp>();
|
||||
patterns.add<DecomposeAtenWhereScalarSelfOp>(context);
|
||||
target.addIllegalOp<AtenWhereScalarSelfOp>();
|
||||
patterns.add<DecomposeAtenConvolutionBackwardOverrideableOp>(context);
|
||||
target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
|
||||
patterns.add<DecomposeAtenSizeOp>(context);
|
||||
target.addIllegalOp<AtenSizeOp>();
|
||||
patterns.add<DecomposeAtenReshapeOp>(context);
|
||||
target.addIllegalOp<AtenReshapeOp>();
|
||||
patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
|
||||
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
|
||||
patterns.add<DecomposeAtenTanhBackwardOp>(context);
|
||||
target.addIllegalOp<AtenTanhBackwardOp>();
|
||||
patterns.add<DecomposeAtenAddmmOp>(context);
|
||||
target.addIllegalOp<AtenAddmmOp>();
|
||||
patterns.add<DecomposeAtenMeanOp>(context);
|
||||
target.addIllegalOp<AtenMeanOp>();
|
||||
patterns.add<DecomposeAtenMeanDimOp>(context);
|
||||
target.addIllegalOp<AtenMeanDimOp>();
|
||||
patterns.add<DecomposeAtenSelectIntOp>(context);
|
||||
target.addIllegalOp<AtenSelectIntOp>();
|
||||
patterns.add<DecomposeAtenMatmulOp>(context);
|
||||
target.addIllegalOp<AtenMvOp>();
|
||||
patterns.add<DecomposeAtenMvOp>(context);
|
||||
target.addIllegalOp<AtenTOp>();
|
||||
patterns.add<DecomposeAtenTOp>(context);
|
||||
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
|
||||
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
|
||||
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
|
||||
int lhsRank = getTensorRank(op.getSelf());
|
||||
int rhsRank = getTensorRank(op.getOther());
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFlattenUsingIntsOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenConvolutionBackwardOverrideableOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTanhBackwardOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvolutionOverrideableOp>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvolutionBackwardOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConv2dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArgMaxOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeViewOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_ReshapeAliasOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeValsemVariantAtenBernoulliFloatOp>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSiluOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeConstantTensorNewLikeOp<AtenNewZerosOp, AtenZerosOp>>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardtanhOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFullOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenDropoutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAmaxOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_EmbeddingBagOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLiftFreshCopyOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorHackedTwinOp>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanCorrectionOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
|
||||
|
||||
// Make aten.matmul legal if the following condition is satisfied.
|
||||
return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
|
||||
});
|
||||
patterns.add<DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(
|
||||
context);
|
||||
target.addIllegalOp<AtenAddcmulOp>();
|
||||
patterns.add<DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(
|
||||
context);
|
||||
target.addIllegalOp<AtenAddcdivOp>();
|
||||
target.addIllegalOp<AtenLayerNormOp>();
|
||||
patterns.add<DecomposeAtenLayerNormOp>(context);
|
||||
target.addIllegalOp<AtenNativeLayerNormOp>();
|
||||
patterns.add<DecomposeAtenNativeLayerNormOp>(context);
|
||||
GreedyRewriteConfig config;
|
||||
config.useTopDownTraversal = true;
|
||||
config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;
|
||||
|
||||
target.addIllegalOp<AtenNativeBatchNormOp>();
|
||||
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
|
||||
target.addIllegalOp<AtenConvolutionOverrideableOp>();
|
||||
patterns.add<DecomposeAtenConvolutionOverrideableOp>(context);
|
||||
target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
|
||||
patterns.add<DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>,
|
||||
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
|
||||
context);
|
||||
target.addIllegalOp<AtenConvolutionBackwardOp>();
|
||||
patterns.add<DecomposeAtenConvolutionBackwardOp>(context);
|
||||
target.addIllegalOp<AtenConv2dOp>();
|
||||
patterns.add<DecomposeAtenConv2dOp>(context);
|
||||
target.addIllegalOp<AtenConvTranspose2dInputOp>();
|
||||
patterns.add<DecomposeAtenConvTranspose2dOp>(context);
|
||||
patterns.add<DecomposeAtenArangeOp>(context);
|
||||
target.addIllegalOp<AtenArangeOp>();
|
||||
patterns.add<DecomposeAtenArangeStartOp>(context);
|
||||
target.addIllegalOp<AtenArangeStartOp>();
|
||||
patterns.add<DecomposeAtenArgMaxOp>(context);
|
||||
target.addIllegalOp<AtenArgmaxOp>();
|
||||
patterns.add<DecomposeAtenSquareOp>(context);
|
||||
target.addIllegalOp<AtenSquareOp>();
|
||||
patterns.add<DecomposeAtenVarOp>(context);
|
||||
target.addIllegalOp<AtenVarOp>();
|
||||
patterns.add<DecomposeAtenStdOp>(context);
|
||||
target.addIllegalOp<AtenStdOp>();
|
||||
patterns.add<DecomposeAten_UnsafeViewOp>(context);
|
||||
target.addIllegalOp<Aten_UnsafeViewOp>();
|
||||
patterns.add<DecomposeAten_ReshapeAliasOp>(context);
|
||||
target.addIllegalOp<Aten_ReshapeAliasOp>();
|
||||
patterns.add<DecomposeAtenBernoulliOp>(context);
|
||||
target.addIllegalOp<AtenBernoulliOp>();
|
||||
patterns.add<DecomposeValsemVariantAtenBernoulliFloatOp>(context);
|
||||
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
|
||||
patterns.add<DecomposeAtenBernoulliTensorOp>(context);
|
||||
target.addIllegalOp<AtenBernoulliTensorOp>();
|
||||
patterns.add<DecomposeAtenZeroOp>(context);
|
||||
target.addIllegalOp<AtenZeroOp>();
|
||||
patterns.add<DecomposeAtenRandLikeOp>(context);
|
||||
target.addIllegalOp<AtenRandLikeOp>();
|
||||
patterns.add<DecomposeAtenHardsigmoidOp>(context);
|
||||
target.addIllegalOp<AtenHardsigmoidOp>();
|
||||
patterns.add<DecomposeAtenRelu6Op>(context);
|
||||
target.addIllegalOp<AtenRelu6Op>();
|
||||
patterns.add<DecomposeAtenHardswishOp>(context);
|
||||
target.addIllegalOp<AtenHardswishOp>();
|
||||
patterns.add<DecomposeAtenSoftplusOp>(context);
|
||||
target.addIllegalOp<AtenSoftplusOp>();
|
||||
patterns.add<DecomposeAtenSiluOp>(context);
|
||||
target.addIllegalOp<AtenSiluOp>();
|
||||
patterns.add<DecomposeConstantTensorNewLikeOp<AtenNewZerosOp, AtenZerosOp>>(
|
||||
context);
|
||||
target.addIllegalOp<AtenNewZerosOp>();
|
||||
patterns.add<DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(
|
||||
context);
|
||||
target.addIllegalOp<AtenNewOnesOp>();
|
||||
patterns.add<DecomposeAtenHardtanhOp>(context);
|
||||
target.addIllegalOp<AtenHardtanhOp>();
|
||||
patterns.add<DecomposeAtenFullOp>(context);
|
||||
target.addIllegalOp<AtenFullOp>();
|
||||
patterns.add<DecomposeAtenLinearOp>(context);
|
||||
target.addIllegalOp<AtenLinearOp>();
|
||||
patterns.add<DecomposeAtenMishOp>(context);
|
||||
target.addIllegalOp<AtenMishOp>();
|
||||
patterns.add<DecomposeAtenFullLikeOp>(context);
|
||||
target.addIllegalOp<AtenFullLikeOp>();
|
||||
patterns.add<DecomposeAtenIndexPutOp>(context);
|
||||
target.addIllegalOp<AtenIndexPutOp>();
|
||||
patterns.add<DecomposeAtenExpandAsOp>(context);
|
||||
target.addIllegalOp<AtenExpandAsOp>();
|
||||
patterns.add<DecomposeAten_ToCopyOp>(context);
|
||||
target.addIllegalOp<Aten_ToCopyOp>();
|
||||
patterns.add<DecomposeAtenDropoutOp>(context);
|
||||
target.addIllegalOp<AtenDropoutOp>();
|
||||
target.addIllegalOp<AtenNewEmptyOp>();
|
||||
patterns.add<DecomposeAtenNewEmptyOp>(context);
|
||||
patterns.add<DecomposeAtenIndexPutHackedTwinOp>(context);
|
||||
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
|
||||
target.addIllegalOp<AtenPadOp>();
|
||||
patterns.add<DecomposeAtenPadOp>(context);
|
||||
patterns.add<DecomposeAtenToDtypeLayoutOp>(context);
|
||||
target.addIllegalOp<AtenToDtypeLayoutOp>();
|
||||
patterns.add<DecomposeAtenToDeviceOp>(context);
|
||||
target.addIllegalOp<AtenToDeviceOp>();
|
||||
patterns.add<DecomposeAtenAdaptiveAvgPool2dOp>(context);
|
||||
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
|
||||
patterns.add<DecomposeAtenClampMinOp>(context);
|
||||
target.addIllegalOp<AtenClampMinOp>();
|
||||
patterns.add<DecomposeAtenClampMaxOp>(context);
|
||||
target.addIllegalOp<AtenClampMaxOp>();
|
||||
patterns.add<DecomposeAtenBaddbmmOp>(context);
|
||||
target.addIllegalOp<AtenBaddbmmOp>();
|
||||
patterns.add<DecomposeAtenFloorDivideOp>(context);
|
||||
target.addIllegalOp<AtenFloorDivideOp>();
|
||||
patterns.add<DecomposeAtenNumpyTOp>(context);
|
||||
target.addIllegalOp<AtenNumpyTOp>();
|
||||
patterns.add<DecomposeAtenSelectScatterOp>(context);
|
||||
target.addIllegalOp<AtenSelectScatterOp>();
|
||||
patterns.add<DecomposeAtenVarDimOp>(context);
|
||||
target.addIllegalOp<AtenVarDimOp>();
|
||||
patterns.add<DecomposeAtenAmaxOp>(context);
|
||||
target.addIllegalOp<AtenAmaxOp>();
|
||||
patterns.add<DecomposeAtenVarCorrectionOp>(context);
|
||||
target.addIllegalOp<AtenVarCorrectionOp>();
|
||||
patterns.add<DecomposeAtenStdDimOp>(context);
|
||||
target.addIllegalOp<AtenStdDimOp>();
|
||||
patterns.add<DecomposeAtenNarrowOp>(context);
|
||||
target.addIllegalOp<AtenNarrowOp>();
|
||||
patterns.add<DecomposeAten_EmbeddingBagOp>(context);
|
||||
target.addIllegalOp<Aten_EmbeddingBagOp>();
|
||||
patterns.add<DecomposeAtenLiftFreshCopyOp>(context);
|
||||
target.addIllegalOp<AtenLiftFreshCopyOp>();
|
||||
patterns.add<DecomposeAtenIndexTensorHackedTwinOp>(context);
|
||||
target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
|
||||
patterns.add<DecomposeAtenMseLossOp>(context);
|
||||
target.addIllegalOp<AtenMseLossOp>();
|
||||
patterns.add<DecomposeAtenRandintLowOp>(context);
|
||||
target.addIllegalOp<AtenRandintLowOp>();
|
||||
patterns.add<DecomposeAtenVarMeanCorrectionOp>(context);
|
||||
target.addIllegalOp<AtenVarMeanCorrectionOp>();
|
||||
patterns.add<DecomposePrimsConvertElementTypeOp>(context);
|
||||
target.addIllegalOp<PrimsConvertElementTypeOp>();
|
||||
patterns.add<DecomposeAtenRandnOp>(context);
|
||||
target.addIllegalOp<AtenRandnOp>();
|
||||
patterns.add<DecomposeAtenRandnGeneratorOp>(context);
|
||||
target.addIllegalOp<AtenRandnGeneratorOp>();
|
||||
|
||||
for (std::string opName : legalOps) {
|
||||
target.addLegalOp(OperationName(opName, context));
|
||||
}
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
|
||||
config))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,10 +11,12 @@
|
|||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "torch-lower-to-backend-contract"
|
||||
|
@ -27,6 +29,10 @@ using namespace mlir::torch::Torch;
|
|||
// Checking the backend contract.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||
ConversionTarget &target,
|
||||
ArrayRef<std::string> backendLegalOps);
|
||||
|
||||
static LogicalResult checkType(Operation *op, Type type,
|
||||
bool actuallyEmitDiagnostics) {
|
||||
// Allow various scalar types that backends are expected to be able to handle.
|
||||
|
@ -149,7 +155,24 @@ static LogicalResult checkType(Operation *op, Type type,
|
|||
}
|
||||
}
|
||||
|
||||
static LogicalResult checkOpIsBackendLegal(Operation *op,
|
||||
const ConversionTarget &target,
|
||||
bool actuallyEmitDiagnostics) {
|
||||
if (target.isLegal(op))
|
||||
return success();
|
||||
|
||||
if (actuallyEmitDiagnostics) {
|
||||
return op->emitError("found an op that was marked as backend illegal")
|
||||
.attachNote()
|
||||
.append("this is likely due to DecomposeComplexOps being unable to "
|
||||
"decompose this op");
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
static bool satisfiesBackendContract(ModuleOp module,
|
||||
const ConversionTarget &target,
|
||||
bool actuallyEmitDiagnostics = false) {
|
||||
// We do not permit `torch.global_slot`'s in the backend contract, since
|
||||
// support for them is not widespread, and this does not align with PyTorch's
|
||||
|
@ -174,7 +197,8 @@ static bool satisfiesBackendContract(ModuleOp module,
|
|||
if (walkResult0.wasInterrupted())
|
||||
return false;
|
||||
|
||||
// Check all the type of all Value's in the program.
|
||||
// Check all the types of all Value's in the program and the legality of all
|
||||
// the ops.
|
||||
//
|
||||
// A pre-order walk gives a more intuitive "first error".
|
||||
// TODO: Should we report more than the first error?
|
||||
|
@ -185,10 +209,14 @@ static bool satisfiesBackendContract(ModuleOp module,
|
|||
actuallyEmitDiagnostics))) {
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
for (Operation &op : *block)
|
||||
for (Operation &op : *block) {
|
||||
if (failed(checkOpIsBackendLegal(&op, target, actuallyEmitDiagnostics)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
for (OpResult result : op.getResults())
|
||||
if (failed(checkType(&op, result.getType(), actuallyEmitDiagnostics)))
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
@ -210,6 +238,11 @@ public:
|
|||
}
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<func::FuncDialect, Torch::TorchDialect>();
|
||||
if (decompose)
|
||||
markDecomposedOpsAsIllegal(context, target, backendLegalOps);
|
||||
|
||||
OpPassManager pm(module.getOperationName());
|
||||
TorchLoweringPipelineOptions options;
|
||||
|
@ -227,14 +260,14 @@ public:
|
|||
<< " iterations of the simplification pipeline\n";
|
||||
});
|
||||
// Show the diagnostics.
|
||||
(void)satisfiesBackendContract(module,
|
||||
(void)satisfiesBackendContract(module, target,
|
||||
/*actuallyEmitDiagnostics=*/true);
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
if (failed(runPipeline(pm, module)))
|
||||
return signalPassFailure();
|
||||
} while (!satisfiesBackendContract(module));
|
||||
} while (!satisfiesBackendContract(module, target));
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "LowerToBackendContractPass: "
|
||||
<< "succeeded after " << i
|
||||
|
@ -247,7 +280,10 @@ class VerifyBackendContractPass
|
|||
: public VerifyBackendContractBase<VerifyBackendContractPass> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
if (!satisfiesBackendContract(getOperation(), /*actuallyEmitDiagnostics=*/true)) {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
if (!satisfiesBackendContract(getOperation(), target,
|
||||
/*actuallyEmitDiagnostics=*/true)) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
@ -265,3 +301,124 @@ std::unique_ptr<OperationPass<ModuleOp>>
|
|||
mlir::torch::Torch::createVerifyBackendContractPass() {
|
||||
return std::make_unique<VerifyBackendContractPass>();
|
||||
}
|
||||
|
||||
// The backend contract guarantees that ops with decompositions available will
|
||||
// be decomposed. The only way to have an op reach the backend contract without
|
||||
// getting decomposed is by having the user explicitly specify that op in the
|
||||
// `backendLegalOps` argument to the `LowerToBackendContractPass`. Therefore,
|
||||
// here we mark as illegal all ops with decompositions except for those in
|
||||
// `backendLegalOps`.
|
||||
//
|
||||
// The legality check takes place here instead of in the `DecomposeComplexOps`
|
||||
// pass for two reasons:
|
||||
// 1. Makes sure the `DecomposeComplexOps` pass always succeeds, allowing it to
|
||||
// run multiple times. This is needed for graphs where static information such
|
||||
// as dtypes and shapes takes multiple iterations to propagate through the
|
||||
// entire graph. `DecomposeComplexOps` pass failing would cause the entire
|
||||
// `LowerToBackendContractPass` to fail
|
||||
// 2. Makes the legality requirements in the backend contract for ops with
|
||||
// decompositions explicit in this file
|
||||
static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||
ConversionTarget &target,
|
||||
ArrayRef<std::string> backendLegalOps) {
|
||||
target.addIllegalOp<AtenSoftmaxIntOp>();
|
||||
target.addIllegalOp<Aten_SoftmaxOp>();
|
||||
target.addIllegalOp<Aten_LogSoftmaxOp>();
|
||||
target.addIllegalOp<AtenLogSoftmaxIntOp>();
|
||||
target.addIllegalOp<AtenEmptyLikeOp>();
|
||||
target.addIllegalOp<AtenOnesLikeOp>();
|
||||
target.addIllegalOp<AtenZerosLikeOp>();
|
||||
target.addIllegalOp<AtenRollOp>();
|
||||
target.addIllegalOp<AtenRepeatOp>();
|
||||
target.addIllegalOp<AtenExpandOp>();
|
||||
target.addIllegalOp<AtenFlattenUsingIntsOp>();
|
||||
target.addIllegalOp<AtenWhereScalarOp>();
|
||||
target.addIllegalOp<AtenWhereScalarOtherOp>();
|
||||
target.addIllegalOp<AtenWhereScalarSelfOp>();
|
||||
target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
|
||||
target.addIllegalOp<AtenSizeOp>();
|
||||
target.addIllegalOp<AtenReshapeOp>();
|
||||
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
|
||||
target.addIllegalOp<AtenTanhBackwardOp>();
|
||||
target.addIllegalOp<AtenAddmmOp>();
|
||||
target.addIllegalOp<AtenMeanOp>();
|
||||
target.addIllegalOp<AtenMeanDimOp>();
|
||||
target.addIllegalOp<AtenSelectIntOp>();
|
||||
target.addIllegalOp<AtenMvOp>();
|
||||
target.addIllegalOp<AtenTOp>();
|
||||
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
|
||||
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
|
||||
int lhsRank = getTensorRank(op.getSelf());
|
||||
int rhsRank = getTensorRank(op.getOther());
|
||||
// Make aten.matmul legal if the following condition is satisfied.
|
||||
return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
|
||||
});
|
||||
target.addIllegalOp<AtenAddcmulOp>();
|
||||
target.addIllegalOp<AtenAddcdivOp>();
|
||||
target.addIllegalOp<AtenLayerNormOp>();
|
||||
target.addIllegalOp<AtenNativeLayerNormOp>();
|
||||
target.addIllegalOp<AtenNativeBatchNormOp>();
|
||||
target.addIllegalOp<AtenConvolutionOverrideableOp>();
|
||||
target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
|
||||
target.addIllegalOp<AtenConvolutionBackwardOp>();
|
||||
target.addIllegalOp<AtenConv2dOp>();
|
||||
target.addIllegalOp<AtenConvTranspose2dInputOp>();
|
||||
target.addIllegalOp<AtenArangeOp>();
|
||||
target.addIllegalOp<AtenArangeStartOp>();
|
||||
target.addIllegalOp<AtenArgmaxOp>();
|
||||
target.addIllegalOp<AtenSquareOp>();
|
||||
target.addIllegalOp<AtenVarOp>();
|
||||
target.addIllegalOp<AtenStdOp>();
|
||||
target.addIllegalOp<Aten_UnsafeViewOp>();
|
||||
target.addIllegalOp<Aten_ReshapeAliasOp>();
|
||||
target.addIllegalOp<AtenBernoulliOp>();
|
||||
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
|
||||
target.addIllegalOp<AtenBernoulliTensorOp>();
|
||||
target.addIllegalOp<AtenZeroOp>();
|
||||
target.addIllegalOp<AtenRandLikeOp>();
|
||||
target.addIllegalOp<AtenHardsigmoidOp>();
|
||||
target.addIllegalOp<AtenRelu6Op>();
|
||||
target.addIllegalOp<AtenHardswishOp>();
|
||||
target.addIllegalOp<AtenSoftplusOp>();
|
||||
target.addIllegalOp<AtenSiluOp>();
|
||||
target.addIllegalOp<AtenNewZerosOp>();
|
||||
target.addIllegalOp<AtenNewOnesOp>();
|
||||
target.addIllegalOp<AtenHardtanhOp>();
|
||||
target.addIllegalOp<AtenFullOp>();
|
||||
target.addIllegalOp<AtenLinearOp>();
|
||||
target.addIllegalOp<AtenMishOp>();
|
||||
target.addIllegalOp<AtenFullLikeOp>();
|
||||
target.addIllegalOp<AtenIndexPutOp>();
|
||||
target.addIllegalOp<AtenExpandAsOp>();
|
||||
target.addIllegalOp<Aten_ToCopyOp>();
|
||||
target.addIllegalOp<AtenDropoutOp>();
|
||||
target.addIllegalOp<AtenNewEmptyOp>();
|
||||
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
|
||||
target.addIllegalOp<AtenPadOp>();
|
||||
target.addIllegalOp<AtenToDtypeLayoutOp>();
|
||||
target.addIllegalOp<AtenToDeviceOp>();
|
||||
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
|
||||
target.addIllegalOp<AtenClampMinOp>();
|
||||
target.addIllegalOp<AtenClampMaxOp>();
|
||||
target.addIllegalOp<AtenBaddbmmOp>();
|
||||
target.addIllegalOp<AtenFloorDivideOp>();
|
||||
target.addIllegalOp<AtenNumpyTOp>();
|
||||
target.addIllegalOp<AtenSelectScatterOp>();
|
||||
target.addIllegalOp<AtenVarDimOp>();
|
||||
target.addIllegalOp<AtenAmaxOp>();
|
||||
target.addIllegalOp<AtenVarCorrectionOp>();
|
||||
target.addIllegalOp<AtenStdDimOp>();
|
||||
target.addIllegalOp<AtenNarrowOp>();
|
||||
target.addIllegalOp<Aten_EmbeddingBagOp>();
|
||||
target.addIllegalOp<AtenLiftFreshCopyOp>();
|
||||
target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
|
||||
target.addIllegalOp<AtenMseLossOp>();
|
||||
target.addIllegalOp<AtenRandintLowOp>();
|
||||
target.addIllegalOp<AtenVarMeanCorrectionOp>();
|
||||
target.addIllegalOp<PrimsConvertElementTypeOp>();
|
||||
target.addIllegalOp<AtenRandnOp>();
|
||||
target.addIllegalOp<AtenRandnGeneratorOp>();
|
||||
for (std::string opName : backendLegalOps) {
|
||||
target.addLegalOp(OperationName(opName, context));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,30 +28,27 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
|
|||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(
|
||||
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
// CHECK: %[[CST7:.*]] = torch.constant.int 7
|
||||
// CHECK: %[[OUTPUT_SIZE:.*]] = torch.prim.ListConstruct %[[CST7]], %[[CST7]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[CST2:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[CST0:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[CST1:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[CST2:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[CST3:.*]] = torch.constant.int 3
|
||||
// CHECK-DAG: %[[CST6:.*]] = torch.constant.int 6
|
||||
// CHECK-DAG: %[[CST7:.*]] = torch.constant.int 7
|
||||
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[CST3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[CST1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[COND1:.*]] = torch.aten.eq.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[COND1]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
|
||||
// CHECK: %[[T1:.*]] = torch.aten.sub.int %[[CST7]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[T2:.*]] = torch.aten.sub.int %[[DIM2]], %[[T1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[T1:.*]] = torch.aten.sub.int %[[DIM2]], %[[CST6]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[COND2:.*]] = torch.aten.eq.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: torch.runtime.assert %[[COND2]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
|
||||
// CHECK: %[[T3:.*]] = torch.aten.sub.int %[[CST7]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[T4:.*]] = torch.aten.sub.int %[[DIM3]], %[[T3]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T2]], %[[T4]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T6:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[T7:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[OUT:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[T5]], %[[T6]], %[[T7]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: %[[T2:.*]] = torch.aten.sub.int %[[DIM3]], %[[CST6]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
|
||||
func.func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
%int7 = torch.constant.int 7
|
||||
%output_size = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
|
@ -63,22 +60,19 @@ func.func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(%arg0: !torch.vte
|
|||
|
||||
// CHECK-LABEL: func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(
|
||||
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
// CHECK: %[[CST_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[OUTPUT_SIZE:.*]] = torch.prim.ListConstruct %[[CST_1]], %[[CST_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[CST_2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[DIM_2:.*]] = torch.aten.size.int %[[SELF]], %[[CST_2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[CST_3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[DIM_3:.*]] = torch.aten.size.int %[[SELF]], %[[CST_3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[CST_1_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[CST_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[CEIL_MODE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[COUNT_INCLUDE_PAD:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[DIVISOR_OVERRIDE:.*]] = torch.constant.none
|
||||
// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[DIM_2]], %[[DIM_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[CST_1_1]], %[[CST_1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST_0]], %[[CST_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[OUT:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[STRIDE]], %[[PADDING]], %[[CEIL_MODE]], %[[COUNT_INCLUDE_PAD]], %[[DIVISOR_OVERRIDE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK-DAG: %[[CST0:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[CST1:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[CST2:.*]] = torch.constant.int 2
|
||||
// CHECK-DAG: %[[CST3:.*]] = torch.constant.int 3
|
||||
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[DIM2]], %[[DIM3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
|
||||
func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%output_size = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
|
|
|
@ -44,6 +44,16 @@ func.func @f(%arg0: !torch.any) {
|
|||
|
||||
// -----
|
||||
|
||||
// Decomposition of `aten.t` fails if `inputRank > 2`
|
||||
func.func @f(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
// expected-error @+2 {{found an op that was marked as backend illegal}}
|
||||
// expected-note @+1 {{this is likely due to}}
|
||||
%t = torch.aten.t %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
|
||||
return %t : !torch.vtensor<[?,?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: checking of op results.
|
||||
// TODO: In theory we could diagnose every single value, but for now we bail out on the first one.
|
||||
|
||||
|
|
Loading…
Reference in New Issue