mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] add RSub, ScalarImplicit canonicalize (#1899)
* add rsub, scalarimplit canonicalizer * reformat * address comments * fix bug * fix test * Update elementwise.py * resolve merge conflict * change to 3 * change to 3 * real fix * fix name * add torchdynamo fail test --------- Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>pull/1919/merge
parent
c2ef5f4165
commit
1d3a7419c5
|
@ -91,6 +91,8 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
"AtenIntTensorByteDtypeModule_basic",
|
"AtenIntTensorByteDtypeModule_basic",
|
||||||
# ERROR: assert isinstance(e, FakeTensor)
|
# ERROR: assert isinstance(e, FakeTensor)
|
||||||
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||||
|
# ERROR: assert isinstance(e, FakeTensor)
|
||||||
|
"RsubInt0d_NumToTensor_Module_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
STABLEHLO_PASS_SET = {
|
STABLEHLO_PASS_SET = {
|
||||||
|
@ -316,6 +318,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"RsubFloatModule_noalpha_basic",
|
"RsubFloatModule_noalpha_basic",
|
||||||
"RsubIntModule_basic",
|
"RsubIntModule_basic",
|
||||||
"RsubIntModule_noalpha_basic",
|
"RsubIntModule_noalpha_basic",
|
||||||
|
"RsubInt0d_NumToTensor_Module_basic",
|
||||||
"SliceStaticModule_basic",
|
"SliceStaticModule_basic",
|
||||||
"SliceModule_basic",
|
"SliceModule_basic",
|
||||||
"SliceNegIdxModule_basic",
|
"SliceNegIdxModule_basic",
|
||||||
|
@ -513,6 +516,7 @@ TOSA_PASS_SET = {
|
||||||
"Matmul_3d",
|
"Matmul_3d",
|
||||||
"RsubFloatModule_basic",
|
"RsubFloatModule_basic",
|
||||||
"RsubFloatModule_noalpha_basic",
|
"RsubFloatModule_noalpha_basic",
|
||||||
|
"RsubInt0d_NumToTensor_Module_basic",
|
||||||
"ElementwiseBitwiseAndModule_basic",
|
"ElementwiseBitwiseAndModule_basic",
|
||||||
"ElementwiseBitwiseAndStaticShapeModule_basic",
|
"ElementwiseBitwiseAndStaticShapeModule_basic",
|
||||||
"ElementwiseBitwiseNotInt32Module_basic",
|
"ElementwiseBitwiseNotInt32Module_basic",
|
||||||
|
|
|
@ -3389,6 +3389,7 @@ def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
|
||||||
printDefaultTorchOp(printer, *this, 3, 1);
|
printDefaultTorchOp(printer, *this, 3, 1);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
|
def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
|
||||||
|
@ -10569,6 +10570,7 @@ def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
|
||||||
printDefaultTorchOp(printer, *this, 1, 1);
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
|
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
|
||||||
|
|
|
@ -881,14 +881,17 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only int scalar lhs or rhs is supported");
|
op, "only int scalar lhs or rhs is supported");
|
||||||
}
|
}
|
||||||
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenAddTensorOp, AtenAddScalarOp>(
|
if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenRsubScalarOp, AtenAddTensorOp,
|
||||||
op)) {
|
AtenAddScalarOp>(op)) {
|
||||||
Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter);
|
Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter);
|
||||||
if (!alpha) {
|
if (!alpha) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only int scalar alpha is supported");
|
"only int scalar alpha is supported");
|
||||||
}
|
}
|
||||||
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
|
if (isa<AtenRsubScalarOp>(op))
|
||||||
|
lhs = rewriter.create<AtenMulIntOp>(loc, lhs, alpha);
|
||||||
|
else
|
||||||
|
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<AtenDivTensorModeOp>(op)) {
|
if (isa<AtenDivTensorModeOp>(op)) {
|
||||||
|
@ -941,6 +944,8 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
||||||
result = rewriter.create<AtenAddIntOp>(loc, lhs, rhs);
|
result = rewriter.create<AtenAddIntOp>(loc, lhs, rhs);
|
||||||
} else if (isa<AtenSubScalarOp, AtenSubTensorOp>(op)) {
|
} else if (isa<AtenSubScalarOp, AtenSubTensorOp>(op)) {
|
||||||
result = rewriter.create<AtenSubIntOp>(loc, lhs, rhs);
|
result = rewriter.create<AtenSubIntOp>(loc, lhs, rhs);
|
||||||
|
} else if (isa<AtenRsubScalarOp>(op)) {
|
||||||
|
result = rewriter.create<AtenSubIntOp>(loc, rhs, lhs);
|
||||||
} else if (isa<AtenMulScalarOp, AtenMulTensorOp>(op)) {
|
} else if (isa<AtenMulScalarOp, AtenMulTensorOp>(op)) {
|
||||||
result = rewriter.create<AtenMulIntOp>(loc, lhs, rhs);
|
result = rewriter.create<AtenMulIntOp>(loc, lhs, rhs);
|
||||||
}
|
}
|
||||||
|
@ -988,6 +993,16 @@ void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenRSubScalarOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
|
MLIRContext *context) {
|
||||||
|
patterns.add(+[](AtenRsubScalarOp op, PatternRewriter &rewriter) {
|
||||||
|
return rewrite0DBinaryTensorOp(op, rewriter);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenMulTensorOp
|
// AtenMulTensorOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1018,6 +1033,23 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenScalarImplicitOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
void AtenScalarImplicitOp::getCanonicalizationPatterns(
|
||||||
|
RewritePatternSet &patterns, MLIRContext *context) {
|
||||||
|
patterns.add(+[](AtenScalarImplicitOp op, PatternRewriter &rewriter) {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value a = op.getA();
|
||||||
|
auto outType = op.getResult().getType();
|
||||||
|
Value scalarValue = getScalarIntValue(a, loc, rewriter);
|
||||||
|
if (!scalarValue)
|
||||||
|
return failure();
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType, scalarValue);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenSizeOp
|
// AtenSizeOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -318,7 +318,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::mish : (Tensor) -> (Tensor)")
|
emit("aten::mish : (Tensor) -> (Tensor)")
|
||||||
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
|
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
|
||||||
emit("aten::gelu : (Tensor, str) -> (Tensor)")
|
emit("aten::gelu : (Tensor, str) -> (Tensor)")
|
||||||
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
|
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||||
emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||||
|
@ -641,7 +641,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::eq.device : (Device, Device) -> (bool)")
|
emit("aten::eq.device : (Device, Device) -> (bool)")
|
||||||
emit("aten::ceil.float : (float) -> (int)", has_folder=True)
|
emit("aten::ceil.float : (float) -> (int)", has_folder=True)
|
||||||
emit("aten::narrow : (Tensor, int, int, int) -> (Tensor)")
|
emit("aten::narrow : (Tensor, int, int, int) -> (Tensor)")
|
||||||
emit("aten::ScalarImplicit : (Tensor) -> (Scalar)")
|
emit("aten::ScalarImplicit : (Tensor) -> (Scalar)", has_canonicalizer=True)
|
||||||
|
|
||||||
# backprop ops
|
# backprop ops
|
||||||
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
|
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||||
|
|
|
@ -774,6 +774,26 @@ class RsubIntModule_noalpha(torch.nn.Module):
|
||||||
def RsubIntModule_noalpha_basic(module, tu: TestUtils):
|
def RsubIntModule_noalpha_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, high=100))
|
module.forward(tu.randint(3, 4, high=100))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class RsubInt0d_NumToTensor_Module(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
x = torch.ops.prim.NumToTensor(5)
|
||||||
|
return torch.rsub(x, 2)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: RsubInt0d_NumToTensor_Module())
|
||||||
|
def RsubInt0d_NumToTensor_Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
|
@ -1838,3 +1838,52 @@ func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor<
|
||||||
%0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4], f32>
|
%0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4], f32>
|
||||||
return %0 : !torch.vtensor<[4],f32>
|
return %0 : !torch.vtensor<[4],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
|
||||||
|
// CHECK: %int-1 = torch.constant.int -1
|
||||||
|
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: return %[[VAL_0]] : !torch.vtensor<[],si64>
|
||||||
|
func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
%2 = torch.aten.rsub.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
return %2 : !torch.vtensor<[],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
|
||||||
|
// CHECK: %int-1 = torch.constant.int -1
|
||||||
|
// CHECK: %int1 = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64>
|
||||||
|
func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
%2 = torch.aten.rsub.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
return %2 : !torch.vtensor<[],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
|
||||||
|
// CHECK: %int1 = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch.derefine %int1 : !torch.int to !torch.number
|
||||||
|
// CHECK: return %[[VAL_1]] : !torch.number
|
||||||
|
func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
%1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],si64> -> !torch.number
|
||||||
|
return %1 : !torch.number
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number {
|
||||||
|
// CHECK: %int1 = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_0:.*]] = torch.derefine %int1 : !torch.int to !torch.number
|
||||||
|
// CHECK: return %[[VAL_0]] : !torch.number
|
||||||
|
func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number {
|
||||||
|
%0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
%1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],si64> -> !torch.number
|
||||||
|
return %1 : !torch.number
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue