Revert "Decompose AtenNonzeroOp" (#3289)

Reverts llvm/torch-mlir#3281
pull/3296/head
Vivek Khandelwal 2024-05-06 22:22:04 +05:30 committed by GitHub
parent 17c3c15131
commit e60160d793
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 0 additions and 25 deletions

View File

@ -1201,21 +1201,6 @@ class DecomposeAtenIsposinfOp : public OpRewritePattern<AtenIsposinfOp> {
}; };
} // namespace } // namespace
namespace {
class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
public:
using OpRewritePattern<AtenNonzeroOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNonzeroOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value zeroScalar =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<AtenNeScalarOp>(op, op.getType(), op.getSelf(),
zeroScalar);
return success();
}
};
} // namespace
namespace { namespace {
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> { class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
public: public:
@ -7755,13 +7740,10 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
// is-xxx ops
addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsinfOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenIsinfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsneginfOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenIsneginfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsposinfOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenIsposinfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);

View File

@ -78,10 +78,3 @@ func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16> %0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
return %0 : !torch.tensor<[?], f16> return %0 : !torch.tensor<[?], f16>
} }
// -----
// CHECK-LABEL: func.func @torch.aten.nonzero
func.func @torch.aten.nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> {
%0 = torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>
return %0 : !torch.vtensor<[3,4,5],si64>
}