diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index c7ebc2ecb..d815b5fb2 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/BuiltinDialect.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" @@ -18,6 +19,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSet.h" #include using namespace mlir; @@ -675,13 +677,15 @@ public: int lhsRank = getTensorRank(lhs); int rhsRank = getTensorRank(rhs); - // If both lhs and rhs ranks are 2 then map it to `aten.mm` op. - if (lhsRank == 2 && rhsRank == 2) + if (lhsRank == 2 && rhsRank == 2) { + // If both lhs and rhs ranks are 2 then map it to `aten.mm` op. rewriter.replaceOpWithNewOp(op, op.getType(), lhs, rhs); - - // If both lhs and rhs ranks are 3 then map it to `aten.bmm` op. - if (lhsRank == 3 && rhsRank == 3) + } else if (lhsRank == 3 && rhsRank == 3) { + // If both lhs and rhs ranks are 3 then map it to `aten.bmm` op. rewriter.replaceOpWithNewOp(op, op.getType(), lhs, rhs); + } else { + return failure(); + } return success(); } @@ -3298,6 +3302,24 @@ public: namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { +private: + llvm::StringSet<> legalOpsSet; + + template + void addPatternIfTargetOpIsIllegal(RewritePatternSet &patterns) { + MLIRContext *context = &getContext(); + Optional opName = DecomposePattern(context).getRootKind(); + // Because the `DecomposeComplexOpsPass` uses a greedy algorithm + // to apply patterns, only patterns that we for sure know we want to run + // must be added. This restricts the set of patterns allowed in this file to + // patterns that apply to a single op. In other words, patterns that match + // on `Operation *` are not allowed, since there is no way of telling if + // that pattern will match on an op in the `legalOpsSet` or not. + assert(opName && "All decomposition patterns must target a single op"); + if (!legalOpsSet.contains(opName->getStringRef())) + patterns.add(context); + } + public: DecomposeComplexOpsPass() = default; DecomposeComplexOpsPass(ArrayRef legalOps) { @@ -3306,215 +3328,128 @@ public: void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - ConversionTarget target(*context); - target.addLegalDialect(); + // The strings in the `legalOps` ArrayRef don't exist during the call to the + // constructor `DecomposeComplexOpsPass`, so the creation of the + // `legalOpsSet` must be delayed to when `runOnOperation` gets called. + legalOpsSet.clear(); + legalOpsSet.insert(legalOps.begin(), legalOps.end()); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add>( - context); - target.addIllegalOp(); - patterns.add>( - context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - patterns.add(context); - target.addIllegalOp(); - target.addDynamicallyLegalOp([](AtenMatmulOp op) { - int lhsRank = getTensorRank(op.getSelf()); - int rhsRank = getTensorRank(op.getOther()); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeConstantTensorAllocLikeOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeConstantTensorAllocLikeOp>(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenConvolutionBackwardOverrideableOp>(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenAddCLikeOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenAddCLikeOp>(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAten_ConvolutionLikeOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAten_ConvolutionLikeOp>( + patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeConstantTensorNewLikeOp>( + patterns); + addPatternIfTargetOpIsIllegal< + DecomposeConstantTensorNewLikeOp>(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); - // Make aten.matmul legal if the following condition is satisfied. - return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3); - }); - patterns.add>( - context); - target.addIllegalOp(); - patterns.add>( - context); - target.addIllegalOp(); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); + GreedyRewriteConfig config; + config.useTopDownTraversal = true; + config.maxIterations = GreedyRewriteConfig::kNoIterationLimit; - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add, - DecomposeAten_ConvolutionLikeOp>( - context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add>( - context); - target.addIllegalOp(); - patterns.add>( - context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - target.addIllegalOp(); - patterns.add(context); - patterns.add(context); - target.addIllegalOp(); - target.addIllegalOp(); - patterns.add(context); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - patterns.add(context); - target.addIllegalOp(); - - for (std::string opName : legalOps) { - target.addLegalOp(OperationName(opName, context)); - } - - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 9307030e2..fa853796a 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -11,10 +11,12 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "torch-lower-to-backend-contract" @@ -27,6 +29,10 @@ using namespace mlir::torch::Torch; // Checking the backend contract. //===----------------------------------------------------------------------===// +static void markDecomposedOpsAsIllegal(MLIRContext *context, + ConversionTarget &target, + ArrayRef backendLegalOps); + static LogicalResult checkType(Operation *op, Type type, bool actuallyEmitDiagnostics) { // Allow various scalar types that backends are expected to be able to handle. @@ -149,7 +155,24 @@ static LogicalResult checkType(Operation *op, Type type, } } +static LogicalResult checkOpIsBackendLegal(Operation *op, + const ConversionTarget &target, + bool actuallyEmitDiagnostics) { + if (target.isLegal(op)) + return success(); + + if (actuallyEmitDiagnostics) { + return op->emitError("found an op that was marked as backend illegal") + .attachNote() + .append("this is likely due to DecomposeComplexOps being unable to " + "decompose this op"); + } else { + return failure(); + } +} + static bool satisfiesBackendContract(ModuleOp module, + const ConversionTarget &target, bool actuallyEmitDiagnostics = false) { // We do not permit `torch.global_slot`'s in the backend contract, since // support for them is not widespread, and this does not align with PyTorch's @@ -174,7 +197,8 @@ static bool satisfiesBackendContract(ModuleOp module, if (walkResult0.wasInterrupted()) return false; - // Check all the type of all Value's in the program. + // Check all the types of all Value's in the program and the legality of all + // the ops. // // A pre-order walk gives a more intuitive "first error". // TODO: Should we report more than the first error? @@ -185,10 +209,14 @@ static bool satisfiesBackendContract(ModuleOp module, actuallyEmitDiagnostics))) { return WalkResult::interrupt(); } - for (Operation &op : *block) + for (Operation &op : *block) { + if (failed(checkOpIsBackendLegal(&op, target, actuallyEmitDiagnostics))) + return WalkResult::interrupt(); + for (OpResult result : op.getResults()) if (failed(checkType(&op, result.getType(), actuallyEmitDiagnostics))) return WalkResult::interrupt(); + } return WalkResult::advance(); }); @@ -210,6 +238,11 @@ public: } void runOnOperation() override { ModuleOp module = getOperation(); + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + target.addLegalDialect(); + if (decompose) + markDecomposedOpsAsIllegal(context, target, backendLegalOps); OpPassManager pm(module.getOperationName()); TorchLoweringPipelineOptions options; @@ -227,14 +260,14 @@ public: << " iterations of the simplification pipeline\n"; }); // Show the diagnostics. - (void)satisfiesBackendContract(module, + (void)satisfiesBackendContract(module, target, /*actuallyEmitDiagnostics=*/true); return signalPassFailure(); } if (failed(runPipeline(pm, module))) return signalPassFailure(); - } while (!satisfiesBackendContract(module)); + } while (!satisfiesBackendContract(module, target)); LLVM_DEBUG({ llvm::dbgs() << "LowerToBackendContractPass: " << "succeeded after " << i @@ -247,7 +280,10 @@ class VerifyBackendContractPass : public VerifyBackendContractBase { public: void runOnOperation() override { - if (!satisfiesBackendContract(getOperation(), /*actuallyEmitDiagnostics=*/true)) { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + if (!satisfiesBackendContract(getOperation(), target, + /*actuallyEmitDiagnostics=*/true)) { return signalPassFailure(); } } @@ -265,3 +301,124 @@ std::unique_ptr> mlir::torch::Torch::createVerifyBackendContractPass() { return std::make_unique(); } + +// The backend contract guarantees that ops with decompositions available will +// be decomposed. The only way to have an op reach the backend contract without +// getting decomposed is by having the user explicitly specify that op in the +// `backendLegalOps` argument to the `LowerToBackendContractPass`. Therefore, +// here we mark as illegal all ops with decompositions except for those in +// `backendLegalOps`. +// +// The legality check takes place here instead of in the `DecomposeComplexOps` +// pass for two reasons: +// 1. Makes sure the `DecomposeComplexOps` pass always succeeds, allowing it to +// run multiple times. This is needed for graphs where static information such +// as dtypes and shapes takes multiple iterations to propagate through the +// entire graph. `DecomposeComplexOps` pass failing would cause the entire +// `LowerToBackendContractPass` to fail +// 2. Makes the legality requirements in the backend contract for ops with +// decompositions explicit in this file +static void markDecomposedOpsAsIllegal(MLIRContext *context, + ConversionTarget &target, + ArrayRef backendLegalOps) { + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addDynamicallyLegalOp([](AtenMatmulOp op) { + int lhsRank = getTensorRank(op.getSelf()); + int rhsRank = getTensorRank(op.getOther()); + // Make aten.matmul legal if the following condition is satisfied. + return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3); + }); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + for (std::string opName : backendLegalOps) { + target.addLegalOp(OperationName(opName, context)); + } +} diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 58e1a1290..abaa2860c 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -28,30 +28,27 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch // ----- // CHECK-LABEL: func @torch.aten.adaptive_avg_pool2d$non_unit_output_size( // CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { -// CHECK: %[[CST7:.*]] = torch.constant.int 7 -// CHECK: %[[OUTPUT_SIZE:.*]] = torch.prim.ListConstruct %[[CST7]], %[[CST7]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[CST2:.*]] = torch.constant.int 2 +// CHECK-DAG: %[[CST0:.*]] = torch.constant.int 0 +// CHECK-DAG: %[[CST1:.*]] = torch.constant.int 1 +// CHECK-DAG: %[[CST2:.*]] = torch.constant.int 2 +// CHECK-DAG: %[[CST3:.*]] = torch.constant.int 3 +// CHECK-DAG: %[[CST6:.*]] = torch.constant.int 6 +// CHECK-DAG: %[[CST7:.*]] = torch.constant.int 7 +// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true +// CHECK-DAG: %[[NONE:.*]] = torch.constant.none // CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int -// CHECK: %[[CST3:.*]] = torch.constant.int 3 // CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int -// CHECK: %[[CST1:.*]] = torch.constant.int 1 -// CHECK: %[[CST0:.*]] = torch.constant.int 0 -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[TRUE:.*]] = torch.constant.bool true -// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[COND1:.*]] = torch.aten.eq.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool // CHECK: torch.runtime.assert %[[COND1]], "unimplemented: only support cases where input and output size are equal for non-unit output size" -// CHECK: %[[T1:.*]] = torch.aten.sub.int %[[CST7]], %[[CST1]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[T2:.*]] = torch.aten.sub.int %[[DIM2]], %[[T1]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[T1:.*]] = torch.aten.sub.int %[[DIM2]], %[[CST6]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[COND2:.*]] = torch.aten.eq.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.bool // CHECK: torch.runtime.assert %[[COND2]], "unimplemented: only support cases where input and output size are equal for non-unit output size" -// CHECK: %[[T3:.*]] = torch.aten.sub.int %[[CST7]], %[[CST1]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[T4:.*]] = torch.aten.sub.int %[[DIM3]], %[[T3]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T2]], %[[T4]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T6:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T7:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[OUT:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[T5]], %[[T6]], %[[T7]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32> -// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?,?],f32> +// CHECK: %[[T2:.*]] = torch.aten.sub.int %[[DIM3]], %[[CST6]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %int7 = torch.constant.int 7 %output_size = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list @@ -63,22 +60,19 @@ func.func @torch.aten.adaptive_avg_pool2d$non_unit_output_size(%arg0: !torch.vte // CHECK-LABEL: func.func @torch.aten.adaptive_avg_pool2d$unit_output_size( // CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { -// CHECK: %[[CST_1:.*]] = torch.constant.int 1 -// CHECK: %[[OUTPUT_SIZE:.*]] = torch.prim.ListConstruct %[[CST_1]], %[[CST_1]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[CST_2:.*]] = torch.constant.int 2 -// CHECK: %[[DIM_2:.*]] = torch.aten.size.int %[[SELF]], %[[CST_2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int -// CHECK: %[[CST_3:.*]] = torch.constant.int 3 -// CHECK: %[[DIM_3:.*]] = torch.aten.size.int %[[SELF]], %[[CST_3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int -// CHECK: %[[CST_1_1:.*]] = torch.constant.int 1 -// CHECK: %[[CST_0:.*]] = torch.constant.int 0 -// CHECK: %[[CEIL_MODE:.*]] = torch.constant.bool false -// CHECK: %[[COUNT_INCLUDE_PAD:.*]] = torch.constant.bool true -// CHECK: %[[DIVISOR_OVERRIDE:.*]] = torch.constant.none -// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[DIM_2]], %[[DIM_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[CST_1_1]], %[[CST_1_1]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST_0]], %[[CST_0]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[OUT:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[STRIDE]], %[[PADDING]], %[[CEIL_MODE]], %[[COUNT_INCLUDE_PAD]], %[[DIVISOR_OVERRIDE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32> -// CHECK: return %[[OUT]] : !torch.vtensor<[?,?,?,?],f32> +// CHECK-DAG: %[[CST0:.*]] = torch.constant.int 0 +// CHECK-DAG: %[[CST1:.*]] = torch.constant.int 1 +// CHECK-DAG: %[[CST2:.*]] = torch.constant.int 2 +// CHECK-DAG: %[[CST3:.*]] = torch.constant.int 3 +// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true +// CHECK-DAG: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[DIM2]], %[[DIM3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %int1 = torch.constant.int 1 %output_size = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list diff --git a/test/Dialect/Torch/lower-to-backend-contract-error.mlir b/test/Dialect/Torch/lower-to-backend-contract-error.mlir index 824f3ae23..7b3dda039 100644 --- a/test/Dialect/Torch/lower-to-backend-contract-error.mlir +++ b/test/Dialect/Torch/lower-to-backend-contract-error.mlir @@ -44,6 +44,16 @@ func.func @f(%arg0: !torch.any) { // ----- +// Decomposition of `aten.t` fails if `inputRank > 2` +func.func @f(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { + // expected-error @+2 {{found an op that was marked as backend illegal}} + // expected-note @+1 {{this is likely due to}} + %t = torch.aten.t %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> + return %t : !torch.vtensor<[?,?,?],f32> +} + +// ----- + // Test case: checking of op results. // TODO: In theory we could diagnose every single value, but for now we bail out on the first one.