From f8080bd1c5493365a84e9de7e6c1ae095d1600e5 Mon Sep 17 00:00:00 2001
From: Anup Gangwar
Date: Thu, 20 Jan 2022 12:58:30 -0600
Subject: [PATCH] * [tosa] Support for AtenRsubScalarOp for scalar constants (#531)

* [tosa] Support for AtenCeilOp and AtenReciprocalOp

* [tosa] Support for comparator ops, Aten[Gt|Lt|Eq][Tensor|Scalar]Op with
  scalar constant

* [tosa] Support for Scalar variants of Aten[Mul|Div|Add|Sub] Ops with
  scalar constants

Signed-off-by: Anup Gangwar
Co-authored-by: Anup Gangwar
---
 e2e_testing/torchscript/xfail_sets.py      |  27 ++
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 391 +++++++++++++-----
 .../TorchToTosa/TosaLegalizeUtils.cpp      |   4 +
 test/Conversion/TorchToTosa/basic.mlir     | 156 ++++++-
 4 files changed, 469 insertions(+), 109 deletions(-)

diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 6275ab762..8bacab38a 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -54,4 +54,31 @@ TOSA_PASS_SET = {
     "BmmModule_basic",
     "Matmul_dot",
     "Matmul_3d",
+    "RsubModule_basic",
+    "RsubModule_noalpha_basic",
+    "ElementwiseGtFloatScalarModule_basic",
+    "ElementwiseGtIntScalarModule_basic",
+    "ElementwiseGtMixed2ScalarModule_basic",
+    "ElementwiseGtFloatTensorModule_basic",
+    "ElementwiseGtIntTensorModule_basic",
+    "ElementwiseLtFloatScalarModule_basic",
+    "ElementwiseLtIntScalarModule_basic",
+    "ElementwiseLtDiffWidthScalarModule_basic",
+    "ElementwiseLtFloatTensorModule_basic",
+    "ElementwiseLtIntTensorModule_basic",
+    "ElementwiseEqFloatScalarModule_basic",
+    "ElementwiseEqIntScalarModule_basic",
+    "ElementwiseEqDiffWidthScalarModule_basic",
+    "ElementwiseEqFloatTensorModule_basic",
+    "ElementwiseEqIntTensorModule_basic",
+    "ElementwiseMulScalarModule_int",
+    "ElementwiseMulScalarModule_float",
+    "ElementwiseMulTensorIntModule_basic",
+    "ElementwiseDivScalarModule_basic",
+    "ElementwiseSubScalarFloatModule_basic",
+    "ElementwiseAddScalarFloatModule_basic",
+    "ElementwiseMulScalarModule_float",
+    "ElementwiseCeilModule_basic",
+    "ElementwiseReciprocalModule_basic",
+    "TypePromotionAlphaWiderModule_basic",
 }
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index aa0fe6f60..17ce945e9 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -110,6 +110,67 @@ public:
   }
 };
 
+// FIXME: This will eventually go into a Tosa*Utils file.
+LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
+                                      Operation *op, Value torchScalarValue,
+                                      Value &tosaTensor, Type dtype) {
+  if (dtype.isa<mlir::FloatType>()) {
+    double scalarValue;
+
+    if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
+      return failure();
+
+    tosaTensor =
+        mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
+  } else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
+    int64_t scalarValue;
+
+    if (!matchPattern(torchScalarValue, m_TorchConstantInt(&scalarValue)))
+      return failure();
+
+    auto w = intType.getWidth();
+    if (w != 32 && w != 64)
+      return op->emitError("Unsupported integer type") << intType;
+
+    if (w == 32) {
+      tosaTensor = tosa::getConstTensor<int32_t>(
+                       rewriter, op, {static_cast<int32_t>(scalarValue)}, {})
+                       .getValue();
+    } else if (w == 64) {
+      tosaTensor =
+          tosa::getConstTensor<int64_t>(rewriter, op, {scalarValue}, {})
+              .getValue();
+    }
+    return success();
+  } else
+    return op->emitError("Unsupported element type");
+
+  return success();
+}
+
+LogicalResult torchAlphaToTosaTensor(ConversionPatternRewriter &rewriter,
+                                     Operation *op, Value alphaScalar,
+                                     Value &alphaTensor, Type dtype,
+                                     bool checkForUnity) {
+  if (succeeded(torchScalarToTosaTensor(rewriter, op, alphaScalar, alphaTensor,
+                                        dtype)))
+    return success();
+
+  // `alpha` has not been specified.
+  int64_t alphaValue;
+  if (!matchPattern(alphaScalar, m_TorchConstantInt(&alphaValue)))
+    return op->emitError("Currently only scalar constants are supported for "
+                         "alpha in TOSA operation");
+  // When no alpha has been specified, this must be 1.
+  if (checkForUnity && alphaValue != 1)
+    return op->emitError("Unsupported integer value for alpha");
+
+  alphaTensor =
+      mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, alphaValue);
+
+  return success();
+}
+
 // These binary op legalizations are specific to add/sub which have an
 // alpha multiplier.
 template <typename AtenOpT, typename TosaOpT>
@@ -121,34 +182,191 @@ public:
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value lhs = adaptor.self();
-    auto lhsTy = lhs.getType().cast<TensorType>();
+    auto lhsTy = lhs.getType().dyn_cast<TensorType>();
     Value rhs = adaptor.other();
-    auto rhsTy = rhs.getType().cast<TensorType>();
+    auto rhsTy = rhs.getType().dyn_cast<TensorType>();
 
-    if (!lhsTy || !rhsTy)
+    if (!lhsTy)
       return op.emitError("Only Tensor types supported in TOSA");
 
     auto lhsElemTy = lhsTy.getElementType();
-    auto rhsElemTy = rhsTy.getElementType();
+    if (!lhsElemTy.isIntOrFloat())
+      return op.emitError(
+          "Only floating-point or integer datatype legalization supported");
 
-    if (lhsElemTy != rhsElemTy)
-      return op.emitError("Add: input datatypes mismatched");
+    Value rhsAsTensor;
+    if (!rhsTy) {
+      if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
+                                         op.other(), rhsAsTensor, lhsElemTy)))
+        return op.emitError("Currently only scalar constants are supported for "
+                            "conversion in TOSA operation");
+    }
+    auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
 
-    // FIXME: Handle alpha.
-    // Needs extraction of floating point constant.
+    // Handle alpha.
+    Value alphaTensor;
+    if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), op.alpha(),
+                                      alphaTensor, lhsElemTy, false)))
+      return op.emitError("Currently only scalar constants are supported for "
+                          "alpha in conversion to TOSA operation");
+
+    auto multTensor = rewriter.create<tosa::MulOp>(
+        op.getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
+        rhsTensor, alphaTensor, /*shift*/ 0);
 
     if (lhsElemTy.isa<mlir::FloatType>()) {
       rewriter.replaceOpWithNewOp<TosaOpT>(
           op,
           OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
              op.getType()),
-          lhs, rhs);
+          lhs, multTensor);
       return success();
     } else {
       return op.emitError(
           "Only floating-point datatype legalization supported");
     }
   }
+}; // namespace
+
+// Binary op legalizations for comparator ops.
+template <typename AtenOpT, typename TosaOpT>
+class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value lhs = adaptor.self();
+    auto lhsTy = lhs.getType().dyn_cast<TensorType>();
+    Value rhs = adaptor.other();
+    auto rhsTy = rhs.getType().dyn_cast<TensorType>();
+
+    if (!lhsTy)
+      return op.emitError("Only Tensor types supported in TOSA");
+
+    auto lhsElemTy = lhsTy.getElementType();
+    if (!lhsElemTy.isIntOrFloat())
+      return op.emitError(
+          "Only floating-point or integer datatype legalization supported");
+
+    Value rhsAsTensor;
+    if (!rhsTy) {
+      if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
+                                         op.other(), rhsAsTensor, lhsElemTy)))
+        return op.emitError("Currently only scalar constants are supported for "
+                            "conversion in TOSA operation");
+    }
+    auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
+    // There is no Lesser operator in TOSA.
+    auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
+                       std::is_same<AtenOpT, AtenLtScalarOp>());
+
+    rewriter.replaceOpWithNewOp<TosaOpT>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType()),
+        (swapLhsRhs ? rhsTensor : lhs), (swapLhsRhs ? lhs : rhsTensor));
+    return success();
+  }
+};
+
+// Binary op legalizations for Mul variants.
+template <typename AtenOpT>
+class ConvertAtenMulOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value lhs = adaptor.self();
+    auto lhsTy = lhs.getType().dyn_cast<TensorType>();
+    Value rhs = adaptor.other();
+    auto rhsTy = rhs.getType().dyn_cast<TensorType>();
+
+    if (!lhsTy)
+      return op.emitError("Only Tensor types supported in TOSA");
+
+    auto lhsElemTy = lhsTy.getElementType();
+    if (!lhsElemTy.isIntOrFloat())
+      return op.emitError(
+          "Only floating-point or integer datatype legalization supported");
+
+    Value rhsAsTensor;
+    if (!rhsTy) {
+      if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
+                                         op.other(), rhsAsTensor, lhsElemTy)))
+        return op.emitError("Currently only scalar constants are supported for "
+                            "conversion in TOSA operation");
+    }
+    auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
+
+    if (lhsElemTy.isa<mlir::FloatType>() ||
+        lhsElemTy.isa<mlir::IntegerType>()) {
+      rewriter.replaceOpWithNewOp<tosa::MulOp>(
+          op,
+          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+              op.getType()),
+          lhs, rhsTensor,
+          /*shift=*/0);
+      return success();
+    } else {
+      // Quantized multiplication may need to rescale inputs.
+      return op.emitError("Only floating-point or integer datatype "
+                          "legalization currently supported");
+    }
+  }
+};
+
+template <typename AtenOpT>
+class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value lhs = adaptor.self();
+    auto lhsTy = lhs.getType().dyn_cast<TensorType>();
+    Value rhs = adaptor.other();
+    auto rhsTy = rhs.getType().dyn_cast<TensorType>();
+
+    if (!lhsTy)
+      return op.emitError("Only Tensor types supported in TOSA");
+
+    auto lhsElemTy = lhsTy.getElementType();
+    if (!lhsElemTy.isIntOrFloat())
+      return op.emitError(
+          "Only floating-point or integer datatype legalization supported");
+
+    Value rhsAsTensor;
+    if (!rhsTy) {
+      if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
+                                         op.other(), rhsAsTensor, lhsElemTy)))
+        return op.emitError("Currently only scalar constants are supported for "
+                            "conversion in TOSA operation");
+    }
+    auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
+
+    if (lhsElemTy.isa<mlir::FloatType>()) {
+      auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
+          op->getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
+          rhsTensor);
+      rewriter.replaceOpWithNewOp<tosa::MulOp>(
+          op,
+          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+              op.getType()),
+          lhs, rcpOp.getResult(), /*shift=*/0);
+    } else {
+      rewriter.replaceOpWithNewOp<tosa::DivOp>(
+          op,
+          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+              op.getType()),
+          lhs, rhsTensor);
+    }
+    return success();
+  }
 };
 
 // This defines a template to construct ops whose legalizations are
@@ -227,69 +445,6 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
   }
 }
 
-template <>
-LogicalResult ConvertAtenOp<AtenMulTensorOp>::matchAndRewrite(
-    AtenMulTensorOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-
-  Value lhs = adaptor.self();
-  auto lhsTy = lhs.getType().cast<TensorType>();
-  Value rhs = adaptor.other();
-  auto rhsTy = rhs.getType().cast<TensorType>();
-
-  if (!lhsTy || !rhsTy)
-    return op.emitError("Only Tensor types supported in TOSA");
-
-  auto lhsElemTy = lhsTy.getElementType();
-  auto rhsElemTy = rhsTy.getElementType();
-
-  if (lhsElemTy != rhsElemTy)
-    return op.emitError("Add: input datatypes mismatched");
-
-  if (lhsElemTy.isa<mlir::FloatType>()) {
-    rewriter.replaceOpWithNewOp<tosa::MulOp>(
-        op, getTypeConverter()->convertType(op.getType()), lhs, rhs,
-        /*shift=*/0);
-    return success();
-  } else {
-    // Quantized multiplication may need to rescale inputs.
-    return op.emitError(
-        "Only floating-point datatype legalization currently supported");
-  }
-}
-
-template <>
-LogicalResult ConvertAtenOp<AtenDivTensorOp>::matchAndRewrite(
-    AtenDivTensorOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-
-  Value lhs = adaptor.self();
-  auto lhsTy = lhs.getType().cast<TensorType>();
-  Value rhs = adaptor.other();
-  auto rhsTy = rhs.getType().cast<TensorType>();
-
-  if (!lhsTy || !rhsTy)
-    return op.emitError("Only Tensor types supported in TOSA");
-
-  auto lhsElemTy = lhsTy.getElementType();
-  auto rhsElemTy = rhsTy.getElementType();
-
-  if (lhsElemTy != rhsElemTy)
-    return op.emitError("Add: input datatypes mismatched");
-
-  if (lhsElemTy.isa<mlir::FloatType>()) {
-    auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
-        op->getLoc(), getTypeConverter()->convertType(op.getType()), rhs);
-    rewriter.replaceOpWithNewOp<tosa::MulOp>(
-        op, getTypeConverter()->convertType(op.getType()), lhs,
-        rcpOp.getResult(), /*shift=*/0);
-  } else {
-    rewriter.replaceOpWithNewOp<tosa::DivOp>(
-        op, getTypeConverter()->convertType(op.getType()), lhs, rhs);
-  }
-  return success();
-}
-
 using ReductionConvFunc = llvm::Optional<Value> (*)(PatternRewriter &,
                                                     Operation *,
                                                     RankedTensorType, Value,
@@ -635,23 +790,6 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
   }
 };
 
-// FIXME(AG): This will eventually go into a Tosa*Utils file
-// Convert an fp32 scalar into tosa fp32 tensor.
-static LogicalResult
-tosaF32TensorFromTorchFloat(ConversionPatternRewriter &rewriter, Operation *op,
-                            Value torchScalarValue, Value &tosaTensor) {
-  double scalarValue;
-
-  if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
-    return failure();
-
-  // Construct a tosa.const
-  tosaTensor =
-      mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
-
-  return success();
-}
-
 template <>
 LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
     AtenPowTensorScalarOp op, OpAdaptor adaptor,
@@ -668,8 +806,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
 
   Value expTensor;
   Value expScalar = op.exponent();
-  if (failed(tosaF32TensorFromTorchFloat(rewriter, op.getOperation(), expScalar,
-                                         expTensor)))
+  if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), expScalar,
+                                     expTensor, selfTy.getElementType())))
     return op.emitError("Currently only scalar constants are supported for "
                         "conversion in TOSA Pow operation");
 
@@ -1238,6 +1376,45 @@ public:
   }
 };
 
+template <>
+LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
+    AtenRsubScalarOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  auto self = adaptor.self();
+  auto otherScalar = op.other();
+  auto alphaScalar = op.alpha();
+
+  auto selfTy = self.getType().template cast<RankedTensorType>();
+  if (!selfTy)
+    return op.emitError("Only ranked tensor types supported in TOSA Rsub");
+
+  if (!selfTy.getElementType().isa<mlir::FloatType>())
+    return op.emitError("Only floating-point datatype legalization supported");
+
+  Value otherTensor, alphaTensor;
+
+  if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), otherScalar,
+                                     otherTensor, selfTy.getElementType())))
+    return op.emitError("Currently only scalar constants are supported for "
+                        "conversion in TOSA Rsub operation");
+
+  if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar,
+                                    alphaTensor, selfTy.getElementType(),
+                                    true)))
+    return failure();
+
+  auto multTensor = rewriter.create<tosa::MulOp>(
+      op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
+      alphaTensor, /*shift*/ 0);
+
+  rewriter.replaceOpWithNewOp<tosa::SubOp>(
+      op, getTypeConverter()->convertType(op.getType()), otherTensor,
+      multTensor);
+
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -1281,6 +1458,8 @@ public:
     INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
     INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp)
     INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
+    INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp)
+    INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp)
 #undef INSERT_UNARY_PATTERN
 
 #define INSERT_BINARY_PATTERN(AtenOp, TosaOp)                                  \
@@ -1294,9 +1473,36 @@ public:
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
     INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp)
+    INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp)
     INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
+    INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp)
 #undef INSERT_BINARY_ADDSUB_PATTERN
 
+#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp)                          \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
+    INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
+#undef INSERT_BINARY_COMPARE_PATTERN
+
+#define INSERT_BINARY_MUL_PATTERN(AtenOp)                                      \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenMulOp<AtenOp>>(typeConverter, context);
+    INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp);
+    INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp);
+#undef INSERT_BINARY_MUL_PATTERN
+
+#define INSERT_BINARY_DIV_PATTERN(AtenOp)                                      \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
+    INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
+    INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
+#undef INSERT_BINARY_DIV_PATTERN
+
 #define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc)              \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>(    \
@@ -1348,10 +1554,9 @@ public:
     INSERT_ATENOP_PATTERN(AtenTanhOp);
     INSERT_ATENOP_PATTERN(AtenSigmoidOp);
     INSERT_ATENOP_PATTERN(AtenReluOp);
-    INSERT_ATENOP_PATTERN(AtenMulTensorOp);
-    INSERT_ATENOP_PATTERN(AtenDivTensorOp);
     INSERT_ATENOP_PATTERN(AtenArgmaxOp);
     INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
+    INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
 #undef INSERT_ATENOP_PATTERN
 
   if (failed(applyPartialConversion(getOperation(), target,
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index c7569b5e3..e5fc50624 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -141,5 +141,9 @@ template llvm::Optional<Value> getConstTensor<float>(PatternRewriter &,
                                                      ArrayRef<float> vec,
                                                      ArrayRef<int64_t> shape);
 
+template llvm::Optional<Value> getConstTensor<int64_t>(PatternRewriter &,
+                                                       Operation *,
+                                                       ArrayRef<int64_t> vec,
+                                                       ArrayRef<int64_t> shape);
 } // namespace tosa
 } // namespace mlir
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 2aaa71edd..a71ca6944 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -105,15 +105,46 @@ func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
 
 // -----
 
+// CHECK-LABEL: func @torch.aten.ceil$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.ceil"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+// CHECK: }
+func @torch.aten.ceil$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.ceil %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.reciprocal$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+// CHECK: }
+func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.reciprocal %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func @torch.aten.add$basic(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[ARG2:.*]] = torch.constant.int 1
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.add"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_7:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_6]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
+// CHECK: }
 func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
   %int1 = torch.constant.int 1
   %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
@@ -123,14 +154,17 @@ func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
 // -----
 
 // CHECK-LABEL: func @torch.aten.sub$basic(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[ARG2:.*]] = torch.constant.int 1
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.sub"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_2]], %[[VAL_6]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
+// CHECK: }
 func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
   %int1 = torch.constant.int 1
   %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
@@ -347,3 +381,93 @@ func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !t
   %0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.rsub.Scalar$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
+// CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<3.123400e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<6.432100e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_1]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_4]], %[[VAL_6]]) : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
+// CHECK: }
+func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %other = torch.constant.float 3.123400e+00
+  %alpha = torch.constant.float 6.432100e+00
+  %0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.rsub.Scalar$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<3.123400e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_1]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_4]], %[[VAL_6]]) : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
+// CHECK: }
+func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %other = torch.constant.float 3.123400e+00
+  %alpha = torch.constant.int 1
+  %0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.gt.Tensor$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = "tosa.greater"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK: }
+func @torch.aten.gt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.lt.Tensor$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = "tosa.greater"(%[[VAL_3]], %[[VAL_2]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK: }
+func @torch.aten.lt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.eq.Tensor$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = "tosa.equal"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+// CHECK: }
+func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
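
Note (illustrative only, not part of this patch): with the torchScalarToTosaTensor helper and the ConvertAtenMulOp pattern added above, a Scalar variant such as torch.aten.mul.Scalar is expected to lower by materializing the Torch scalar constant as a rank-0 "tosa.const" and reusing the same "tosa.mul" path as the Tensor variant. The test name and constant below are hypothetical and follow the style of basic.mlir:

  func @torch.aten.mul.Scalar$example(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
    %float2 = torch.constant.float 2.000000e+00
    // Expected lowering (sketch), mirroring the CHECK patterns used above:
    //   %c = "tosa.const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
    //   %m = "tosa.mul"(%self, %c) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
    %0 = torch.aten.mul.Scalar %arg0, %float2 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
    return %0 : !torch.vtensor<[?,?],f32>
  }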