mirror of https://github.com/llvm/torch-mlir
[TOSA] Add upsample_nearest2d, split_dim, outer, GELU tanh mode and misc
- Add Torch to TOSA lowering for the following ops:
  + torch.aten.upsample_nearest2d
  + torch.aten.upsample_nearest2d.vec
  + torch.aten.outer
  + torch.prims.split_dim
- Add Tanh approximation mode for GELU lowering
- Add support for comparing operands of different types in compare ops
- Add support for different input and output types in the linalg vector norm lowering
- Update xfail with new e2e results
- Add new LIT tests to basic.mlir

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I7b1d44d94319cf94fcc9d234cc07708ef9ce321e
parent 1b8d7e094b
commit 40d3de502e
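For context, a hypothetical PyTorch-level sketch (not part of this change; it assumes the standard eager APIs torch.outer, torch.nn.functional.gelu, torch.nn.functional.interpolate, and that torch.ops.prims.split_dim is callable from Python) of the ops whose Torch-dialect forms this commit legalizes to TOSA:

```python
# Sketch only: eager-mode calls that produce the newly supported ops.
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 2, 3)
up = F.interpolate(x, size=(8, 9), mode="nearest")            # aten.upsample_nearest2d / .vec
outer = torch.outer(torch.randn(3), torch.randn(4))           # aten.outer
gelu = F.gelu(torch.randn(5, 3), approximate="tanh")          # aten.gelu, tanh approximation
split = torch.ops.prims.split_dim(torch.randn(1, 8, 3, 3), 1, 2)  # prims.split_dim
```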
@@ -23,6 +23,7 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cmath>
#include <numeric>
#include <optional>
#include <random>
@@ -405,6 +406,36 @@ public:
                   "conversion in TOSA operation");
    }
    auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
    auto rhsTensorTy = dyn_cast<TensorType>(rhsTensor.getType());
    auto rhsElemTy = rhsTensorTy.getElementType();

    auto isLhsElemFloat = isa<mlir::FloatType>(lhsElemTy);
    auto isRhsElemFloat = isa<mlir::FloatType>(rhsElemTy);

    // Support comparisons between operands of different element types
    if (lhsElemTy != rhsElemTy) {
      if (isLhsElemFloat && !isRhsElemFloat) {
        rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy);
      } else if (!isLhsElemFloat && isRhsElemFloat) {
        lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy);
      } else if (isLhsElemFloat && isRhsElemFloat) {
        auto lhsElemFloatTy = dyn_cast<mlir::FloatType>(lhsElemTy);
        auto rhsElemFloatTy = dyn_cast<mlir::FloatType>(rhsElemTy);
        if (lhsElemFloatTy.getWidth() > rhsElemFloatTy.getWidth()) {
          rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy);
        } else {
          lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy);
        }
      } else {
        auto lhsElemIntTy = dyn_cast<mlir::IntegerType>(lhsElemTy);
        auto rhsElemIntTy = dyn_cast<mlir::IntegerType>(rhsElemTy);
        if (lhsElemIntTy.getWidth() > rhsElemIntTy.getWidth()) {
          rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy);
        } else {
          lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy);
        }
      }
    }
    // There is no Lesser operator in TOSA.
    constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
                                 std::is_same<AtenOpT, AtenLtScalarOp>() ||
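The mixed-type support added above mirrors PyTorch's own promotion behaviour for compare ops; a small eager-mode illustration (an assumption about the motivating case, not code from the patch):

```python
# Comparing an int tensor against a float tensor: PyTorch promotes the integer
# side before comparing, which the TOSA compare lowering now reproduces.
import torch

ints = torch.tensor([1, 2, 3], dtype=torch.int32)
floats = torch.tensor([0.5, 2.5, 3.0])
print(torch.lt(ints, floats))  # tensor([False,  True, False])
```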
@@ -3196,9 +3227,10 @@ template <>
LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
    AtenGeluOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.getSelf();

  // Not a tensor type.
  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
  auto selfType = dyn_cast<TensorType>(self.getType());
  if (!selfType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor types are currently supported");

@@ -3209,21 +3241,97 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
        op, "Only floating-point datatype legalization supported");
  }

  // TODO: Handle approximate.
  auto resultType =
      dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));

  std::string approximate;
  if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate)) ||
      approximate != "none") {
    return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
  if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate))) {
    return rewriter.notifyMatchFailure(
        op, "Non-const approximate value not supported");
  }

  Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy);
  cdf = rewriter.createOrFold<tosa::CastOp>(
      op->getLoc(),
      cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf);
  if (approximate.compare("none") == 0) {
    // GELU(x) = x * CDF(x)
    Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy);
    cdf = rewriter.createOrFold<tosa::CastOp>(
        op->getLoc(),
        cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf);

    rewriter.replaceOpWithNewOp<tosa::MulOp>(
        op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf,
        /*shift=*/0);
    rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, cdf,
                                             /*shift=*/0);
  } else if (approximate.compare("tanh") == 0) {
    // "tanh" approximation mode
    // GELU(x) = 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    auto selfShape = selfType.getShape();
    auto numElem = std::accumulate(selfShape.begin(), selfShape.end(), 1,
                                   std::multiplies<int64_t>());

    Value half = tosa::getConstTensor<float>(rewriter, op,
                                             SmallVector<float>(numElem, 0.5),
                                             selfShape, selfElemTy)
                     .value();
    Value one = tosa::getConstTensor<float>(rewriter, op,
                                            SmallVector<float>(numElem, 1.0),
                                            selfShape, selfElemTy)
                    .value();
    Value three = tosa::getConstTensor<float>(rewriter, op,
                                              SmallVector<float>(numElem, 3.0),
                                              selfShape, selfElemTy)
                      .value();

    // 0.044715
    Value magicNumber = tosa::getConstTensor<float>(
                            rewriter, op, SmallVector<float>(numElem, 0.044715),
                            selfShape, selfElemTy)
                            .value();

    // From <cmath> header: M_2_PI = 2 / pi
    Value twoOverPi = tosa::getConstTensor<float>(
                          rewriter, op, SmallVector<float>(numElem, M_2_PI),
                          selfShape, selfElemTy)
                          .value();

    // 0.5 * x
    auto halfInput = rewriter.create<tosa::MulOp>(op->getLoc(), resultType,
                                                  half, self, /*shift=*/0);

    // sqrt(2/pi)
    auto sqrtTwoOverPi =
        rewriter.create<tosa::PowOp>(op->getLoc(), resultType, twoOverPi, half);

    // x^3
    auto inputPowThree =
        rewriter.create<tosa::PowOp>(op->getLoc(), resultType, self, three);

    // 0.044715 * x^3
    auto inputPowThreeMul =
        rewriter.create<tosa::MulOp>(op->getLoc(), resultType, magicNumber,
                                     inputPowThree.getResult(), /*shift=*/0);

    // x + 0.044715 * x^3
    auto inputPowThreeMulAdd = rewriter.create<tosa::AddOp>(
        op->getLoc(), resultType, self, inputPowThreeMul.getResult());

    // sqrt(2/pi) * (x + 0.044715 * x^3)
    auto sqrtTwoOverPiMul = rewriter.create<tosa::MulOp>(
        op->getLoc(), resultType, sqrtTwoOverPi.getResult(),
        inputPowThreeMulAdd.getResult(), /*shift=*/0);

    // tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
    auto tanh = rewriter.create<tosa::TanhOp>(op->getLoc(), resultType,
                                              sqrtTwoOverPiMul.getResult());

    // 1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
    auto tanhAdd = rewriter.create<tosa::AddOp>(op->getLoc(), resultType, one,
                                                tanh.getResult());

    rewriter.replaceOpWithNewOp<tosa::MulOp>(
        op, resultType, halfInput.getResult(), tanhAdd.getResult(),
        /*shift=*/0);
  } else {
    return rewriter.notifyMatchFailure(op,
                                       "Unsupported approximation algorithm");
  }

  return success();
}
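A standalone numeric sketch of the two GELU paths handled above; the formulas mirror the comments in the lowering, and gelu_exact/gelu_tanh are illustrative helpers, not code from the patch:

```python
# "none": GELU(x) = x * Phi(x), with Phi the unit-normal CDF.
# "tanh": GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
import math

def gelu_exact(x):
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))

for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={x:+.1f}  exact={gelu_exact(x):+.6f}  tanh={gelu_tanh(x):+.6f}")
```

Note that the lowering above materializes sqrt(2/pi) by raising the M_2_PI constant tensor to the power 0.5 with tosa.PowOp, reusing the 0.5 constant.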
@@ -7620,6 +7728,298 @@ LogicalResult ConvertAtenOp<AtenReplicationPad2dOp>::matchAndRewrite(
  return success();
}

// Legalization for torch.prims.split_dim
template <>
LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
    PrimsSplitDimOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.getA();

  // Not a tensor type
  auto selfType = dyn_cast<TensorType>(self.getType());
  if (!selfType)
    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

  auto resultType =
      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
  auto resultShape = resultType.getShape();

  int64_t dim, outerLength;
  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
    return rewriter.notifyMatchFailure(
        op, "Only constant int dim value is supported");

  auto selfRank = selfType.getRank();
  dim = toPositiveDim(dim, selfRank);
  if (!isValidDim(dim, selfRank))
    return rewriter.notifyMatchFailure(op, "Dim is invalid");

  if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength)))
    return rewriter.notifyMatchFailure(
        op, "Only constant int outer length value is supported");

  // The output shape could be computed from the dim and outer length values,
  // but that would produce exactly the shape already available from
  // resultType. Reshaping the input directly to that shape with
  // tosa::ReshapeOp is simpler and faster, so that approach is used here.
  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
      op, resultType, self,
      rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));

  return success();
}
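A small shape-arithmetic sketch (the helper below is made up for illustration) of what prims.split_dim computes, matching the reshape-only lowering above:

```python
def split_dim_shape(shape, dim, outer_length):
    # shape[dim] is split into (outer_length, shape[dim] // outer_length).
    inner_length = shape[dim] // outer_length
    return shape[:dim] + [outer_length, inner_length] + shape[dim + 1:]

# [1, 8, 3, 3] split at dim=1 with outer_length=2 -> [1, 2, 4, 3, 3], exactly
# the result shape the pattern feeds to tosa.reshape.
print(split_dim_shape([1, 8, 3, 3], 1, 2))
```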
// Legalization for aten.outer
template <>
LogicalResult ConvertAtenOp<AtenOuterOp>::matchAndRewrite(
    AtenOuterOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.getSelf();

  auto selfType = dyn_cast<TensorType>(self.getType());
  if (!selfType)
    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

  if (selfType.getRank() != 1)
    return rewriter.notifyMatchFailure(op, "Only rank 1 vectors are supported");

  auto vec2 = adaptor.getVec2();

  auto vec2Type = dyn_cast<TensorType>(vec2.getType());
  if (!vec2Type)
    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

  if (vec2Type.getRank() != 1)
    return rewriter.notifyMatchFailure(op, "Only rank 1 vectors are supported");

  auto resultType =
      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
  auto resultShape = resultType.getShape();

  self = tosa::promoteType(rewriter, self, resultType);
  vec2 = tosa::promoteType(rewriter, vec2, resultType);

  SmallVector<int64_t, 2> resultShapeIndex1Replaced({resultShape[0], 1});
  SmallVector<int64_t, 2> resultShapeIndex0Replaced({1, resultShape[1]});

  // Reshape and tile self to shape {selfShape[0], resultShape[1]}
  auto selfReshaped = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(),
      RankedTensorType::get(resultShapeIndex1Replaced,
                            resultType.getElementType()),
      self, rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced));

  auto selfTiled = rewriter.create<tosa::TileOp>(
      op->getLoc(), resultType, selfReshaped.getResult(),
      rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced));

  // Reshape and tile vec2 to shape {resultShape[0], vec2Shape[0]}
  auto vec2Reshaped = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(),
      RankedTensorType::get(resultShapeIndex0Replaced,
                            resultType.getElementType()),
      vec2, rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced));

  auto vec2Tiled = rewriter.create<tosa::TileOp>(
      op->getLoc(), resultType, vec2Reshaped.getResult(),
      rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced));

  auto result =
      tosa::createMulOpAndCast(rewriter, op, resultType, selfTiled.getResult(),
                               vec2Tiled.getResult(), /*shift=*/0);

  rewriter.replaceOp(op, result);
  return success();
}
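The reshape/tile/multiply decomposition above can be sanity-checked with a short NumPy sketch (an illustration under fixed shapes, not the emitted TOSA IR):

```python
import numpy as np

a, b = np.arange(3.0), np.arange(4.0)             # rank-1 inputs
lhs = np.tile(a.reshape(3, 1), (1, 4))            # reshape self to [3, 1], tile to [3, 4]
rhs = np.tile(b.reshape(1, 4), (3, 1))            # reshape vec2 to [1, 4], tile to [3, 4]
assert np.array_equal(lhs * rhs, np.outer(a, b))  # elementwise product == outer product
```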
// Legalization for aten.upsample_nearest2d
template <typename AtenOpT>
class ConvertUpsampleNearest2dForward : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // aten.upsample_nearest2d lowering process:
    // 1. Reshape input: (N, C, H, W) -> (N, C, H x W)
    // 2. Calculate PyTorch-styled gather op indices based on the following
    //    formula (based on Torch to Linalg UpsampleNearest2d lowering formula):
    //      for i in range(N x C):
    //        for heightIndex in range(scaledHeight):
    //          for widthIndex in range(scaledWidth):
    //            indices.append(int(heightIndex // scalesH * selfWidth +
    //                               widthIndex // scalesW))
    // 3. Convert PyTorch-styled indices to TensorFlow-styled indices
    // 4. Apply TensorFlow-styled convertGatherNdOp to retrieve the output
    // 5. Reshape output to desired output shape
    Value self;
    if constexpr (std::is_same<AtenOpT, AtenUpsampleNearest2dOp>()) {
      self = adaptor.getSelf();
    } else if constexpr (std::is_same<AtenOpT, AtenUpsampleNearest2dVecOp>()) {
      self = adaptor.getInput();
    } else {
      return rewriter.notifyMatchFailure(
          op, "Expected either AtenUpsampleNearest2dOp or "
              "AtenUpsampleNearest2dVecOp");
    }

    auto selfType = dyn_cast<TensorType>(self.getType());
    if (!selfType)
      return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

    auto selfShape = selfType.getShape();
    auto selfRank = selfType.getRank();
    auto selfElemTy = selfType.getElementType();

    auto selfHeight = selfShape[selfRank - 2];
    auto selfWidth = selfShape[selfRank - 1];

    auto resultType = dyn_cast<TensorType>(
        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
            op.getType()));
    auto resultShape = resultType.getShape();
    auto resultElemTy = resultType.getElementType();

    // Get op's parameters
    SmallVector<int64_t> outputSize;
    SmallVector<double> scaleFactors;
    double scalesH;
    double scalesW;
    int64_t outputHeight;
    int64_t outputWidth;
    if constexpr (std::is_same<AtenOpT, AtenUpsampleNearest2dOp>()) {
      if (!matchPattern(op.getOutputSize(),
                        m_TorchListOfConstantInts(outputSize)))
        return rewriter.notifyMatchFailure(
            op, "Non-constant output size not supported");

      outputHeight = outputSize[0];
      outputWidth = outputSize[1];

      if (isa<Torch::NoneType>(op.getScalesH().getType())) {
        scalesH =
            static_cast<double>(outputHeight) / static_cast<double>(selfHeight);
      } else {
        if (!matchPattern(op.getScalesH(), m_TorchConstantFloat(&scalesH)))
          return rewriter.notifyMatchFailure(
              op, "Non-constant height scales not supported");

        scalesH = std::ceil(scalesH);
      }

      if (isa<Torch::NoneType>(op.getScalesW().getType())) {
        scalesW =
            static_cast<double>(outputWidth) / static_cast<double>(selfWidth);
      } else {
        if (!matchPattern(op.getScalesW(), m_TorchConstantFloat(&scalesW)))
          return rewriter.notifyMatchFailure(
              op, "Non-constant width scales not supported");

        scalesW = std::ceil(scalesW);
      }
    } else if constexpr (std::is_same<AtenOpT, AtenUpsampleNearest2dVecOp>()) {
      auto isOutputSizeNone =
          isa<Torch::NoneType>(op.getOutputSize().getType());
      auto isScaleFactorsNone =
          isa<Torch::NoneType>(op.getScaleFactors().getType());

      if ((isOutputSizeNone && isScaleFactorsNone) ||
          (!isOutputSizeNone && !isScaleFactorsNone))
        return rewriter.notifyMatchFailure(
            op, "Must specify exactly one of output size and scale factors");

      if (!isOutputSizeNone) {
        if (!matchPattern(op.getOutputSize(),
                          m_TorchListOfConstantInts(outputSize))) {
          return rewriter.notifyMatchFailure(
              op, "Non-constant output size not supported");
        } else {
          outputHeight = outputSize[0];
          outputWidth = outputSize[1];
        }
      }

      if (!isScaleFactorsNone) {
        if (!matchPattern(op.getScaleFactors(),
                          m_TorchListOfConstantFloats(scaleFactors))) {
          return rewriter.notifyMatchFailure(
              op, "Non-constant scale factors not supported");
        } else {
          scalesH = std::ceil(scaleFactors[0]);
          scalesW = std::ceil(scaleFactors[1]);
        }
      } else {
        scalesH =
            static_cast<double>(outputHeight) / static_cast<double>(selfHeight);
        scalesW =
            static_cast<double>(outputWidth) / static_cast<double>(selfWidth);
      }

      if (isOutputSizeNone) {
        outputHeight = static_cast<int64_t>(scalesH * selfHeight);
        outputWidth = static_cast<int64_t>(scalesW * selfWidth);
      }
    }

    // Reshape input
    SmallVector<int64_t> reshapedSelfShape(selfShape.begin(),
                                           selfShape.end() - 2);
    reshapedSelfShape.push_back(selfHeight * selfWidth);

    auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
        op->getLoc(), RankedTensorType::get(reshapedSelfShape, selfElemTy),
        self, rewriter.getDenseI64ArrayAttr(reshapedSelfShape));

    // Calculate PyTorch-styled gather indices
    SmallVector<int32_t> targetIndicesVec;
    int64_t indexRepeat = std::accumulate(
        selfShape.begin(), selfShape.end() - 2, 1, std::multiplies<int64_t>());
    for (int64_t i = 0; i < indexRepeat; i++) {
      for (int64_t heightIndex = 0; heightIndex < outputHeight; heightIndex++) {
        for (int64_t widthIndex = 0; widthIndex < outputWidth; widthIndex++) {
          targetIndicesVec.push_back(static_cast<int32_t>(
              std::floor(heightIndex / scalesH) * selfWidth +
              std::floor(widthIndex / scalesW)));
        }
      }
    }

    SmallVector<int64_t> targetIndicesShape(selfShape.begin(),
                                            selfShape.end() - 2);
    targetIndicesShape.push_back(outputHeight * outputWidth);
    auto targetIndicesTorch =
        tosa::getConstTensor<int32_t>(rewriter, op, targetIndicesVec,
                                      targetIndicesShape)
            .value();

    // Convert PyTorch-styled indices to TensorFlow-styled indices
    auto targetIndicesTF = tosa::convertTorchIndexToTfIndices(
        rewriter, op, reshapedSelf.getResult(), targetIndicesTorch,
        selfRank - 2);
    if (!targetIndicesTF)
      return rewriter.notifyMatchFailure(
          op, "Convert PyTorch-styled indices and dim "
              "to TensorFlow-styled indices failed");
    // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve
    // target elements
    auto gatherOp = tosa::convertGatherNdOp(
        rewriter, op, RankedTensorType::get(targetIndicesShape, resultElemTy),
        reshapedSelf.getResult(), targetIndicesTF.value());
    if (!gatherOp)
      return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");

    auto result = rewriter.create<tosa::ReshapeOp>(
        op->getLoc(), resultType, gatherOp.value(),
        rewriter.getDenseI64ArrayAttr(resultShape));

    rewriter.replaceOp(op, {result.getResult()});

    return success();
  }
};

} // namespace

// -----------------------------------------------------------------------------
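The gather-index formula from the upsample_nearest2d comment above can be reproduced with a short sketch (the helper is made up for illustration; the shapes match the LIT test added further down):

```python
import math

def nearest_gather_indices(in_w, out_h, out_w, scales_h, scales_w):
    # Flattened (H*W) index of the nearest source element for each output pixel.
    return [int(math.floor(h / scales_h) * in_w + math.floor(w / scales_w))
            for h in range(out_h) for w in range(out_w)]

# A 2x3 input upsampled to 8x9 with scales_h=4.0, scales_w=3.0: the first
# output row gathers source indices [0, 0, 0, 1, 1, 1, 2, 2, 2].
print(nearest_gather_indices(3, 8, 9, 4.0, 3.0)[:9])
```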
@@ -7891,6 +8291,13 @@ public:
    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
#undef INSERT_ACTIVATION_FUNCTION_OP_PATTERN

#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp) \
  target.addIllegalOp<AtenOp>(); \
  patterns.add<ConvertUpsampleNearest2dForward<AtenOp>>(typeConverter, context);
    INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dOp);
    INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dVecOp);
#undef INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN

#define INSERT_ATENOP_PATTERN(AtenOp) \
  target.addIllegalOp<AtenOp>(); \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);

@@ -7950,6 +8357,8 @@ public:
    INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp);
    INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp);
    INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp);
    INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
    INSERT_ATENOP_PATTERN(AtenOuterOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
@@ -1031,11 +1031,17 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
    return std::nullopt;
  }

  auto absVal = CreateOpAndInfer<tosa::AbsOp>(rewriter, op->getLoc(),
                                              input_type, input_value)
  auto input_value_casted =
      tosa::promoteType(rewriter, input_value, output_type);
  auto absVal = CreateOpAndInfer<tosa::AbsOp>(
                    rewriter, op->getLoc(),
                    RankedTensorType::get(input_type.getShape(), elemType),
                    input_value_casted)
                    .getResult();
  auto powVal = CreateOpAndInfer<tosa::PowOp>(rewriter, op->getLoc(),
                                              input_type, absVal, ordVal)
  auto powVal = CreateOpAndInfer<tosa::PowOp>(
                    rewriter, op->getLoc(),
                    RankedTensorType::get(input_type.getShape(), elemType),
                    absVal, ordVal)
                    .getResult();
  std::optional<Value> result = convertReduceSumOp(
      rewriter, op, output_type, powVal, axes_elems, keep_dims);
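The change above makes the norm accumulate in the converted output element type rather than the input type; a PyTorch-level example of the case it targets (assumes the standard torch.linalg.vector_norm dtype keyword):

```python
import torch

x = torch.randn(4, dtype=torch.float16)
# f16 input with an f32 computation/output dtype: the TOSA lowering now casts
# the input to the output element type before the abs/pow/reduce_sum chain.
n = torch.linalg.vector_norm(x, ord=2, dtype=torch.float32)
print(n.dtype)  # torch.float32
```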
@@ -1714,27 +1714,26 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
    "Aten_TrilinearModuleSumAllDims_basic",
    "Aten_TrilinearModuleSumdims_basic",
    "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
    "GridSamplerBasic1_basic",
    "GridSamplerBasic2_basic",
    "GridSamplerBasic3_basic",
    "GridSamplerBasic4_basic",
    "ScatterSrcModule_basic",
    "ScatterSrcStaticModule_basic",
    "HBC_basic",
    "InterpolateDynamicModule_scales_recompute_bilinear",
    "InterpolateDynamicModule_sizes_bilinear",
    "InterpolateDynamicModule_sizes_nearest",
    "InterpolateStaticModule_scales_bilinear_align_corners",
    "UpSampleNearest2d_basic",
    "UpSampleNearest2dStaticSize_basic",
    "UpSampleNearest2dDynamicSize_basic",
    "UpSampleNearest2dDynamicFactor_basic",
    "UpSampleNearest2dStaticFactor_basic",
}

# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
    "Deg2radModule_basic",
    "ElementwiseIntTensorLtFloatTensorModule_basic",
    "L1LossMeanReductionModule_basic",
    "L1LossNoReductionModule_basic",
    "L1LossSumReductionModule_basic",
    "PixelShuffleModuleStaticRank3Int64_basic",
    "PixelShuffleModuleStaticRank4Float32_basic",
    "RandIntLowModule_basic",
    "RandIntModule_basic",
    "RandIntPinMemoryModule_basic",
    "RenormModuleFloat16_basic",
    "SplitDimStaticModule_basic",
    "ReflectionPad1dModule2dInput_Right",
    "ReflectionPad1dModule2dInput_basic",
    "ReflectionPad1dModule3dInput_Left",

@@ -3462,8 +3461,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "Conv_Transpose2dStaticModule_basic",
    "Conv_Transpose3dModule_basic",
    "Conv_Transpose3dStaticModule_basic",
    "ElementwiseFloatTensorGtIntTensorModule_basic",
    "ElementwiseIntTensorLtFloatTensorModule_basic",
    "IndexPutWithNoneAndBroadcastModule_basic",
    "MaskedScatterStaticBasic_basic",
    "MaxUnpool3dModulePad0_basic",

@@ -3471,7 +3468,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "MultinomialModule2D_F32",
    "MultinomialModule2D_basic",
    "MultinomialModule_basic",
    "RenormModuleFloat16_basic",
    # REMOVE WHEN ENABLE_GQA IS ADDED
    "ScatterAddStaticModule_basic",
    "TensorsConcatComplex128FloatModule_basic",

@@ -3635,7 +3631,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "ElementwiseExpIntModule_basic",
    "ElementwiseExpm1IntModule_basic",
    "ElementwiseExpm1Module_basic",
    "ElementwiseGeluApproximateTanhModule_basic",
    "ElementwiseIntTensorLtFloatScalarModule_basic",
    "ElementwiseLog10IntModule_basic",
    "ElementwiseLog10Module_basic",

@@ -3691,8 +3686,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "IndexPutImpl3DFloatAccumulateModule_basic",
    "IndexPutImplIndexWithNoneModule_basic",
    "InterpolateDynamicModule_sizes_bilinear",
    "InterpolateDynamicModule_sizes_nearest",
    "InterpolateStaticModule_scales_bilinear_align_corners",
    "InterpolateDynamicModule_scales_recompute_bilinear",
    "IntFloatModule_basic",
    "IntImplicitModule_basic",

@@ -3703,7 +3696,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "LinalgVectorNormComplexModule_basic",
    "LinspaceDtypeModule_basic",
    "LinspaceEmptyModule_basic",
    "MaskedFillTensorFloatValueModule_basic",
    "MaskedScatterStaticBasic_basic",
    "MaxPool1dCeilModeTrueModule_basic",
    "MaxPool1dModule_basic",

@@ -3764,11 +3756,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "NumelModule_basic",
    "NumelZeroRankModule_basic",
    "OnesLikeModule_falsePinMemory",
    "PixelShuffleModuleFullDynamic_basic",
    "PixelShuffleModuleSpatiallyDynamic_basic",
    "PixelShuffleModuleSpatiallyStatic_basic",
    "PixelShuffleModuleStaticRank3Int64_basic",
    "PixelShuffleModuleStaticRank4Float32_basic",
    "PowIntFloatModule_basic",
    "PrimMaxIntModule_basic",
    "PrimMinIntDynamicModule_basic",

@@ -3785,9 +3772,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "QuantizedReluInt8_basic",
    "QuantizedReluUint8_basic",
    "QuantizedSingleLayer_basic",
    "RandIntLowModule_basic",
    "RandIntModule_basic",
    "RandIntPinMemoryModule_basic",
    "RandnDtypeDeviceModule_basic",
    "RandnGeneratorF64Module_basic",
    "RandnGeneratorModule_basic",

@@ -3797,26 +3781,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "ReduceAllDimEmpty_basic",
    "ReduceFrobeniusNormComplexModule_basic",
    "ReduceL1NormComplexModule_basic",
    "ReduceL1NormWithDTypeModule_basic",
    "ReduceL2NormComplexModule_basic",
    "ReduceL3NormKeepDimComplexModule_basic",
    "ReduceMaxAlongDimUnsignedInt_basic",
    "ReduceMinAlongDimUnsignedInt_basic",
    "ReduceSumDimIntListEmptyDimModule_basic",
    "ReflectionPad1dModule2dInput_Right",
    "ReflectionPad1dModule2dInput_basic",
    "ReflectionPad1dModule3dInput_Left",
    "ReflectionPad1dModule3dInput_basic",
    "ReflectionPad2dModule_Bottom",
    "ReflectionPad2dModule_Left",
    "ReflectionPad2dModule_Right",
    "ReflectionPad2dModule_Top",
    "ReflectionPad2dModule_basic",
    "ReplicationPad2dModule_basic",
    "ReplicationPad2dModule_bottom0",
    "ReplicationPad2dModule_left0",
    "ReplicationPad2dModule_right0",
    "ReplicationPad2dModule_top0",
    "RollModule_basic",
    "ScalarConstantTupleModule_basic",
    "ScalarImplicitFloatModule_basic",

@@ -3888,11 +3857,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
    "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
    "UpSampleNearest2dBackwardScalesNone_basic",
    "UpSampleNearest2dBackward_basic",
    "UpSampleNearest2dDynamicFactor_basic",
    "UpSampleNearest2dDynamicSize_basic",
    "UpSampleNearest2dStaticFactor_basic",
    "UpSampleNearest2dStaticSize_basic",
    "UpSampleNearest2d_basic",
    "ViewCollapseDynamicWithAtenSizeIntModule_basic",
    "ViewSizeFromOtherTensor_basic",
    "VisionTransformerModule_basic",

@@ -3939,6 +3903,13 @@ ONNX_TOSA_CRASHING_SET = {
}

ONNX_TOSA_XFAIL_SET = {
    "ColumnStack0dModule_basic",
    "ColumnStack1dModule_basic",
    "ColumnStackBasicIntModule_basic",
    "Deg2radModule_basic",
    "L1LossMeanReductionModule_basic",
    "L1LossNoReductionModule_basic",
    "L1LossSumReductionModule_basic",
    "FloatPowerTensorTensorStaticModule_basic",
    "IsInfiniteModule_basic",
    "ElementwiseCopysignModule_basic",

@@ -4648,7 +4619,6 @@ ONNX_TOSA_XFAIL_SET = {
    "QuantizedSingleLayer_basic",
    "RandIntDtypeModule_basic",
    "RandIntLowDtypeModule_basic",
    "RandIntLowModule_basic",
    "RandIntModule_basic",
    "RandIntPinMemoryModule_basic",
    "RandLikeDtypeModule_basic",
@@ -2519,3 +2519,117 @@ func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f3
  %1 = torch.aten.replication_pad2d %arg0, %0 : !torch.vtensor<[1,1,3,3],f32>, !torch.list<int> -> !torch.vtensor<[1,1,10,6],f32>
  return %1 : !torch.vtensor<[1,1,10,6],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.outer$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4],f32> -> tensor<4xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3],f32> -> tensor<3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 3, 1>} : (tensor<3xf32>) -> tensor<3x1xf32>
// CHECK: %[[VAL_5:.*]] = tosa.tile %[[VAL_4]] {multiples = array<i64: 1, 4>} : (tensor<3x1xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 4>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array<i64: 3, 1>} : (tensor<1x4xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_9]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.outer$basic(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> {
  %0 = torch.aten.outer %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[3,4],f32>
  return %0 : !torch.vtensor<[3,4],f32>
}

// -----

// CHECK-LABEL: func.func @torch.prims.split_dim$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,8,3,3],si64>) -> !torch.vtensor<[1,2,2,2,3,3],si64> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,8,3,3],si64> -> tensor<1x8x3x3xi64>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 2, 4, 3, 3>} : (tensor<1x8x3x3xi64>) -> tensor<1x2x4x3x3xi64>
// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1, 2, 2, 2, 3, 3>} : (tensor<1x2x4x3x3xi64>) -> tensor<1x2x2x2x3x3xi64>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x2x2x3x3xi64> -> !torch.vtensor<[1,2,2,2,3,3],si64>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,2,2,2,3,3],si64>
// CHECK: }
func.func @torch.prims.split_dim$basic(%arg0: !torch.vtensor<[1,8,3,3],si64>) -> !torch.vtensor<[1,2,2,2,3,3],si64> {
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %0 = torch.prims.split_dim %arg0, %int1, %int2 : !torch.vtensor<[1,8,3,3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2,4,3,3],si64>
  %1 = torch.prims.split_dim %0, %int2, %int2 : !torch.vtensor<[1,2,4,3,3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,3,3],si64>
  return %1 : !torch.vtensor<[1,2,2,2,3,3],si64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.upsample_nearest2d$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,2,3],f64>) -> !torch.vtensor<[1,1,8,9],f64> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,2,3],f64> -> tensor<1x1x2x3xf64>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 4.000000e+00
// CHECK: %[[VAL_3:.*]] = torch.constant.float 3.000000e+00
// CHECK: %[[VAL_4:.*]] = torch.constant.int 8
// CHECK: %[[VAL_5:.*]] = torch.constant.int 9
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 1, 6>} : (tensor<1x1x2x3xf64>) -> tensor<1x1x6xf64>
// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5]]]> : tensor<1x1x72xi32>}> : () -> tensor<1x1x72xi32>
// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 1, 72, 1>} : (tensor<1x1x72xi32>) -> tensor<1x1x72x1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32>
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32>
// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_11]], %[[VAL_9]] {axis = 3 : i32} : (tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>) -> tensor<1x1x72x3xi32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 1, 6, 1>} : (tensor<1x1x6xf64>) -> tensor<1x6x1xf64>
// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 72, 3>} : (tensor<1x1x72x3xi32>) -> tensor<72x3xi32>
// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[6, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_14]], %[[VAL_15]] {shift = 0 : i8} : (tensor<72x3xi32>, tensor<3xi32>) -> tensor<72x3xi32>
// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<72x3xi32>) -> tensor<72x1xi32>
// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array<i64: 1, 72>} : (tensor<72x1xi32>) -> tensor<1x72xi32>
// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_13]], %[[VAL_18]] : (tensor<1x6x1xf64>, tensor<1x72xi32>) -> tensor<1x72x1xf64>
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 1, 1, 72>} : (tensor<1x72x1xf64>) -> tensor<1x1x72xf64>
// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array<i64: 1, 1, 8, 9>} : (tensor<1x1x72xf64>) -> tensor<1x1x8x9xf64>
// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x1x8x9xf64> -> !torch.vtensor<[1,1,8,9],f64>
// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,1,8,9],f64>
// CHECK: }
func.func @torch.aten.upsample_nearest2d$basic(%arg0: !torch.vtensor<[1,1,2,3],f64>) -> !torch.vtensor<[1,1,8,9],f64> {
  %float4.000000e00 = torch.constant.float 4.000000e+00
  %float3.000000e00 = torch.constant.float 3.000000e+00
  %int8 = torch.constant.int 8
  %int9 = torch.constant.int 9
  %0 = torch.prim.ListConstruct %int8, %int9 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.upsample_nearest2d %arg0, %0, %float4.000000e00, %float3.000000e00 : !torch.vtensor<[1,1,2,3],f64>, !torch.list<int>, !torch.float, !torch.float -> !torch.vtensor<[1,1,8,9],f64>
  return %1 : !torch.vtensor<[1,1,8,9],f64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.upsample_nearest2d.vec$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4,5],f32>) -> !torch.vtensor<[1,1,2,7],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,4,5],f32> -> tensor<1x1x4x5xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = torch.constant.int 7
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 1, 20>} : (tensor<1x1x4x5xf32>) -> tensor<1x1x20xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0, 0, 1, 2, 2, 3, 4, 10, 10, 11, 12, 12, 13, 14]]]> : tensor<1x1x14xi32>}> : () -> tensor<1x1x14xi32>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 1, 1, 14, 1>} : (tensor<1x1x14xi32>) -> tensor<1x1x14x1xi32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>) -> tensor<1x1x14x3xi32>
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array<i64: 1, 20, 1>} : (tensor<1x1x20xf32>) -> tensor<1x20x1xf32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 14, 3>} : (tensor<1x1x14x3xi32>) -> tensor<14x3xi32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[20, 20, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<14x3xi32>, tensor<3xi32>) -> tensor<14x3xi32>
// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<14x3xi32>) -> tensor<14x1xi32>
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 14>} : (tensor<14x1xi32>) -> tensor<1x14xi32>
// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x20x1xf32>, tensor<1x14xi32>) -> tensor<1x14x1xf32>
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 1, 1, 14>} : (tensor<1x14x1xf32>) -> tensor<1x1x14xf32>
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 1, 1, 2, 7>} : (tensor<1x1x14xf32>) -> tensor<1x1x2x7xf32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<1x1x2x7xf32> -> !torch.vtensor<[1,1,2,7],f32>
// CHECK: return %[[VAL_21]] : !torch.vtensor<[1,1,2,7],f32>
// CHECK: }
func.func @torch.aten.upsample_nearest2d.vec$basic(%arg0: !torch.vtensor<[1,1,4,5],f32>) -> !torch.vtensor<[1,1,2,7],f32> {
  %none = torch.constant.none
  %int2 = torch.constant.int 2
  %int7 = torch.constant.int 7
  %0 = torch.prim.ListConstruct %int2, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.upsample_nearest2d.vec %arg0, %0, %none : !torch.vtensor<[1,1,4,5],f32>, !torch.list<int>, !torch.none -> !torch.vtensor<[1,1,2,7],f32>
  return %1 : !torch.vtensor<[1,1,2,7],f32>
}