[Torch][MHLO] Decompose aten.copy op. Lower aten.rsqrt & sigmoid to mhlo. (#1734)

Jiahao Li, 2022-12-22 10:13:59 +08:00, committed by GitHub
parent 9dc09ac8c5
commit 15b249777b
6 changed files with 69 additions and 5 deletions

e2e_testing/xfail_sets.py

@@ -108,6 +108,8 @@ MHLO_PASS_SET = {
     "ElementwiseExpModule_basic",
     "ElementwiseLogModule_basic",
     "ElementwiseNegModule_basic",
+    "ElementwiseRsqrtModule_basic",
+    "ElementwiseSigmoidModule_basic",
     "ElementwiseSqrtModule_basic",
     "ElementwiseUnaryModule_basic",
     "ElementwiseUnsqueezeNegDimsModule_basic",
@@ -186,6 +188,9 @@ MHLO_PASS_SET = {
     "IndexSelectTwoIdxModule_basic",
     "IndexSelectWholeDimensionModule_basic",
     "IndexSelectWholeTensorModule_basic",
+    "LayerNormLastDimModule_basic",
+    "LayerNormModule_basic",
+    "LayerNormNormalizeOverAllDimsModule_basic",
     "MatmulBroadcastBatchDim_basic",
     "MatmulSingleDynamicBatchDim_basic",
     "Matmul_3d",
@@ -197,6 +202,8 @@ MHLO_PASS_SET = {
     "MeanModule_basic",
     "MmTanhModule_basic",
     "Mv_basic",
+    "NativeLayerNormModule4D_basic",
+    "NativeLayerNormModule_basic",
     "PrimsConvertElementTypeModule_basic",
     "ReduceFrobeniusNormKeepDimModule_basic",
     "ReduceSumDimIntListElementTypeBoolModule_basic",
@@ -212,6 +219,10 @@ MHLO_PASS_SET = {
     "SliceSingleIdxModule_basic",
     "SqueezeDimModule_dynamic",
     "SqueezeDimModule_negDim",
+    "ToCopyBoolDTypeStaticModule_basic",
+    "ToCopyModule_basic",
+    "ToCopyWithDTypeFalsePinMemoryModule_basic",
+    "ToCopyWithDTypeModule_basic",
     "ReduceFrobeniusNormModule_basic",
     "FlattenStaticModule_basic",
     "FlattenRank0Module_basic",

lib/Conversion/TorchToMhlo/Basic.cpp

@@ -1313,6 +1313,8 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp);
   INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp);
   INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp);
+  INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp);
+  INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, mhlo::LogisticOp);
 #undef INSERT_UNARY_FPONLY_PATTERN
 
 #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

@@ -5472,6 +5472,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
 " %int6 = torch.constant.int 6\n"
+" %int5 = torch.constant.int 5\n"
 " %int15 = torch.constant.int 15\n"
 " %true = torch.constant.bool true\n"
 " %int7 = torch.constant.int 7\n"
@@ -5479,15 +5480,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %1 = torch.prim.If %0 -> (!torch.bool) {\n"
 " torch.prim.If.yield %true : !torch.bool\n"
 " } else {\n"
-" %3 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
-" torch.prim.If.yield %3 : !torch.bool\n"
+" %4 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %4 : !torch.bool\n"
 " }\n"
-" %2 = torch.prim.If %1 -> (!torch.int) {\n"
+" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
+" torch.prim.If.yield %true : !torch.bool\n"
+" } else {\n"
+" %4 = torch.aten.eq.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %4 : !torch.bool\n"
+" }\n"
+" %3 = torch.prim.If %2 -> (!torch.int) {\n"
 " torch.prim.If.yield %arg1 : !torch.int\n"
 " } else {\n"
 " torch.prim.If.yield %int6 : !torch.int\n"
 " }\n"
-" return %2 : !torch.int\n"
+" return %3 : !torch.int\n"
 " }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.erf\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

@@ -2512,6 +2512,23 @@ public:
 };
 } // namespace
 
+namespace {
+// Decompose `aten.copy` op into `aten.to.dtype` and `aten.expand_as`.
+class DecomposeAtenCopyOp : public OpRewritePattern<AtenCopyOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenCopyOp op,
+                                PatternRewriter &rewriter) const override {
+    Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
+    Value srcToDtype =
+        convertTensorToDtype(rewriter, op.getLoc(), op.getSrc(), resultDtype);
+    rewriter.replaceOpWithNewOp<AtenExpandAsOp>(op, op.getType(), srcToDtype,
+                                                op.getSelf());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.newEmpty` op into `aten.empty.memoryFormat` op.
 class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
@@ -3476,6 +3493,7 @@ public:
     addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenCopyOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenDropoutOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
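In eager-mode terms, the decomposition replaces the copy with a dtype cast of the source followed by a broadcast to the destination's shape. A small sketch (mine, not from the patch) of the equivalent PyTorch behavior:

    import torch

    def decomposed_copy(self: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
        # Mirrors DecomposeAtenCopyOp: aten.to.dtype on src, then
        # aten.expand_as to broadcast it to self's shape.
        return src.to(self.dtype).expand_as(self)

    dst = torch.zeros(4, 3)
    src = torch.arange(3)          # int64, broadcastable to (4, 3)
    ref = dst.clone()
    ref.copy_(src)                 # reference aten.copy semantics
    assert torch.equal(decomposed_copy(dst, src), ref)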

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py

@@ -63,7 +63,7 @@ def aten〡tanh〡shape(self: List[int]) -> List[int]:
     Invocation(ZeroDTensorWithDtype(torch.bool)),
 ])
 def aten〡tanh〡dtype(self_rank: int, self_dtype: int) -> int:
-    if self_dtype == torch.float64 or self_dtype == torch.bfloat16:
+    if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16:
         return self_dtype
     else:
         return torch.float32
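The effect of adding torch.float16 to this rule is that half-precision inputs are now reported as keeping their dtype instead of promoting to float32, matching eager PyTorch; a quick check (my illustration, not part of the test suite):

    import torch

    assert torch.tanh(torch.zeros(2, dtype=torch.float16)).dtype == torch.float16
    assert torch.tanh(torch.zeros(2, dtype=torch.bfloat16)).dtype == torch.bfloat16
    # Integer inputs still promote to the default float dtype.
    assert torch.tanh(torch.zeros(2, dtype=torch.int64)).dtype == torch.float32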

test/Conversion/TorchToMhlo/basic.mlir

@@ -75,6 +75,32 @@ func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.rsqrt$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = mhlo.rsqrt %[[T0]] : tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.sigmoid$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = mhlo.logistic %[[T0]] : tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.sigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.addscalar$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>