mirror of https://github.com/llvm/torch-mlir
[MHLO] Add AtenRSubScalarOp conversion pattern to MHLO (#1233)
* [MHLO] Add AtenRSubScalarOp conversion pattern Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com> Co-authored-by: Jiawei Wu <xremold@gmail.com> Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com> Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com> Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>pull/1230/head
parent
fde390c766
commit
11a5b5ac52
|
@ -178,6 +178,9 @@ public:
|
||||||
|
|
||||||
if (!rhsType) {
|
if (!rhsType) {
|
||||||
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
|
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
|
||||||
|
if (isa<AtenRsubScalarOp>(op)) {
|
||||||
|
std::swap(lhs, rhs);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
lhs = mhlo::promoteType(rewriter, lhs, outType);
|
lhs = mhlo::promoteType(rewriter, lhs, outType);
|
||||||
|
@ -1117,6 +1120,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
||||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp);
|
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp);
|
||||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp);
|
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp);
|
||||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp);
|
INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp);
|
||||||
|
INSERT_BINARY_ADDSUB_PATTERN(AtenRsubScalarOp, chlo::BroadcastSubOp);
|
||||||
#undef INSERT_BINARY_ADDSUB_PATTERN
|
#undef INSERT_BINARY_ADDSUB_PATTERN
|
||||||
|
|
||||||
#define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \
|
#define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \
|
||||||
|
|
|
@ -197,6 +197,27 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.rsubscalar$basic(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[INT9:.*]] = torch.constant.int 9
|
||||||
|
// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
|
||||||
|
// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
|
||||||
|
// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
|
||||||
|
// CHECK: %[[T5:.*]] = chlo.broadcast_subtract %[[T4]], %[[T0]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
|
||||||
|
func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
%int9 = torch.constant.int 9
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.aten.rsub.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.subscalar$alpha(
|
// CHECK-LABEL: func.func @torch.aten.subscalar$alpha(
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
|
Loading…
Reference in New Issue