diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp index a8f77c112..71ca9de23 100644 --- a/lib/Conversion/TorchToMhlo/Basic.cpp +++ b/lib/Conversion/TorchToMhlo/Basic.cpp @@ -178,6 +178,9 @@ public: if (!rhsType) { rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy); + if (isa(op)) { + std::swap(lhs, rhs); + } } lhs = mhlo::promoteType(rewriter, lhs, outType); @@ -1117,6 +1120,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp); INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp); INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp); + INSERT_BINARY_ADDSUB_PATTERN(AtenRsubScalarOp, chlo::BroadcastSubOp); #undef INSERT_BINARY_ADDSUB_PATTERN #define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \ diff --git a/test/Conversion/TorchToMhlo/elementwise.mlir b/test/Conversion/TorchToMhlo/elementwise.mlir index 655f62a89..aaf5720aa 100644 --- a/test/Conversion/TorchToMhlo/elementwise.mlir +++ b/test/Conversion/TorchToMhlo/elementwise.mlir @@ -197,6 +197,27 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc // ----- +// CHECK-LABEL: func.func @torch.aten.rsubscalar$basic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[INT9:.*]] = torch.constant.int 9 +// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> +// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T5:.*]] = chlo.broadcast_subtract %[[T4]], %[[T0]] : (tensor, tensor) -> tensor +// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int9 = torch.constant.int 9 + %int1 = torch.constant.int 1 + %0 = torch.aten.rsub.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.subscalar$alpha( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor