mirror of https://github.com/llvm/torch-mlir
Decompose AtenNonzeroOp (#3281)
This fixes some onnx lit tests not lowering to linalg in https://github.com/nod-ai/SHARK-Turbine/issues/450pull/3288/head
parent
53299eb224
commit
1af00e6040
|
@ -1201,6 +1201,21 @@ class DecomposeAtenIsposinfOp : public OpRewritePattern<AtenIsposinfOp> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
|
||||
public:
|
||||
using OpRewritePattern<AtenNonzeroOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenNonzeroOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value zeroScalar =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
rewriter.replaceOpWithNewOp<AtenNeScalarOp>(op, op.getType(), op.getSelf(),
|
||||
zeroScalar);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
namespace {
|
||||
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
|
||||
public:
|
||||
|
@ -7740,10 +7755,13 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
|
||||
// is-xxx ops
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIsinfOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIsneginfOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIsposinfOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);
|
||||
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
||||
|
|
|
@ -78,3 +78,10 @@ func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch
|
|||
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
|
||||
return %0 : !torch.tensor<[?], f16>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.nonzero
|
||||
func.func @torch.aten.nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> {
|
||||
%0 = torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>
|
||||
return %0 : !torch.vtensor<[3,4,5],si64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue