mirror of https://github.com/llvm/torch-mlir
[tosa] support lowering basic torch binary ops with mixed dtypes (#2122)
Lowering torch operations that allow different compatible data types in its operands to tosa end up generating invalid tosa IR with mixed data types. In tosa spec, certain operations (generally element-wise operations) require all operands to have the same data type. Add wrapper functions for those element-wise tosa ops to perform op creation with type conversion if necessary.pull/2134/head
parent
5698893ae4
commit
ed4ecb072f
|
@ -10,8 +10,11 @@
|
|||
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
|
||||
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace tosa {
|
||||
|
@ -21,6 +24,26 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
|
|||
SmallVector<int64_t> indiceOneDimShape, int32_t dim,
|
||||
ArrayRef<int64_t> indexShape);
|
||||
|
||||
mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
|
||||
TensorType outType, Value lhs, Value rhs,
|
||||
int32_t shift);
|
||||
|
||||
// Create TOSA elementwise binary op with type conversion if necessary.
|
||||
template <typename TosaOpT>
|
||||
TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
|
||||
TensorType outType, Value lhs, Value rhs) {
|
||||
lhs = promoteType(rewriter, lhs, outType);
|
||||
rhs = promoteType(rewriter, rhs, outType);
|
||||
return CreateOpAndInfer<TosaOpT>(rewriter, op->getLoc(), outType, lhs, rhs);
|
||||
}
|
||||
|
||||
// This specialization is for Div op. Unlike other binary ops, it doesn't support
|
||||
// floating type.
|
||||
template <>
|
||||
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
|
||||
Operation *op, TensorType outType,
|
||||
Value lhs, Value rhs);
|
||||
|
||||
std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
|
||||
Operation *op,
|
||||
Value params_value,
|
||||
|
|
|
@ -45,6 +45,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
|
|||
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||
float val);
|
||||
|
||||
// Create a zero constant tensor of the desired type and shape.
|
||||
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
|
||||
Operation *op, Type type);
|
||||
|
||||
// Templated function to create a constant op for given type and shape.
|
||||
// T: storage C type.
|
||||
// Default template creates a constant tensor in T.
|
||||
|
|
|
@ -100,17 +100,13 @@ public:
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"Only Tensor types supported in TOSA");
|
||||
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
auto rhsElemTy = rhsTy.getElementType();
|
||||
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
|
||||
if (lhsElemTy != rhsElemTy)
|
||||
return rewriter.notifyMatchFailure(op, "Input datatypes mismatched");
|
||||
|
||||
rewriter.replaceOpWithNewOp<TosaOpT>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
lhs, rhs);
|
||||
auto binaryOp =
|
||||
tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
|
||||
rewriter.replaceOp(op, binaryOp.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -291,52 +287,30 @@ public:
|
|||
"alpha in conversion to TOSA operation");
|
||||
}
|
||||
|
||||
// make sure input of MulOp is same datetype, otherwise the lowering to
|
||||
// arith dialect will bug
|
||||
auto multTensor = rewriter.create<tosa::MulOp>(
|
||||
op.getLoc(),
|
||||
auto mulAlphaOp = tosa::createMulOpAndCast(
|
||||
rewriter, op,
|
||||
rhsType ? rhsType : RankedTensorType::get({}, rhsAlphaMulElemType),
|
||||
rhsTensor, alphaTensor, /*shift=*/0);
|
||||
|
||||
if (outElemTy.isa<mlir::FloatType>() || outElemTy.isInteger(32)) {
|
||||
// if outElemTy tensor<f32>, mulTensor must be tensor<f32>,
|
||||
// left value could be tensor<f32/i32/i64>, cast left value to
|
||||
// tensor<f32> type
|
||||
// if outElemTy tensor<i32>, mulTensor must be tensor<i32>,
|
||||
// left value could be tensor<f32/i32/i64>, cast left value to
|
||||
// tensor<i32> type
|
||||
if (lhsType.getElementType() != rhsAlphaMulElemType)
|
||||
lhs = rewriter.create<tosa::CastOp>(
|
||||
op.getLoc(),
|
||||
RankedTensorType::get(lhsType.getShape(), rhsAlphaMulElemType),
|
||||
lhs);
|
||||
|
||||
rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, lhs, multTensor);
|
||||
|
||||
return success();
|
||||
} else if (outElemTy.isInteger(64)) {
|
||||
if (outElemTy.isInteger(64)) {
|
||||
// Tosa doesn't support 64-bit elementwise addition and subtraction.
|
||||
// if outElemTy tensor<i64>, mulTensor must be tensor<i32>,
|
||||
// left value could be tensor<f32/i32/i64> type, cast left value to
|
||||
// tensor<i32> type
|
||||
if (lhsType.getElementType() != rhsAlphaMulElemType)
|
||||
lhs = rewriter.create<tosa::CastOp>(
|
||||
op.getLoc(),
|
||||
RankedTensorType::get(lhsType.getShape(), rhsAlphaMulElemType),
|
||||
lhs);
|
||||
|
||||
auto tosaOpTOutputTensor = rewriter.create<TosaOpT>(
|
||||
op.getLoc(),
|
||||
auto addOrSubi64Op = tosa::createBinaryOpAndCast<TosaOpT>(
|
||||
rewriter, op,
|
||||
RankedTensorType::get(outType.getShape(), rhsAlphaMulElemType), lhs,
|
||||
multTensor);
|
||||
// cast tensor<i32> back to tensor<i64>
|
||||
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
|
||||
tosaOpTOutputTensor);
|
||||
mulAlphaOp);
|
||||
|
||||
// cast tensor<i32> back to tensor<i64>
|
||||
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, addOrSubi64Op);
|
||||
return success();
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point, i32, i64 datatype legalization supported");
|
||||
}
|
||||
|
||||
auto binaryOp = tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outType,
|
||||
lhs, mulAlphaOp);
|
||||
rewriter.replaceOp(op, binaryOp.getResult());
|
||||
return success();
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
|
@ -457,15 +431,13 @@ public:
|
|||
|
||||
if (outElemTy.isa<mlir::FloatType>() ||
|
||||
outElemTy.isa<mlir::IntegerType>()) {
|
||||
if (lhsType.getElementType() != outElemTy)
|
||||
lhs = rewriter.create<tosa::CastOp>(op.getLoc(), outType, lhs);
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
lhs, rhsTensor,
|
||||
/*shift=*/0);
|
||||
auto mulOp = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
|
||||
rhsTensor, /*shift=*/0);
|
||||
rewriter.replaceOp(op, mulOp.getResult());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -507,23 +479,27 @@ public:
|
|||
"conversion in TOSA operation");
|
||||
}
|
||||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
|
||||
// auto result;
|
||||
Value result;
|
||||
if (lhsElemTy.isa<mlir::FloatType>()) {
|
||||
auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
|
||||
op->getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
|
||||
rhsTensor);
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
lhs, rcpOp.getResult(), /*shift=*/0);
|
||||
|
||||
result = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
|
||||
rcpOp.getResult(), /*shift=*/0)
|
||||
.getResult();
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<tosa::DivOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
lhs, rhsTensor);
|
||||
result = tosa::createBinaryOpAndCast<tosa::DivOp>(rewriter, op, outType,
|
||||
lhs, rhsTensor)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, {result});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -1033,8 +1009,12 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
|||
op, "Currently only scalar constants are supported for "
|
||||
"conversion in TOSA Pow operation");
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::PowOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self, expTensor);
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
|
||||
|
||||
auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
|
||||
self, expTensor);
|
||||
rewriter.replaceOp(op, powOp.getResult());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -3289,15 +3269,8 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
|||
// +0. (sign bit flips). These are probably acceptable in the short term,
|
||||
// but we should put a comment acknowledging the danger, as there isn't an
|
||||
// op that avoids the denorm flushing.
|
||||
SmallVector<int64_t> intValues(totalNumElements, 0);
|
||||
SmallVector<float> floatValues(totalNumElements, 0.0);
|
||||
Value zeroTensor = selfType.getElementType().isa<mlir::FloatType>()
|
||||
? tosa::getConstTensor<float>(
|
||||
rewriter, op, floatValues, zeroTensorShape)
|
||||
.value()
|
||||
: tosa::getConstTensor<int64_t>(
|
||||
rewriter, op, intValues, zeroTensorShape)
|
||||
.value();
|
||||
Value zeroTensor =
|
||||
tosa::getZerosLikeTensor(rewriter, op, resultType).value();
|
||||
|
||||
// Use add broadcast
|
||||
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, adaptor.getSelf(),
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
|
||||
#include <climits>
|
||||
|
@ -19,7 +18,6 @@
|
|||
|
||||
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
|
@ -105,6 +103,32 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
|
|||
return indicesDim;
|
||||
}
|
||||
|
||||
tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
|
||||
TensorType outType, Value lhs, Value rhs,
|
||||
int32_t shift) {
|
||||
lhs = promoteType(rewriter, lhs, outType);
|
||||
rhs = promoteType(rewriter, rhs, outType);
|
||||
return tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), outType,
|
||||
lhs, rhs, shift);
|
||||
}
|
||||
|
||||
template <>
|
||||
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
|
||||
Operation *op, TensorType outType,
|
||||
Value lhs, Value rhs) {
|
||||
auto lhsElemTy = lhs.getType().cast<TensorType>().getElementType();
|
||||
auto rhsElemTy = rhs.getType().cast<TensorType>().getElementType();
|
||||
if (lhsElemTy.isa<mlir::FloatType>() || rhsElemTy.isa<mlir::FloatType>()) {
|
||||
(void)rewriter.notifyMatchFailure(op,
|
||||
"tosa.div only supports integer type");
|
||||
}
|
||||
|
||||
lhs = promoteType(rewriter, lhs, outType);
|
||||
rhs = promoteType(rewriter, rhs, outType);
|
||||
return tosa::CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), outType,
|
||||
lhs, rhs);
|
||||
}
|
||||
|
||||
std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
|
||||
Operation *op,
|
||||
Value paramsValue,
|
||||
|
|
|
@ -149,6 +149,27 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
|||
return const_op.getResult();
|
||||
}
|
||||
|
||||
// Create a zero constant tensor of the desired type and shape.
|
||||
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
|
||||
Operation *op, Type type) {
|
||||
RankedTensorType resultType = type.dyn_cast<RankedTensorType>();
|
||||
|
||||
if (!resultType) {
|
||||
(void)rewriter.notifyMatchFailure(op, "not ranked tensor type");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto resultShape = resultType.getShape();
|
||||
ShapedType zeroType =
|
||||
RankedTensorType::get(resultShape, resultType.getElementType());
|
||||
Attribute zeroAttr = rewriter.getZeroAttr(zeroType);
|
||||
|
||||
return CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), zeroType,
|
||||
zeroAttr.cast<ElementsAttr>())
|
||||
.getResult();
|
||||
}
|
||||
|
||||
|
||||
// Templated function to create a constant op for given type and shape.
|
||||
// T: storage C type.
|
||||
// Default template creates a constant tensor in T.
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: torch.aten.mul.Scalar$mixed_type
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16>
|
||||
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_1]]) <{shift = 0 : i32}> : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16>
|
||||
func.func @torch.aten.mul.Scalar$mixed_type(%arg0: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],bf16> {
|
||||
%float2.000000e00 = torch.constant.float 2.000000e+00
|
||||
%0 = torch.aten.mul.Scalar %arg0, %float2.000000e00 : !torch.vtensor<[5],bf16>, !torch.float -> !torch.vtensor<[5],bf16>
|
||||
return %0 : !torch.vtensor<[5],bf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: torch.aten.add.Tensor$mixed_type_fp
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<6xbf16>
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<6xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<6xf32>) -> tensor<6xbf16>
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.add"(%[[VAL_0]], %[[VAL_3]]) : (tensor<6xbf16>, tensor<6xbf16>) -> tensor<6xbf16>
|
||||
func.func @torch.aten.add.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[6],bf16>, %arg1: !torch.vtensor<[6],f32>, %arg2: !torch.float) -> !torch.vtensor<[6],bf16> {
|
||||
%float1 = torch.constant.float 1.000000e+00
|
||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %float1 : !torch.vtensor<[6],bf16>, !torch.vtensor<[6],f32>, !torch.float -> !torch.vtensor<[6],bf16>
|
||||
return %0 : !torch.vtensor<[6],bf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: torch.aten.add.Tensor$mixed_type_int
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xf32>
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<5xbf16>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<5xbf16>) -> tensor<5xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_0]], %[[VAL_2]]) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
|
||||
func.func @torch.aten.add.Tensor$mixed_type_int(%arg0: !torch.vtensor<[5],f32>, %arg1: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[5],f32>, !torch.vtensor<[5],bf16>, !torch.int -> !torch.vtensor<[5],f32>
|
||||
return %0 : !torch.vtensor<[5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: torch.aten.Scalar$mixed_type
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x32x64xi16>
|
||||
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<256> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<1x1x32x64xi16>) -> tensor<1x1x32x64xi32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_1]]) : (tensor<1x1x32x64xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x32x64xi32>
|
||||
func.func @torch.aten.Scalar$mixed_type(%arg0: !torch.vtensor<[1,1,32,64],si16>) -> !torch.vtensor<[1,1,32,64],si32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int256 = torch.constant.int 256
|
||||
%0 = torch.aten.add.Scalar %arg0, %int256, %int1 : !torch.vtensor<[1,1,32,64],si16>, !torch.int, !torch.int -> !torch.vtensor<[1,1,32,64],si32>
|
||||
return %0 : !torch.vtensor<[1,1,32,64],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: torch.aten.sub.Scalar$mixed_type
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<bf16>,
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<bf16>}> : () -> tensor<bf16>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.sub"(%[[VAL_0]], %[[VAL_2]]) : (tensor<bf16>, tensor<bf16>) -> tensor<bf16>
|
||||
func.func @torch.aten.sub.Scalar$mixed_type(%arg0: !torch.vtensor<[],bf16>, %arg1: !torch.vtensor<[],bf16>) -> !torch.vtensor<[],bf16> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.sub.Scalar %arg0, %int1, %int1 : !torch.vtensor<[],bf16>, !torch.int, !torch.int -> !torch.vtensor<[],bf16>
|
||||
return %0 : !torch.vtensor<[],bf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: torch.aten.maximum$mixed_type
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x1xi32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x1xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<1x3x1xi32>) -> tensor<1x3x1xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.maximum"(%[[VAL_2]], %[[VAL_1]]) : (tensor<1x3x1xf32>, tensor<1x3x1xf32>) -> tensor<1x3x1xf32>
|
||||
func.func @torch.aten.maximum$mixed_type(%arg0: !torch.vtensor<[1,3,1],si32>, %arg1: !torch.vtensor<[1,3,1],f32>) -> !torch.vtensor<[1,3,1],f32> {
|
||||
%0 = torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[1,3,1],si32>, !torch.vtensor<[1,3,1],f32> -> !torch.vtensor<[1,3,1],f32>
|
||||
return %0 : !torch.vtensor<[1,3,1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: torch.aten.bitwise_and.Tensor$mixed_type
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi16>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xi16>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.bitwise_and"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],si16>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
|
||||
%0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],si16>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
|
||||
return %0 : !torch.vtensor<[?,?],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: torch.aten.div.Tensor$mixed_type_fp
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_3]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],f32> {
|
||||
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],f32>
|
||||
return %0 : !torch.vtensor<[?, ?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: torch.aten.div.Tensor$mixed_type_int
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi16>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xi16>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.div"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> {
|
||||
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32>
|
||||
return %0 : !torch.vtensor<[?, ?],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: torch.aten.pow.Tensor$mixed_type
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf16>
|
||||
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xf16>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.pow"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
|
||||
func.func @torch.aten.pow.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f32> {
|
||||
%fp0 = torch.constant.float 3.123400e+00
|
||||
%0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f16>, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
Loading…
Reference in New Issue