Decompose AtenNonzeroOp (#3281)

This fixes some onnx lit tests not lowering to linalg in
https://github.com/nod-ai/SHARK-Turbine/issues/450
pull/3288/head
Xida Ren (Cedar) 2024-05-05 09:59:25 -04:00 committed by GitHub
parent 53299eb224
commit 1af00e6040
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 0 deletions

View File

@ -1201,6 +1201,21 @@ class DecomposeAtenIsposinfOp : public OpRewritePattern<AtenIsposinfOp> {
};
} // namespace
namespace {
class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
public:
using OpRewritePattern<AtenNonzeroOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNonzeroOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value zeroScalar =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<AtenNeScalarOp>(op, op.getType(), op.getSelf(),
zeroScalar);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
public:
@ -7740,10 +7755,13 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
// is-xxx ops
addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsinfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsneginfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsposinfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);

View File

@ -78,3 +78,10 @@ func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
return %0 : !torch.tensor<[?], f16>
}
// -----
// CHECK-LABEL: func.func @torch.aten.nonzero
func.func @torch.aten.nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> {
%0 = torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>
return %0 : !torch.vtensor<[3,4,5],si64>
}