From ed4ecb072f0b944f050bb0bca2461c639abcd353 Mon Sep 17 00:00:00 2001
From: TatWai Chong <78814694+tatwaichong@users.noreply.github.com>
Date: Thu, 18 May 2023 17:12:18 -0700
Subject: [PATCH] [tosa] support lowering basic torch binary ops with mixed
 dtypes (#2122)

Lowering torch operations that allow different compatible data types in
their operands to tosa ends up generating invalid tosa IR with mixed data
types. In the tosa spec, certain operations (generally element-wise
operations) require all operands to have the same data type.

Add wrapper functions for those element-wise tosa ops that perform op
creation with type conversion where necessary.
---
 .../TorchToTosa/TosaLegalizeCommon.h          |  27 +++-
 .../TorchToTosa/TosaLegalizeUtils.h           |   4 +
 lib/Conversion/TorchToTosa/TorchToTosa.cpp    | 123 +++++++----
 .../TorchToTosa/TosaLegalizeCommon.cpp        |  28 +++-
 .../TorchToTosa/TosaLegalizeUtils.cpp         |  21 +++
 ...orch-backend-to-tosa-backend-pipeline.mlir | 126 ++++++++++++++++++
 6 files changed, 250 insertions(+), 79 deletions(-)
 create mode 100644 test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir

diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
index 1ef3ae8a4..3ff4581d6 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h
@@ -10,8 +10,11 @@
 #ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
 #define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
 
-#include "mlir/IR/PatternMatch.h" // from @llvm-project
-#include "mlir/Support/LLVM.h"    // from @llvm-project
+#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
+#include "mlir/IR/PatternMatch.h"         // from @llvm-project
+#include "mlir/Support/LLVM.h"            // from @llvm-project
 
 namespace mlir {
 namespace tosa {
@@ -21,6 +24,26 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
                       SmallVector<int64_t> indiceOneDimShape, int32_t dim,
                       ArrayRef<int64_t> indexShape);
 
+mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
+                                     TensorType outType, Value lhs, Value rhs,
+                                     int32_t shift);
+
+// Create TOSA elementwise binary op with type conversion if necessary.
+template <typename TosaOpT>
+TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
+                              TensorType outType, Value lhs, Value rhs) {
+  lhs = promoteType(rewriter, lhs, outType);
+  rhs = promoteType(rewriter, rhs, outType);
+  return CreateOpAndInfer<TosaOpT>(rewriter, op->getLoc(), outType, lhs, rhs);
+}
+
+// This specialization is for the Div op. Unlike the other binary ops, it
+// doesn't support floating-point types.
+template <>
+tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
+                                         Operation *op, TensorType outType,
+                                         Value lhs, Value rhs);
+
 std::optional<Value>
 convertTorchIndexToTfIndices(PatternRewriter &rewriter, Operation *op,
                              Value params_value,
diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index 717972ae9..39cb1eacc 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -45,6 +45,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
 Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                   float val);
 
+// Create a zero constant tensor of the desired type and shape.
+std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
+                                        Operation *op, Type type);
+
 // Templated function to create a constant op for given type and shape.
 // T: storage C type.
 // Default template creates a constant tensor in T.
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 5b8aee0cd..eeae753cf 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -100,17 +100,13 @@ public:
       return rewriter.notifyMatchFailure(op,
                                          "Only Tensor types supported in TOSA");
 
-    auto lhsElemTy = lhsTy.getElementType();
-    auto rhsElemTy = rhsTy.getElementType();
+    auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()
+                     ->convertType(op.getType())
+                     .template cast<TensorType>();
 
-    if (lhsElemTy != rhsElemTy)
-      return rewriter.notifyMatchFailure(op, "Input datatypes mismatched");
-
-    rewriter.replaceOpWithNewOp<TosaOpT>(
-        op,
-        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-            op.getType()),
-        lhs, rhs);
+    auto binaryOp =
+        tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
+    rewriter.replaceOp(op, binaryOp.getResult());
     return success();
   }
 };
@@ -291,52 +287,30 @@ public:
                                        "alpha in conversion to TOSA operation");
     }
 
