Revert "Decompose AtenNonzeroOp" (#3289)

Reverts llvm/torch-mlir#3281
pull/3296/head
Vivek Khandelwal 2024-05-06 22:22:04 +05:30 committed by GitHub
parent 17c3c15131
commit e60160d793
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 0 additions and 25 deletions

View File

@ -1201,21 +1201,6 @@ class DecomposeAtenIsposinfOp : public OpRewritePattern<AtenIsposinfOp> {
};
} // namespace
namespace {
class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
public:
using OpRewritePattern<AtenNonzeroOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNonzeroOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value zeroScalar =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<AtenNeScalarOp>(op, op.getType(), op.getSelf(),
zeroScalar);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
public:
@ -7755,13 +7740,10 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
// is-xxx ops
addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsinfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsneginfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsposinfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);

View File

@ -78,10 +78,3 @@ func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
return %0 : !torch.tensor<[?], f16>
}
// -----
// CHECK-LABEL: func.func @torch.aten.nonzero
func.func @torch.aten.nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> {
%0 = torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>
return %0 : !torch.vtensor<[3,4,5],si64>
}