[tosa] support lowering basic torch binary ops with mixed dtypes (#2122)

Lowering torch operations that allow different but compatible data types
in their operands to TOSA ends up generating invalid TOSA IR with mixed
data types. In the TOSA spec, certain operations (generally element-wise
operations) require all operands to have the same data type.

Add wrapper functions for those element-wise TOSA ops that insert the
necessary type conversions before creating the op.
pull/2134/head
TatWai Chong 2023-05-18 17:12:18 -07:00 committed by GitHub
parent 5698893ae4
commit ed4ecb072f
6 changed files with 250 additions and 79 deletions
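As an illustration of the approach, here is a minimal sketch of how a conversion
pattern uses the new wrapper (fragment assumed to sit inside a
ConversionPattern::matchAndRewrite; outType, lhs, and rhs stand for the
already-converted result type and operands, AtenOpT/TosaOpT for the pattern's
template parameters):

    // Sketch: operands whose element types differ from the result type are
    // cast by the wrapper before the TOSA op is created with shape inference.
    auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
                       ->convertType(op.getType())
                       .template cast<TensorType>();
    auto binaryOp =
        tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outType, lhs, rhs);
    rewriter.replaceOp(op, binaryOp.getResult());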


@@ -10,6 +10,9 @@
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
@@ -21,6 +24,26 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
SmallVector<int64_t> indiceOneDimShape, int32_t dim,
ArrayRef<int64_t> indexShape);
mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs,
int32_t shift);
// Create TOSA elementwise binary op with type conversion if necessary.
template <typename TosaOpT>
TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs) {
lhs = promoteType(rewriter, lhs, outType);
rhs = promoteType(rewriter, rhs, outType);
return CreateOpAndInfer<TosaOpT>(rewriter, op->getLoc(), outType, lhs, rhs);
}
// This specialization is for the Div op. Unlike the other binary ops, tosa.div
// does not support floating-point types.
template <>
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
Operation *op, TensorType outType,
Value lhs, Value rhs);
std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
Operation *op,
Value params_value,

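For context, promoteType (declared in TosaLegalizeUtils.h and used by the
wrappers above) is assumed here to behave roughly as sketched below: it returns
the value unchanged when its element type already matches the target type, and
inserts a tosa.cast otherwise. Illustrative sketch only, not the exact
implementation, and it assumes ranked tensors:

    // Rough sketch of a promoteType-style helper.
    Value promoteTypeSketch(PatternRewriter &rewriter, Value input,
                            TensorType targetType) {
      auto inputTy = input.getType().cast<RankedTensorType>();
      if (inputTy.getElementType() == targetType.getElementType())
        return input;
      auto newTy = RankedTensorType::get(inputTy.getShape(),
                                         targetType.getElementType());
      return rewriter.create<tosa::CastOp>(input.getLoc(), newTy, input);
    }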

@@ -45,6 +45,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);
// Create a zero constant tensor of the desired type and shape.
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
Operation *op, Type type);
// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.


@@ -100,17 +100,13 @@ public:
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");
auto lhsElemTy = lhsTy.getElementType();
auto rhsElemTy = rhsTy.getElementType();
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
if (lhsElemTy != rhsElemTy)
return rewriter.notifyMatchFailure(op, "Input datatypes mismatched");
rewriter.replaceOpWithNewOp<TosaOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
lhs, rhs);
auto binaryOp =
tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
rewriter.replaceOp(op, binaryOp.getResult());
return success();
}
};
@@ -291,52 +287,30 @@ public:
"alpha in conversion to TOSA operation");
}
// make sure the inputs of MulOp have the same datatype, otherwise the lowering
// to the arith dialect will fail
auto multTensor = rewriter.create<tosa::MulOp>(
op.getLoc(),
auto mulAlphaOp = tosa::createMulOpAndCast(
rewriter, op,
rhsType ? rhsType : RankedTensorType::get({}, rhsAlphaMulElemType),
rhsTensor, alphaTensor, /*shift=*/0);
if (outElemTy.isa<mlir::FloatType>() || outElemTy.isInteger(32)) {
// if outElemTy tensor<f32>, mulTensor must be tensor<f32>,
// left value could be tensor<f32/i32/i64>, cast left value to
// tensor<f32> type
// if outElemTy tensor<i32>, mulTensor must be tensor<i32>,
// left value could be tensor<f32/i32/i64>, cast left value to
// tensor<i32> type
if (lhsType.getElementType() != rhsAlphaMulElemType)
lhs = rewriter.create<tosa::CastOp>(
op.getLoc(),
RankedTensorType::get(lhsType.getShape(), rhsAlphaMulElemType),
lhs);
rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, lhs, multTensor);
return success();
} else if (outElemTy.isInteger(64)) {
if (outElemTy.isInteger(64)) {
// Tosa doesn't support 64-bit elementwise addition and subtraction.
// if outElemTy tensor<i64>, mulTensor must be tensor<i32>,
// left value could be tensor<f32/i32/i64> type, cast left value to
// tensor<i32> type
if (lhsType.getElementType() != rhsAlphaMulElemType)
lhs = rewriter.create<tosa::CastOp>(
op.getLoc(),
RankedTensorType::get(lhsType.getShape(), rhsAlphaMulElemType),
lhs);
auto tosaOpTOutputTensor = rewriter.create<TosaOpT>(
op.getLoc(),
auto addOrSubi64Op = tosa::createBinaryOpAndCast<TosaOpT>(
rewriter, op,
RankedTensorType::get(outType.getShape(), rhsAlphaMulElemType), lhs,
multTensor);
// cast tensor<i32> back to tensor<i64>
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
tosaOpTOutputTensor);
mulAlphaOp);
// cast tensor<i32> back to tensor<i64>
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, addOrSubi64Op);
return success();
} else {
return rewriter.notifyMatchFailure(
op, "Only floating-point, i32, i64 datatype legalization supported");
}
auto binaryOp = tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outType,
lhs, mulAlphaOp);
rewriter.replaceOp(op, binaryOp.getResult());
return success();
}
}; // namespace
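To recap the i64 branch of the add/sub pattern above, a condensed sketch using
the same names as the diff (rhsAlphaMulElemType is i32 here, since TOSA has no
64-bit elementwise addition or subtraction):

    // The alpha-scaled rhs and the add/sub are computed in i32, and the i32
    // result is cast back to the i64 output type.
    auto i32Ty =
        RankedTensorType::get(outType.getShape(), rhsAlphaMulElemType);
    auto addOrSubi64Op = tosa::createBinaryOpAndCast<TosaOpT>(
        rewriter, op, i32Ty, lhs, mulAlphaOp);
    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, addOrSubi64Op);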
@@ -457,15 +431,13 @@ public:
if (outElemTy.isa<mlir::FloatType>() ||
outElemTy.isa<mlir::IntegerType>()) {
if (lhsType.getElementType() != outElemTy)
lhs = rewriter.create<tosa::CastOp>(op.getLoc(), outType, lhs);
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
rewriter.replaceOpWithNewOp<tosa::MulOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
lhs, rhsTensor,
/*shift=*/0);
auto mulOp = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
rhsTensor, /*shift=*/0);
rewriter.replaceOp(op, mulOp.getResult());
return success();
}
@@ -507,23 +479,27 @@ public:
"conversion in TOSA operation");
}
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
// auto result;
Value result;
if (lhsElemTy.isa<mlir::FloatType>()) {
auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
op->getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
rhsTensor);
rewriter.replaceOpWithNewOp<tosa::MulOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
lhs, rcpOp.getResult(), /*shift=*/0);
result = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
rcpOp.getResult(), /*shift=*/0)
.getResult();
} else {
rewriter.replaceOpWithNewOp<tosa::DivOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
lhs, rhsTensor);
result = tosa::createBinaryOpAndCast<tosa::DivOp>(rewriter, op, outType,
lhs, rhsTensor)
.getResult();
}
rewriter.replaceOp(op, {result});
return success();
}
};
@@ -1033,8 +1009,12 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
op, "Currently only scalar constants are supported for "
"conversion in TOSA Pow operation");
rewriter.replaceOpWithNewOp<tosa::PowOp>(
op, getTypeConverter()->convertType(op.getType()), self, expTensor);
auto outType =
getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
self, expTensor);
rewriter.replaceOp(op, powOp.getResult());
return success();
}
@@ -3289,15 +3269,8 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
// +0. (sign bit flips). These are probably acceptable in the short term,
// but we should put a comment acknowledging the danger, as there isn't an
// op that avoids the denorm flushing.
SmallVector<int64_t> intValues(totalNumElements, 0);
SmallVector<float> floatValues(totalNumElements, 0.0);
Value zeroTensor = selfType.getElementType().isa<mlir::FloatType>()
? tosa::getConstTensor<float>(
rewriter, op, floatValues, zeroTensorShape)
.value()
: tosa::getConstTensor<int64_t>(
rewriter, op, intValues, zeroTensorShape)
.value();
Value zeroTensor =
tosa::getZerosLikeTensor(rewriter, op, resultType).value();
// Use add broadcast
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, adaptor.getSelf(),


@@ -8,7 +8,6 @@
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include <climits>
@@ -19,7 +18,6 @@
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
@@ -105,6 +103,32 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
return indicesDim;
}
tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
TensorType outType, Value lhs, Value rhs,
int32_t shift) {
lhs = promoteType(rewriter, lhs, outType);
rhs = promoteType(rewriter, rhs, outType);
return tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), outType,
lhs, rhs, shift);
}
template <>
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
Operation *op, TensorType outType,
Value lhs, Value rhs) {
auto lhsElemTy = lhs.getType().cast<TensorType>().getElementType();
auto rhsElemTy = rhs.getType().cast<TensorType>().getElementType();
if (lhsElemTy.isa<mlir::FloatType>() || rhsElemTy.isa<mlir::FloatType>()) {
(void)rewriter.notifyMatchFailure(op,
"tosa.div only supports integer type");
}
lhs = promoteType(rewriter, lhs, outType);
rhs = promoteType(rewriter, rhs, outType);
return tosa::CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), outType,
lhs, rhs);
}
std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
Operation *op,
Value paramsValue,


@@ -149,6 +149,27 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
return const_op.getResult();
}
// Create a zero constant tensor of the desired type and shape.
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
Operation *op, Type type) {
RankedTensorType resultType = type.dyn_cast<RankedTensorType>();
if (!resultType) {
(void)rewriter.notifyMatchFailure(op, "not ranked tensor type");
return std::nullopt;
}
auto resultShape = resultType.getShape();
ShapedType zeroType =
RankedTensorType::get(resultShape, resultType.getElementType());
Attribute zeroAttr = rewriter.getZeroAttr(zeroType);
return CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), zeroType,
zeroAttr.cast<ElementsAttr>())
.getResult();
}
// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.

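A minimal usage sketch of the new helper, mirroring the call site in the
broadcast lowering above (the early-return error path is an illustrative
assumption; the actual call site simply unwraps with .value()):

    // Build an all-zeros constant with the result's ranked tensor type; bail
    // out of the pattern if the type is not a ranked tensor.
    std::optional<Value> zeros =
        tosa::getZerosLikeTensor(rewriter, op, resultType);
    if (!zeros)
      return rewriter.notifyMatchFailure(op, "expected a ranked tensor type");
    Value zeroTensor = zeros.value();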

@@ -0,0 +1,126 @@
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' -split-input-file -verify-diagnostics %s | FileCheck %s
// CHECK-LABEL: torch.aten.mul.Scalar$mixed_type
// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16>
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16>
// CHECK: %[[VAL_2:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_1]]) <{shift = 0 : i32}> : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16>
func.func @torch.aten.mul.Scalar$mixed_type(%arg0: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],bf16> {
%float2.000000e00 = torch.constant.float 2.000000e+00
%0 = torch.aten.mul.Scalar %arg0, %float2.000000e00 : !torch.vtensor<[5],bf16>, !torch.float -> !torch.vtensor<[5],bf16>
return %0 : !torch.vtensor<[5],bf16>
}
// -----
// CHECK-LABEL: torch.aten.add.Tensor$mixed_type_fp
// CHECK-SAME: %[[VAL_0:.*]]: tensor<6xbf16>
// CHECK-SAME: %[[VAL_1:.*]]: tensor<6xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<6xf32>) -> tensor<6xbf16>
// CHECK: %[[VAL_4:.*]] = "tosa.add"(%[[VAL_0]], %[[VAL_3]]) : (tensor<6xbf16>, tensor<6xbf16>) -> tensor<6xbf16>
func.func @torch.aten.add.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[6],bf16>, %arg1: !torch.vtensor<[6],f32>, %arg2: !torch.float) -> !torch.vtensor<[6],bf16> {
%float1 = torch.constant.float 1.000000e+00
%0 = torch.aten.add.Tensor %arg0, %arg1, %float1 : !torch.vtensor<[6],bf16>, !torch.vtensor<[6],f32>, !torch.float -> !torch.vtensor<[6],bf16>
return %0 : !torch.vtensor<[6],bf16>
}
// -----
// CHECK-LABEL: torch.aten.add.Tensor$mixed_type_int
// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xf32>
// CHECK-SAME: %[[VAL_1:.*]]: tensor<5xbf16>
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<5xbf16>) -> tensor<5xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_0]], %[[VAL_2]]) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
func.func @torch.aten.add.Tensor$mixed_type_int(%arg0: !torch.vtensor<[5],f32>, %arg1: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[5],f32>, !torch.vtensor<[5],bf16>, !torch.int -> !torch.vtensor<[5],f32>
return %0 : !torch.vtensor<[5],f32>
}
// -----
// CHECK-LABEL: torch.aten.Scalar$mixed_type
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x32x64xi16>
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<256> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32>
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<1x1x32x64xi16>) -> tensor<1x1x32x64xi32>
// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_1]]) : (tensor<1x1x32x64xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x32x64xi32>
func.func @torch.aten.Scalar$mixed_type(%arg0: !torch.vtensor<[1,1,32,64],si16>) -> !torch.vtensor<[1,1,32,64],si32> {
%int1 = torch.constant.int 1
%int256 = torch.constant.int 256
%0 = torch.aten.add.Scalar %arg0, %int256, %int1 : !torch.vtensor<[1,1,32,64],si16>, !torch.int, !torch.int -> !torch.vtensor<[1,1,32,64],si32>
return %0 : !torch.vtensor<[1,1,32,64],si32>
}
// -----
// CHECK-LABEL: torch.aten.sub.Scalar$mixed_type
// CHECK-SAME: %[[VAL_0:.*]]: tensor<bf16>,
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<bf16>}> : () -> tensor<bf16>
// CHECK: %[[VAL_3:.*]] = "tosa.sub"(%[[VAL_0]], %[[VAL_2]]) : (tensor<bf16>, tensor<bf16>) -> tensor<bf16>
func.func @torch.aten.sub.Scalar$mixed_type(%arg0: !torch.vtensor<[],bf16>, %arg1: !torch.vtensor<[],bf16>) -> !torch.vtensor<[],bf16> {
%int1 = torch.constant.int 1
%0 = torch.aten.sub.Scalar %arg0, %int1, %int1 : !torch.vtensor<[],bf16>, !torch.int, !torch.int -> !torch.vtensor<[],bf16>
return %0 : !torch.vtensor<[],bf16>
}
// -----
// CHECK-LABEL: torch.aten.maximum$mixed_type
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x1xi32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x1xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<1x3x1xi32>) -> tensor<1x3x1xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.maximum"(%[[VAL_2]], %[[VAL_1]]) : (tensor<1x3x1xf32>, tensor<1x3x1xf32>) -> tensor<1x3x1xf32>
func.func @torch.aten.maximum$mixed_type(%arg0: !torch.vtensor<[1,3,1],si32>, %arg1: !torch.vtensor<[1,3,1],f32>) -> !torch.vtensor<[1,3,1],f32> {
%0 = torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[1,3,1],si32>, !torch.vtensor<[1,3,1],f32> -> !torch.vtensor<[1,3,1],f32>
return %0 : !torch.vtensor<[1,3,1],f32>
}
// -----
// CHECK-LABEL: torch.aten.bitwise_and.Tensor$mixed_type
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi16>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xi16>) -> tensor<?x?xi32>
// CHECK: %[[VAL_3:.*]] = "tosa.bitwise_and"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],si16>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
%0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],si16>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
return %0 : !torch.vtensor<[?,?],si32>
}
// -----
// CHECK-LABEL: torch.aten.div.Tensor$mixed_type_fp
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor<?x?xi32>) -> tensor<?x?xi32>
// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_3]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],f32> {
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],f32>
return %0 : !torch.vtensor<[?, ?],f32>
}
// -----
// CHECK-LABEL: torch.aten.div.Tensor$mixed_type_int
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi16>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xi16>) -> tensor<?x?xi32>
// CHECK: %[[VAL_3:.*]] = "tosa.div"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> {
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32>
return %0 : !torch.vtensor<[?, ?],si32>
}
// -----
// CHECK-LABEL: torch.aten.pow.Tensor$mixed_type
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf16>
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xf16>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.pow"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
func.func @torch.aten.pow.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f32> {
%fp0 = torch.constant.float 3.123400e+00
%0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f16>, !torch.float -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}