* [tosa] Support for AtenRsubScalarOp for scalar constants (#531)

* [tosa] Support for AtenCeilOp and AtenReciprocalOp
* [tosa] Support for comparator ops, Aten[Gt|Lt|Eq][Tensor|Scalar]Op, with scalar constants
* [tosa] Support for the Scalar variants of Aten[Mul|Div|Add|Sub] ops with scalar constants

Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>

Co-authored-by: Anup Gangwar <anup.gangwar@arm.com>
Anup Gangwar 2022-01-20 12:58:30 -06:00 committed by GitHub
parent c8ee8d02eb
commit f8080bd1c5
4 changed files with 469 additions and 109 deletions
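As an illustration of the new scalar-constant handling, here is a rough sketch of the IR these patterns are intended to accept and lower for a scalar comparison, modeled on the torch.aten.rsub.Scalar tests added below; the function name, constant value, and result shape are illustrative assumptions, not lines taken from the test suite:

// Hypothetical input: compare a tensor against a float scalar.
func @torch.aten.gt.Scalar$sketch(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
  %float2 = torch.constant.float 2.000000e+00
  %0 = torch.aten.gt.Scalar %arg0, %float2 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],i1>
  return %0 : !torch.vtensor<[?,?],i1>
}
// Assumed shape of the lowering: torchScalarToTosaTensor materializes the scalar as a
// rank-0 "tosa.const" of f32, and ConvertAtenCompareOp then emits
// "tosa.greater"(%self, %const) : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>.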


@@ -54,4 +54,31 @@ TOSA_PASS_SET = {
"BmmModule_basic",
"Matmul_dot",
"Matmul_3d",
"RsubModule_basic",
"RsubModule_noalpha_basic",
"ElementwiseGtFloatScalarModule_basic",
"ElementwiseGtIntScalarModule_basic",
"ElementwiseGtMixed2ScalarModule_basic",
"ElementwiseGtFloatTensorModule_basic",
"ElementwiseGtIntTensorModule_basic",
"ElementwiseLtFloatScalarModule_basic",
"ElementwiseLtIntScalarModule_basic",
"ElementwiseLtDiffWidthScalarModule_basic",
"ElementwiseLtFloatTensorModule_basic",
"ElementwiseLtIntTensorModule_basic",
"ElementwiseEqFloatScalarModule_basic",
"ElementwiseEqIntScalarModule_basic",
"ElementwiseEqDiffWidthScalarModule_basic",
"ElementwiseEqFloatTensorModule_basic",
"ElementwiseEqIntTensorModule_basic",
"ElementwiseMulScalarModule_int",
"ElementwiseMulScalarModule_float",
"ElementwiseMulTensorIntModule_basic",
"ElementwiseDivScalarModule_basic",
"ElementwiseSubScalarFloatModule_basic",
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseMulScalarModule_float",
"ElementwiseCeilModule_basic",
"ElementwiseReciprocalModule_basic",
"TypePromotionAlphaWiderModule_basic",
}


@@ -110,6 +110,67 @@ public:
}
};
// FIXME: This will eventually go into a Tosa*Utils file.
LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value torchScalarValue,
Value &tosaTensor, Type dtype) {
if (dtype.isa<mlir::FloatType>()) {
double scalarValue;
if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
return failure();
tosaTensor =
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
int64_t scalarValue;
if (!matchPattern(torchScalarValue, m_TorchConstantInt(&scalarValue)))
return failure();
auto w = intType.getWidth();
if (w != 32 && w != 64)
return op->emitError("Unsupported integer type") << intType;
if (w == 32) {
tosaTensor = tosa::getConstTensor<int32_t>(
rewriter, op, {static_cast<int32_t>(scalarValue)}, {})
.getValue();
} else if (w == 64) {
tosaTensor =
tosa::getConstTensor<int64_t>(rewriter, op, {scalarValue}, {})
.getValue();
}
return success();
} else
return op->emitError("Usupported element type");
return success();
}
LogicalResult torchAlphaToTosaTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value alphaScalar,
Value &alphaTensor, Type dtype,
bool checkForUnity) {
if (succeeded(torchScalarToTosaTensor(rewriter, op, alphaScalar, alphaTensor,
dtype)))
return success();
// `alpha` has not been specified.
int64_t alphaValue;
if (!matchPattern(alphaScalar, m_TorchConstantInt(&alphaValue)))
return op->emitError("Currently only scalar constants are supported for "
"alpha in TOSA operation");
// When no alpha has been specified, this must be 1.
if (checkForUnity && alphaValue != 1)
return op->emitError("Unsupported integer value for alpha");
alphaTensor =
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, alphaValue);
return success();
}
// These binary op legalizations are specific to add/sub which have an
// alpha multiplier.
template <typename AtenOpT, typename TosaOpT>
@@ -121,34 +182,191 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.self();
auto lhsTy = lhs.getType().cast<TensorType>();
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
Value rhs = adaptor.other();
auto rhsTy = rhs.getType().cast<TensorType>();
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
if (!lhsTy || !rhsTy)
if (!lhsTy)
return op.emitError("Only Tensor types supported in TOSA");
auto lhsElemTy = lhsTy.getElementType();
auto rhsElemTy = rhsTy.getElementType();
if (!lhsElemTy.isIntOrFloat())
return op.emitError(
"Only floating-point or integer datatype legalization supported");
if (lhsElemTy != rhsElemTy)
return op.emitError("Add: input datatypes mismatched");
Value rhsAsTensor;
if (!rhsTy) {
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
op.other(), rhsAsTensor, lhsElemTy)))
return op.emitError("Currently only scalar constants are supported for "
"conversion in TOSA operation");
}
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
// FIXME: Handle alpha.
// Needs extraction of floating point constant.
// Handle alpha.
Value alphaTensor;
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), op.alpha(),
alphaTensor, lhsElemTy, false)))
return op.emitError("Currently only scalar constants are supported for "
"alpha in conversion to TOSA operation");
auto multTensor = rewriter.create<tosa::MulOp>(
op.getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
rhsTensor, alphaTensor, /*shift*/ 0);
if (lhsElemTy.isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<TosaOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
lhs, rhs);
lhs, multTensor);
return success();
} else {
return op.emitError(
"Only floating-point datatype legalization supported");
}
}
}; // namespace
// Binary op legalizations for comparator ops.
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.self();
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
Value rhs = adaptor.other();
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
if (!lhsTy)
return op.emitError("Only Tensor types supported in TOSA");
auto lhsElemTy = lhsTy.getElementType();
if (!lhsElemTy.isIntOrFloat())
return op.emitError(
"Only floating-point or integer datatype legalization supported");
Value rhsAsTensor;
if (!rhsTy) {
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
op.other(), rhsAsTensor, lhsElemTy)))
return op.emitError("Currently only scalar constants are supported for "
"conversion in TOSA operation");
}
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
// There is no Lesser operator in TOSA
auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>());
rewriter.replaceOpWithNewOp<TosaOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
(swapLhsRhs ? rhsTensor : lhs), (swapLhsRhs ? lhs : rhsTensor));
return success();
}
};
// Binary op legalizations for Mul variants.
template <typename AtenOpT>
class ConvertAtenMulOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.self();
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
Value rhs = adaptor.other();
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
if (!lhsTy)
return op.emitError("Only Tensor types supported in TOSA");
auto lhsElemTy = lhsTy.getElementType();
if (!lhsElemTy.isIntOrFloat())
return op.emitError(
"Only floating-point or integer datatype legalization supported");
Value rhsAsTensor;
if (!rhsTy) {
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
op.other(), rhsAsTensor, lhsElemTy)))
return op.emitError("Currently only scalar constants are supported for "
"conversion in TOSA operation");
}
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
if (lhsElemTy.isa<mlir::FloatType>() ||
lhsElemTy.isa<mlir::IntegerType>()) {
rewriter.replaceOpWithNewOp<tosa::MulOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
lhs, rhsTensor,
/*shift=*/0);
return success();
} else {
// Quantized multiplication may need to rescale inputs.
return op.emitError("Only floating-point or integer datatype "
"legalization currently supported");
}
}
};
template <typename AtenOpT>
class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.self();
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
Value rhs = adaptor.other();
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
if (!lhsTy)
return op.emitError("Only Tensor types supported in TOSA");
auto lhsElemTy = lhsTy.getElementType();
if (!lhsElemTy.isIntOrFloat())
return op.emitError(
"Only floating-point or integer datatype legalization supported");
Value rhsAsTensor;
if (!rhsTy) {
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
op.other(), rhsAsTensor, lhsElemTy)))
return op.emitError("Currently only scalar constants are supported for "
"conversion in TOSA operation");
}
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
if (lhsElemTy.isa<mlir::FloatType>()) {
auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
op->getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
rhsTensor);
rewriter.replaceOpWithNewOp<tosa::MulOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
lhs, rcpOp.getResult(), /*shift=*/0);
} else {
rewriter.replaceOpWithNewOp<tosa::DivOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
lhs, rhsTensor);
}
return success();
}
};
// This defines a template to construct ops whose legalizations are
@@ -227,69 +445,6 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
}
}
template <>
LogicalResult ConvertAtenOp<AtenMulTensorOp>::matchAndRewrite(
AtenMulTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value lhs = adaptor.self();
auto lhsTy = lhs.getType().cast<TensorType>();
Value rhs = adaptor.other();
auto rhsTy = rhs.getType().cast<TensorType>();
if (!lhsTy || !rhsTy)
return op.emitError("Only Tensor types supported in TOSA");
auto lhsElemTy = lhsTy.getElementType();
auto rhsElemTy = rhsTy.getElementType();
if (lhsElemTy != rhsElemTy)
return op.emitError("Add: input datatypes mismatched");
if (lhsElemTy.isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<tosa::MulOp>(
op, getTypeConverter()->convertType(op.getType()), lhs, rhs,
/*shift=*/0);
return success();
} else {
// Quantized multiplication may need to rescale inputs.
return op.emitError(
"Only floating-point datatype legalization currently supported");
}
}
template <>
LogicalResult ConvertAtenOp<AtenDivTensorOp>::matchAndRewrite(
AtenDivTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value lhs = adaptor.self();
auto lhsTy = lhs.getType().cast<TensorType>();
Value rhs = adaptor.other();
auto rhsTy = rhs.getType().cast<TensorType>();
if (!lhsTy || !rhsTy)
return op.emitError("Only Tensor types supported in TOSA");
auto lhsElemTy = lhsTy.getElementType();
auto rhsElemTy = rhsTy.getElementType();
if (lhsElemTy != rhsElemTy)
return op.emitError("Add: input datatypes mismatched");
if (lhsElemTy.isa<mlir::FloatType>()) {
auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
op->getLoc(), getTypeConverter()->convertType(op.getType()), rhs);
rewriter.replaceOpWithNewOp<tosa::MulOp>(
op, getTypeConverter()->convertType(op.getType()), lhs,
rcpOp.getResult(), /*shift=*/0);
} else {
rewriter.replaceOpWithNewOp<tosa::DivOp>(
op, getTypeConverter()->convertType(op.getType()), lhs, rhs);
}
return success();
}
using ReductionConvFunc = llvm::Optional<Value> (*)(PatternRewriter &,
Operation *,
RankedTensorType, Value,
@@ -635,23 +790,6 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
}
};
// FIXME(AG): This will eventually go into a Tosa*Utils file
// Convert an fp32 scalar into tosa fp32 tensor.
static LogicalResult
tosaF32TensorFromTorchFloat(ConversionPatternRewriter &rewriter, Operation *op,
Value torchScalarValue, Value &tosaTensor) {
double scalarValue;
if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
return failure();
// Construct a tosa.const
tosaTensor =
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
AtenPowTensorScalarOp op, OpAdaptor adaptor,
@@ -668,8 +806,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
Value expTensor;
Value expScalar = op.exponent();
if (failed(tosaF32TensorFromTorchFloat(rewriter, op.getOperation(), expScalar,
expTensor)))
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), expScalar,
expTensor, selfTy.getElementType())))
return op.emitError("Currently only scalar constants are supported for "
"conversion in TOSA Pow operation");
@@ -1238,6 +1376,45 @@ public:
}
};
template <>
LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
AtenRsubScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.self();
auto otherScalar = op.other();
auto alphaScalar = op.alpha();
auto selfTy = self.getType().template cast<RankedTensorType>();
if (!selfTy)
return op.emitError("Only ranked tensor types supported in TOSA Rsub");
if (!selfTy.getElementType().isa<mlir::FloatType>())
return op.emitError("Only floating-point datatype legalization supported");
Value otherTensor, alphaTensor;
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), otherScalar,
otherTensor, selfTy.getElementType())))
return op.emitError("Currently only scalar constants are supported for "
"conversion in TOSA Rsub operation");
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar,
alphaTensor, selfTy.getElementType(),
true)))
return failure();
auto multTensor = rewriter.create<tosa::MulOp>(
op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
alphaTensor, /*shift*/ 0);
rewriter.replaceOpWithNewOp<tosa::SubOp>(
op, getTypeConverter()->convertType(op.getType()), otherTensor,
multTensor);
return success();
}
} // namespace
// -----------------------------------------------------------------------------
@@ -1281,6 +1458,8 @@ public:
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp)
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp)
INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp)
#undef INSERT_UNARY_PATTERN
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
@@ -1294,9 +1473,36 @@ public:
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp)
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp)
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp)
#undef INSERT_BINARY_ADDSUB_PATTERN
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
#undef INSERT_BINARY_COMPARE_PATTERN
#define INSERT_BINARY_MUL_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMulOp<AtenOp>>(typeConverter, context);
INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp);
INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp);
#undef INSERT_BINARY_MUL_PATTERN
#define INSERT_BINARY_DIV_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
#undef INSERT_BINARY_DIV_PATTERN
#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
@@ -1348,10 +1554,9 @@ public:
INSERT_ATENOP_PATTERN(AtenTanhOp);
INSERT_ATENOP_PATTERN(AtenSigmoidOp);
INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
#undef INSERT_ATENOP_PATTERN
if (failed(applyPartialConversion(getOperation(), target,


@@ -141,5 +141,9 @@ template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
ArrayRef<int32_t> vec,
ArrayRef<int64_t> shape);
template llvm::Optional<Value> getConstTensor<int64_t>(PatternRewriter &,
Operation *,
ArrayRef<int64_t> vec,
ArrayRef<int64_t> shape);
} // namespace tosa
} // namespace mlir


@@ -105,15 +105,46 @@ func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
// -----
// CHECK-LABEL: func @torch.aten.ceil$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.ceil"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.ceil$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.ceil %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.reciprocal$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.reciprocal %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.add$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[ARG2:.*]] = torch.constant.int 1
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.add"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_6]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
@@ -123,14 +154,17 @@ func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
// -----
// CHECK-LABEL: func @torch.aten.sub$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[ARG2:.*]] = torch.constant.int 1
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.sub"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_2]], %[[VAL_6]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
@@ -347,3 +381,93 @@ func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !t
%0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.rsub.Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
// CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<3.123400e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<6.432100e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_1]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_4]], %[[VAL_6]]) : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%other = torch.constant.float 3.123400e+00
%alpha = torch.constant.float 6.432100e+00
%0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.rsub.Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<3.123400e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_1]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_4]], %[[VAL_6]]) : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%other = torch.constant.float 3.123400e+00
%alpha = torch.constant.int 1
%0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.int -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.gt.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = "tosa.greater"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func @torch.aten.gt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func @torch.aten.lt.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = "tosa.greater"(%[[VAL_3]], %[[VAL_2]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func @torch.aten.lt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func @torch.aten.eq.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = "tosa.equal"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}