[MHLO] Add AtenRSubScalarOp conversion pattern to MHLO (#1233)

* [MHLO] Add AtenRSubScalarOp conversion pattern
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
武家伟 2022-08-17 09:07:36 +08:00 committed by GitHub
parent fde390c766
commit 11a5b5ac52
2 changed files with 25 additions and 0 deletions

@@ -178,6 +178,9 @@ public:
     if (!rhsType) {
       rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
+      if (isa<AtenRsubScalarOp>(op)) {
+        std::swap(lhs, rhs);
+      }
     }
     lhs = mhlo::promoteType(rewriter, lhs, outType);
@@ -1117,6 +1120,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp);
   INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp);
   INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp);
+  INSERT_BINARY_ADDSUB_PATTERN(AtenRsubScalarOp, chlo::BroadcastSubOp);
 #undef INSERT_BINARY_ADDSUB_PATTERN
 #define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \
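For context: aten.rsub.Scalar computes other - alpha * self, i.e. a subtraction with the operands reversed relative to aten.sub.Scalar. Swapping lhs and rhs once the scalar other has been materialized as a tensor lets the op reuse the existing chlo::BroadcastSubOp lowering unchanged. A minimal sketch of the op's semantics in PyTorch (the tensor values are illustrative, not taken from this commit):

import torch

# aten.rsub.Scalar(self, other, alpha) computes other - alpha * self,
# which is why the conversion swaps lhs and rhs before emitting the
# broadcast subtract. This sketches the semantics, not the lowering.
x = torch.full((2, 2), 3.0)
out = torch.rsub(x, 9, alpha=1)  # 9 - 1 * x
assert torch.equal(out, torch.full((2, 2), 6.0))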

@@ -197,6 +197,27 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
 // -----
+// CHECK-LABEL: func.func @torch.aten.rsubscalar$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT9:.*]] = torch.constant.int 9
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T5:.*]] = chlo.broadcast_subtract %[[T4]], %[[T0]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.rsub.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+// -----
 // CHECK-LABEL: func.func @torch.aten.subscalar$alpha(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
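For reference, FileCheck tests like this are driven by a RUN line at the top of the test file; the invocation below is an assumption based on the TorchToMhlo conversion pass name of the time, not part of this diff:

// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s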