-    // make sure input of MulOp is same datetype, otherwise the lowering to
-    // arith dialect will bug
-    auto multTensor = rewriter.create<tosa::MulOp>(
-        op.getLoc(),
+    auto mulAlphaOp = tosa::createMulOpAndCast(
+        rewriter, op,
         rhsType ? rhsType : RankedTensorType::get({}, rhsAlphaMulElemType),
         rhsTensor, alphaTensor, /*shift=*/0);
 
-    if (outElemTy.isa<mlir::FloatType>() || outElemTy.isInteger(32)) {
-      // if outElemTy tensor<f32>, mulTensor must be tensor<f32>,
-      // left value could be tensor<i32/i64/f32>, cast left value to
-      // tensor<f32> type
-      // if outElemTy tensor<i32>, mulTensor must be tensor<i32>,
-      // left value could be tensor<i32/i64>, cast left value to
-      // tensor<i32> type
-      if (lhsType.getElementType() != rhsAlphaMulElemType)
-        lhs = rewriter.create<tosa::CastOp>(
-            op.getLoc(),
-            RankedTensorType::get(lhsType.getShape(), rhsAlphaMulElemType),
-            lhs);
-
-      rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, lhs, multTensor);
-
-      return success();
-    } else if (outElemTy.isInteger(64)) {
+    if (outElemTy.isInteger(64)) {
+      // Tosa doesn't support 64-bit elementwise addition and subtraction.
       // if outElemTy tensor<i64>, mulTensor must be tensor<i32>,
       // left value could be tensor<i32/i64> type, cast left value to
       // tensor<i32> type
-      if (lhsType.getElementType() != rhsAlphaMulElemType)
-        lhs = rewriter.create<tosa::CastOp>(
-            op.getLoc(),
-            RankedTensorType::get(lhsType.getShape(), rhsAlphaMulElemType),
-            lhs);
-
-      auto tosaOpTOutputTensor = rewriter.create<TosaOpT>(
-          op.getLoc(),
+      auto addOrSubi64Op = tosa::createBinaryOpAndCast<TosaOpT>(
+          rewriter, op,
           RankedTensorType::get(outType.getShape(), rhsAlphaMulElemType), lhs,
-          multTensor);
-      // cast tensor<i32> back to tensor<i64>
-      rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
-                                                tosaOpTOutputTensor);
+          mulAlphaOp);
+      // cast tensor<i32> back to tensor<i64>
+      rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, addOrSubi64Op);
 
       return success();
-    } else {
-      return rewriter.notifyMatchFailure(
-          op, "Only floating-point, i32, i64 datatype legalization supported");
     }
+
+    auto binaryOp = tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outType,
+                                                         lhs, mulAlphaOp);
+    rewriter.replaceOp(op, binaryOp.getResult());
+    return success();
   }
 };
 // namespace
@@ -457,15 +431,13 @@ public:
     if (outElemTy.isa<mlir::FloatType>() || outElemTy.isa<mlir::IntegerType>()) {
-      if (lhsType.getElementType() != outElemTy)
-        lhs = rewriter.create<tosa::CastOp>(op.getLoc(), outType, lhs);
+      auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
+                         ->convertType(op.getType())
+                         .template cast<TensorType>();
 
-      rewriter.replaceOpWithNewOp<tosa::MulOp>(
-          op,
-          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-              op.getType()),
-          lhs, rhsTensor,
-          /*shift=*/0);
+      auto mulOp = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
+                                            rhsTensor, /*shift=*/0);
+      rewriter.replaceOp(op, mulOp.getResult());
       return success();
     }
@@ -507,23 +479,27 @@ public:
                                        "conversion in TOSA operation");
     }
     auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
+    auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
+                       ->convertType(op.getType())
+                       .template cast<TensorType>();
+    // auto result;
+    Value result;
     if (lhsElemTy.isa<mlir::FloatType>()) {
       auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
           op->getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
           rhsTensor);
-      rewriter.replaceOpWithNewOp<tosa::MulOp>(
-          op,
-          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-              op.getType()),
-          lhs, rcpOp.getResult(), /*shift=*/0);
+
+      result = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
+                                        rcpOp.getResult(), /*shift=*/0)
+                   .getResult();
     } else {
-      rewriter.replaceOpWithNewOp<tosa::DivOp>(
-          op,
-          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-              op.getType()),
-          lhs, rhsTensor);
+      result = tosa::createBinaryOpAndCast<tosa::DivOp>(rewriter, op, outType,
+                                                        lhs, rhsTensor)
+                   .getResult();
     }
+
+    rewriter.replaceOp(op, {result});
     return success();
   }
 };
@@ -1033,8 +1009,12 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
         op, "Currently only scalar constants are supported for "
             "conversion in TOSA Pow operation");
 
-  rewriter.replaceOpWithNewOp<tosa::PowOp>(
-      op, getTypeConverter()->convertType(op.getType()), self, expTensor);
+  auto outType =
+      getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
+
+  auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
+                                                        self, expTensor);
+  rewriter.replaceOp(op, powOp.getResult());
 
   return success();
 }
@@ -3289,15 +3269,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   // +0. (sign bit flips). These are probably acceptable in the short term,
   // but we should put a comment acknowledging the danger, as there isn't an
   // op that avoids the denorm flushing.
