mirror of https://github.com/llvm/torch-mlir
[tosa] support lowering basic torch binary ops with mixed dtypes (#2122)
Lowering torch operations that allow different compatible data types in its operands to tosa end up generating invalid tosa IR with mixed data types. In tosa spec, certain operations (generally element-wise operations) require all operands to have the same data type. Add wrapper functions for those element-wise tosa ops to perform op creation with type conversion if necessary.pull/2134/head
parent
5698893ae4
commit
ed4ecb072f
|
@ -10,6 +10,9 @@
|
||||||
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
|
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
|
||||||
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
|
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZECOMMON_H
|
||||||
|
|
||||||
|
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
|
||||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||||
|
|
||||||
|
@ -21,6 +24,26 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
|
||||||
SmallVector<int64_t> indiceOneDimShape, int32_t dim,
|
SmallVector<int64_t> indiceOneDimShape, int32_t dim,
|
||||||
ArrayRef<int64_t> indexShape);
|
ArrayRef<int64_t> indexShape);
|
||||||
|
|
||||||
|
mlir::tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
|
||||||
|
TensorType outType, Value lhs, Value rhs,
|
||||||
|
int32_t shift);
|
||||||
|
|
||||||
|
// Create TOSA elementwise binary op with type conversion if necessary.
|
||||||
|
template <typename TosaOpT>
|
||||||
|
TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
|
||||||
|
TensorType outType, Value lhs, Value rhs) {
|
||||||
|
lhs = promoteType(rewriter, lhs, outType);
|
||||||
|
rhs = promoteType(rewriter, rhs, outType);
|
||||||
|
return CreateOpAndInfer<TosaOpT>(rewriter, op->getLoc(), outType, lhs, rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// This specialization is for Div op. Unlike other binary ops, it doesn't support
|
||||||
|
// floating type.
|
||||||
|
template <>
|
||||||
|
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
|
||||||
|
Operation *op, TensorType outType,
|
||||||
|
Value lhs, Value rhs);
|
||||||
|
|
||||||
std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
|
std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
|
||||||
Operation *op,
|
Operation *op,
|
||||||
Value params_value,
|
Value params_value,
|
||||||
|
|
|
@ -45,6 +45,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
|
||||||
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||||
float val);
|
float val);
|
||||||
|
|
||||||
|
// Create a zero constant tensor of the desired type and shape.
|
||||||
|
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
|
||||||
|
Operation *op, Type type);
|
||||||
|
|
||||||
// Templated function to create a constant op for given type and shape.
|
// Templated function to create a constant op for given type and shape.
|
||||||
// T: storage C type.
|
// T: storage C type.
|
||||||
// Default template creates a constant tensor in T.
|
// Default template creates a constant tensor in T.
|
||||||
|
|
|
@ -100,17 +100,13 @@ public:
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Only Tensor types supported in TOSA");
|
"Only Tensor types supported in TOSA");
|
||||||
|
|
||||||
auto lhsElemTy = lhsTy.getElementType();
|
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||||
auto rhsElemTy = rhsTy.getElementType();
|
->convertType(op.getType())
|
||||||
|
.template cast<TensorType>();
|
||||||
|
|
||||||
if (lhsElemTy != rhsElemTy)
|
auto binaryOp =
|
||||||
return rewriter.notifyMatchFailure(op, "Input datatypes mismatched");
|
tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
|
||||||
|
rewriter.replaceOp(op, binaryOp.getResult());
|
||||||
rewriter.replaceOpWithNewOp<TosaOpT>(
|
|
||||||
op,
|
|
||||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
|
||||||
op.getType()),
|
|
||||||
lhs, rhs);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -291,52 +287,30 @@ public:
|
||||||
"alpha in conversion to TOSA operation");
|
"alpha in conversion to TOSA operation");
|
||||||
}
|
}
|
||||||
|
|
||||||
// make sure input of MulOp is same datetype, otherwise the lowering to
|
auto mulAlphaOp = tosa::createMulOpAndCast(
|
||||||
// arith dialect will bug
|
rewriter, op,
|
||||||
auto multTensor = rewriter.create<tosa::MulOp>(
|
|
||||||
op.getLoc(),
|
|
||||||
rhsType ? rhsType : RankedTensorType::get({}, rhsAlphaMulElemType),
|
rhsType ? rhsType : RankedTensorType::get({}, rhsAlphaMulElemType),
|
||||||
rhsTensor, alphaTensor, /*shift=*/0);
|
rhsTensor, alphaTensor, /*shift=*/0);
|
||||||
|
|
||||||
if (outElemTy.isa<mlir::FloatType>() || outElemTy.isInteger(32)) {
|
if (outElemTy.isInteger(64)) {
|
||||||
// if outElemTy tensor<f32>, mulTensor must be tensor<f32>,
|
// Tosa doesn't support 64-bit elementwise addition and subtraction.
|
||||||
// left value could be tensor<f32/i32/i64>, cast left value to
|
|
||||||
// tensor<f32> type
|
|
||||||
// if outElemTy tensor<i32>, mulTensor must be tensor<i32>,
|
|
||||||
// left value could be tensor<f32/i32/i64>, cast left value to
|
|
||||||
// tensor<i32> type
|
|
||||||
if (lhsType.getElementType() != rhsAlphaMulElemType)
|
|
||||||
lhs = rewriter.create<tosa::CastOp>(
|
|
||||||
op.getLoc(),
|
|
||||||
RankedTensorType::get(lhsType.getShape(), rhsAlphaMulElemType),
|
|
||||||
lhs);
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, lhs, multTensor);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
} else if (outElemTy.isInteger(64)) {
|
|
||||||
// if outElemTy tensor<i64>, mulTensor must be tensor<i32>,
|
// if outElemTy tensor<i64>, mulTensor must be tensor<i32>,
|
||||||
// left value could be tensor<f32/i32/i64> type, cast left value to
|
// left value could be tensor<f32/i32/i64> type, cast left value to
|
||||||
// tensor<i32> type
|
// tensor<i32> type
|
||||||
if (lhsType.getElementType() != rhsAlphaMulElemType)
|
auto addOrSubi64Op = tosa::createBinaryOpAndCast<TosaOpT>(
|
||||||
lhs = rewriter.create<tosa::CastOp>(
|
rewriter, op,
|
||||||
op.getLoc(),
|
|
||||||
RankedTensorType::get(lhsType.getShape(), rhsAlphaMulElemType),
|
|
||||||
lhs);
|
|
||||||
|
|
||||||
auto tosaOpTOutputTensor = rewriter.create<TosaOpT>(
|
|
||||||
op.getLoc(),
|
|
||||||
RankedTensorType::get(outType.getShape(), rhsAlphaMulElemType), lhs,
|
RankedTensorType::get(outType.getShape(), rhsAlphaMulElemType), lhs,
|
||||||
multTensor);
|
mulAlphaOp);
|
||||||
// cast tensor<i32> back to tensor<i64>
|
|
||||||
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
|
|
||||||
tosaOpTOutputTensor);
|
|
||||||
|
|
||||||
|
// cast tensor<i32> back to tensor<i64>
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, addOrSubi64Op);
|
||||||
return success();
|
return success();
|
||||||
} else {
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "Only floating-point, i32, i64 datatype legalization supported");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto binaryOp = tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outType,
|
||||||
|
lhs, mulAlphaOp);
|
||||||
|
rewriter.replaceOp(op, binaryOp.getResult());
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
}; // namespace
|
}; // namespace
|
||||||
|
|
||||||
|
@ -457,15 +431,13 @@ public:
|
||||||
|
|
||||||
if (outElemTy.isa<mlir::FloatType>() ||
|
if (outElemTy.isa<mlir::FloatType>() ||
|
||||||
outElemTy.isa<mlir::IntegerType>()) {
|
outElemTy.isa<mlir::IntegerType>()) {
|
||||||
if (lhsType.getElementType() != outElemTy)
|
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||||
lhs = rewriter.create<tosa::CastOp>(op.getLoc(), outType, lhs);
|
->convertType(op.getType())
|
||||||
|
.template cast<TensorType>();
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(
|
auto mulOp = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
|
||||||
op,
|
rhsTensor, /*shift=*/0);
|
||||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
rewriter.replaceOp(op, mulOp.getResult());
|
||||||
op.getType()),
|
|
||||||
lhs, rhsTensor,
|
|
||||||
/*shift=*/0);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -507,23 +479,27 @@ public:
|
||||||
"conversion in TOSA operation");
|
"conversion in TOSA operation");
|
||||||
}
|
}
|
||||||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||||
|
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||||
|
->convertType(op.getType())
|
||||||
|
.template cast<TensorType>();
|
||||||
|
|
||||||
|
// auto result;
|
||||||
|
Value result;
|
||||||
if (lhsElemTy.isa<mlir::FloatType>()) {
|
if (lhsElemTy.isa<mlir::FloatType>()) {
|
||||||
auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
|
auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
|
||||||
op->getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
|
op->getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
|
||||||
rhsTensor);
|
rhsTensor);
|
||||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(
|
|
||||||
op,
|
result = tosa::createMulOpAndCast(rewriter, op, outType, lhs,
|
||||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
rcpOp.getResult(), /*shift=*/0)
|
||||||
op.getType()),
|
.getResult();
|
||||||
lhs, rcpOp.getResult(), /*shift=*/0);
|
|
||||||
} else {
|
} else {
|
||||||
rewriter.replaceOpWithNewOp<tosa::DivOp>(
|
result = tosa::createBinaryOpAndCast<tosa::DivOp>(rewriter, op, outType,
|
||||||
op,
|
lhs, rhsTensor)
|
||||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
.getResult();
|
||||||
op.getType()),
|
|
||||||
lhs, rhsTensor);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, {result});
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1033,8 +1009,12 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
||||||
op, "Currently only scalar constants are supported for "
|
op, "Currently only scalar constants are supported for "
|
||||||
"conversion in TOSA Pow operation");
|
"conversion in TOSA Pow operation");
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tosa::PowOp>(
|
auto outType =
|
||||||
op, getTypeConverter()->convertType(op.getType()), self, expTensor);
|
getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
|
||||||
|
|
||||||
|
auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
|
||||||
|
self, expTensor);
|
||||||
|
rewriter.replaceOp(op, powOp.getResult());
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -3289,15 +3269,8 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
||||||
// +0. (sign bit flips). These are probably acceptable in the short term,
|
// +0. (sign bit flips). These are probably acceptable in the short term,
|
||||||
// but we should put a comment acknowledging the danger, as there isn't an
|
// but we should put a comment acknowledging the danger, as there isn't an
|
||||||
// op that avoids the denorm flushing.
|
// op that avoids the denorm flushing.
|
||||||
SmallVector<int64_t> intValues(totalNumElements, 0);
|
Value zeroTensor =
|
||||||
SmallVector<float> floatValues(totalNumElements, 0.0);
|
tosa::getZerosLikeTensor(rewriter, op, resultType).value();
|
||||||
Value zeroTensor = selfType.getElementType().isa<mlir::FloatType>()
|
|
||||||
? tosa::getConstTensor<float>(
|
|
||||||
rewriter, op, floatValues, zeroTensorShape)
|
|
||||||
.value()
|
|
||||||
: tosa::getConstTensor<int64_t>(
|
|
||||||
rewriter, op, intValues, zeroTensorShape)
|
|
||||||
.value();
|
|
||||||
|
|
||||||
// Use add broadcast
|
// Use add broadcast
|
||||||
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, adaptor.getSelf(),
|
rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, adaptor.getSelf(),
|
||||||
|
|
|
@ -8,7 +8,6 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
|
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
|
||||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
|
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||||
|
|
||||||
#include <climits>
|
#include <climits>
|
||||||
|
@ -19,7 +18,6 @@
|
||||||
|
|
||||||
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
|
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
|
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
|
|
||||||
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
|
||||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||||
|
@ -105,6 +103,32 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
|
||||||
return indicesDim;
|
return indicesDim;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tosa::MulOp createMulOpAndCast(PatternRewriter &rewriter, Operation *op,
|
||||||
|
TensorType outType, Value lhs, Value rhs,
|
||||||
|
int32_t shift) {
|
||||||
|
lhs = promoteType(rewriter, lhs, outType);
|
||||||
|
rhs = promoteType(rewriter, rhs, outType);
|
||||||
|
return tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), outType,
|
||||||
|
lhs, rhs, shift);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
|
||||||
|
Operation *op, TensorType outType,
|
||||||
|
Value lhs, Value rhs) {
|
||||||
|
auto lhsElemTy = lhs.getType().cast<TensorType>().getElementType();
|
||||||
|
auto rhsElemTy = rhs.getType().cast<TensorType>().getElementType();
|
||||||
|
if (lhsElemTy.isa<mlir::FloatType>() || rhsElemTy.isa<mlir::FloatType>()) {
|
||||||
|
(void)rewriter.notifyMatchFailure(op,
|
||||||
|
"tosa.div only supports integer type");
|
||||||
|
}
|
||||||
|
|
||||||
|
lhs = promoteType(rewriter, lhs, outType);
|
||||||
|
rhs = promoteType(rewriter, rhs, outType);
|
||||||
|
return tosa::CreateOpAndInfer<tosa::DivOp>(rewriter, op->getLoc(), outType,
|
||||||
|
lhs, rhs);
|
||||||
|
}
|
||||||
|
|
||||||
std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
|
std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
|
||||||
Operation *op,
|
Operation *op,
|
||||||
Value paramsValue,
|
Value paramsValue,
|
||||||
|
|
|
@ -149,6 +149,27 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
||||||
return const_op.getResult();
|
return const_op.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a zero constant tensor of the desired type and shape.
|
||||||
|
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
|
||||||
|
Operation *op, Type type) {
|
||||||
|
RankedTensorType resultType = type.dyn_cast<RankedTensorType>();
|
||||||
|
|
||||||
|
if (!resultType) {
|
||||||
|
(void)rewriter.notifyMatchFailure(op, "not ranked tensor type");
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto resultShape = resultType.getShape();
|
||||||
|
ShapedType zeroType =
|
||||||
|
RankedTensorType::get(resultShape, resultType.getElementType());
|
||||||
|
Attribute zeroAttr = rewriter.getZeroAttr(zeroType);
|
||||||
|
|
||||||
|
return CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), zeroType,
|
||||||
|
zeroAttr.cast<ElementsAttr>())
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// Templated function to create a constant op for given type and shape.
|
// Templated function to create a constant op for given type and shape.
|
||||||
// T: storage C type.
|
// T: storage C type.
|
||||||
// Default template creates a constant tensor in T.
|
// Default template creates a constant tensor in T.
|
||||||
|
|
|
@ -0,0 +1,126 @@
|
||||||
|
// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.aten.mul.Scalar$mixed_type
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16>
|
||||||
|
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_1]]) <{shift = 0 : i32}> : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16>
|
||||||
|
func.func @torch.aten.mul.Scalar$mixed_type(%arg0: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],bf16> {
|
||||||
|
%float2.000000e00 = torch.constant.float 2.000000e+00
|
||||||
|
%0 = torch.aten.mul.Scalar %arg0, %float2.000000e00 : !torch.vtensor<[5],bf16>, !torch.float -> !torch.vtensor<[5],bf16>
|
||||||
|
return %0 : !torch.vtensor<[5],bf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.aten.add.Tensor$mixed_type_fp
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<6xbf16>
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: tensor<6xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<6xf32>) -> tensor<6xbf16>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = "tosa.add"(%[[VAL_0]], %[[VAL_3]]) : (tensor<6xbf16>, tensor<6xbf16>) -> tensor<6xbf16>
|
||||||
|
func.func @torch.aten.add.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[6],bf16>, %arg1: !torch.vtensor<[6],f32>, %arg2: !torch.float) -> !torch.vtensor<[6],bf16> {
|
||||||
|
%float1 = torch.constant.float 1.000000e+00
|
||||||
|
%0 = torch.aten.add.Tensor %arg0, %arg1, %float1 : !torch.vtensor<[6],bf16>, !torch.vtensor<[6],f32>, !torch.float -> !torch.vtensor<[6],bf16>
|
||||||
|
return %0 : !torch.vtensor<[6],bf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.aten.add.Tensor$mixed_type_int
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xf32>
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: tensor<5xbf16>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<5xbf16>) -> tensor<5xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_0]], %[[VAL_2]]) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
|
||||||
|
func.func @torch.aten.add.Tensor$mixed_type_int(%arg0: !torch.vtensor<[5],f32>, %arg1: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],f32> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[5],f32>, !torch.vtensor<[5],bf16>, !torch.int -> !torch.vtensor<[5],f32>
|
||||||
|
return %0 : !torch.vtensor<[5],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.aten.Scalar$mixed_type
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x32x64xi16>
|
||||||
|
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<256> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<1x1x32x64xi16>) -> tensor<1x1x32x64xi32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_1]]) : (tensor<1x1x32x64xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x32x64xi32>
|
||||||
|
func.func @torch.aten.Scalar$mixed_type(%arg0: !torch.vtensor<[1,1,32,64],si16>) -> !torch.vtensor<[1,1,32,64],si32> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int256 = torch.constant.int 256
|
||||||
|
%0 = torch.aten.add.Scalar %arg0, %int256, %int1 : !torch.vtensor<[1,1,32,64],si16>, !torch.int, !torch.int -> !torch.vtensor<[1,1,32,64],si32>
|
||||||
|
return %0 : !torch.vtensor<[1,1,32,64],si32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.aten.sub.Scalar$mixed_type
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<bf16>,
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<bf16>}> : () -> tensor<bf16>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tosa.sub"(%[[VAL_0]], %[[VAL_2]]) : (tensor<bf16>, tensor<bf16>) -> tensor<bf16>
|
||||||
|
func.func @torch.aten.sub.Scalar$mixed_type(%arg0: !torch.vtensor<[],bf16>, %arg1: !torch.vtensor<[],bf16>) -> !torch.vtensor<[],bf16> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.aten.sub.Scalar %arg0, %int1, %int1 : !torch.vtensor<[],bf16>, !torch.int, !torch.int -> !torch.vtensor<[],bf16>
|
||||||
|
return %0 : !torch.vtensor<[],bf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.aten.maximum$mixed_type
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x1xi32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x1xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<1x3x1xi32>) -> tensor<1x3x1xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tosa.maximum"(%[[VAL_2]], %[[VAL_1]]) : (tensor<1x3x1xf32>, tensor<1x3x1xf32>) -> tensor<1x3x1xf32>
|
||||||
|
func.func @torch.aten.maximum$mixed_type(%arg0: !torch.vtensor<[1,3,1],si32>, %arg1: !torch.vtensor<[1,3,1],f32>) -> !torch.vtensor<[1,3,1],f32> {
|
||||||
|
%0 = torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[1,3,1],si32>, !torch.vtensor<[1,3,1],f32> -> !torch.vtensor<[1,3,1],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,3,1],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.aten.bitwise_and.Tensor$mixed_type
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi16>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xi16>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tosa.bitwise_and"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||||
|
func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],si16>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
|
||||||
|
%0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],si16>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
|
||||||
|
return %0 : !torch.vtensor<[?,?],si32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.aten.div.Tensor$mixed_type_fp
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_3]]) <{shift = 0 : i32}> : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],f32> {
|
||||||
|
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?, ?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.aten.div.Tensor$mixed_type_int
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xi16>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xi16>) -> tensor<?x?xi32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tosa.div"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||||
|
func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> {
|
||||||
|
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32>
|
||||||
|
return %0 : !torch.vtensor<[?, ?],si32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: torch.aten.pow.Tensor$mixed_type
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf16>
|
||||||
|
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<?x?xf16>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tosa.pow"(%[[VAL_2]], %[[VAL_1]]) : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
|
||||||
|
func.func @torch.aten.pow.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
%fp0 = torch.constant.float 3.123400e+00
|
||||||
|
%0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f16>, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue