[Torch Dialect] add RSub, ScalarImplicit canonicalize (#1899)

* add rsub, scalarimplicit canonicalizers

* reformat

* address comments

* fix bug

* fix test

* Update elementwise.py

* resolve merge conflict

* change to 3

* change to 3

* real fix

* fix name

* add torchdynamo fail test

---------

Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>
Zhekun Zhang 2023-03-06 17:38:27 -08:00 committed by GitHub
parent c2ef5f4165
commit 1d3a7419c5
6 changed files with 112 additions and 5 deletions
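
For context, torch.rsub(input, other, alpha=1) computes other - alpha * input, i.e. subtraction with the operands reversed; this is the semantics the new canonicalizations fold away for statically known 0-d integer inputs. A quick eager-mode illustration (plain PyTorch, not code from this patch):

import torch

x = torch.tensor([1, 2, 3])
# rsub reverses the subtraction: result = other - alpha * input
assert torch.equal(torch.rsub(x, 10, alpha=2), 10 - 2 * x)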


@@ -91,6 +91,8 @@ TORCHDYNAMO_XFAIL_SET = {
"AtenIntTensorByteDtypeModule_basic",
# ERROR: assert isinstance(e, FakeTensor)
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
# ERROR: assert isinstance(e, FakeTensor)
"RsubInt0d_NumToTensor_Module_basic",
}
STABLEHLO_PASS_SET = {
@@ -316,6 +318,7 @@ STABLEHLO_PASS_SET = {
"RsubFloatModule_noalpha_basic",
"RsubIntModule_basic",
"RsubIntModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"SliceStaticModule_basic",
"SliceModule_basic",
"SliceNegIdxModule_basic",
@@ -513,6 +516,7 @@ TOSA_PASS_SET = {
"Matmul_3d",
"RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"ElementwiseBitwiseAndModule_basic",
"ElementwiseBitwiseAndStaticShapeModule_basic",
"ElementwiseBitwiseNotInt32Module_basic",


@@ -3389,6 +3389,7 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
@@ -10569,6 +10570,7 @@ def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [


@@ -881,14 +881,17 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
return rewriter.notifyMatchFailure(
op, "only int scalar lhs or rhs is supported");
}
-  if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenAddTensorOp, AtenAddScalarOp>(
-          op)) {
+  if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenRsubScalarOp, AtenAddTensorOp,
+          AtenAddScalarOp>(op)) {
Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter);
if (!alpha) {
return rewriter.notifyMatchFailure(op,
"only int scalar alpha is supported");
}
-    rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
+    if (isa<AtenRsubScalarOp>(op))
+      lhs = rewriter.create<AtenMulIntOp>(loc, lhs, alpha);
+    else
+      rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
}
if (isa<AtenDivTensorModeOp>(op)) {
@@ -941,6 +944,8 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
result = rewriter.create<AtenAddIntOp>(loc, lhs, rhs);
} else if (isa<AtenSubScalarOp, AtenSubTensorOp>(op)) {
result = rewriter.create<AtenSubIntOp>(loc, lhs, rhs);
} else if (isa<AtenRsubScalarOp>(op)) {
result = rewriter.create<AtenSubIntOp>(loc, rhs, lhs);
} else if (isa<AtenMulScalarOp, AtenMulTensorOp>(op)) {
result = rewriter.create<AtenMulIntOp>(loc, lhs, rhs);
}
@@ -988,6 +993,16 @@ void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}
//===----------------------------------------------------------------------===//
// AtenRsubScalarOp
//===----------------------------------------------------------------------===//
void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenRsubScalarOp op, PatternRewriter &rewriter) {
return rewrite0DBinaryTensorOp(op, rewriter);
});
}
//===----------------------------------------------------------------------===//
// AtenMulTensorOp
//===----------------------------------------------------------------------===//
@@ -1018,6 +1033,23 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
});
}
//===----------------------------------------------------------------------===//
// AtenScalarImplicitOp
//===----------------------------------------------------------------------===//
void AtenScalarImplicitOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenScalarImplicitOp op, PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value a = op.getA();
auto outType = op.getResult().getType();
Value scalarValue = getScalarIntValue(a, loc, rewriter);
if (!scalarValue)
return failure();
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType, scalarValue);
return success();
});
}
//===----------------------------------------------------------------------===//
// AtenSizeOp
//===----------------------------------------------------------------------===//
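
The AtenRsubScalarOp pattern above goes through rewrite0DBinaryTensorOp: when self, other, and alpha are statically known 0-d integers, alpha scales the lhs (rather than the rhs, as for sub/add) and the operands of the resulting AtenSubIntOp are swapped. A minimal Python sketch of the folded arithmetic (an illustrative helper, not code from the pass):

def fold_rsub_scalar(self_val: int, other: int, alpha: int = 1) -> int:
    # Mirrors the 0-d fold: rsub(self, other, alpha) = other - alpha * self
    lhs = self_val * alpha
    return other - lhs

# Matches the canonicalize.mlir tests later in this commit (rsub(1, 2, alpha=3) -> -1)
# and the new e2e test (rsub(5, 2) -> -3).
assert fold_rsub_scalar(1, 2, 3) == -1
assert fold_rsub_scalar(5, 2) == -3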


@@ -318,7 +318,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
emit("aten::mish : (Tensor) -> (Tensor)")
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
emit("aten::gelu : (Tensor, str) -> (Tensor)")
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
@@ -641,7 +641,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::eq.device : (Device, Device) -> (bool)")
emit("aten::ceil.float : (float) -> (int)", has_folder=True)
emit("aten::narrow : (Tensor, int, int, int) -> (Tensor)")
emit("aten::ScalarImplicit : (Tensor) -> (Scalar)")
emit("aten::ScalarImplicit : (Tensor) -> (Scalar)", has_canonicalizer=True)
# backprop ops
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")


@@ -774,6 +774,26 @@ class RsubIntModule_noalpha(torch.nn.Module):
def RsubIntModule_noalpha_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, high=100))
# ==============================================================================
class RsubInt0d_NumToTensor_Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
x = torch.ops.prim.NumToTensor(5)
return torch.rsub(x, 2)
@register_test_case(module_factory=lambda: RsubInt0d_NumToTensor_Module())
def RsubInt0d_NumToTensor_Module_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
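
In eager mode the new module's forward reduces to a constant, which is what the e2e harness compares the compiled result against; a quick standalone check using the same ops as the test:

import torch

x = torch.ops.prim.NumToTensor(5)  # 0-d int64 tensor holding 5
y = torch.rsub(x, 2)               # 2 - 1 * 5
assert y.item() == -3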


@@ -1838,3 +1838,52 @@ func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor<
%0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4], f32>
return %0 : !torch.vtensor<[4],f32>
}
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %int-1 = torch.constant.int -1
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[VAL_0]] : !torch.vtensor<[],si64>
func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
%2 = torch.aten.rsub.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
return %2 : !torch.vtensor<[],si64>
}
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %int-1 = torch.constant.int -1
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[VAL_1:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64>
func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
%2 = torch.aten.rsub.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
return %2 : !torch.vtensor<[],si64>
}
// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[VAL_1:.*]] = torch.derefine %int1 : !torch.int to !torch.number
// CHECK: return %[[VAL_1]] : !torch.number
func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
%int1 = torch.constant.int 1
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
%1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],si64> -> !torch.number
return %1 : !torch.number
}
// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number {
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[VAL_0:.*]] = torch.derefine %int1 : !torch.int to !torch.number
// CHECK: return %[[VAL_0]] : !torch.number
func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number {
%0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
%1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],si64> -> !torch.number
return %1 : !torch.number
}
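
The ScalarImplicit checks above canonicalize a statically known 0-d integer tensor to its scalar value, derefined to !torch.number. An eager-mode counterpart, assuming torch.ops.aten.ScalarImplicit is directly callable from Python (an assumption for illustration, not something this patch relies on):

import torch

t = torch.tensor(1)                    # 0-d si64 tensor, like the literal in the test above
s = torch.ops.aten.ScalarImplicit(t)   # assumed direct call; should yield the Python scalar 1
assert s == 1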