-  SmallVector<int32_t> intValues(totalNumElements, 0);
-  SmallVector<float> floatValues(totalNumElements, 0.0);
-  Value zeroTensor = selfType.getElementType().isa<mlir::FloatType>()
-                         ? tosa::getConstTensor<float>(
-                               rewriter, op, floatValues, zeroTensorShape)
-                               .value()
-                         : tosa::getConstTensor<int32_t>(
-                               rewriter, op, intValues, zeroTensorShape)
-                               .value();
+  Value zeroTensor =
+      tosa::getZerosLikeTensor(rewriter, op, resultType).value();
 
   // Use add broadcast
   rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, adaptor.getSelf(),
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
index ca5ef974f..2bb6045d9 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -8,7 +8,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
-#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 
 #include 
@@ -19,7 +18,6 @@
 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"          // from @llvm-project
 #include "mlir/IR/Matchers.h"              // from @llvm-project
 #include "mlir/IR/PatternMatch.h"          // from @llvm-project
@@ -105,6 +103,32 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
   return indicesDim;
 }
 
+tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
+                               TensorType outType, Value lhs, Value rhs,
+                               int32_t shift) {
+  lhs = promoteType(rewriter, lhs, outType);
+  rhs = promoteType(rewriter, rhs, outType);
+  return tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), outType,
+                                             lhs, rhs, shift);
+}
+
+template <>
+tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
+                                         Operation *op, TensorType outType,
+                                         Value lhs, Value rhs) {
+  auto lhsElemTy = lhs.getType().cast<TensorType>().getElementType();
+  auto rhsElemTy = rhs.getType().cast<TensorType>().getElementType();
+  if (lhsElemTy.isa<mlir::FloatType>() || rhsElemTy.isa<mlir::FloatType>()) {
+    (void)rewriter.notifyMatchFailure(op,
+                                      "tosa.div only supports integer type");
+  }
+
+  lhs = promoteType(rewriter, lhs, outType);
+  rhs = promoteType(rewriter, rhs, outType);
+  return tosa::CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), outType,
+                                             lhs, rhs);
+}
+
 std::optional<Value>
 convertTorchIndexToTfIndices(PatternRewriter &rewriter, Operation *op,
                              Value paramsValue,
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index fa56ca39d..c4f8d2b0b 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -149,6 +149,27 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
   return const_op.getResult();
 }
 
+// Create a zero constant tensor of the desired type and shape.
+std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
+                                        Operation *op, Type type) {
+  RankedTensorType resultType = type.dyn_cast<RankedTensorType>();
+
+  if (!resultType) {
+    (void)rewriter.notifyMatchFailure(op, "not ranked tensor type");
+    return std::nullopt;
+  }
+
+  auto resultShape = resultType.getShape();
+  ShapedType zeroType =
+      RankedTensorType::get(resultShape, resultType.getElementType());
+  Attribute zeroAttr = rewriter.getZeroAttr(zeroType);
+
+  return CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), zeroType,
+                                         zeroAttr.cast<ElementsAttr>())
+      .getResult();
+}
+
+
 // Templated function to create a constant op for given type and shape.
 // T: storage C type.
 // Default template creates a constant tensor in T.
