Support aten.sign (#2205)

pull/2226/head snapshot-20230611.866
Matthias Gehre 2023-06-10 20:45:35 +02:00 committed by GitHub
parent 5ead1d549e
commit 4e2ba2e0af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 92 additions and 0 deletions

View File

@ -395,6 +395,7 @@ STABLEHLO_PASS_SET = {
"ElementwiseClampModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseSignModule_basic",
"ElementwisePowModule_basic",
"ElementwisePowTensorStaticModule_basic",
"ElementwisePowTensorBroadcastStaticModule_basic",
@ -846,6 +847,7 @@ TOSA_PASS_SET = {
"SqueezeDimModule_identity",
"SqueezeDimModule_unitDim",
"ReturnTwoTensorF32I64_basic",
"ElementwiseSignModule_basic",
"ElementwisePowModule_basic",
"BmmModule_basic",
"MmDagModule_basic",
@ -863,6 +865,10 @@ TOSA_PASS_SET = {
"ElementwiseBitwiseOrStaticShapeModule_basic",
"ElementwiseBitwiseXorModule_basic",
"ElementwiseBitwiseXorStaticShapeModule_basic",
"ElementwiseGeFloatIntScalarModule_basic",
"ElementwiseGeFloatScalarModule_basic",
"ElementwiseGeIntScalarModule_basic",
"ElementwiseGeMixedIntScalarModule_basic",
"ElementwiseGtFloatScalarModule_basic",
"ElementwiseGtIntScalarModule_basic",
"ElementwiseGtMixed2ScalarModule_basic",

View File

@ -4591,6 +4591,7 @@ public:
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)

View File

@ -6194,6 +6194,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.sign\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -8148,6 +8152,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sign\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.floor\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

View File

@ -4487,6 +4487,52 @@ public:
};
} // namespace
namespace {
// Decompose `aten.sign` op into comparisons and aten.where.
class DecomposeAtenSignOp : public OpRewritePattern<AtenSignOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSignOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto outType = op.getType().dyn_cast<BaseTensorType>();
if (!outType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
auto zero =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
auto one =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
auto minusOne =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(-1.0));
auto compTy = outType.getWithSizesAndDtype(outType.getOptionalSizes(),
rewriter.getI1Type());
auto greater =
rewriter.create<AtenGtScalarOp>(loc, compTy, op.getSelf(), zero);
auto greaterEqual =
rewriter.create<AtenGeScalarOp>(loc, compTy, op.getSelf(), zero);
// Pseudo code:
// if (in >= 0)
// if (in > 0)
// return 1
// else
// return 0
// else
// return -1
auto selectGreater =
rewriter.create<AtenWhereScalarOp>(loc, outType, greater, one, zero);
rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(op, outType, greaterEqual,
selectGreater, minusOne);
return success();
}
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -4654,6 +4700,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;

View File

@ -107,6 +107,9 @@ def atenneg〡shape(self: List[int]) -> List[int]:
def atenfloor〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atensign〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atendetach〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
@ -1485,6 +1488,11 @@ def atenflip〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> in
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atensign〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenfloor〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype

View File

@ -1291,6 +1291,28 @@ def ElementwiseCeilModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseSignModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.sign(a)
@register_test_case(module_factory=lambda: ElementwiseSignModule())
def ElementwiseSignModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwisePowModule(torch.nn.Module):
def __init__(self):