mirror of https://github.com/llvm/torch-mlir
[Torch][MHLO] Decompose aten.copy op. Lower aten.rsqrt & sigmoid to mhlo. (#1734)
parent
9dc09ac8c5
commit
15b249777b
|
@ -108,6 +108,8 @@ MHLO_PASS_SET = {
|
|||
"ElementwiseExpModule_basic",
|
||||
"ElementwiseLogModule_basic",
|
||||
"ElementwiseNegModule_basic",
|
||||
"ElementwiseRsqrtModule_basic",
|
||||
"ElementwiseSigmoidModule_basic",
|
||||
"ElementwiseSqrtModule_basic",
|
||||
"ElementwiseUnaryModule_basic",
|
||||
"ElementwiseUnsqueezeNegDimsModule_basic",
|
||||
|
@ -186,6 +188,9 @@ MHLO_PASS_SET = {
|
|||
"IndexSelectTwoIdxModule_basic",
|
||||
"IndexSelectWholeDimensionModule_basic",
|
||||
"IndexSelectWholeTensorModule_basic",
|
||||
"LayerNormLastDimModule_basic",
|
||||
"LayerNormModule_basic",
|
||||
"LayerNormNormalizeOverAllDimsModule_basic",
|
||||
"MatmulBroadcastBatchDim_basic",
|
||||
"MatmulSingleDynamicBatchDim_basic",
|
||||
"Matmul_3d",
|
||||
|
@ -197,6 +202,8 @@ MHLO_PASS_SET = {
|
|||
"MeanModule_basic",
|
||||
"MmTanhModule_basic",
|
||||
"Mv_basic",
|
||||
"NativeLayerNormModule4D_basic",
|
||||
"NativeLayerNormModule_basic",
|
||||
"PrimsConvertElementTypeModule_basic",
|
||||
"ReduceFrobeniusNormKeepDimModule_basic",
|
||||
"ReduceSumDimIntListElementTypeBoolModule_basic",
|
||||
|
@ -212,6 +219,10 @@ MHLO_PASS_SET = {
|
|||
"SliceSingleIdxModule_basic",
|
||||
"SqueezeDimModule_dynamic",
|
||||
"SqueezeDimModule_negDim",
|
||||
"ToCopyBoolDTypeStaticModule_basic",
|
||||
"ToCopyModule_basic",
|
||||
"ToCopyWithDTypeFalsePinMemoryModule_basic",
|
||||
"ToCopyWithDTypeModule_basic",
|
||||
"ReduceFrobeniusNormModule_basic",
|
||||
"FlattenStaticModule_basic",
|
||||
"FlattenRank0Module_basic",
|
||||
|
|
|
@ -1313,6 +1313,8 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, mhlo::LogisticOp);
|
||||
#undef INSERT_UNARY_FPONLY_PATTERN
|
||||
|
||||
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
||||
|
|
|
@ -5472,6 +5472,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int7 = torch.constant.int 7\n"
|
||||
|
@ -5479,15 +5480,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %3 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %3 : !torch.bool\n"
|
||||
" %4 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %4 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %4 = torch.aten.eq.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %4 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %arg1 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.erf\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
|
|
|
@ -2512,6 +2512,23 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.copy` op into `aten.to.dtype` and `aten.expand_as`.
|
||||
class DecomposeAtenCopyOp : public OpRewritePattern<AtenCopyOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenCopyOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Type resultDtype = op.getType().cast<BaseTensorType>().getDtype();
|
||||
Value srcToDtype =
|
||||
convertTensorToDtype(rewriter, op.getLoc(), op.getSrc(), resultDtype);
|
||||
rewriter.replaceOpWithNewOp<AtenExpandAsOp>(op, op.getType(), srcToDtype,
|
||||
op.getSelf());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.newEmpty` op into `aten.empty.memoryFormat` op.
|
||||
class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
|
||||
|
@ -3476,6 +3493,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenCopyOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenDropoutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
|
||||
|
|
|
@ -63,7 +63,7 @@ def aten〇tanh〡shape(self: List[int]) -> List[int]:
|
|||
Invocation(ZeroDTensorWithDtype(torch.bool)),
|
||||
])
|
||||
def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
if self_dtype == torch.float64 or self_dtype == torch.bfloat16:
|
||||
if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16:
|
||||
return self_dtype
|
||||
else:
|
||||
return torch.float32
|
||||
|
|
|
@ -75,6 +75,32 @@ func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.rsqrt$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.rsqrt %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.sigmoid$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[T1:.*]] = mhlo.logistic %[[T0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.sigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.addscalar$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
|
|
Loading…
Reference in New Issue