diff --git a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir
new file mode 100644
index 000000000..94dd0aed5
--- /dev/null
+++ b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir
@@ -0,0 +1,126 @@
+// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: torch.aten.mul.Scalar$mixed_type
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16>
+// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16>
+// CHECK: %[[VAL_2:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_1]]) <{shift = 0 : i32}> : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16>
+func.func @torch.aten.mul.Scalar$mixed_type(%arg0: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],bf16> {
+  %float2.000000e00 = torch.constant.float 2.000000e+00
+  %0 = torch.aten.mul.Scalar %arg0, %float2.000000e00 : !torch.vtensor<[5],bf16>, !torch.float -> !torch.vtensor<[5],bf16>
+  return %0 : !torch.vtensor<[5],bf16>
+}
+
+// -----
+
+// CHECK-LABEL: torch.aten.add.Tensor$mixed_type_fp
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<6xbf16>
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<6xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<6xf32>) -> tensor<6xbf16>
+// CHECK: %[[VAL_4:.*]] = "tosa.add"(%[[VAL_0]], %[[VAL_3]]) : (tensor<6xbf16>, tensor<6xbf16>) -> tensor<6xbf16>
+func.func @torch.aten.add.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[6],bf16>, %arg1: !torch.vtensor<[6],f32>, %arg2: !torch.float) -> !torch.vtensor<[6],bf16> {
+  %float1 = torch.constant.float 1.000000e+00
+  %0 = torch.aten.add.Tensor %arg0, %arg1, %float1 : !torch.vtensor<[6],bf16>, !torch.vtensor<[6],f32>, !torch.float -> !torch.vtensor<[6],bf16>
+  return %0 : !torch.vtensor<[6],bf16>
+}
+
+// -----
+
+// CHECK-LABEL: torch.aten.add.Tensor$mixed_type_int
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xf32>
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<5xbf16>
+// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<5xbf16>) -> tensor<5xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_0]], %[[VAL_2]]) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
+func.func @torch.aten.add.Tensor$mixed_type_int(%arg0: !torch.vtensor<[5],f32>, %arg1: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[5],f32>, !torch.vtensor<[5],bf16>, !torch.int -> !torch.vtensor<[5],f32>
+  return %0 : !torch.vtensor<[5],f32>
+}
+
+// -----
+
+// CHECK-LABEL: torch.aten.Scalar$mixed_type
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x32x64xi16>
+// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<256> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32>
+// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<1x1x32x64xi16>) -> tensor<1x1x32x64xi32>
+// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_1]]) : (tensor<1x1x32x64xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x32x64xi32>
+func.func @torch.aten.Scalar$mixed_type(%arg0: !torch.vtensor<[1,1,32,64],si16>) -> !torch.vtensor<[1,1,32,64],si32> {
+  %int1 = torch.constant.int 1
+  %int256 = torch.constant.int 256
+  %0 = torch.aten.add.Scalar %arg0, %int256, %int1 : !torch.vtensor<[1,1,32,64],si16>, !torch.int, !torch.int -> !torch.vtensor<[1,1,32,64],si32>
+  return %0 : !torch.vtensor<[1,1,32,64],si32>
+}
+
+// -----
+
+// CHECK-LABEL: torch.aten.sub.Scalar$mixed_type
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<bf16>,
+// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<bf16>}> : () -> tensor<bf16>
+// CHECK: %[[VAL_3:.*]] = "tosa.sub"(%[[VAL_0]], %[[VAL_2]]) : (tensor<bf16>, tensor<bf16>) -> tensor<bf16>
+func.func @torch.aten.sub.Scalar$mixed_type(%arg0: !torch.vtensor<[],bf16>, %arg1: !torch.vtensor<[],bf16>) -> !torch.vtensor<[],bf16> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.sub.Scalar %arg0, %int1, %int1 : !torch.vtensor<[],bf16>, !torch.int, !torch.int -> !torch.vtensor<[],bf16>
+  return %0 : !torch.vtensor<[],bf16>
+}
+
+// -----
+
+// CHECK-LABEL: torch.aten.maximum$mixed_type
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x1xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x1xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<1x3x1xi32>) -> tensor<1x3x1xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.maximum"(%[[VAL_2]], %[[VAL_1]]) : (tensor<1x3x1xf32>, tensor<1x3x1xf32>) -> tensor<1x3x1xf32>
+func.func @torch.aten.maximum$mixed_type(%arg0: !torch.vtensor<[1,3,1],si32>, %arg1: !torch.vtensor<[1,3,1],f32>) -> !torch.vtensor<[1,3,1],f32> {
+  %0 = torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[1,3,1],si32>, !torch.vtensor<[1,3,1],f32> -> !torch.vtensor<[1,3,1],f32>
+  return %0 : !torch.vtensor<[1,3,1],f32>
+}
+
+// -----
+
+// CHECK-LABEL: torch.aten.bitwise_and.Tensor$mixed_type
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi16>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
+// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xi16>) -> tensor<?x?xi32>
+// CHECK: %[[VAL_3:.*]] = "tosa.bitwise_and"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],si16>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+  %0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],si16>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
+  return %0 : !torch.vtensor<[?,?],si32>
+}
+
+// -----
+
+// CHECK-LABEL: torch.aten.div.Tensor$mixed_type_fp
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
+// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_3]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],f32> {
+  %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],f32>
+  return %0 : !torch.vtensor<[?, ?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: torch.aten.div.Tensor$mixed_type_int
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi16>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
+// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xi16>) -> tensor<?x?xi32>
+// CHECK: %[[VAL_3:.*]] = "tosa.div"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> {
+  %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32>
+  return %0 : !torch.vtensor<[?, ?],si32>
+}
+
+// -----
+
+// CHECK-LABEL: torch.aten.pow.Tensor$mixed_type
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf16>
+// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xf16>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.pow"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
+func.func @torch.aten.pow.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f32> {
+  %fp0 = torch.constant.float 3.123400e+00
+  %0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f16>, !torch.float -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}