mirror of https://github.com/llvm/torch-mlir
parent
17c3c15131
commit
e60160d793
|
@ -1201,21 +1201,6 @@ class DecomposeAtenIsposinfOp : public OpRewritePattern<AtenIsposinfOp> {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
|
||||||
class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
|
|
||||||
public:
|
|
||||||
using OpRewritePattern<AtenNonzeroOp>::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(AtenNonzeroOp op,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
Location loc = op.getLoc();
|
|
||||||
Value zeroScalar =
|
|
||||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
|
||||||
rewriter.replaceOpWithNewOp<AtenNeScalarOp>(op, op.getType(), op.getSelf(),
|
|
||||||
zeroScalar);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
|
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -7755,13 +7740,10 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
|
||||||
// is-xxx ops
|
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIsinfOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIsinfOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIsneginfOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIsneginfOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIsposinfOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenIsposinfOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);
|
|
||||||
|
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
||||||
|
|
|
@ -78,10 +78,3 @@ func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch
|
||||||
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
|
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
|
||||||
return %0 : !torch.tensor<[?], f16>
|
return %0 : !torch.tensor<[?], f16>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.nonzero
|
|
||||||
func.func @torch.aten.nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> {
|
|
||||||
%0 = torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>
|
|
||||||
return %0 : !torch.vtensor<[3,4,5],si64>
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue