Remove unused RemoveUnused patterns

pull/3891/head
zjgarvey 2024-11-22 17:52:01 -06:00
parent 7e30ef798b
commit c4f1e49a68
6 changed files with 15 additions and 110 deletions

View File

@ -18394,7 +18394,6 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
printDefaultTorchOp(printer, *this, 0, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [

View File

@ -2288,17 +2288,6 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
listElements);
return success();
});
// One-off pattern to erase if dead.
// TODO: Use the effects infra to express the semantics of this op and enable
// a centralized "erase if dead" canonicalization.
// Specifically, we need to mark the op as only MemoryEffects::Allocate
// so that `mlir::wouldOpBeTriviallyDead` does the right thing.
patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
if (!op.use_empty())
return failure();
rewriter.eraseOp(op);
return failure();
});
}
//===----------------------------------------------------------------------===//
@ -3490,20 +3479,6 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}
//===----------------------------------------------------------------------===//
// PrimUninitializedOp
//===----------------------------------------------------------------------===//
void PrimUninitializedOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) {
if (!op.use_empty())
return failure();
rewriter.eraseOp(op);
return success();
});
}
//===----------------------------------------------------------------------===//
// PrimTupleUnpackOp
//===----------------------------------------------------------------------===//

View File

@ -4892,17 +4892,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" } else {\n"
" %12 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
" %15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int\n"
" %16 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %16 : !torch.bool\n"
" %15 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %15 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %14 = torch.prim.If %13 -> (!torch.bool) {\n"
" %15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int\n"
" %16 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int\n"
" %17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %17 : !torch.bool\n"
" %15 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int\n"
" %16 = torch.aten.eq.int %15, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %16 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
@ -4982,17 +4980,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" } else {\n"
" %9 = torch.aten.__isnot__ %3#1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %10 = torch.prim.If %9 -> (!torch.bool) {\n"
" %12 = torch.prim.unchecked_cast %3#1 : !torch.optional<int> -> !torch.int\n"
" %13 = torch.aten.gt.int %3#0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" %12 = torch.aten.gt.int %3#0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %11 = torch.prim.If %10 -> (!torch.bool) {\n"
" %12 = torch.prim.unchecked_cast %3#1 : !torch.optional<int> -> !torch.int\n"
" %13 = torch.aten.remainder.int %arg1, %3#0 : !torch.int, !torch.int -> !torch.int\n"
" %14 = torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %14 : !torch.bool\n"
" %12 = torch.aten.remainder.int %arg1, %3#0 : !torch.int, !torch.int -> !torch.int\n"
" %13 = torch.aten.eq.int %12, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
@ -5452,7 +5448,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %11 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %11 : !torch.bool\n"
" } else {\n"
" %11 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<float>> -> !torch.list<float>\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %7 -> () {\n"
@ -7444,7 +7439,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
" torch.prim.If.yield %arg2 : !torch.bool\n"
" } else {\n"
" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
@ -11138,9 +11132,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %11 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>\n"
" %12 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" %11 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %11 : !torch.bool\n"
" }\n"
" %5:2 = torch.prim.If %4 -> (!torch.optional<list<int>>, !torch.optional<list<float>>) {\n"
" torch.prim.If.yield %arg1, %arg2 : !torch.optional<list<int>>, !torch.optional<list<float>>\n"
@ -11153,7 +11146,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %11 = torch.aten.__is__ %5#1, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %11 : !torch.bool\n"
" } else {\n"
" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n"
@ -11217,9 +11209,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %11 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>\n"
" %12 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %12 : !torch.bool\n"
" %11 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %11 : !torch.bool\n"
" }\n"
" %5:2 = torch.prim.If %4 -> (!torch.optional<list<int>>, !torch.optional<list<float>>) {\n"
" torch.prim.If.yield %arg1, %arg2 : !torch.optional<list<int>>, !torch.optional<list<float>>\n"
@ -11232,7 +11223,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %11 = torch.aten.__is__ %5#1, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" torch.prim.If.yield %11 : !torch.bool\n"
" } else {\n"
" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n"

View File

@ -37,8 +37,6 @@ template <> struct QuantInfo<AtenReluOp> {
// where MPTQT = "Aten_MakePerTensorQuantizedTensorOp"
// and Dequant = "AtenDequantizeSelfOp" or "AtenDequantizeTensorOp"
bool isQCommutingOp(mlir::Operation *op) {
// if adding a new commuting op here, be sure to add a
// RemoveUnused pattern for that op to clean up afterwards
return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp,
PrimsCollapseOp, AtenViewOp, AtenPadOp, AtenConstantPadNdOp>(
op);
@ -419,35 +417,12 @@ public:
}
};
template <typename SrcOp> class RemoveUnused : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
auto result = op.getResult();
if (result.use_empty()) {
op.erase();
return success();
}
return failure();
}
};
class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.insert<
RemoveUnused<AtenDequantizeSelfOp>,
RemoveUnused<AtenDequantizeTensorOp>,
RemoveUnused<AtenQuantizePerTensorOp>,
RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
RemoveUnused<AtenTransposeIntOp>, RemoveUnused<AtenSliceTensorOp>,
RemoveUnused<AtenReshapeOp>, RemoveUnused<PrimsCollapseOp>,
RemoveUnused<AtenViewOp>, RemoveUnused<AtenPadOp>,
RemoveUnused<AtenConstantPadNdOp>,
QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 5>,
QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,

View File

@ -1396,22 +1396,6 @@ public:
};
} // namespace
namespace {
template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
public:
using OpRewritePattern<T>::OpRewritePattern;
LogicalResult matchAndRewrite(T op,
PatternRewriter &rewriter) const override {
for (auto use : op->getResults())
if (!use.use_empty())
return failure();
rewriter.eraseOp(op);
return success();
}
};
} // namespace
namespace {
bool isItemForSliceOp(Operation *op) {
@ -1512,23 +1496,6 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
patterns.getContext());
}
void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
patterns.insert<RemoveUnusedPattern<Torch::AtenIntBoolOp>,
RemoveUnusedPattern<Torch::AtenEqIntOp>,
RemoveUnusedPattern<Torch::AtenToDtypeOp>,
RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
RemoveUnusedPattern<Torch::AtenFullOp>,
RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
RemoveUnusedPattern<Torch::AtenTensorOp>,
RemoveUnusedPattern<Torch::AtenFloatScalarOp>,
RemoveUnusedPattern<Torch::AtenIntScalarOp>,
RemoveUnusedPattern<Torch::PrimListConstructOp>>(
patterns.getContext());
}
} // namespace
namespace {
class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
@ -1545,7 +1512,6 @@ public:
populateScalarizationPropagationPatterns(patterns);
populateScalarizationFoldPatterns(patterns);
populateScalarizationCanonicalizePatterns(patterns);
populateScalarizationRemovePatterns(patterns);
context->getLoadedDialect<mlir::arith::ArithDialect>()
->getCanonicalizationPatterns(patterns);
// don't load torch canonicalization patterns, since these may lead to

View File

@ -1234,7 +1234,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("prim::max.self_int : (int[]) -> (int)")
emit("prim::max.int : (int, int) -> (int)", has_folder=True)
emit("prim::RaiseException : (str, str?) -> ()")
emit("prim::Uninitialized : () -> (Any)", has_canonicalizer=True, traits=["Pure"])
emit("prim::Uninitialized : () -> (Any)", traits=["Pure"])
emit(
"prim::unchecked_cast : (t) -> (t)",
has_folder=True,