mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] add RSub, ScalarImplicit canonicalize (#1899)
* add rsub, scalarimplit canonicalizer * reformat * address comments * fix bug * fix test * Update elementwise.py * resolve merge conflict * change to 3 * change to 3 * real fix * fix name * add torchdynamo fail test --------- Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>pull/1919/merge
parent
c2ef5f4165
commit
1d3a7419c5
|
@ -91,6 +91,8 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"AtenIntTensorByteDtypeModule_basic",
|
||||
# ERROR: assert isinstance(e, FakeTensor)
|
||||
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||
# ERROR: assert isinstance(e, FakeTensor)
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
}
|
||||
|
||||
STABLEHLO_PASS_SET = {
|
||||
|
@ -316,6 +318,7 @@ STABLEHLO_PASS_SET = {
|
|||
"RsubFloatModule_noalpha_basic",
|
||||
"RsubIntModule_basic",
|
||||
"RsubIntModule_noalpha_basic",
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"SliceStaticModule_basic",
|
||||
"SliceModule_basic",
|
||||
"SliceNegIdxModule_basic",
|
||||
|
@ -513,6 +516,7 @@ TOSA_PASS_SET = {
|
|||
"Matmul_3d",
|
||||
"RsubFloatModule_basic",
|
||||
"RsubFloatModule_noalpha_basic",
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"ElementwiseBitwiseAndModule_basic",
|
||||
"ElementwiseBitwiseAndStaticShapeModule_basic",
|
||||
"ElementwiseBitwiseNotInt32Module_basic",
|
||||
|
|
|
@ -3389,6 +3389,7 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
|
|||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
|
||||
|
@ -10569,6 +10570,7 @@ def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
|
|||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
|
||||
|
|
|
@ -881,14 +881,17 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "only int scalar lhs or rhs is supported");
|
||||
}
|
||||
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenAddTensorOp, AtenAddScalarOp>(
|
||||
op)) {
|
||||
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenRsubScalarOp, AtenAddTensorOp,
|
||||
AtenAddScalarOp>(op)) {
|
||||
Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter);
|
||||
if (!alpha) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only int scalar alpha is supported");
|
||||
}
|
||||
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
|
||||
if (isa<AtenRsubScalarOp>(op))
|
||||
lhs = rewriter.create<AtenMulIntOp>(loc, lhs, alpha);
|
||||
else
|
||||
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
|
||||
}
|
||||
|
||||
if (isa<AtenDivTensorModeOp>(op)) {
|
||||
|
@ -941,6 +944,8 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
|||
result = rewriter.create<AtenAddIntOp>(loc, lhs, rhs);
|
||||
} else if (isa<AtenSubScalarOp, AtenSubTensorOp>(op)) {
|
||||
result = rewriter.create<AtenSubIntOp>(loc, lhs, rhs);
|
||||
} else if (isa<AtenRsubScalarOp>(op)) {
|
||||
result = rewriter.create<AtenSubIntOp>(loc, rhs, lhs);
|
||||
} else if (isa<AtenMulScalarOp, AtenMulTensorOp>(op)) {
|
||||
result = rewriter.create<AtenMulIntOp>(loc, lhs, rhs);
|
||||
}
|
||||
|
@ -988,6 +993,16 @@ void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenRSubScalarOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](AtenRsubScalarOp op, PatternRewriter &rewriter) {
|
||||
return rewrite0DBinaryTensorOp(op, rewriter);
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenMulTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1018,6 +1033,23 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenScalarImplicitOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
void AtenScalarImplicitOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.add(+[](AtenScalarImplicitOp op, PatternRewriter &rewriter) {
|
||||
Location loc = op.getLoc();
|
||||
Value a = op.getA();
|
||||
auto outType = op.getResult().getType();
|
||||
Value scalarValue = getScalarIntValue(a, loc, rewriter);
|
||||
if (!scalarValue)
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType, scalarValue);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -318,7 +318,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::mish : (Tensor) -> (Tensor)")
|
||||
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
|
||||
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
|
||||
emit("aten::gelu : (Tensor, str) -> (Tensor)")
|
||||
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||
|
@ -641,7 +641,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::eq.device : (Device, Device) -> (bool)")
|
||||
emit("aten::ceil.float : (float) -> (int)", has_folder=True)
|
||||
emit("aten::narrow : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::ScalarImplicit : (Tensor) -> (Scalar)")
|
||||
emit("aten::ScalarImplicit : (Tensor) -> (Scalar)", has_canonicalizer=True)
|
||||
|
||||
# backprop ops
|
||||
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||
|
|
|
@ -774,6 +774,26 @@ class RsubIntModule_noalpha(torch.nn.Module):
|
|||
def RsubIntModule_noalpha_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, high=100))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RsubInt0d_NumToTensor_Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
x = torch.ops.prim.NumToTensor(5)
|
||||
return torch.rsub(x, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: RsubInt0d_NumToTensor_Module())
|
||||
def RsubInt0d_NumToTensor_Module_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
|
@ -1838,3 +1838,52 @@ func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor<
|
|||
%0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4], f32>
|
||||
return %0 : !torch.vtensor<[4],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
|
||||
// CHECK: %int-1 = torch.constant.int -1
|
||||
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[VAL_0]] : !torch.vtensor<[],si64>
|
||||
func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||
%2 = torch.aten.rsub.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
return %2 : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
|
||||
// CHECK: %int-1 = torch.constant.int -1
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[VAL_1:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64>
|
||||
func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
|
||||
%2 = torch.aten.rsub.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
return %2 : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[VAL_1:.*]] = torch.derefine %int1 : !torch.int to !torch.number
|
||||
// CHECK: return %[[VAL_1]] : !torch.number
|
||||
func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
|
||||
%1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],si64> -> !torch.number
|
||||
return %1 : !torch.number
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number {
|
||||
// CHECK: %int1 = torch.constant.int 1
|
||||
// CHECK: %[[VAL_0:.*]] = torch.derefine %int1 : !torch.int to !torch.number
|
||||
// CHECK: return %[[VAL_0]] : !torch.number
|
||||
func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number {
|
||||
%0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||
%1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],si64> -> !torch.number
|
||||
return %1 : !torch.number
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue