Allow running DecomposeComplexOps more than once (#1671)

The current implementation of `DecomposeComplexOps` fails if an op
expected to be decomposed is not decomposed in the first iteration
of the `createTorchSimplificationPipeline` run by
`LowerToBackendContractPass`. However, some graphs require multiple
iterations of the simplification pipeline to fully propagate all
statically knowable information, such as dtypes and shapes, through
the entire graph, which means `DecomposeComplexOps` sometimes needs
to run more than once.

This commit changes `DecomposeComplexOps` to use a greedy
pattern-rewrite driver and moves the op legality check into
`LowerToBackendContractPass`, allowing `DecomposeComplexOps` to run
more than once.
Ramiro Leal-Cavazos 2022-12-08 09:26:38 -08:00 committed by GitHub
parent e8511840c3
commit a54b334578
4 changed files with 346 additions and 250 deletions
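Before the diffs, a condensed restatement of the new registration scheme: each decomposition pattern declares a single root op via `getRootKind()`, and the pass registers a pattern only when its root op is absent from the user-provided legal set, so no `ConversionTarget` remains in this pass at all. A minimal sketch against the MLIR APIs of this era (`llvm::Optional`-based `getRootKind()`); the free-function shape and the name `addPatternIfIllegal` are illustrative, the in-tree version is the member helper shown in the first file below.

#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringSet.h"

// Register `DecomposePattern` only when its root op has not been declared
// backend-legal. Patterns matching generic `Operation *` have no root kind,
// so the assertion rejects them: with a greedy driver there is no way to
// tell whether such a pattern would fire on a legal op.
template <typename DecomposePattern>
static void addPatternIfIllegal(mlir::RewritePatternSet &patterns,
                                const llvm::StringSet<> &legalOpsSet,
                                mlir::MLIRContext *context) {
  llvm::Optional<mlir::OperationName> opName =
      DecomposePattern(context).getRootKind();
  assert(opName && "all decomposition patterns must target a single op");
  if (!legalOpsSet.contains(opName->getStringRef()))
    patterns.add<DecomposePattern>(context);
}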

@@ -11,6 +11,7 @@
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
@@ -18,6 +19,7 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include <cstdint>
using namespace mlir;
@@ -675,13 +677,15 @@ public:
int lhsRank = getTensorRank(lhs);
int rhsRank = getTensorRank(rhs);
// If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
if (lhsRank == 2 && rhsRank == 2)
if (lhsRank == 2 && rhsRank == 2) {
// If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
rewriter.replaceOpWithNewOp<AtenMmOp>(op, op.getType(), lhs, rhs);
// If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
if (lhsRank == 3 && rhsRank == 3)
} else if (lhsRank == 3 && rhsRank == 3) {
// If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
rewriter.replaceOpWithNewOp<AtenBmmOp>(op, op.getType(), lhs, rhs);
} else {
return failure();
}
return success();
}
@@ -3298,6 +3302,24 @@
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
private:
llvm::StringSet<> legalOpsSet;
template <typename DecomposePattern>
void addPatternIfTargetOpIsIllegal(RewritePatternSet &patterns) {
MLIRContext *context = &getContext();
Optional<OperationName> opName = DecomposePattern(context).getRootKind();
// Because the `DecomposeComplexOpsPass` uses a greedy algorithm
// to apply patterns, only patterns that we for sure know we want to run
// must be added. This restricts the set of patterns allowed in this file to
// patterns that apply to a single op. In other words, patterns that match
// on `Operation *` are not allowed, since there is no way of telling if
// that pattern will match on an op in the `legalOpsSet` or not.
assert(opName && "All decomposition patterns must target a single op");
if (!legalOpsSet.contains(opName->getStringRef()))
patterns.add<DecomposePattern>(context);
}
public:
DecomposeComplexOpsPass() = default;
DecomposeComplexOpsPass(ArrayRef<std::string> legalOps) {
@@ -3306,215 +3328,128 @@
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<Torch::TorchDialect>();
// The strings in the `legalOps` ArrayRef don't exist during the call to the
// constructor `DecomposeComplexOpsPass`, so the creation of the
// `legalOpsSet` must be delayed to when `runOnOperation` gets called.
legalOpsSet.clear();
legalOpsSet.insert(legalOps.begin(), legalOps.end());
patterns.add<DecomposeAtenSoftmaxIntOp>(context);
target.addIllegalOp<AtenSoftmaxIntOp>();
patterns.add<DecomposeAten_SoftmaxOp>(context);
target.addIllegalOp<Aten_SoftmaxOp>();
patterns.add<DecomposeAten_LogSoftmaxOp>(context);
target.addIllegalOp<Aten_LogSoftmaxOp>();
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
target.addIllegalOp<AtenLogSoftmaxIntOp>();
patterns.add<DecomposeAtenEmptyLikeOp>(context);
target.addIllegalOp<AtenEmptyLikeOp>();
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(
context);
target.addIllegalOp<AtenOnesLikeOp>();
patterns.add<DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(
context);
target.addIllegalOp<AtenZerosLikeOp>();
patterns.add<DecomposeAtenRollOp>(context);
target.addIllegalOp<AtenRollOp>();
patterns.add<DecomposeAtenRepeatOp>(context);
target.addIllegalOp<AtenRepeatOp>();
patterns.add<DecomposeAtenExpandOp>(context);
target.addIllegalOp<AtenExpandOp>();
patterns.add<DecomposeAtenFlattenUsingIntsOp>(context);
target.addIllegalOp<AtenFlattenUsingIntsOp>();
patterns.add<DecomposeAtenWhereScalarOp>(context);
target.addIllegalOp<AtenWhereScalarOp>();
patterns.add<DecomposeAtenWhereScalarOtherOp>(context);
target.addIllegalOp<AtenWhereScalarOtherOp>();
patterns.add<DecomposeAtenWhereScalarSelfOp>(context);
target.addIllegalOp<AtenWhereScalarSelfOp>();
patterns.add<DecomposeAtenConvolutionBackwardOverrideableOp>(context);
target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
patterns.add<DecomposeAtenSizeOp>(context);
target.addIllegalOp<AtenSizeOp>();
patterns.add<DecomposeAtenReshapeOp>(context);
target.addIllegalOp<AtenReshapeOp>();
patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
patterns.add<DecomposeAtenTanhBackwardOp>(context);
target.addIllegalOp<AtenTanhBackwardOp>();
patterns.add<DecomposeAtenAddmmOp>(context);
target.addIllegalOp<AtenAddmmOp>();
patterns.add<DecomposeAtenMeanOp>(context);
target.addIllegalOp<AtenMeanOp>();
patterns.add<DecomposeAtenMeanDimOp>(context);
target.addIllegalOp<AtenMeanDimOp>();
patterns.add<DecomposeAtenSelectIntOp>(context);
target.addIllegalOp<AtenSelectIntOp>();
patterns.add<DecomposeAtenMatmulOp>(context);
target.addIllegalOp<AtenMvOp>();
patterns.add<DecomposeAtenMvOp>(context);
target.addIllegalOp<AtenTOp>();
patterns.add<DecomposeAtenTOp>(context);
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
int lhsRank = getTensorRank(op.getSelf());
int rhsRank = getTensorRank(op.getOther());
// Make aten.matmul legal if the following condition is satisfied.
return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
});
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeConstantTensorAllocLikeOp<AtenOnesLikeOp, 1>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeConstantTensorAllocLikeOp<AtenZerosLikeOp, 0>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRollOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFlattenUsingIntsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenConvolutionBackwardOverrideableOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTanhBackwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConvolutionOverrideableOp>(
patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConvolutionBackwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConv2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArgMaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeViewOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_ReshapeAliasOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeValsemVariantAtenBernoulliFloatOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSiluOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeConstantTensorNewLikeOp<AtenNewZerosOp, AtenZerosOp>>(
patterns);
addPatternIfTargetOpIsIllegal<
DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardtanhOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFullOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenDropoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAmaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_EmbeddingBagOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLiftFreshCopyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorHackedTwinOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMseLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanCorrectionOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
patterns.add<DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(
context);
target.addIllegalOp<AtenAddcmulOp>();
patterns.add<DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(
context);
target.addIllegalOp<AtenAddcdivOp>();
target.addIllegalOp<AtenLayerNormOp>();
patterns.add<DecomposeAtenLayerNormOp>(context);
target.addIllegalOp<AtenNativeLayerNormOp>();
patterns.add<DecomposeAtenNativeLayerNormOp>(context);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;
target.addIllegalOp<AtenNativeBatchNormOp>();
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
target.addIllegalOp<AtenConvolutionOverrideableOp>();
patterns.add<DecomposeAtenConvolutionOverrideableOp>(context);
target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
patterns.add<DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>,
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
context);
target.addIllegalOp<AtenConvolutionBackwardOp>();
patterns.add<DecomposeAtenConvolutionBackwardOp>(context);
target.addIllegalOp<AtenConv2dOp>();
patterns.add<DecomposeAtenConv2dOp>(context);
target.addIllegalOp<AtenConvTranspose2dInputOp>();
patterns.add<DecomposeAtenConvTranspose2dOp>(context);
patterns.add<DecomposeAtenArangeOp>(context);
target.addIllegalOp<AtenArangeOp>();
patterns.add<DecomposeAtenArangeStartOp>(context);
target.addIllegalOp<AtenArangeStartOp>();
patterns.add<DecomposeAtenArgMaxOp>(context);
target.addIllegalOp<AtenArgmaxOp>();
patterns.add<DecomposeAtenSquareOp>(context);
target.addIllegalOp<AtenSquareOp>();
patterns.add<DecomposeAtenVarOp>(context);
target.addIllegalOp<AtenVarOp>();
patterns.add<DecomposeAtenStdOp>(context);
target.addIllegalOp<AtenStdOp>();
patterns.add<DecomposeAten_UnsafeViewOp>(context);
target.addIllegalOp<Aten_UnsafeViewOp>();
patterns.add<DecomposeAten_ReshapeAliasOp>(context);
target.addIllegalOp<Aten_ReshapeAliasOp>();
patterns.add<DecomposeAtenBernoulliOp>(context);
target.addIllegalOp<AtenBernoulliOp>();
patterns.add<DecomposeValsemVariantAtenBernoulliFloatOp>(context);
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
patterns.add<DecomposeAtenBernoulliTensorOp>(context);
target.addIllegalOp<AtenBernoulliTensorOp>();
patterns.add<DecomposeAtenZeroOp>(context);
target.addIllegalOp<AtenZeroOp>();
patterns.add<DecomposeAtenRandLikeOp>(context);
target.addIllegalOp<AtenRandLikeOp>();
patterns.add<DecomposeAtenHardsigmoidOp>(context);
target.addIllegalOp<AtenHardsigmoidOp>();
patterns.add<DecomposeAtenRelu6Op>(context);
target.addIllegalOp<AtenRelu6Op>();
patterns.add<DecomposeAtenHardswishOp>(context);
target.addIllegalOp<AtenHardswishOp>();
patterns.add<DecomposeAtenSoftplusOp>(context);
target.addIllegalOp<AtenSoftplusOp>();
patterns.add<DecomposeAtenSiluOp>(context);
target.addIllegalOp<AtenSiluOp>();
patterns.add<DecomposeConstantTensorNewLikeOp<AtenNewZerosOp, AtenZerosOp>>(
context);
target.addIllegalOp<AtenNewZerosOp>();
patterns.add<DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(
context);
target.addIllegalOp<AtenNewOnesOp>();
patterns.add<DecomposeAtenHardtanhOp>(context);
target.addIllegalOp<AtenHardtanhOp>();
patterns.add<DecomposeAtenFullOp>(context);
target.addIllegalOp<AtenFullOp>();
patterns.add<DecomposeAtenLinearOp>(context);
target.addIllegalOp<AtenLinearOp>();
patterns.add<DecomposeAtenMishOp>(context);
target.addIllegalOp<AtenMishOp>();
patterns.add<DecomposeAtenFullLikeOp>(context);
target.addIllegalOp<AtenFullLikeOp>();
patterns.add<DecomposeAtenIndexPutOp>(context);
target.addIllegalOp<AtenIndexPutOp>();
patterns.add<DecomposeAtenExpandAsOp>(context);
target.addIllegalOp<AtenExpandAsOp>();
patterns.add<DecomposeAten_ToCopyOp>(context);
target.addIllegalOp<Aten_ToCopyOp>();
patterns.add<DecomposeAtenDropoutOp>(context);
target.addIllegalOp<AtenDropoutOp>();
target.addIllegalOp<AtenNewEmptyOp>();
patterns.add<DecomposeAtenNewEmptyOp>(context);
patterns.add<DecomposeAtenIndexPutHackedTwinOp>(context);
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
target.addIllegalOp<AtenPadOp>();
patterns.add<DecomposeAtenPadOp>(context);
patterns.add<DecomposeAtenToDtypeLayoutOp>(context);
target.addIllegalOp<AtenToDtypeLayoutOp>();
patterns.add<DecomposeAtenToDeviceOp>(context);
target.addIllegalOp<AtenToDeviceOp>();
patterns.add<DecomposeAtenAdaptiveAvgPool2dOp>(context);
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
patterns.add<DecomposeAtenClampMinOp>(context);
target.addIllegalOp<AtenClampMinOp>();
patterns.add<DecomposeAtenClampMaxOp>(context);
target.addIllegalOp<AtenClampMaxOp>();
patterns.add<DecomposeAtenBaddbmmOp>(context);
target.addIllegalOp<AtenBaddbmmOp>();
patterns.add<DecomposeAtenFloorDivideOp>(context);
target.addIllegalOp<AtenFloorDivideOp>();
patterns.add<DecomposeAtenNumpyTOp>(context);
target.addIllegalOp<AtenNumpyTOp>();
patterns.add<DecomposeAtenSelectScatterOp>(context);
target.addIllegalOp<AtenSelectScatterOp>();
patterns.add<DecomposeAtenVarDimOp>(context);
target.addIllegalOp<AtenVarDimOp>();
patterns.add<DecomposeAtenAmaxOp>(context);
target.addIllegalOp<AtenAmaxOp>();
patterns.add<DecomposeAtenVarCorrectionOp>(context);
target.addIllegalOp<AtenVarCorrectionOp>();
patterns.add<DecomposeAtenStdDimOp>(context);
target.addIllegalOp<AtenStdDimOp>();
patterns.add<DecomposeAtenNarrowOp>(context);
target.addIllegalOp<AtenNarrowOp>();
patterns.add<DecomposeAten_EmbeddingBagOp>(context);
target.addIllegalOp<Aten_EmbeddingBagOp>();
patterns.add<DecomposeAtenLiftFreshCopyOp>(context);
target.addIllegalOp<AtenLiftFreshCopyOp>();
patterns.add<DecomposeAtenIndexTensorHackedTwinOp>(context);
target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
patterns.add<DecomposeAtenMseLossOp>(context);
target.addIllegalOp<AtenMseLossOp>();
patterns.add<DecomposeAtenRandintLowOp>(context);
target.addIllegalOp<AtenRandintLowOp>();
patterns.add<DecomposeAtenVarMeanCorrectionOp>(context);
target.addIllegalOp<AtenVarMeanCorrectionOp>();
patterns.add<DecomposePrimsConvertElementTypeOp>(context);
target.addIllegalOp<PrimsConvertElementTypeOp>();
patterns.add<DecomposeAtenRandnOp>(context);
target.addIllegalOp<AtenRandnOp>();
patterns.add<DecomposeAtenRandnGeneratorOp>(context);
target.addIllegalOp<AtenRandnGeneratorOp>();
for (std::string opName : legalOps) {
target.addLegalOp(OperationName(opName, context));
}
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) {
return signalPassFailure();
}
}
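Note the driver swap at the end of `runOnOperation` above: `applyPartialConversion` is replaced with `applyPatternsAndFoldGreedily`, so a surviving op no longer fails this pass; failure now only signals that the rewriter did not converge. A hedged sketch of scheduling the pass in a pipeline; the factory signature is assumed from the `DecomposeComplexOpsPass(ArrayRef<std::string>)` constructor above.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

// Given a module-level PassManager `pm` (construction elided), run the
// decomposition on each function. Ops named in `legalOps` simply never get
// their decomposition patterns registered.
pm.addNestedPass<mlir::func::FuncOp>(
    mlir::torch::Torch::createDecomposeComplexOpsPass(
        /*legalOps=*/{"torch.aten.softmax.int"}));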

@@ -11,10 +11,12 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "torch-lower-to-backend-contract"
@@ -27,6 +29,10 @@ using namespace mlir::torch::Torch;
// Checking the backend contract.
//===----------------------------------------------------------------------===//
static void markDecomposedOpsAsIllegal(MLIRContext *context,
ConversionTarget &target,
ArrayRef<std::string> backendLegalOps);
static LogicalResult checkType(Operation *op, Type type,
bool actuallyEmitDiagnostics) {
// Allow various scalar types that backends are expected to be able to handle.
@@ -149,7 +155,24 @@ static LogicalResult checkType(Operation *op, Type type,
}
}
static LogicalResult checkOpIsBackendLegal(Operation *op,
const ConversionTarget &target,
bool actuallyEmitDiagnostics) {
if (target.isLegal(op))
return success();
if (actuallyEmitDiagnostics) {
return op->emitError("found an op that was marked as backend illegal")
.attachNote()
.append("this is likely due to DecomposeComplexOps being unable to "
"decompose this op");
} else {
return failure();
}
}
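For illustration, when `actuallyEmitDiagnostics` is true, an op that survives decomposition is now reported with a diagnostic of roughly this shape (the exact location prefix is up to MLIR's diagnostic engine; compare the `expected-error`/`expected-note` test in the last file):

error: found an op that was marked as backend illegal
note: this is likely due to DecomposeComplexOps being unable to decompose this op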
static bool satisfiesBackendContract(ModuleOp module,
const ConversionTarget &target,
bool actuallyEmitDiagnostics = false) {
// We do not permit `torch.global_slot`'s in the backend contract, since
// support for them is not widespread, and this does not align with PyTorch's
@@ -174,7 +197,8 @@ static bool satisfiesBackendContract(ModuleOp module,
if (walkResult0.wasInterrupted())
return false;
// Check all the type of all Value's in the program.
// Check all the types of all Value's in the program and the legality of all
// the ops.
//
// A pre-order walk gives a more intuitive "first error".
// TODO: Should we report more than the first error?
@@ -185,10 +209,14 @@
actuallyEmitDiagnostics))) {
return WalkResult::interrupt();
}
for (Operation &op : *block)
for (Operation &op : *block) {
if (failed(checkOpIsBackendLegal(&op, target, actuallyEmitDiagnostics)))
return WalkResult::interrupt();
for (OpResult result : op.getResults())
if (failed(checkType(&op, result.getType(), actuallyEmitDiagnostics)))
return WalkResult::interrupt();
}
return WalkResult::advance();
});
@@ -210,6 +238,11 @@ public:
}
void runOnOperation() override {
ModuleOp module = getOperation();
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<func::FuncDialect, Torch::TorchDialect>();
if (decompose)
markDecomposedOpsAsIllegal(context, target, backendLegalOps);
OpPassManager pm(module.getOperationName());
TorchLoweringPipelineOptions options;
@@ -227,14 +260,14 @@
<< " iterations of the simplification pipeline\n";
});
// Show the diagnostics.
(void)satisfiesBackendContract(module,
(void)satisfiesBackendContract(module, target,
/*actuallyEmitDiagnostics=*/true);
return signalPassFailure();
}
if (failed(runPipeline(pm, module)))
return signalPassFailure();
} while (!satisfiesBackendContract(module));
} while (!satisfiesBackendContract(module, target));
LLVM_DEBUG({
llvm::dbgs() << "LowerToBackendContractPass: "
<< "succeeded after " << i
@@ -247,7 +280,10 @@ class VerifyBackendContractPass
: public VerifyBackendContractBase<VerifyBackendContractPass> {
public:
void runOnOperation() override {
if (!satisfiesBackendContract(getOperation(), /*actuallyEmitDiagnostics=*/true)) {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
if (!satisfiesBackendContract(getOperation(), target,
/*actuallyEmitDiagnostics=*/true)) {
return signalPassFailure();
}
}
@@ -265,3 +301,124 @@ std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createVerifyBackendContractPass() {
return std::make_unique<VerifyBackendContractPass>();
}
// The backend contract guarantees that ops with decompositions available will
// be decomposed. The only way to have an op reach the backend contract without
// getting decomposed is by having the user explicitly specify that op in the
// `backendLegalOps` argument to the `LowerToBackendContractPass`. Therefore,
// here we mark as illegal all ops with decompositions except for those in
// `backendLegalOps`.
//
// The legality check takes place here instead of in the `DecomposeComplexOps`
// pass for two reasons:
// 1. Makes sure the `DecomposeComplexOps` pass always succeeds, allowing it to
// run multiple times. This is needed for graphs where static information such
// as dtypes and shapes takes multiple iterations to propagate through the
// entire graph. `DecomposeComplexOps` pass failing would cause the entire
// `LowerToBackendContractPass` to fail
// 2. Makes the legality requirements in the backend contract for ops with
// decompositions explicit in this file
static void markDecomposedOpsAsIllegal(MLIRContext *context,
ConversionTarget &target,
ArrayRef<std::string> backendLegalOps) {
target.addIllegalOp<AtenSoftmaxIntOp>();
target.addIllegalOp<Aten_SoftmaxOp>();
target.addIllegalOp<Aten_LogSoftmaxOp>();
target.addIllegalOp<AtenLogSoftmaxIntOp>();
target.addIllegalOp<AtenEmptyLikeOp>();
target.addIllegalOp<AtenOnesLikeOp>();
target.addIllegalOp<AtenZerosLikeOp>();
target.addIllegalOp<AtenRollOp>();
target.addIllegalOp<AtenRepeatOp>();
target.addIllegalOp<AtenExpandOp>();
target.addIllegalOp<AtenFlattenUsingIntsOp>();
target.addIllegalOp<AtenWhereScalarOp>();
target.addIllegalOp<AtenWhereScalarOtherOp>();
target.addIllegalOp<AtenWhereScalarSelfOp>();
target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
target.addIllegalOp<AtenSizeOp>();
target.addIllegalOp<AtenReshapeOp>();
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
target.addIllegalOp<AtenTanhBackwardOp>();
target.addIllegalOp<AtenAddmmOp>();
target.addIllegalOp<AtenMeanOp>();
target.addIllegalOp<AtenMeanDimOp>();
target.addIllegalOp<AtenSelectIntOp>();
target.addIllegalOp<AtenMvOp>();
target.addIllegalOp<AtenTOp>();
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
int lhsRank = getTensorRank(op.getSelf());
int rhsRank = getTensorRank(op.getOther());
// Make aten.matmul legal if the following condition is satisfied.
return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
});
target.addIllegalOp<AtenAddcmulOp>();
target.addIllegalOp<AtenAddcdivOp>();
target.addIllegalOp<AtenLayerNormOp>();
target.addIllegalOp<AtenNativeLayerNormOp>();
target.addIllegalOp<AtenNativeBatchNormOp>();
target.addIllegalOp<AtenConvolutionOverrideableOp>();
target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
target.addIllegalOp<AtenConvolutionBackwardOp>();
target.addIllegalOp<AtenConv2dOp>();
target.addIllegalOp<AtenConvTranspose2dInputOp>();
target.addIllegalOp<AtenArangeOp>();
target.addIllegalOp<AtenArangeStartOp>();
target.addIllegalOp<AtenArgmaxOp>();
target.addIllegalOp<AtenSquareOp>();
target.addIllegalOp<AtenVarOp>();
target.addIllegalOp<AtenStdOp>();
target.addIllegalOp<Aten_UnsafeViewOp>();
target.addIllegalOp<Aten_ReshapeAliasOp>();
target.addIllegalOp<AtenBernoulliOp>();
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
target.addIllegalOp<AtenBernoulliTensorOp>();
target.addIllegalOp<AtenZeroOp>();
target.addIllegalOp<AtenRandLikeOp>();
target.addIllegalOp<AtenHardsigmoidOp>();
target.addIllegalOp<AtenRelu6Op>();
target.addIllegalOp<AtenHardswishOp>();
target.addIllegalOp<AtenSoftplusOp>();
target.addIllegalOp<AtenSiluOp>();
target.addIllegalOp<AtenNewZerosOp>();
target.addIllegalOp<AtenNewOnesOp>();
target.addIllegalOp<AtenHardtanhOp>();
target.addIllegalOp<AtenFullOp>();
target.addIllegalOp<AtenLinearOp>();
target.addIllegalOp<AtenMishOp>();
target.addIllegalOp<AtenFullLikeOp>();
target.addIllegalOp<AtenIndexPutOp>();
target.addIllegalOp<AtenExpandAsOp>();
target.addIllegalOp<Aten_ToCopyOp>();
target.addIllegalOp<AtenDropoutOp>();
target.addIllegalOp<AtenNewEmptyOp>();
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
target.addIllegalOp<AtenPadOp>();
target.addIllegalOp<AtenToDtypeLayoutOp>();
target.addIllegalOp<AtenToDeviceOp>();
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
target.addIllegalOp<AtenClampMinOp>();
target.addIllegalOp<AtenClampMaxOp>();
target.addIllegalOp<AtenBaddbmmOp>();
target.addIllegalOp<AtenFloorDivideOp>();
target.addIllegalOp<AtenNumpyTOp>();
target.addIllegalOp<AtenSelectScatterOp>();
target.addIllegalOp<AtenVarDimOp>();
target.addIllegalOp<AtenAmaxOp>();
target.addIllegalOp<AtenVarCorrectionOp>();
target.addIllegalOp<AtenStdDimOp>();
target.addIllegalOp<AtenNarrowOp>();
target.addIllegalOp<Aten_EmbeddingBagOp>();
target.addIllegalOp<AtenLiftFreshCopyOp>();
target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
target.addIllegalOp<AtenMseLossOp>();
target.addIllegalOp<AtenRandintLowOp>();
target.addIllegalOp<AtenVarMeanCorrectionOp>();
target.addIllegalOp<PrimsConvertElementTypeOp>();
target.addIllegalOp<AtenRandnOp>();
target.addIllegalOp<AtenRandnGeneratorOp>();
for (std::string opName : backendLegalOps) {
target.addLegalOp(OperationName(opName, context));
}
}
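To close out this file, a hedged sketch of constructing the pass with user-specified backend-legal ops; the factory signature (`maxIterations`, `decompose`, `backendLegalOps`) is an assumption based on this pass's options, and op names must be dialect-qualified because they are parsed with `OperationName(opName, context)`.

// Hypothetical construction on a module-level PassManager `pm`; the op
// named in backendLegalOps is skipped by decomposition and left legal here.
pm.addPass(mlir::torch::Torch::createLowerToBackendContractPass(
    /*maxIterations=*/10, /*decompose=*/true,
    /*backendLegalOps=*/{"torch.aten.flatten.using_ints"}));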

@@ -28,30 +28,27 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
// -----
// CHECK-LABEL: func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[CST7:.*]] = torch.constant.int 7
// CHECK: %[[OUTPUT_SIZE:.*]] = torch.prim.ListConstruct %[[CST7]], %[[CST7]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[CST2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[CST0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[CST1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[CST2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[CST3:.*]] = torch.constant.int 3
// CHECK-DAG: %[[CST6:.*]] = torch.constant.int 6
// CHECK-DAG: %[[CST7:.*]] = torch.constant.int 7
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[CST3:.*]] = torch.constant.int 3
// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[COND1:.*]] = torch.aten.eq.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[COND1]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
// CHECK: %[[T1:.*]] = torch.aten.sub.int %[[CST7]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[T2:.*]] = torch.aten.sub.int %[[DIM2]], %[[T1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[T1:.*]] = torch.aten.sub.int %[[DIM2]], %[[CST6]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[COND2:.*]] = torch.aten.eq.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[COND2]], "unimplemented: only support cases where input and output size are equal for non-unit output size"
// CHECK: %[[T3:.*]] = torch.aten.sub.int %[[CST7]], %[[CST1]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[T4:.*]] = torch.aten.sub.int %[[DIM3]], %[[T3]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T2]], %[[T4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T6:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T7:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[OUT:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[T5]], %[[T6]], %[[T7]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?,?],f32>
// CHECK: %[[T2:.*]] = torch.aten.sub.int %[[DIM3]], %[[CST6]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int7 = torch.constant.int 7
%output_size = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -63,22 +60,19 @@ func.func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(%arg0: !torch.vte
// CHECK-LABEL: func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[CST_1:.*]] = torch.constant.int 1
// CHECK: %[[OUTPUT_SIZE:.*]] = torch.prim.ListConstruct %[[CST_1]], %[[CST_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[CST_2:.*]] = torch.constant.int 2
// CHECK: %[[DIM_2:.*]] = torch.aten.size.int %[[SELF]], %[[CST_2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[CST_3:.*]] = torch.constant.int 3
// CHECK: %[[DIM_3:.*]] = torch.aten.size.int %[[SELF]], %[[CST_3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[CST_1_1:.*]] = torch.constant.int 1
// CHECK: %[[CST_0:.*]] = torch.constant.int 0
// CHECK: %[[CEIL_MODE:.*]] = torch.constant.bool false
// CHECK: %[[COUNT_INCLUDE_PAD:.*]] = torch.constant.bool true
// CHECK: %[[DIVISOR_OVERRIDE:.*]] = torch.constant.none
// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[DIM_2]], %[[DIM_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[CST_1_1]], %[[CST_1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST_0]], %[[CST_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[OUT:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[STRIDE]], %[[PADDING]], %[[CEIL_MODE]], %[[COUNT_INCLUDE_PAD]], %[[DIVISOR_OVERRIDE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?,?],f32>
// CHECK-DAG: %[[CST0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[CST1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[CST2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[CST3:.*]] = torch.constant.int 3
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[DIM2]], %[[DIM3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int1 = torch.constant.int 1
%output_size = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>

@@ -44,6 +44,16 @@ func.func @f(%arg0: !torch.any) {
// -----
// Decomposition of `aten.t` fails if `inputRank > 2`
func.func @f(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// expected-error @+2 {{found an op that was marked as backend illegal}}
// expected-note @+1 {{this is likely due to}}
%t = torch.aten.t %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
return %t : !torch.vtensor<[?,?,?],f32>
}
// -----
// Test case: checking of op results.
// TODO: In theory we could diagnose every single value, but for now we bail out on the first one.