mirror of https://github.com/llvm/torch-mlir
Remove unused RemoveUnused patterns
parent
7e30ef798b
commit
c4f1e49a68
|
@ -18394,7 +18394,6 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
|
|||
printDefaultTorchOp(printer, *this, 0, 1);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [
|
||||
|
|
|
@ -2288,17 +2288,6 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
listElements);
|
||||
return success();
|
||||
});
|
||||
// One-off pattern to erase if dead.
|
||||
// TODO: Use the effects infra to express the semantics of this op and enable
|
||||
// a centralized "erase if dead" canonicalization.
|
||||
// Specifically, we need to mark the op as only MemoryEffects::Allocate
|
||||
// so that `mlir::wouldOpBeTriviallyDead` does the right thing.
|
||||
patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
|
||||
if (!op.use_empty())
|
||||
return failure();
|
||||
rewriter.eraseOp(op);
|
||||
return failure();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -3490,20 +3479,6 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimUninitializedOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PrimUninitializedOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) {
|
||||
if (!op.use_empty())
|
||||
return failure();
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimTupleUnpackOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -4892,17 +4892,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" } else {\n"
|
||||
" %12 = torch.aten.__isnot__ %5#1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
|
||||
" %15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int\n"
|
||||
" %16 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %16 : !torch.bool\n"
|
||||
" %15 = torch.aten.gt.int %5#0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %15 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %14 = torch.prim.If %13 -> (!torch.bool) {\n"
|
||||
" %15 = torch.prim.unchecked_cast %5#1 : !torch.optional<int> -> !torch.int\n"
|
||||
" %16 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %17 = torch.aten.eq.int %16, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %17 : !torch.bool\n"
|
||||
" %15 = torch.aten.remainder.int %1, %5#0 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %16 = torch.aten.eq.int %15, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %16 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
|
@ -4982,17 +4980,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" } else {\n"
|
||||
" %9 = torch.aten.__isnot__ %3#1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %10 = torch.prim.If %9 -> (!torch.bool) {\n"
|
||||
" %12 = torch.prim.unchecked_cast %3#1 : !torch.optional<int> -> !torch.int\n"
|
||||
" %13 = torch.aten.gt.int %3#0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %13 : !torch.bool\n"
|
||||
" %12 = torch.aten.gt.int %3#0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %12 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %11 = torch.prim.If %10 -> (!torch.bool) {\n"
|
||||
" %12 = torch.prim.unchecked_cast %3#1 : !torch.optional<int> -> !torch.int\n"
|
||||
" %13 = torch.aten.remainder.int %arg1, %3#0 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %14 = torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %14 : !torch.bool\n"
|
||||
" %12 = torch.aten.remainder.int %arg1, %3#0 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %13 = torch.aten.eq.int %12, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %13 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
|
@ -5452,7 +5448,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %11 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %11 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<float>> -> !torch.list<float>\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %7 -> () {\n"
|
||||
|
@ -7444,7 +7439,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %arg2 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
|
||||
|
@ -11138,9 +11132,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>\n"
|
||||
" %12 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %12 : !torch.bool\n"
|
||||
" %11 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %11 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %5:2 = torch.prim.If %4 -> (!torch.optional<list<int>>, !torch.optional<list<float>>) {\n"
|
||||
" torch.prim.If.yield %arg1, %arg2 : !torch.optional<list<int>>, !torch.optional<list<float>>\n"
|
||||
|
@ -11153,7 +11146,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %11 = torch.aten.__is__ %5#1, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %11 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n"
|
||||
|
@ -11217,9 +11209,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>\n"
|
||||
" %12 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %12 : !torch.bool\n"
|
||||
" %11 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %11 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %5:2 = torch.prim.If %4 -> (!torch.optional<list<int>>, !torch.optional<list<float>>) {\n"
|
||||
" torch.prim.If.yield %arg1, %arg2 : !torch.optional<list<int>>, !torch.optional<list<float>>\n"
|
||||
|
@ -11232,7 +11223,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %11 = torch.aten.__is__ %5#1, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||
" torch.prim.If.yield %11 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional<list<int>> -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n"
|
||||
|
|
|
@ -37,8 +37,6 @@ template <> struct QuantInfo<AtenReluOp> {
|
|||
// where MPTQT = "Aten_MakePerTensorQuantizedTensorOp"
|
||||
// and Dequant = "AtenDequantizeSelfOp" or "AtenDequantizeTensorOp"
|
||||
bool isQCommutingOp(mlir::Operation *op) {
|
||||
// if adding a new commuting op here, be sure to add a
|
||||
// RemoveUnused pattern for that op to clean up afterwards
|
||||
return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp,
|
||||
PrimsCollapseOp, AtenViewOp, AtenPadOp, AtenConstantPadNdOp>(
|
||||
op);
|
||||
|
@ -419,35 +417,12 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
template <typename SrcOp> class RemoveUnused : public OpRewritePattern<SrcOp> {
|
||||
public:
|
||||
using OpRewritePattern<SrcOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SrcOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto result = op.getResult();
|
||||
if (result.use_empty()) {
|
||||
op.erase();
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.insert<
|
||||
RemoveUnused<AtenDequantizeSelfOp>,
|
||||
RemoveUnused<AtenDequantizeTensorOp>,
|
||||
RemoveUnused<AtenQuantizePerTensorOp>,
|
||||
RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
|
||||
RemoveUnused<AtenTransposeIntOp>, RemoveUnused<AtenSliceTensorOp>,
|
||||
RemoveUnused<AtenReshapeOp>, RemoveUnused<PrimsCollapseOp>,
|
||||
RemoveUnused<AtenViewOp>, RemoveUnused<AtenPadOp>,
|
||||
RemoveUnused<AtenConstantPadNdOp>,
|
||||
QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 5>,
|
||||
QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
|
||||
QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
|
||||
|
|
|
@ -1396,22 +1396,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
|
||||
public:
|
||||
using OpRewritePattern<T>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(T op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
for (auto use : op->getResults())
|
||||
if (!use.use_empty())
|
||||
return failure();
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
||||
bool isItemForSliceOp(Operation *op) {
|
||||
|
@ -1512,23 +1496,6 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
|
|||
patterns.getContext());
|
||||
}
|
||||
|
||||
void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
|
||||
patterns.insert<RemoveUnusedPattern<Torch::AtenIntBoolOp>,
|
||||
RemoveUnusedPattern<Torch::AtenEqIntOp>,
|
||||
RemoveUnusedPattern<Torch::AtenToDtypeOp>,
|
||||
RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
|
||||
RemoveUnusedPattern<Torch::AtenFullOp>,
|
||||
RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
|
||||
RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
|
||||
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
|
||||
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
|
||||
RemoveUnusedPattern<Torch::AtenTensorOp>,
|
||||
RemoveUnusedPattern<Torch::AtenFloatScalarOp>,
|
||||
RemoveUnusedPattern<Torch::AtenIntScalarOp>,
|
||||
RemoveUnusedPattern<Torch::PrimListConstructOp>>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
namespace {
|
||||
class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
|
||||
|
@ -1545,7 +1512,6 @@ public:
|
|||
populateScalarizationPropagationPatterns(patterns);
|
||||
populateScalarizationFoldPatterns(patterns);
|
||||
populateScalarizationCanonicalizePatterns(patterns);
|
||||
populateScalarizationRemovePatterns(patterns);
|
||||
context->getLoadedDialect<mlir::arith::ArithDialect>()
|
||||
->getCanonicalizationPatterns(patterns);
|
||||
// don't load torch canonicalization patterns, since these may lead to
|
||||
|
|
|
@ -1234,7 +1234,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("prim::max.self_int : (int[]) -> (int)")
|
||||
emit("prim::max.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("prim::RaiseException : (str, str?) -> ()")
|
||||
emit("prim::Uninitialized : () -> (Any)", has_canonicalizer=True, traits=["Pure"])
|
||||
emit("prim::Uninitialized : () -> (Any)", traits=["Pure"])
|
||||
emit(
|
||||
"prim::unchecked_cast : (t) -> (t)",
|
||||
has_folder=True,
|
||||
|
|
Loading…
Reference in New Issue