diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index a7456b0d7..48b2a4138 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -18394,7 +18394,6 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
       printDefaultTorchOp(printer, *this, 0, 1);
     }
   }];
-  let hasCanonicalizer = 1;
 }
 
 def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 87d1464e2..eee37302a 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2288,17 +2288,6 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
         listElements);
     return success();
   });
-  // One-off pattern to erase if dead.
-  // TODO: Use the effects infra to express the semantics of this op and enable
-  // a centralized "erase if dead" canonicalization.
-  // Specifically, we need to mark the op as only MemoryEffects::Allocate
-  // so that `mlir::wouldOpBeTriviallyDead` does the right thing.
-  patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
-    if (!op.use_empty())
-      return failure();
-    rewriter.eraseOp(op);
-    return failure();
-  });
 }
 
 //===----------------------------------------------------------------------===//
@@ -3490,20 +3479,6 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
-//===----------------------------------------------------------------------===//
-// PrimUninitializedOp
-//===----------------------------------------------------------------------===//
-
-void PrimUninitializedOp::getCanonicalizationPatterns(
-    RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) {
-    if (!op.use_empty())
-      return failure();
-    rewriter.eraseOp(op);
-    return success();
-  });
-}
-
 //===----------------------------------------------------------------------===//
 // PrimTupleUnpackOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index a8ce5ed20..6ca393798 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -4892,17 +4892,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    } else {\n"
 "      %12 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
 "      %13 = torch.prim.If %12 -> (!torch.bool) {\n"
-"        %15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int\n"
-"        %16 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
-"        torch.prim.If.yield %16 : !torch.bool\n"
+"        %15 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"        torch.prim.If.yield %15 : !torch.bool\n"
 "      } else {\n"
 "        torch.prim.If.yield %false : !torch.bool\n"
 "      }\n"
 "      %14 = torch.prim.If %13 -> (!torch.bool) {\n"
-"        %15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int\n"
-"        %16 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int\n"
-"        %17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool\n"
-"        torch.prim.If.yield %17 : !torch.bool\n"
+"        %15 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int\n"
+"        %16 = torch.aten.eq.int %15, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"        torch.prim.If.yield %16 : !torch.bool\n"
 "      } else {\n"
 "        torch.prim.If.yield %false : !torch.bool\n"
 "      }\n"
@@ -4982,17 +4980,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    } else {\n"
 "      %9 = torch.aten.__isnot__ %3#1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
 "      %10 = torch.prim.If %9 -> (!torch.bool) {\n"
-"        %12 = torch.prim.unchecked_cast %3#1 : !torch.optional<int> -> !torch.int\n"
-"        %13 = torch.aten.gt.int %3#0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
-"        torch.prim.If.yield %13 : !torch.bool\n"
+"        %12 = torch.aten.gt.int %3#0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"        torch.prim.If.yield %12 : !torch.bool\n"
 "      } else {\n"
 "        torch.prim.If.yield %false : !torch.bool\n"
 "      }\n"
 "      %11 = torch.prim.If %10 -> (!torch.bool) {\n"
-"        %12 = torch.prim.unchecked_cast %3#1 : !torch.optional<int> -> !torch.int\n"
-"        %13 = torch.aten.remainder.int %arg1, %3#0 : !torch.int, !torch.int -> !torch.int\n"
-"        %14 = torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n"
-"        torch.prim.If.yield %14 : !torch.bool\n"
+"        %12 = torch.aten.remainder.int %arg1, %3#0 : !torch.int, !torch.int -> !torch.int\n"
+"        %13 = torch.aten.eq.int %12, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"        torch.prim.If.yield %13 : !torch.bool\n"
 "      } else {\n"
 "        torch.prim.If.yield %false : !torch.bool\n"
 "      }\n"
@@ -5452,7 +5448,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "      %11 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
 "      torch.prim.If.yield %11 : !torch.bool\n"
 "    } else {\n"
-"      %11 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<int>> -> !torch.list<int>\n"
 "      torch.prim.If.yield %false : !torch.bool\n"
 "    }\n"
 "    torch.prim.If %7 -> () {\n"
@@ -7444,7 +7439,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = torch.prim.If %0 -> (!torch.bool) {\n"
 "      torch.prim.If.yield %arg2 : !torch.bool\n"
 "    } else {\n"
-"      %3 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
 "      torch.prim.If.yield %false : !torch.bool\n"
 "    }\n"
 "    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
@@ -11138,9 +11132,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = torch.prim.If %3 -> (!torch.bool) {\n"
 "      torch.prim.If.yield %true : !torch.bool\n"
 "    } else {\n"
-"      %11 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>\n"
-"      %12 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
-"      torch.prim.If.yield %12 : !torch.bool\n"
+"      %11 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
+"      torch.prim.If.yield %11 : !torch.bool\n"
 "    }\n"
 "    %5:2 = torch.prim.If %4 -> (!torch.optional<list<int>>, !torch.optional<list<int>>) {\n"
 "      torch.prim.If.yield %arg1, %arg2 : !torch.optional<list<int>>, !torch.optional<list<int>>\n"
@@ -11153,7 +11146,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "      %11 = torch.aten.__is__ %5#1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
 "      torch.prim.If.yield %11 : !torch.bool\n"
 "    } else {\n"
-"      %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
 "      torch.prim.If.yield %false : !torch.bool\n"
 "    }\n"
 "    %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n"
@@ -11217,9 +11209,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = torch.prim.If %3 -> (!torch.bool) {\n"
 "      torch.prim.If.yield %true : !torch.bool\n"
 "    } else {\n"
-"      %11 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>\n"
-"      %12 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
-"      torch.prim.If.yield %12 : !torch.bool\n"
+"      %11 = torch.aten.__is__ %arg2, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
+"      torch.prim.If.yield %11 : !torch.bool\n"
 "    }\n"
 "    %5:2 = torch.prim.If %4 -> (!torch.optional<list<int>>, !torch.optional<list<int>>) {\n"
 "      torch.prim.If.yield %arg1, %arg2 : !torch.optional<list<int>>, !torch.optional<list<int>>\n"
@@ -11232,7 +11223,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "      %11 = torch.aten.__is__ %5#1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
 "      torch.prim.If.yield %11 : !torch.bool\n"
 "    } else {\n"
-"      %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
 "      torch.prim.If.yield %false : !torch.bool\n"
 "    }\n"
 "    %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n"
diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
index 5da8217f6..c56b1cad0 100644
--- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
+++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -37,8 +37,6 @@ template <> struct QuantInfo<AtenReluOp> {
 // where MPTQT = "Aten_MakePerTensorQuantizedTensorOp"
 // and Dequant = "AtenDequantizeSelfOp" or "AtenDequantizeTensorOp"
 bool isQCommutingOp(mlir::Operation *op) {
-  // if adding a new commuting op here, be sure to add a
-  // RemoveUnused pattern for that op to clean up afterwards
   return llvm::isa(
       op);
 }
@@ -419,35 +417,12 @@ public:
   }
 };
 
-template <typename SrcOp> class RemoveUnused : public OpRewritePattern<SrcOp> {
-public:
-  using OpRewritePattern<SrcOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(SrcOp op,
-                                PatternRewriter &rewriter) const override {
-    auto result = op.getResult();
-    if (result.use_empty()) {
-      op.erase();
-      return success();
-    }
-    return failure();
-  }
-};
-
 class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
 public:
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
     patterns.insert<
-        RemoveUnused,
-        RemoveUnused,
-        RemoveUnused,
-        RemoveUnused,
-        RemoveUnused, RemoveUnused,
-        RemoveUnused, RemoveUnused,
-        RemoveUnused, RemoveUnused,
-        RemoveUnused,
         QuantizeOperandsPastCommutingOps,
         QuantizeOperandsPastCommutingOps,
         QuantizeOperandsPastCommutingOps,
diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
index 634e910d4..c7eeaf5b2 100644
--- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
+++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
@@ -1396,22 +1396,6 @@ public:
 };
 } // namespace
 
-namespace {
-template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
-public:
-  using OpRewritePattern<T>::OpRewritePattern;
-  LogicalResult matchAndRewrite(T op,
-                                PatternRewriter &rewriter) const override {
-    for (auto use : op->getResults())
-      if (!use.use_empty())
-        return failure();
-
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 
 bool isItemForSliceOp(Operation *op) {
@@ -1512,23 +1496,6 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }
 
-void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
-  patterns.insert,
-                  RemoveUnusedPattern,
-                  RemoveUnusedPattern,
-                  RemoveUnusedPattern,
-                  RemoveUnusedPattern,
-                  RemoveUnusedPattern,
-                  RemoveUnusedPattern,
-                  RemoveUnusedPattern,
-                  RemoveUnusedPattern,
-                  RemoveUnusedPattern,
-                  RemoveUnusedPattern,
-                  RemoveUnusedPattern>(
-      patterns.getContext());
-}
-
 } // namespace
 namespace {
 class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
@@ -1545,7 +1512,6 @@ public:
     populateScalarizationPropagationPatterns(patterns);
     populateScalarizationFoldPatterns(patterns);
     populateScalarizationCanonicalizePatterns(patterns);
-    populateScalarizationRemovePatterns(patterns);
     context->getLoadedDialect()
         ->getCanonicalizationPatterns(patterns);
     // don't load torch canonicalization patterns, since these may lead to
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index c59edd7ab..dee65579c 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1234,7 +1234,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("prim::max.self_int : (int[]) -> (int)")
     emit("prim::max.int : (int, int) -> (int)", has_folder=True)
     emit("prim::RaiseException : (str, str?) -> ()")
-    emit("prim::Uninitialized : () -> (Any)", has_canonicalizer=True, traits=["Pure"])
+    emit("prim::Uninitialized : () -> (Any)", traits=["Pure"])
    emit(
         "prim::unchecked_cast : (t) -> (t)",
         has_folder=True,
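Note for context (illustration only, not part of the patch): every construct deleted above — the `prim.Uninitialized` canonicalizer, the erase-if-dead pattern on `AtenSizeOp`, and the `RemoveUnused`/`RemoveUnusedPattern` templates — hand-rolled the same rule: erase an op whose results have no uses. MLIR's canonicalization infrastructure already does this centrally for any op it can prove trivially dead (the removed TODO in TorchOps.cpp pointed at `mlir::wouldOpBeTriviallyDead`), which is why emitting `prim::Uninitialized` with the `Pure` trait alone in torch_ods_gen.py suffices. A minimal sketch of that built-in behavior, assuming standard MLIR headers; `eraseTriviallyDeadOps` and `eraseIfDead` are hypothetical helper names, not APIs introduced by this patch:

#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: even with an empty pattern set, the greedy rewrite
// driver folds and erases trivially dead ops -- ops whose results are unused
// and which have no side effects (e.g. ops carrying the Pure trait).
LogicalResult eraseTriviallyDeadOps(Operation *root) {
  RewritePatternSet patterns(root->getContext()); // intentionally empty
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}

// Hypothetical helper: the same dead-op test is available directly.
// isOpTriviallyDead already implies op->use_empty(), so no separate check
// of the use list is needed before erasing.
void eraseIfDead(Operation *op) {
  if (isOpTriviallyDead(op))
    op->erase();
}

This is the centralized "erase if dead" behavior the removed TODO asked for, so neither the canonicalizers nor the per-op cleanup patterns need to be registered by hand anymore.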