diff --git a/lib/Conversion/TorchToMhlo/BasicOp.cpp b/lib/Conversion/TorchToMhlo/BasicOp.cpp index 96839fbfb..4abc401e7 100644 --- a/lib/Conversion/TorchToMhlo/BasicOp.cpp +++ b/lib/Conversion/TorchToMhlo/BasicOp.cpp @@ -10,6 +10,7 @@ #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "../PassDetail.h" +#include "./MhloLegalizeUtils.h" #include "./PopulatePatterns.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "torch-mlir/Conversion/Utils/Utils.h" @@ -25,6 +26,470 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +bool skipMultiplyAlpha(Value alphaValue) { + double doubleValue; + auto isFloat = matchPattern(alphaValue, m_TorchConstantFloat(&doubleValue)); + + int64_t intValue; + auto isInt = matchPattern(alphaValue, m_TorchConstantInt(&intValue)); + + return ((isFloat && doubleValue == 1.0) || (isInt && intValue == 1.0)); +} + +// These legalizations are for unary ops that support only floating-point datatypes. +// There is no supported quantized integer mode for these. 
+namespace { +template +class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value self = adaptor.self(); + auto selfTy = self.getType().cast(); + + if (!selfTy) + return op.emitError("Only Tensor types supported in MHLO"); + + if (selfTy.getElementType().isa()) { + rewriter.replaceOpWithNewOp( + op, + OpConversionPattern::getTypeConverter()->convertType( + op.getType()), + self); + return success(); + } else { + return op.emitError( + "Only floating-point datatype legalization supported"); + } + } +}; +} // namespace + +// aten.ones & aten.zeros +// Ref: Error checking based on the Torch to TOSA lowering +namespace { +template +class ConvertAtenConstPatternOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto outType = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template dyn_cast(); + + if (!outType) + return op.emitError("Only Tensor types supported in MHLO"); + + Type outElemTy = outType.getElementType(); + if (!outElemTy.isIntOrFloat()) + return op.emitError( + "Only floating-point or integer datatype legalization supported"); + + // FIXME: Handle layout, device and pin_memory. Assume dtype has been + // processed to set output type correctly? 
+ if (!op.layout().getType().template isa()) + return op.emitError("Only default layout is supported"); + + bool pinMemory; + if (!op.pin_memory().getType().template isa() && + (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) || + pinMemory)) { + return op.emitError( + "Unsupported pin_memory, should be either None or false"); + } + + SmallVector shape; + if (!matchPattern(op.size(), m_TorchConstantIntList(shape))) { + return op.emitError("Shape must be a list of Scalar constants"); + } + + int64_t size = 1; + for (auto s : shape) + size *= s; + + SmallVector values(size, fillVal); + auto constOp = + mhlo::getConstTensor(rewriter, op, values, shape).getValue(); + + rewriter.replaceOpWithNewOp(op, outType, constOp); + return success(); + } +}; + +} // namespace + +// These binary op legalizations are specific to add/sub which have an +// alpha multiplier. +namespace { +template +class ConvertAtenAddSubOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = adaptor.self(); + TensorType lhsType = lhs.getType().dyn_cast(); + Value rhs = adaptor.other(); + TensorType rhsType = rhs.getType().dyn_cast(); + + if (!lhsType) + return op.emitError("Only Tensor types supported in MHLO"); + + TensorType outType = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + Type outElemTy = outType.getElementType(); + if (!outElemTy.isIntOrFloat()) { + return op.emitError( + "Only floating-point or integer datatype legalization supported"); + } + + Value rhsAsTensor; + if (!rhsType) { + if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), + rhsAsTensor, outElemTy, + outType.getShape()))) + return op.emitError("Currently only scalar constants are supported for " + "conversion in MHLO operation"); + } + 
+ Value lhsTensor = lhs; + Value rhsTensor = rhsType ? rhs : rhsAsTensor; + + // Handle broadcasting. Since we have the output type already, here we + // just broadcast operands' shape to output shape. + lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType); + rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType); + + // Handle alpha. + Value multTensor; + if (skipMultiplyAlpha(op.alpha())) { + multTensor = rhsTensor; + } else { + Value alphaTensor; + if (failed(mhlo::torchAlphaToMhloTensor(rewriter, op.getOperation(), + op.alpha(), alphaTensor, + outElemTy, outType.getShape(), + /*checkForUnity=*/false))) { + return op.emitError("Currently only scalar constants are supported for " + "alpha in conversion to MHLO operation"); + } + + multTensor = rewriter.create(op.getLoc(), outType, rhsTensor, + alphaTensor); + } + + rewriter.replaceOpWithNewOp(op, outType, lhsTensor, multTensor); + return success(); + } +}; +} // namespace + +// Binary op legalizations for Mul variants. 
+namespace { +template +class ConvertAtenMulOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = adaptor.self(); + auto lhsType = lhs.getType().dyn_cast(); + Value rhs = adaptor.other(); + TensorType rhsType = rhs.getType().dyn_cast(); + + if (!lhsType) + return op.emitError("Only Tensor types supported in MHLO"); + + auto outType = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + Type outElemTy = outType.getElementType(); + if (!outElemTy.isIntOrFloat()) { + return op.emitError( + "Only floating-point or integer datatype legalization supported"); + } + + Value lhsTensor = lhs; + Value rhsTensor; + if (std::is_same()) { + rhsTensor = lhs; + } else { + if (!rhsType) { + if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), + rhsTensor, outElemTy, + outType.getShape()))) + return op.emitError( + "Currently only scalar constants are supported for " + "conversion in MHLO operation"); + } else { + rhsTensor = rhs; + } + } + + // Handle broadcasting. Since we have the output type already, here we + // just broadcast operands' shape to output shape. + lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType); + rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType); + + rewriter.replaceOpWithNewOp(op, outType, lhsTensor, rhsTensor); + return success(); + } +}; +} // namespace + +// Binary op legalizations for Div variants. 
+namespace { +template +class ConvertAtenDivOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = adaptor.self(); + auto lhsTy = lhs.getType().dyn_cast(); + Value rhs = adaptor.other(); + auto rhsTy = rhs.getType().dyn_cast(); + + if (!lhsTy) + return op.emitError("Only Tensor types supported."); + + auto outType = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + Type outElemTy = outType.getElementType(); + if (!outElemTy.isIntOrFloat()) { + return op.emitError( + "Only floating-point or integer datatype legalization supported"); + } + + Value lhsTensor = lhs; + Value rhsTensor; + if (!rhsTy) { + if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), + rhsTensor, outElemTy, + outType.getShape()))) + return op.emitError("Currently only scalar constants are supported for " + "conversion in MHLO operation"); + } else { + rhsTensor = rhs; + } + + // Handle broadcasting. Since we have the output type already, here we + // just broadcast operands' shape to output shape. + lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType); + rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType); + + rewriter.replaceOpWithNewOp(op, outType, lhsTensor, rhsTensor); + return success(); + } +}; + +} // namespace + +// Binary op legalizations for comparator ops. 
+namespace { +template +class ConvertAtenCompareOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = adaptor.self(); + Value rhs = adaptor.other(); + RankedTensorType lhsTy = lhs.getType().dyn_cast(); + RankedTensorType rhsTy = rhs.getType().dyn_cast(); + + if (!lhsTy) + return op.emitError("Only Tensor types supported in MHLO"); + + RankedTensorType outType = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + + Type lhsElemTy = lhsTy.getElementType(); + if (!lhsElemTy.isIntOrFloat()) { + return op.emitError( + "Only floating-point or integer datatype legalization supported"); + } + + Value rhsAsTensor; + if (!rhsTy) { + if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), + rhsAsTensor, lhsElemTy, {}))) { + return op.emitError("Currently only scalar constants are supported for " + "conversion in MHLO operation"); + } + } + + Value lhsTensor = lhs; + Value rhsTensor = rhsTy ? 
rhs : rhsAsTensor; + rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, lhsTy); + + mhlo::ComparisonTypeAttr compareTypeAttr; + mhlo::ComparisonDirectionAttr compareDirectionAttr; + + if (lhsElemTy.isa()) { + compareTypeAttr = mhlo::ComparisonTypeAttr::get( + op->getContext(), mhlo::ComparisonType::FLOAT); + } else if (lhsElemTy.isa()) { + compareTypeAttr = mhlo::ComparisonTypeAttr::get( + op->getContext(), mhlo::ComparisonType::SIGNED); + } + + if (std::is_same() || + std::is_same()) { + compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( + op->getContext(), mhlo::ComparisonDirection::LT); + } else if (std::is_same() || + std::is_same()) { + compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( + op->getContext(), mhlo::ComparisonDirection::GT); + } else if (std::is_same() || + std::is_same()) { + compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( + op->getContext(), mhlo::ComparisonDirection::EQ); + } else if (std::is_same() || + std::is_same()) { + compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( + op->getContext(), mhlo::ComparisonDirection::NE); + } + + rewriter.replaceOpWithNewOp( + op, outType, lhsTensor, rhsTensor, compareDirectionAttr, + compareTypeAttr); + return success(); + } +}; + +} // namespace + +// AtenTransposeIntOp +namespace { +class ConvertAtenTransposeIntOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenTransposeIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value self = adaptor.self(); + int64_t dim0; + if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0))) { + return rewriter.notifyMatchFailure(op, "dim0 must be constant"); + } + int64_t dim1; + if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1))) { + return rewriter.notifyMatchFailure(op, "dim1 must be constant"); + } + + auto inType = self.getType().cast(); + auto inputRank = inType.getRank(); + auto outType = 
getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + + dim0 = toPositiveDim(dim0, inputRank); + if (!isValidDim(dim0, inputRank)) { + return rewriter.notifyMatchFailure(op, "dim0 out of range"); + } + dim1 = toPositiveDim(dim1, inputRank); + if (!isValidDim(dim1, inputRank)) { + return rewriter.notifyMatchFailure(op, "dim1 out of range"); + } + + SmallVector permValues(inputRank); + std::iota(std::begin(permValues), std::end(permValues), 0); + std::swap(permValues[dim0], permValues[dim1]); + DenseIntElementsAttr permutation = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(permValues.size())}, + rewriter.getI64Type()), + permValues); + rewriter.replaceOpWithNewOp(op, outType, self, + permutation); + return success(); + } +}; +} // namespace + +// AtenBroadcastToOp +namespace { +class ConvertAtenBroadcastToOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value self = adaptor.self(); + auto outType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + + Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType); + rewriter.replaceOp(op, bcastOp); + return success(); + } +}; +} // namespace + +// AtenPermuteOp +namespace { +class ConvertAtenPermuteOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenPermuteOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value self = adaptor.self(); + // Not a ranked tensor type + auto inType = self.getType().dyn_cast(); + auto outType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + if (!inType) + return op.emitError("Only ranked tensor types with static shapes are " + "currently supported"); + + SmallVector permValues; + if 
(!matchPattern(adaptor.dims(), m_TorchConstantIntList(permValues))) + return rewriter.notifyMatchFailure( + op, "Only constant dimensions are currently supported"); + + int64_t inRank = inType.getRank(); + for (auto &d : permValues) { + d = toPositiveDim(d, inRank); + if (!isValidDim(d, inRank)) + return op.emitError("Not all dims are valid"); + } + + DenseIntElementsAttr permutation = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(permValues.size())}, + rewriter.getI64Type()), + permValues); + rewriter.replaceOpWithNewOp(op, outType, self, + permutation); + return success(); + } +}; + +} // namespace namespace { template @@ -57,15 +522,549 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } } // namespace +// ValueTensorLiteralOp +namespace { +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + ValueTensorLiteralOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + RankedTensorType resultType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + + // Tensors with integer types need to be converted to signless integer + // element type. All tensors with element types other than integer can reuse + // existing elements attribute. 
+ if (auto elements = op.valueAttr().dyn_cast()) { + Type builtinTensorElemTy = resultType.getElementType(); + unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth(); + + DenseElementsAttr valueAttr = + elements.mapValues(builtinTensorElemTy, [&](const APInt &v) { + return APInt(bitWidth, v.getSExtValue()); + }); + rewriter.replaceOpWithNewOp(op, resultType, valueAttr); + return success(); + } + + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.value()); + return success(); +} + +} // namespace + +// AtenReciprocalOp +// Reciprocal(x) = Div(1, x) +namespace { +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenReciprocalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.self(); + auto inputTy = input.getType().cast(); + auto outTy = + getTypeConverter()->convertType(op.getType()).cast(); + if (!inputTy.getElementType().isa()) { + return op.emitError("Only floating-point datatype legalization supported " + "for AtenReciprocalOp"); + } + Value oneTensor = + mhlo::getConstTensor(rewriter, op, {static_cast(1.0)}, {}) + .getValue(); + oneTensor = mhlo::promoteAndBroadcast(rewriter, oneTensor, inputTy); + rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input); + return success(); +} +} // namespace + +// PrimNumToTensorScalarOp +namespace { +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + PrimNumToTensorScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + RankedTensorType outputType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + auto outputShape = outputType.getShape(); + auto outputElemType = outputType.getElementType(); + Value mhloTensor; + if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.a(), mhloTensor, + outputElemType, outputShape, + false))) { + return op->emitError("Failed lowering PrimNumToTensorScalarOp to MHLO"); + } + rewriter.replaceOp(op, mhloTensor); + return success(); +} +} // namespace + +// 
AtenContiguousOp +// Ref: TosaToTosa.cpp for implementation details +namespace { +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenContiguousOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Not a tensor type. + auto selfType = adaptor.self().getType().dyn_cast(); + if (!selfType) + return op.emitError("Only tensor types are currently supported"); + + // FIXME: memory_format is not handled. + + rewriter.replaceOp(op, adaptor.self()); + + return success(); +} + +} // namespace + +// AtenReluOp +// Relu(x) = Max(0, x) +namespace { +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value lhs = adaptor.self(); + auto lhsTy = lhs.getType().cast(); + auto lhsElemTy = lhsTy.getElementType(); + + int64_t lhsSize = 1; + for (auto &en : llvm::enumerate(lhsTy.getShape())) { + lhsSize *= en.value(); + } + auto constTy = RankedTensorType::get(lhsTy.getShape(), lhsElemTy); + DenseElementsAttr constAttr; + if (lhsElemTy.isa()) { + std::vector constVec( + lhsSize, + APFloat::getZero(lhsElemTy.cast().getFloatSemantics(), + /*negative=*/false)); + constAttr = DenseElementsAttr::get(constTy, constVec); + } else if (lhsElemTy.isa()) { + std::vector constVec( + lhsSize, APInt::getZero(lhsElemTy.getIntOrFloatBitWidth())); + constAttr = DenseElementsAttr::get(constTy, constVec); + } + Value rhs = + rewriter.create(op.getLoc(), constTy, constAttr); + + rewriter.replaceOpWithNewOp(op, lhs, rhs); + return success(); +} + +} // namespace + +// AtenErfOp +namespace { +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenErfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.self(); + auto inputType = input.getType().cast(); + if (!inputType.getElementType().isa()) { + return rewriter.notifyMatchFailure(op, "Only support float data type"); + } + auto outType = + 
getTypeConverter()->convertType(op.getType()).cast(); + + // Using: + // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with + // maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 = + // 0.000972, a4 = 0.078108. + // Erf = 1 - 1 / (1 + a1X + a2X^2 + a3X^3 + a4X^4)^4 + + auto loc = op->getLoc(); + auto zeroConst = + mhlo::getConstTensor(rewriter, op, {0.0}, {}).getValue(); + auto zero = mhlo::promoteAndBroadcast(rewriter, zeroConst, outType); + auto oneConst = + mhlo::getConstTensor(rewriter, op, {1.0}, {}).getValue(); + auto one = mhlo::promoteAndBroadcast(rewriter, oneConst, outType); + auto a1Const = + mhlo::getConstTensor(rewriter, op, {0.278393}, {}).getValue(); + auto a1 = mhlo::promoteAndBroadcast(rewriter, a1Const, outType); + auto a2Const = + mhlo::getConstTensor(rewriter, op, {0.230389}, {}).getValue(); + auto a2 = mhlo::promoteAndBroadcast(rewriter, a2Const, outType); + auto a3Const = + mhlo::getConstTensor(rewriter, op, {0.000972}, {}).getValue(); + auto a3 = mhlo::promoteAndBroadcast(rewriter, a3Const, outType); + auto a4Const = + mhlo::getConstTensor(rewriter, op, {0.078108}, {}).getValue(); + auto a4 = mhlo::promoteAndBroadcast(rewriter, a4Const, outType); + + auto absX = rewriter.create(loc, outType, input); + auto a1X = rewriter.create(loc, outType, a1, absX); + auto sum = rewriter.create(loc, outType, a1X, one); + + auto x2 = rewriter.create(loc, outType, absX, absX); + auto a2X = rewriter.create(loc, outType, a2, x2); + sum = rewriter.create(loc, outType, sum, a2X); + + auto x3 = rewriter.create(loc, outType, x2, absX); + auto a3X = rewriter.create(loc, outType, a3, x3); + sum = rewriter.create(loc, outType, sum, a3X); + + auto x4 = rewriter.create(loc, outType, x3, absX); + auto a4X = rewriter.create(loc, outType, a4, x4); + sum = rewriter.create(loc, outType, sum, a4X); + + auto rcprl = rewriter.create(loc, outType, one, sum); + auto rcprl2 = rewriter.create(loc, outType, rcprl, rcprl); + auto rcprl4 = 
rewriter.create(loc, outType, rcprl2, rcprl2); + auto erf = rewriter.create(loc, outType, one, rcprl4); + + // Deal with negative x. + mhlo::ComparisonDirectionAttr compareDirectionAttr = + mhlo::ComparisonDirectionAttr::get(op->getContext(), + mhlo::ComparisonDirection::GE); + mhlo::ComparisonTypeAttr compareTypeAttr = mhlo::ComparisonTypeAttr::get( + op->getContext(), mhlo::ComparisonType::FLOAT); + auto geZero = rewriter.create( + loc, RankedTensorType::get(outType.getShape(), rewriter.getI1Type()), + input, zero, compareDirectionAttr, compareTypeAttr); + auto negaErf = rewriter.create(loc, erf); + rewriter.replaceOpWithNewOp(op, outType, geZero, erf, + negaErf); + return success(); +} + +} // namespace + +// AtenBatchNormOp +namespace { +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenBatchNormOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.input(); + // shape = [N, C, H, W] + auto inputTy = input.getType().cast(); + Value weight = adaptor.weight(); + Value bias = adaptor.bias(); + Value runningMean = adaptor.running_mean(); + Value runningVar = adaptor.running_var(); + // momentum is ignored + Value momentum = adaptor.momentum(); + (void)momentum; + + // init weight, bias, runningVar, runningMean if they are none + auto initNoneValue = [&](Value &input, bool zero) { + SmallVector constVec(inputTy.getShape()[1], + APFloat::getZero(inputTy.getElementType() + .cast() + .getFloatSemantics())); + if (!zero) { + for (auto &item : constVec) { + item = APFloat(inputTy.getElementType() + .cast() + .getFloatSemantics(), + 1); + } + } + auto constType = RankedTensorType::get({inputTy.getShape()[1]}, + inputTy.getElementType()); + auto constAttr = DenseElementsAttr::get(constType, constVec); + input = + rewriter.create(op.getLoc(), constType, constAttr); + }; + if (failed(checkNotNone(rewriter, op, weight))) { + initNoneValue(weight, false); + } + if (failed(checkNotNone(rewriter, op, bias))) { + 
initNoneValue(bias, true); + } + if (failed(checkNotNone(rewriter, op, runningVar))) { + initNoneValue(runningVar, false); + } + if (failed(checkNotNone(rewriter, op, runningMean))) { + initNoneValue(runningMean, true); + } + + auto weightTy = weight.getType().cast(); + auto biasTy = bias.getType().cast(); + auto runningMeanTy = runningMean.getType().cast(); + auto runningVarTy = runningVar.getType().cast(); + if (inputTy.getRank() <= 2) { + return rewriter.notifyMatchFailure(op, + "input should have rank larger than 2"); + } + if (weightTy.getRank() != 1 || biasTy.getRank() != 1 || + runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) { + return rewriter.notifyMatchFailure( + op, "expect weight, bias, running_mean and running_var to be rank 1"); + } + if (!inputTy.getElementType().template isa() || + !weightTy.getElementType().template isa() || + !biasTy.getElementType().template isa() || + !runningMeanTy.getElementType().template isa() || + !runningVarTy.getElementType().template isa()) { + return op.emitError( + "Only float element type is supported in MHLO BatchNormOp"); + } + + double eps = 0.0; + if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps))) { + return rewriter.notifyMatchFailure(op, "non-float(double) eps unsupported"); + } + bool training = false; + if (!matchPattern(op.training(), m_TorchConstantBool(&training))) { + return rewriter.notifyMatchFailure(op, "non-bool training unsupported"); + } + // TODO: handle cudnnEnabled parameter. Here, we just ignore it! 
+ bool cudnnEnabled = false; + if (!matchPattern(op.cudnn_enabled(), m_TorchConstantBool(&cudnnEnabled))) { + return rewriter.notifyMatchFailure(op, + "non-bool cudnn_enabled unsupported"); + } + if (training) { + Type outputTy = getTypeConverter()->convertType(op.getType()); + Type batchMeanOrVarTy = + RankedTensorType::get(weightTy.getShape(), inputTy.getElementType()); + auto batchNormTrainingResult = rewriter.create( + op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, + weight, bias, rewriter.getF32FloatAttr(eps), + rewriter.getI64IntegerAttr(1)); + rewriter.replaceOp(op, batchNormTrainingResult.getResult(0)); + return success(); + } else { + Type outputTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp( + op, outputTy, input, weight, bias, runningMean, runningVar, + rewriter.getFloatAttr(inputTy.getElementType(), eps), + rewriter.getI64IntegerAttr(1)); + return success(); + } +} + +} // namespace + +// AtenNativeLayerNormOp +namespace { +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenNativeLayerNormOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.input(); + auto inputTy = input.getType().cast(); + auto inputShape = inputTy.getShape(); + auto inputRank = inputTy.getRank(); + Value weight = adaptor.weight(); + Value bias = adaptor.bias(); + + SmallVector normalizedShape; + if (!matchPattern(op.normalized_shape(), + m_TorchConstantIntList(normalizedShape))) { + return rewriter.notifyMatchFailure( + op, "normalized_shape must be a list of const int"); + } + double eps = 0; + if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps))) { + return rewriter.notifyMatchFailure(op, "non const float eps unsupported"); + } + if (failed(checkNotNone(rewriter, op, weight)) || + failed(checkNotNone(rewriter, op, bias))) { + return op->emitError("Unsupported None for weight or bias"); + } + auto weightTy = weight.getType().cast(); + auto biasTy = 
bias.getType().cast(); + + if (!inputTy.getElementType().isa() || + !biasTy.getElementType().isa() || + !weightTy.getElementType().isa()) { + return op->emitError("For now, only float data type are supported"); + } + int64_t normalizedShapeRank = normalizedShape.size(); + if (weightTy.getRank() != normalizedShapeRank || + biasTy.getRank() != normalizedShapeRank || + inputRank < normalizedShapeRank || normalizedShapeRank < 1) { + return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or " + "normalized shape not compatible"); + } + for (int64_t i = 1; i <= normalizedShapeRank; i++) { + if (inputShape[inputRank - i] != normalizedShape[normalizedShapeRank - i] || + weightTy.getShape()[normalizedShapeRank - i] != + normalizedShape[normalizedShapeRank - i] || + biasTy.getShape()[normalizedShapeRank - i] != + normalizedShape[normalizedShapeRank - i]) { + return op.emitError("mismatching contracting dimension"); + } + } + + // flatten dims to fit batch_norm operation. + int64_t numFeatureDimSize = 1; + int64_t numEmbeddingDimSize = 1; + for (int64_t i = 0; i < inputRank - normalizedShapeRank; i++) { + numFeatureDimSize *= inputShape[i]; + } + for (int64_t i = 0; i < normalizedShapeRank; i++) { + numEmbeddingDimSize *= normalizedShape[i]; + } + SmallVector inputFlattenShape{1, numFeatureDimSize, + numEmbeddingDimSize}; + SmallVector meanOrVarMhloOutShape{numFeatureDimSize}; + + auto mhloBatchNormOutTy = + RankedTensorType::get(inputFlattenShape, inputTy.getElementType()); + auto mhloBathNormOutMeanOrVarTy = + RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType()); + + // reshape input + auto mhloInput = rewriter.create( + op->getLoc(), mhloBatchNormOutTy, input, + mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape), + {static_cast(inputFlattenShape.size())}) + .getValue()); + + // generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp. 
+ SmallVector zeroConstVec( + numFeatureDimSize, APFloat::getZero(inputTy.getElementType() + .cast() + .getFloatSemantics())); + SmallVector oneConstVec( + numFeatureDimSize, + APFloat( + inputTy.getElementType().cast().getFloatSemantics(), + 1)); + auto oneOrZeroConstType = + RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType()); + + Value scale = rewriter.create( + op->getLoc(), oneOrZeroConstType, + DenseElementsAttr::get(oneOrZeroConstType, oneConstVec)); + Value offset = rewriter.create( + op->getLoc(), oneOrZeroConstType, + DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec)); + auto batchNormTrainingResult = rewriter.create( + op->getLoc(), mhloBatchNormOutTy, mhloBathNormOutMeanOrVarTy, + mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset, + rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1)); + + // reshape back + auto outputTy = + getTypeConverter()->convertType(op.getType(0)).cast(); + auto outputMeanOrVarTy = + getTypeConverter()->convertType(op.getType(1)).cast(); + + auto output = rewriter.create( + op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), + mhlo::getConstTensor(rewriter, op, outputTy.getShape(), + {static_cast(outputTy.getShape().size())}) + .getValue()); + auto mean = rewriter.create( + op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1), + mhlo::getConstTensor( + rewriter, op, outputMeanOrVarTy.getShape(), + {static_cast(outputMeanOrVarTy.getShape().size())}) + .getValue()); + auto var = rewriter.create( + op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2), + mhlo::getConstTensor( + rewriter, op, outputMeanOrVarTy.getShape(), + {static_cast(outputMeanOrVarTy.getShape().size())}) + .getValue()); + + // Apply affine transform: output x weight + bias [element-wise] + auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy); + auto bcastedBias = mhlo::promoteAndBroadcast(rewriter, bias, outputTy); + auto outputMulWeight = + 
rewriter.create(op->getLoc(), output, bcastedWeight); + auto finalOuput = + rewriter.create(op->getLoc(), outputMulWeight, bcastedBias); + rewriter.replaceOp(op, {finalOuput, mean, var}); + return success(); +} + +} // namespace + void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); + target.addIllegalOp(); + patterns.add(typeConverter, context); + + target.addIllegalOp(); + patterns.add(typeConverter, context); + + target.addIllegalOp(); + patterns.add(typeConverter, context); + +#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, \ + context); + INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp); + INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp); + INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp); + INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp); + INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp); +#undef INSERT_UNARY_FPONLY_PATTERN + +#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, \ + context); + INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); + INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); +#undef INSERT_CONSTANT_FILL_PATTERN + +#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, MhloOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, mhlo::AddOp); + INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, mhlo::AddOp); + INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, mhlo::SubtractOp); + INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, mhlo::SubtractOp); +#undef INSERT_BINARY_ADDSUB_PATTERN + +#define INSERT_BINARY_MUL_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); + INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); +#undef 
INSERT_BINARY_MUL_PATTERN + +#define INSERT_BINARY_DIV_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); + INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); +#undef INSERT_BINARY_DIV_PATTERN + +#define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp); + INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp); + INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp); + INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp); + INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp); + INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp); + INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp); + INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp); +#undef INSERT_BINARY_COMPARE_PATTERN + #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_ATENOP_PATTERN(AtenTanhOp); -#undef INSERT_ATENOP_PATTERN + INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); + INSERT_ATENOP_PATTERN(AtenReciprocalOp); + INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); + INSERT_ATENOP_PATTERN(AtenContiguousOp); + INSERT_ATENOP_PATTERN(AtenReluOp); + INSERT_ATENOP_PATTERN(AtenErfOp); + + INSERT_ATENOP_PATTERN(AtenBatchNormOp); + INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); +#undef INSERT_ATENOP_PATTERN } diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt index 130ec544a..7798b8d18 100644 --- a/lib/Conversion/TorchToMhlo/CMakeLists.txt +++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo TorchToMhlo.cpp + MhloLegalizeUtils.cpp BasicOp.cpp GatherOp.cpp ViewLikeOps.cpp diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp new file mode 100644 index 000000000..dc8ae290d --- /dev/null +++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp @@ -0,0 +1,318 @@ 
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+#include "./MhloLegalizeUtils.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace mlir {
+namespace mhlo {
+
+// Create a 32-bit float constant operator from a float.
+//
+// The constant has an empty shape (rank 0) and f32 element type.
+//   rewriter: builder used to create the constant op.
+//   op:       only its location is read; the new op is created at that
+//             location.
+//   val:      value stored in the constant.
+// Returns the result value of the newly created constant op.
+Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
+                                  float val) {
+  auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
+  auto const_attr = DenseElementsAttr::get(const_type, val);
+
+  // The op type has to be spelled out: `create` cannot deduce it from the
+  // arguments.
+  auto const_op =
+      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
+  return const_op.getResult();
+}
+
+// Create a 64-bit float constant operator from a double.
+//
+// Same contract as getMhloConstTensorSingleF32, but the element type is f64
+// and the stored value is a double.
+Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
+                                  double val) {
+  auto const_type = RankedTensorType::get({}, rewriter.getF64Type());
+  auto const_attr = DenseElementsAttr::get(const_type, val);
+
+  auto const_op =
+      rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
+  return const_op.getResult();
+}
+
+// Templated function to create a constant op for given type and shape.
+// T: storage C type.
+// Default template creates a constant tensor in T.
+template +llvm::Optional getConstTensor(PatternRewriter &rewriter, Operation *op, + ArrayRef vec, ArrayRef shape) { + uint64_t num_total_elements = 1; + for (int64_t a : shape) { + num_total_elements *= a; + } + + if (vec.size() != num_total_elements) { + op->emitOpError("getConstTensor(): number of elements mismatch."); + return llvm::None; + } + + auto const_type = + RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8)); + auto const_attr = DenseElementsAttr::get(const_type, vec); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + return const_op.getResult(); +} + +// Template specialization for APInt +template <> +llvm::Optional getConstTensor(PatternRewriter &rewriter, + Operation *op, ArrayRef vec, + ArrayRef shape) { + uint64_t num_total_elements = 1; + for (int64_t a : shape) { + num_total_elements *= a; + } + + if (vec.size() != num_total_elements) { + op->emitOpError("getConstTensor(): number of elements mismatch."); + return llvm::None; + } + auto const_type = RankedTensorType::get( + shape, rewriter.getIntegerType(vec[0].getBitWidth())); + auto const_attr = DenseElementsAttr::get(const_type, vec); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + return const_op.getResult(); +} + +// Template specialization for float +template <> +llvm::Optional getConstTensor(PatternRewriter &rewriter, + Operation *op, ArrayRef vec, + ArrayRef shape) { + uint64_t num_total_elements = 1; + for (int64_t a : shape) { + num_total_elements *= a; + } + + if (vec.size() != num_total_elements) { + op->emitOpError("getConstTensor(): number of elements mismatch."); + return llvm::None; + } + + auto const_type = RankedTensorType::get(shape, rewriter.getF32Type()); + auto const_attr = DenseElementsAttr::get(const_type, vec); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + return const_op.getResult(); +} + +template <> +llvm::Optional getConstTensor(PatternRewriter &rewriter, + 
Operation *op, ArrayRef vec, + ArrayRef shape) { + uint64_t num_total_elements = 1; + for (int64_t a : shape) { + num_total_elements *= a; + } + + if (vec.size() != num_total_elements) { + op->emitOpError("getConstTensor(): number of elements mismatch."); + return llvm::None; + } + + auto const_type = RankedTensorType::get(shape, rewriter.getF64Type()); + auto const_attr = DenseElementsAttr::get(const_type, vec); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + return const_op.getResult(); +} + +// Template instantiation +template llvm::Optional getConstTensor(PatternRewriter &, + Operation *, + ArrayRef vec, + ArrayRef shape); + +template llvm::Optional getConstTensor(PatternRewriter &, + Operation *, + ArrayRef vec, + ArrayRef shape); + + +template +static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt, + const int64_t &intValue) { + if (isFloat) { + // Do a round-trip check here instead of numeric limits due to + // compiler warnings around double <-> int conversion. + return (doubleValue == static_cast(static_cast(doubleValue))); + } else { + assert(isInt); + return (intValue >= std::numeric_limits::min()) && + (intValue <= std::numeric_limits::max()); + } + return true; +} + +template +Value getSplatConstTensor(ConversionPatternRewriter &rewriter, + Operation *op, + T val, + Type dtype, + llvm::ArrayRef dshape) { + auto const_type = RankedTensorType::get( + dshape, dtype); + auto const_attr = SplatElementsAttr::get(const_type, val); + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + return const_op.getResult(); +} + + +LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter, + Operation *op, Value torchScalarValue, + Value &mhloTensor, Type dtype, + llvm::ArrayRef dshape, + bool doBroadcast) { + // Retrieve a const float or int value but create the out Tensor with dtype. 
+ double doubleValue; + auto isFloat = + matchPattern(torchScalarValue, m_TorchConstantFloat(&doubleValue)); + + int64_t intValue; + auto isInt = matchPattern(torchScalarValue, m_TorchConstantInt(&intValue)); + + if (!isFloat && !isInt) + return op->emitError("Unable to extract the scalar constant"); + + if (dtype.isa()) { + if (doBroadcast) { + mhloTensor = getSplatConstTensor(rewriter, op, + (isFloat ? doubleValue : intValue), + dtype, dshape); + } else { + mhloTensor = mhlo::getConstTensor( + rewriter, op, (isFloat ? doubleValue : intValue), dshape) + .getValue(); + } + } else if (auto intType = dtype.dyn_cast()) { + auto w = intType.getWidth(); + if (w != 32 && w != 64) + return op->emitError("Unsupported integer type") << intType; + + if (w == 32) { + if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { + return op->emitError("Supplied value of scalar constant exceeds limits " + "of destination type"); + } + int32_t d = isFloat ? static_cast(doubleValue) + : static_cast(intValue); + if (doBroadcast) { + mhloTensor = getSplatConstTensor(rewriter, op, d, dtype, dshape); + } else { + mhloTensor = + mhlo::getConstTensor(rewriter, op, {d}, dshape).getValue(); + } + } else if (w == 64) { + if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { + return op->emitError("Supplied value of scalar constant exceeds limits " + "of destination type"); + } + int64_t d = (isFloat ? 
static_cast(doubleValue) : intValue); + if (doBroadcast) { + mhloTensor = getSplatConstTensor(rewriter, op, d, dtype, dshape); + } else { + mhloTensor = + mhlo::getConstTensor(rewriter, op, {d}, dshape).getValue(); + } + } + } else + return op->emitError("Usupported element type"); + + return success(); +} + + +LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter, + Operation *op, Value alphaScalar, + Value &alphaTensor, Type dtype, + llvm::ArrayRef dshape, + bool checkForUnity) { + if (succeeded(torchScalarToMhloTensor(rewriter, op, alphaScalar, alphaTensor, + dtype, dshape))) + return success(); + + // `alpha` has not been specified. + int64_t alphaValue; + if (!matchPattern(alphaScalar, m_TorchConstantInt(&alphaValue))) + return op->emitError("Currently only scalar constants are supported for " + "alpha in MHLO operation"); + // When no alpha has been specified, this must be 1. + if (checkForUnity && alphaValue != 1) + return op->emitError("Unsupported integer value for alpha"); + + alphaTensor = + mlir::mhlo::getMhloConstTensorSingleF32(rewriter, op, alphaValue); + + return success(); +} + + +Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, + Value input, TensorType outType) { + // Two tensors are “broadcastable” if the following rules hold: + // - Each tensor has at least one dimension. + // - When iterating over the dimension sizes, starting at the trailing dimension, + // the dimension sizes must either be equal, one of them is 1, or one of them + // does not exist. 
+ Operation* op = input.getDefiningOp(); + TensorType in_type = input.getType().dyn_cast(); + + if (in_type.getElementType() != outType.getElementType()) { + TensorType promoted_type = in_type.cloneWith(in_type.getShape(), outType.getElementType()); + input = rewriter.create(op->getLoc(), promoted_type, input); + } + + ArrayRef inShape = in_type.getShape(); + ArrayRef outShape = outType.getShape(); + + bool do_bcast = (inShape.size() != outShape.size()); + SmallVector bcastDims; + for (size_t i = 0; i < inShape.size(); ++i) { + // iterating over the dimension sizes, starting at the trailing dimension + size_t outPos = outShape.size() - 1 - i; + size_t inPos = inShape.size() - 1 - i; + int64_t outDim = outShape[outPos]; + int64_t inDim = inShape[inPos]; + if (inDim == outDim) { + bcastDims.push_back(outPos); + } else if (inDim != outDim && inDim == 1) { + bcastDims.push_back(outPos); + do_bcast = true; + } else { + op->emitError("The size of tensor a (") << inDim << ")" + << "must match the size of tensor b (" << outDim << ")" + << "at non-singleton dimension " << inPos; + } + } + std::reverse(bcastDims.begin(), bcastDims.end()); + if (!do_bcast) { + return input; + } + DenseIntElementsAttr bcast_attr = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(bcastDims.size())}, rewriter.getI64Type()), + bcastDims); + auto bcast_op = + rewriter.create(op->getLoc(), outType, input, bcast_attr); + return bcast_op.getResult(); +} +} // namespace mhlo +} // namespace mlir diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h new file mode 100644 index 000000000..410e51846 --- /dev/null +++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h @@ -0,0 +1,62 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H +#define TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace mhlo { + +using mlir::ConversionPatternRewriter; + +// Create a 32-bit float constant operator from a float +Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, + float val); + +// Create a 64-bit float constant operator from a double +Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, + double val); + +// Templated function to create a constant op for given type and shape. +// T: storage C type. +// Default template creates a constant tensor in T. +// To create INT48 MHLO constant, need to pass in llvm::APInt instead. 
+template +llvm::Optional getConstTensor(PatternRewriter &rewriter, Operation *op, + ArrayRef vec, ArrayRef shape); + +template +Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op, + T val, Type dtype, llvm::ArrayRef dshape); + +LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter, + Operation *op, Value torchScalarValue, + Value &mhloTensor, Type dtype, + llvm::ArrayRef dshape, + bool doBroadcast = true); + +LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter, + Operation *op, Value alphaScalar, + Value &alphaTensor, Type dtype, + llvm::ArrayRef dshape, + bool checkForUnity); + +Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, + TensorType outType); +} // namespace mhlo +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H diff --git a/test/Conversion/TorchToMhlo/basic.mlir b/test/Conversion/TorchToMhlo/basic.mlir index c76d30f2e..045e694ab 100644 --- a/test/Conversion/TorchToMhlo/basic.mlir +++ b/test/Conversion/TorchToMhlo/basic.mlir @@ -10,3 +10,695 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte %0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.addscalar$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 9 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<9.000000e+00> : tensor<4x64xf32> +// CHECK: %[[VAL_5:.*]] = mhlo.add %[[VAL_1]], %[[VAL_4]] : tensor<4x64xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_6]] : 
!torch.vtensor<[4,64],f32> +func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %int9 = torch.constant.int 9 + %int1 = torch.constant.int 1 + %0 = torch.aten.add.Scalar %arg0, %int9, %int1 : !torch.vtensor<[4,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.addtensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_4:.*]] = mhlo.add %[[VAL_2]], %[[VAL_3]] : tensor<4x64xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.addtensor$promote( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],si32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],si64>) -> !torch.vtensor<[4,64],si64> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],si32> -> tensor<4x64xi32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],si64> -> tensor<4x64xi64> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_5:.*]] = mhlo.convert(%[[VAL_2]]) : 
(tensor<4x64xi32>) -> tensor<4x64xi64> +// CHECK: %[[VAL_6:.*]] = mhlo.add %[[VAL_5]], %[[VAL_3]] : tensor<4x64xi64> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xi64> -> !torch.vtensor<[4,64],si64> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],si64> +func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[4,64],si32>, %arg1: !torch.vtensor<[4,64],si64>) -> !torch.vtensor<[4,64],si64> { + %int1 = torch.constant.int 1 + %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[4,64],si32>, !torch.vtensor<[4,64],si64>, !torch.int -> !torch.vtensor<[4,64],si64> + return %0 : !torch.vtensor<[4,64],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.addtensor$bcast( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_5:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32> +// CHECK: %[[VAL_6:.*]] = mhlo.add %[[VAL_5]], %[[VAL_3]] : tensor<4x64xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.addtensor$bcast(%arg0: !torch.vtensor<[64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.addtensor$alpha( +// 
CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<2.000000e+00> : tensor<4x64xf32> +// CHECK: %[[VAL_6:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_5]] : tensor<4x64xf32> +// CHECK: %[[VAL_7:.*]] = mhlo.add %[[VAL_2]], %[[VAL_6]] : tensor<4x64xf32> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %int2 = torch.constant.int 2 + %0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.mulscalar$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 9 +// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<9.000000e+00> : tensor<4x64xf32> +// CHECK: %[[VAL_4:.*]] = mhlo.multiply %[[VAL_1]], %[[VAL_3]] : tensor<4x64xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.mulscalar$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %int9 = torch.constant.int 9 + %0 = 
torch.aten.mul.Scalar %arg0, %int9 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.multensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_4:.*]] = mhlo.multiply %[[VAL_2]], %[[VAL_3]] : tensor<4x64xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32> -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.multensor$bcast( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[8,4,64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[8,1,64],f32>) -> !torch.vtensor<[8,4,64],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[8,4,64],f32> -> tensor<8x4x64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[8,1,64],f32> -> tensor<8x1x64xf32> +// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x64xf32>) -> tensor<8x4x64xf32> +// CHECK: %[[VAL_5:.*]] = mhlo.multiply %[[VAL_2]], %[[VAL_4]] : tensor<8x4x64xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<8x4x64xf32> -> !torch.vtensor<[8,4,64],f32> +// 
CHECK: return %[[VAL_6]] : !torch.vtensor<[8,4,64],f32> +func.func @torch.aten.multensor$bcast(%arg0: !torch.vtensor<[8,4,64],f32>, %arg1: !torch.vtensor<[8,1,64],f32>) -> !torch.vtensor<[8,4,64],f32> { + %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[8,4,64],f32>, !torch.vtensor<[8,1,64],f32> -> !torch.vtensor<[8,4,64],f32> + return %0 : !torch.vtensor<[8,4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.subscalar$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 9 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<9.000000e+00> : tensor<4x64xf32> +// CHECK: %[[VAL_5:.*]] = mhlo.subtract %[[VAL_1]], %[[VAL_4]] : tensor<4x64xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %int9 = torch.constant.int 9 + %int1 = torch.constant.int 1 + %0 = torch.aten.sub.Scalar %arg0, %int9, %int1 : !torch.vtensor<[4,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.subtensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_4:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_3]] : tensor<4x64xf32> +// CHECK: %[[VAL_5:.*]] = 
torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.subtensor$promote( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],si32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],si64>) -> !torch.vtensor<[4,64],si64> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],si32> -> tensor<4x64xi32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],si64> -> tensor<4x64xi64> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_5:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<4x64xi32>) -> tensor<4x64xi64> +// CHECK: %[[VAL_6:.*]] = mhlo.subtract %[[VAL_5]], %[[VAL_3]] : tensor<4x64xi64> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xi64> -> !torch.vtensor<[4,64],si64> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],si64> +func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[4,64],si32>, %arg1: !torch.vtensor<[4,64],si64>) -> !torch.vtensor<[4,64],si64> { + %int1 = torch.constant.int 1 + %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[4,64],si32>, !torch.vtensor<[4,64],si64>, !torch.int -> !torch.vtensor<[4,64],si64> + return %0 : !torch.vtensor<[4,64],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.subtensor$bcast( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_2:.*]] = 
torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_5:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32> +// CHECK: %[[VAL_6:.*]] = mhlo.subtract %[[VAL_5]], %[[VAL_3]] : tensor<4x64xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.subtensor$bcast(%arg0: !torch.vtensor<[64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.subtensor$alpha( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<2.000000e+00> : tensor<4x64xf32> +// CHECK: %[[VAL_6:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_5]] : tensor<4x64xf32> +// CHECK: %[[VAL_7:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_6]] : tensor<4x64xf32> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[4,64],f32>, %arg1: 
!torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %int2 = torch.constant.int 2 + %0 = torch.aten.sub.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.divscalar$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 9 +// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<9.000000e+00> : tensor<4x64xf32> +// CHECK: %[[VAL_4:.*]] = mhlo.divide %[[VAL_1]], %[[VAL_3]] : tensor<4x64xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.divscalar$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %int9 = torch.constant.int 9 + %0 = torch.aten.div.Scalar %arg0, %int9 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.divtensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_4:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_3]] : tensor<4x64xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.divtensor$basic(%arg0: 
!torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32> -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.divtensor$bcast( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[8,4,64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[8,1,64],f32>) -> !torch.vtensor<[8,4,64],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[8,4,64],f32> -> tensor<8x4x64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[8,1,64],f32> -> tensor<8x1x64xf32> +// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x64xf32>) -> tensor<8x4x64xf32> +// CHECK: %[[VAL_5:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_4]] : tensor<8x4x64xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<8x4x64xf32> -> !torch.vtensor<[8,4,64],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[8,4,64],f32> +func.func @torch.aten.divtensor$bcast(%arg0: !torch.vtensor<[8,4,64],f32>, %arg1: !torch.vtensor<[8,1,64],f32>) -> !torch.vtensor<[8,4,64],f32> { + %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[8,4,64],f32>, !torch.vtensor<[8,1,64],f32> -> !torch.vtensor<[8,4,64],f32> + return %0 : !torch.vtensor<[8,4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.log$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> +// CHECK: %[[VAL_2:.*]] = mhlo.log %[[VAL_1]] : tensor<?x?xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.log$basic(%arg0: 
!torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.log %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.exp$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = mhlo.exponential %[[VAL_1]] : tensor +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.exp %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.clone$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = "mhlo.copy"(%[[VAL_1]]) : (tensor<4x64xf32>) -> tensor<4x64xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %none = torch.constant.none + %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[4,64],f32>, !torch.none -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { +// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : 
tensor -> !torch.vtensor<[],f32> +// CHECK: return %[[VAL_1]] : !torch.vtensor<[],f32> +func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { + %0 = torch.vtensor.literal(dense<0.0> : tensor) : !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { +// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<2xi64> +// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<2xi64> -> !torch.vtensor<[2],si64> +// CHECK: return %[[VAL_1]] : !torch.vtensor<[2],si64> +func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { + %0 = torch.vtensor.literal(dense<1> : tensor<2xsi64>) : !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.gt.scalar( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<3.000000e+00> : tensor +// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<4x64xf32> +// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_1]], %[[VAL_4]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1> +func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],i1> { + %int3 = torch.constant.int 3 + %0 = torch.aten.gt.Scalar %arg0, %int3 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],i1> + return %0 : !torch.vtensor<[4,64],i1> +} + +// ----- + +// CHECK-LABEL: 
func.func @torch.aten.gt.tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32> +// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_4]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1> +func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { + %0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1> + return %0 : !torch.vtensor<[4,64],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.gt.tensor$convert( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],si32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],si32> -> tensor<4x64xi32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK: %[[VAL_4:.*]] = mhlo.convert(%[[VAL_3]]) : (tensor<64xf32>) -> tensor<64xi32> +// CHECK: %[[VAL_5:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_4]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xi32>) -> tensor<4x64xi32> +// CHECK: %[[VAL_6:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_5]]) {compare_type = #mhlo, comparison_direction = 
#mhlo} : (tensor<4x64xi32>, tensor<4x64xi32>) -> tensor<4x64xi1> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],i1> + +func.func @torch.aten.gt.tensor$convert(%arg0: !torch.vtensor<[4,64],si32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { + %0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],si32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1> + return %0 : !torch.vtensor<[4,64],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.lt.tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32> +// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_4]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1> +func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { + %0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1> + return %0 : !torch.vtensor<[4,64],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.eq.tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { +// 
CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32> +// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_4]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1> +func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { + %0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1> + return %0 : !torch.vtensor<[4,64],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.ne.tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32> +// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_4]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1> +func.func @torch.aten.ne.tensor(%arg0: 
!torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { + %0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1> + return %0 : !torch.vtensor<[4,64],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.batch_norm( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> { +// CEHCK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,5,5],f32> -> tensor<2x3x5x5xf32> +// CEHCK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> +// CEHCK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> +// CEHCK: %true = torch.constant.bool true +// CEHCK: %[[VAL_4:.*]] = mhlo.constant dense<0> : tensor +// CEHCK: %float1.000000e-01 = torch.constant.float 1.000000e-01 +// CEHCK: %float1.000000e-05 = torch.constant.float 1.000000e-05 +// CEHCK: %int1 = torch.constant.int 1 +// CEHCK: %[[VAL_5:.*]] = mhlo.constant dense<1> : tensor +// CEHCK: %[[VAL_6:.*]] = mhlo.add %[[VAL_4]], %[[VAL_5]] : tensor +// CEHCK: %[[VAL_7:.*]], %batch_mean, %batch_var = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) +// CEHCK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<2x3x5x5xf32> -> !torch.vtensor<[2,3,5,5],f32> +// CEHCK: return %[[VAL_8]] : !torch.vtensor<[2,3,5,5],f32> + +func.func @torch.aten.batch_norm(%arg0: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> { + %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> + %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> + %true = torch.constant.bool true + %2 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %float1.000000e-01 = 
torch.constant.float 1.000000e-01 + %float1.000000e-05 = torch.constant.float 1.000000e-05 + %int1 = torch.constant.int 1 + %3 = torch.aten.add.Scalar %2, %int1, %int1 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + %4 = torch.aten.batch_norm %arg0, %1, %0, %0, %1, %true, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[2,3,5,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,5,5],f32> + return %4 : !torch.vtensor<[2,3,5,5],f32> +} + +// CHECK-LABEL: func.func @torch.aten.batch_norm$none_bias_weight( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> { +// CEHCK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,5,5],f32> -> tensor<2x3x5x5xf32> +// CEHCK: %none = torch.constant.none +// CEHCK: %1 = mhlo.constant dense<1.000000e+00> : tensor<3xf32> +// CEHCK: %2 = mhlo.constant dense<0.000000e+00> : tensor<3xf32> +// CEHCK: %true = torch.constant.bool true +// CEHCK: %[[VAL_2:.*]] = mhlo.constant dense<0> : tensor +// CEHCK: %float1.000000e-01 = torch.constant.float 1.000000e-01 +// CEHCK: %float1.000000e-05 = torch.constant.float 1.000000e-05 +// CEHCK: %int1 = torch.constant.int 1 +// CEHCK: %[[VAL_3:.*]] = mhlo.constant dense<1> : tensor +// CEHCK: %[[VAL_4:.*]] = mhlo.add %[[VAL_2]], %[[VAL_3]] : tensor +// CEHCK: %[[VAL_5:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> +// CEHCK: %[[VAL_6:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> +// CEHCK: %[[VAL_7:.*]], %batch_mean, %batch_var = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_5]], %[[VAL_6]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) +// CEHCK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<2x3x5x5xf32> 
-> !torch.vtensor<[2,3,5,5],f32> +// CEHCK: return %[[VAL_8]] : !torch.vtensor<[2,3,5,5],f32> +func.func @torch.aten.batch_norm$none_bias_weight(%arg0: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> { + %none = torch.constant.none + %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> + %1 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> + %true = torch.constant.bool true + %2 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %float1.000000e-01 = torch.constant.float 1.000000e-01 + %float1.000000e-05 = torch.constant.float 1.000000e-05 + %int1 = torch.constant.int 1 + %3 = torch.aten.add.Scalar %2, %int1, %int1 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + %4 = torch.aten.batch_norm %arg0, %none, %none, %1, %0, %true, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[2,3,5,5],f32>, !torch.none, !torch.none, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,5,5],f32> + return %4 : !torch.vtensor<[2,3,5,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.batch_norm$inference( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> { +// CEHCK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,5,5],f32> -> tensor<2x3x5x5xf32> +// CEHCK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> +// CEHCK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> +// CEHCK: %true = torch.constant.bool true +// CHECK: %false = torch.constant.bool false +// CEHCK: %[[VAL_4:.*]] = mhlo.constant dense<0> : tensor +// CEHCK: %float1.000000e-01 = torch.constant.float 1.000000e-01 +// CEHCK: %float1.000000e-05 = torch.constant.float 1.000000e-05 +// CEHCK: %int1 = torch.constant.int 1 +// CEHCK: %[[VAL_5:.*]] = mhlo.constant dense<1> : tensor +// CEHCK: 
%[[VAL_6:.*]] = mhlo.add %[[VAL_4]], %[[VAL_5]] : tensor +// CEHCK: %[[VAL_7:.*]], %batch_mean, %batch_var = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]], %[[VAL_2]], %[[VAL_3]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) +// CEHCK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<2x3x5x5xf32> -> !torch.vtensor<[2,3,5,5],f32> +// CEHCK: return %[[VAL_8]] : !torch.vtensor<[2,3,5,5],f32> +func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> { + %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> + %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> + %true = torch.constant.bool true + %false = torch.constant.bool false + %2 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %float1.000000e-01 = torch.constant.float 1.000000e-01 + %float1.000000e-05 = torch.constant.float 1.000000e-05 + %int1 = torch.constant.int 1 + %3 = torch.aten.add.Scalar %2, %int1, %int1 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + %4 = torch.aten.batch_norm %arg0, %1, %0, %0, %1, %false, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[2,3,5,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,5,5],f32> + return %4 : !torch.vtensor<[2,3,5,5],f32> +} + +// ----- +// CHECK-LABEL: func.func @torch.aten.relu( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,5],f32>) -> !torch.vtensor<[2,5],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,5],f32> -> tensor<2x5xf32> +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<2x5xf32> +// CHECK: %[[VAL_3:.*]] = 
mhlo.maximum %[[VAL_1]], %[[VAL_2]] : tensor<2x5xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<2x5xf32> -> !torch.vtensor<[2,5],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[2,5],f32> +func.func @torch.aten.relu(%arg0: !torch.vtensor<[2,5],f32>) -> !torch.vtensor<[2,5],f32> { + %0 = torch.aten.relu %arg0 : !torch.vtensor<[2,5],f32> -> !torch.vtensor<[2,5],f32> + return %0 : !torch.vtensor<[2,5],f32> +} +// ----- +// CHECK-LABEL: func.func @torch.aten.relu$int8( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,5],si8>) -> !torch.vtensor<[2,5],si8> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,5],si8> -> tensor<2x5xi8> +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0> : tensor<2x5xi8> +// CHECK: %[[VAL_3:.*]] = mhlo.maximum %[[VAL_1]], %[[VAL_2]] : tensor<2x5xi8> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<2x5xi8> -> !torch.vtensor<[2,5],si8> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[2,5],si8> +func.func @torch.aten.relu$int8(%arg0: !torch.vtensor<[2,5],si8>) -> !torch.vtensor<[2,5],si8> { + %0 = torch.aten.relu %arg0 : !torch.vtensor<[2,5],si8> -> !torch.vtensor<[2,5],si8> + return %0 : !torch.vtensor<[2,5],si8> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.reciprocal( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5,5],f32>) -> !torch.vtensor<[5,5,5],f32> { +// CEHCK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0:.*]] : !torch.vtensor<[5,5,5],f32> -> tensor<5x5x5xf32> +// CEHCK: %[[VAL_2:.*]] = mhlo.constant dense<1.000000e+00> : tensor +// CEHCK: %[[VAL_3:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<5x5x5xf32> +// CEHCK: %[[VAL_4:.*]] = mhlo.divide %[[VAL_3]], %[[VAL_1]] : tensor<5x5x5xf32> +// CEHCK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<5x5x5xf32> -> !torch.vtensor<[5,5,5],f32> +// CEHCK: return %[[VAL_5]] : !torch.vtensor<[5,5,5],f32> 
+func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[5,5,5],f32>) -> !torch.vtensor<[5,5,5],f32> { + %0 = torch.aten.reciprocal %arg0 : !torch.vtensor<[5,5,5],f32> -> !torch.vtensor<[5,5,5],f32> + return %0 : !torch.vtensor<[5,5,5],f32> +} + +// CHECK-LABEL: func @torch.aten.native_layer_norm( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,7,4,5],f32> -> tensor<3x7x4x5xf32> +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x5xf32> +// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<4x5xf32> +// CHECK: %int4 = torch.constant.int 4 +// CHECK: %int5 = torch.constant.int 5 +// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05 +// CHECK: %true = torch.constant.bool true +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<[1, 21, 20]> : tensor<3xi64> +// CHECK: %[[VAL_6:.*]] = "mhlo.dynamic_reshape"(%[[VAL_1]], %[[VAL_5]]) : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32> +// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<21xf32> +// CHECK: %[[VAL_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<21xf32> +// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "mhlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) +// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64> +// CHECK: %[[VAL_13:.*]] = "mhlo.dynamic_reshape"(%[[VAL_9]], %[[VAL_12]]) : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32> +// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64> +// CHECK: %[[VAL_15:.*]] = "mhlo.dynamic_reshape"(%[[VAL_10]], 
%[[VAL_14]]) : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32> +// CHECK: %[[VAL_16:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64> +// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_reshape"(%[[VAL_11]], %[[VAL_16]]) : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32> +// CHECK: %[[VAL_18:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32> +// CHECK: %[[VAL_19:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32> +// CHECK: %[[VAL_20:.*]] = mhlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32> +// CHECK: %[[VAL_21:.*]] = mhlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21:.*]] : tensor<3x7x4x5xf32> -> !torch.vtensor<[3,7,4,5],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[3,7,4,5],f32> +func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> { + %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<4x5xf32>) : !torch.vtensor<[4,5],f32> + %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<4x5xf32>) : !torch.vtensor<[4,5],f32> + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %float1.000000e-05 = torch.constant.float 1.000000e-05 + %true = torch.constant.bool true + %2 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list + %result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %2, %1, %0, %float1.000000e-05 : !torch.vtensor<[3,7,4,5],f32>, !torch.list, !torch.vtensor<[4,5],f32>, !torch.vtensor<[4,5],f32>, !torch.float -> !torch.vtensor<[3,7,4,5],f32>, !torch.vtensor<[3,7,1,1],f32>, !torch.vtensor<[3,7,1,1],f32> + return %result0 : !torch.vtensor<[3,7,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.contiguous( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) 
-> !torch.vtensor<[4,64],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_2]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.contiguous %arg0, %int0 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> +} + +// ----- + +// CEHCK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[],si64> { +// CEHCK: %int1 = torch.constant.int 1 +// CEHCK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor +// CEHCK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor -> !torch.vtensor<[],si64> +// CEHCK: return %[[VAL_1]] : !torch.vtensor<[],si64> +func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> { + %int1 = torch.constant.int 1 + %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64> + return %0 : !torch.vtensor<[], si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.broadcast_to$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[8,4,64],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 64 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 4 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 8 +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_1]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<8x4x64xf32> +// CHECK: %[[VAL_7:.*]] 
= torch_c.from_builtin_tensor %[[VAL_6]] : tensor<8x4x64xf32> -> !torch.vtensor<[8,4,64],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[8,4,64],f32> +func.func @torch.aten.broadcast_to$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[8,4,64],f32> { + %int64 = torch.constant.int 64 + %int4 = torch.constant.int 4 + %int8 = torch.constant.int 8 + %0 = torch.prim.ListConstruct %int8, %int4, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.broadcast_to %arg0, %0 : !torch.vtensor<[4,64],f32>, !torch.list -> !torch.vtensor<[8,4,64],f32> + return %1 : !torch.vtensor<[8,4,64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.permute$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_5:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[64,4],f32> +func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[4,64],f32>, !torch.list -> !torch.vtensor<[64,4],f32> + return %1 : !torch.vtensor<[64,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.transpose$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> { +// 
CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x3xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32> +func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.aten.transpose.int %arg0, %int0, %int1 : !torch.vtensor<[4,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} \ No newline at end of file