[TOSA] Add upsample_nearest2d, split_dim, outer, GELU tanh mode and misc

- Add Torch to TOSA lowering for the following ops:
  + torch.aten.upsample_nearest2d
  + torch.aten.upsample_nearest2d.vec
  + torch.aten.outer
  + torch.prims.split_dim
- Add Tanh approximation mode for GELU lowering
- Add support for operands of different types in compare ops
- Add support for different input and output types in the linalg vector norm
  lowering
- Update xfail with new e2e results
- Add new LIT tests to basic.mlir

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I7b1d44d94319cf94fcc9d234cc07708ef9ce321e
pull/3886/head
Justin Ngo 2024-11-20 22:39:17 +00:00
parent 1b8d7e094b
commit 40d3de502e
4 changed files with 564 additions and 65 deletions


@ -23,6 +23,7 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cmath>
#include <numeric>
#include <optional>
#include <random>
@ -405,6 +406,36 @@ public:
"conversion in TOSA operation");
}
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
auto rhsTensorTy = dyn_cast<TensorType>(rhsTensor.getType());
auto rhsElemTy = rhsTensorTy.getElementType();
auto isLhsElemFloat = isa<mlir::FloatType>(lhsElemTy);
auto isRhsElemFloat = isa<mlir::FloatType>(rhsElemTy);
// Support comparisons between operands of different element types
if (lhsElemTy != rhsElemTy) {
if (isLhsElemFloat && !isRhsElemFloat) {
rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy);
} else if (!isLhsElemFloat && isRhsElemFloat) {
lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy);
} else if (isLhsElemFloat && isRhsElemFloat) {
auto lhsElemFloatTy = dyn_cast<mlir::FloatType>(lhsElemTy);
auto rhsElemFloatTy = dyn_cast<mlir::FloatType>(rhsElemTy);
if (lhsElemFloatTy.getWidth() > rhsElemFloatTy.getWidth()) {
rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy);
} else {
lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy);
}
} else {
auto lhsElemIntTy = dyn_cast<mlir::IntegerType>(lhsElemTy);
auto rhsElemIntTy = dyn_cast<mlir::IntegerType>(rhsElemTy);
if (lhsElemIntTy.getWidth() > rhsElemIntTy.getWidth()) {
rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy);
} else {
lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy);
}
}
}
// There is no Lesser operator in TOSA.
constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>() ||
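
For context, a small PyTorch-level sketch (not part of this patch; the tensor values are illustrative) of the mixed-dtype comparison that the promotion above enables — the integer operand is promoted to float before the element-wise compare, mirroring the promoteType calls:

    import torch

    a = torch.tensor([1, 2, 3], dtype=torch.int32)
    b = torch.tensor([0.5, 2.0, 3.5], dtype=torch.float32)
    # The int32 operand is promoted to float32, then compared element-wise.
    print(torch.lt(a, b))  # tensor([False, False,  True])
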
@ -3196,9 +3227,10 @@ template <>
LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
AtenGeluOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();
// Not a tensor type.
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(
op, "Only tensor types are currently supported");
@ -3209,21 +3241,97 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
op, "Only floating-point datatype legalization supported");
}
auto resultType =
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
std::string approximate;
if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate))) {
return rewriter.notifyMatchFailure(
op, "Non-const approximate value not supported");
}
if (approximate.compare("none") == 0) {
// GELU(x) = x * CDF(x)
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy);
cdf = rewriter.createOrFold<tosa::CastOp>(
op->getLoc(),
cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf);
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, cdf,
/*shift=*/0);
} else if (approximate.compare("tanh") == 0) {
// "tanh" approximate
// GELU(x) = 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
auto selfShape = selfType.getShape();
auto numElem = std::accumulate(selfShape.begin(), selfShape.end(), 1,
std::multiplies<int64_t>());
Value half = tosa::getConstTensor<float>(rewriter, op,
SmallVector<float>(numElem, 0.5),
selfShape, selfElemTy)
.value();
Value one = tosa::getConstTensor<float>(rewriter, op,
SmallVector<float>(numElem, 1.0),
selfShape, selfElemTy)
.value();
Value three = tosa::getConstTensor<float>(rewriter, op,
SmallVector<float>(numElem, 3.0),
selfShape, selfElemTy)
.value();
// 0.044715
Value magicNumber = tosa::getConstTensor<float>(
rewriter, op, SmallVector<float>(numElem, 0.044715),
selfShape, selfElemTy)
.value();
// From <cmath> header: M_2_PI = 2 / pi
Value twoOverPi = tosa::getConstTensor<float>(
rewriter, op, SmallVector<float>(numElem, M_2_PI),
selfShape, selfElemTy)
.value();
// 0.5 * x
auto halfInput = rewriter.create<tosa::MulOp>(op->getLoc(), resultType,
half, self, /*shift=*/0);
// sqrt(2/pi)
auto sqrtTwoOverPi =
rewriter.create<tosa::PowOp>(op->getLoc(), resultType, twoOverPi, half);
// x^3
auto inputPowThree =
rewriter.create<tosa::PowOp>(op->getLoc(), resultType, self, three);
// 0.044715 * x^3
auto inputPowThreeMul =
rewriter.create<tosa::MulOp>(op->getLoc(), resultType, magicNumber,
inputPowThree.getResult(), /*shift=*/0);
// x + 0.044715 * x^3
auto inputPowThreeMulAdd = rewriter.create<tosa::AddOp>(
op->getLoc(), resultType, self, inputPowThreeMul.getResult());
// sqrt(2/pi) * (x + 0.044715 * x^3)
auto sqrtTwoOverPiMul = rewriter.create<tosa::MulOp>(
op->getLoc(), resultType, sqrtTwoOverPi.getResult(),
inputPowThreeMulAdd.getResult(), /*shift=*/0);
// tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
auto tanh = rewriter.create<tosa::TanhOp>(op->getLoc(), resultType,
sqrtTwoOverPiMul.getResult());
// 1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
auto tanhAdd = rewriter.create<tosa::AddOp>(op->getLoc(), resultType, one,
tanh.getResult());
rewriter.replaceOpWithNewOp<tosa::MulOp>(
op, resultType, halfInput.getResult(), tanhAdd.getResult(),
/*shift=*/0);
} else {
return rewriter.notifyMatchFailure(op,
"Unsupported approximation algorithm");
}
return success();
}
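
As a quick cross-check of the formula built above from TOSA mul/pow/add/tanh ops, a standalone Python sketch (not part of this patch) showing that the tanh approximation matches PyTorch's approximate="tanh" GELU:

    import math
    import torch

    def gelu_tanh(x):
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi)
                                           * (x + 0.044715 * x ** 3)))

    x = torch.randn(8)
    print(torch.allclose(gelu_tanh(x),
                         torch.nn.functional.gelu(x, approximate="tanh")))  # True

The lowering obtains sqrt(2/pi) as pow(2/pi, 0.5), which is the same quantity as math.sqrt(2.0 / math.pi) here.
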
@ -7620,6 +7728,298 @@ LogicalResult ConvertAtenOp<AtenReplicationPad2dOp>::matchAndRewrite(
return success();
}
// Legalization for torch.prims.split_dim
template <>
LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
PrimsSplitDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getA();
// Not a tensor type
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultShape = resultType.getShape();
int64_t dim, outerLength;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "Only constant int dim value is supported");
auto selfRank = selfType.getRank();
dim = toPositiveDim(dim, selfRank);
if (!isValidDim(dim, selfRank))
return rewriter.notifyMatchFailure(op, "Dim is invalid");
if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength)))
return rewriter.notifyMatchFailure(
op, "Only constant int outer length value is supported");
// The output shape could be computed from the dim and outer length values,
// but that yields the same shape that is already available in resultType.
// Therefore, the result shape is taken straight from resultType and applied
// to the input with tosa::ReshapeOp, which is simpler and quicker.
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, resultType, self,
rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));
return success();
}
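
As the comment above notes, prims.split_dim is purely a shape change. A minimal Python sketch (values are illustrative, shapes match the LIT test below): splitting dim 1 of size 8 with outer_length 2 is just a reshape to the known result shape.

    import torch

    a = torch.arange(72).reshape(1, 8, 3, 3)
    # split_dim(a, dim=1, outer_length=2): dim 1 of size 8 becomes (2, 4).
    b = a.reshape(1, 2, 4, 3, 3)
    print(b.shape)  # torch.Size([1, 2, 4, 3, 3])
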
// Legalization for aten.outer
template <>
LogicalResult ConvertAtenOp<AtenOuterOp>::matchAndRewrite(
AtenOuterOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
if (selfType.getRank() != 1)
return rewriter.notifyMatchFailure(op, "Only rank 1 vectors are supported");
auto vec2 = adaptor.getVec2();
auto vec2Type = dyn_cast<TensorType>(vec2.getType());
if (!vec2Type)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
if (vec2Type.getRank() != 1)
return rewriter.notifyMatchFailure(op, "Only rank 1 vectors are supported");
auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultShape = resultType.getShape();
self = tosa::promoteType(rewriter, self, resultType);
vec2 = tosa::promoteType(rewriter, vec2, resultType);
SmallVector<int64_t, 2> resultShapeIndex1Replaced({resultShape[0], 1});
SmallVector<int64_t, 2> resultShapeIndex0Replaced({1, resultShape[1]});
// Reshape and tile self to shape {selfShape[0], resultShape[1]}
auto selfReshaped = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
RankedTensorType::get(resultShapeIndex1Replaced,
resultType.getElementType()),
self, rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced));
auto selfTiled = rewriter.create<tosa::TileOp>(
op->getLoc(), resultType, selfReshaped.getResult(),
rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced));
// Reshape and tile vec2 to shape {resultShape[0], vec2Shape[0]}
auto vec2Reshaped = rewriter.create<tosa::ReshapeOp>(
op->getLoc(),
RankedTensorType::get(resultShapeIndex0Replaced,
resultType.getElementType()),
vec2, rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced));
auto vec2Tiled = rewriter.create<tosa::TileOp>(
op->getLoc(), resultType, vec2Reshaped.getResult(),
rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced));
auto result =
tosa::createMulOpAndCast(rewriter, op, resultType, selfTiled.getResult(),
vec2Tiled.getResult(), /*shift=*/0);
rewriter.replaceOp(op, result);
return success();
}
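
The reshape + tile + mul sequence above is the TOSA spelling of a broadcasted outer product. A small Python sketch (values are illustrative, not part of the patch) showing the equivalence with torch.outer:

    import torch

    a = torch.tensor([1.0, 2.0, 3.0])
    b = torch.tensor([10.0, 20.0, 30.0, 40.0])
    # Reshape to (3, 1) and (1, 4), tile both to (3, 4), then multiply element-wise.
    tiled = a.reshape(3, 1).tile(1, 4) * b.reshape(1, 4).tile(3, 1)
    print(torch.equal(tiled, torch.outer(a, b)))  # True
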
// Legalization for aten.upsample_nearest2d
template <typename AtenOpT>
class ConvertUpsampleNearest2dForward : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// aten.upsample_nearest2d lowering process:
// 1. Reshape input: (N, C, H, W) -> (N, C, H x W)
// 2. Calculate PyTorch-styled gather op indices based on the following
// formula (based on Torch to Linalg UpsampleNearest2d lowering formula):
// for i in range(N x C):
// for heightIndex in range(scaledHeight):
// for widthIndex in range(scaledWidth):
// indices.append(int(heightIndex // scalesH * selfWidth +
// widthIndex // scalesW))
// 3. Convert PyTorch-styled indices to TensorFlow-styled indices
// 4. Apply TensorFlow-style convertGatherNdOp to retrieve the output
// 5. Reshape output to desired output shape
Value self;
if constexpr (std::is_same<AtenOpT, AtenUpsampleNearest2dOp>()) {
self = adaptor.getSelf();
} else if constexpr (std::is_same<AtenOpT, AtenUpsampleNearest2dVecOp>()) {
self = adaptor.getInput();
} else {
return rewriter.notifyMatchFailure(
op, "Expected either AtenUpsampleNearest2dOp or "
"AtenUpsampleNearest2dVecOp");
}
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
auto selfShape = selfType.getShape();
auto selfRank = selfType.getRank();
auto selfElemTy = selfType.getElementType();
auto selfHeight = selfShape[selfRank - 2];
auto selfWidth = selfShape[selfRank - 1];
auto resultType = dyn_cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
auto resultShape = resultType.getShape();
auto resultElemTy = resultType.getElementType();
// Get op's parameters
SmallVector<int64_t> outputSize;
SmallVector<double> scaleFactors;
double scalesH;
double scalesW;
int64_t outputHeight;
int64_t outputWidth;
if constexpr (std::is_same<AtenOpT, AtenUpsampleNearest2dOp>()) {
if (!matchPattern(op.getOutputSize(),
m_TorchListOfConstantInts(outputSize)))
return rewriter.notifyMatchFailure(
op, "Non-constant output size not supported");
outputHeight = outputSize[0];
outputWidth = outputSize[1];
if (isa<Torch::NoneType>(op.getScalesH().getType())) {
scalesH =
static_cast<double>(outputHeight) / static_cast<double>(selfHeight);
} else {
if (!matchPattern(op.getScalesH(), m_TorchConstantFloat(&scalesH)))
return rewriter.notifyMatchFailure(
op, "Non-constant height scales not supported");
scalesH = std::ceil(scalesH);
}
if (isa<Torch::NoneType>(op.getScalesW().getType())) {
scalesW =
static_cast<double>(outputWidth) / static_cast<double>(selfWidth);
} else {
if (!matchPattern(op.getScalesW(), m_TorchConstantFloat(&scalesW)))
return rewriter.notifyMatchFailure(
op, "Non-constant width scales not supported");
scalesW = std::ceil(scalesW);
}
} else if constexpr (std::is_same<AtenOpT, AtenUpsampleNearest2dVecOp>()) {
auto isOutputSizeNone =
isa<Torch::NoneType>(op.getOutputSize().getType());
auto isScaleFactorsNone =
isa<Torch::NoneType>(op.getScaleFactors().getType());
if ((isOutputSizeNone && isScaleFactorsNone) ||
(!isOutputSizeNone && !isScaleFactorsNone))
return rewriter.notifyMatchFailure(
op, "Must specified exactly one of output size and scale factors");
if (!isOutputSizeNone) {
if (!matchPattern(op.getOutputSize(),
m_TorchListOfConstantInts(outputSize))) {
return rewriter.notifyMatchFailure(
op, "Non-constant output size not supported");
} else {
outputHeight = outputSize[0];
outputWidth = outputSize[1];
}
}
if (!isScaleFactorsNone) {
if (!matchPattern(op.getScaleFactors(),
m_TorchListOfConstantFloats(scaleFactors))) {
return rewriter.notifyMatchFailure(
op, "Non-constant output size not supported");
} else {
scalesH = std::ceil(scaleFactors[0]);
scalesW = std::ceil(scaleFactors[1]);
}
} else {
scalesH =
static_cast<double>(outputHeight) / static_cast<double>(selfHeight);
scalesW =
static_cast<double>(outputWidth) / static_cast<double>(selfWidth);
}
if (isOutputSizeNone) {
outputHeight = static_cast<int64_t>(scalesH * selfHeight);
outputWidth = static_cast<int64_t>(scalesW * selfWidth);
}
}
// Reshape input
SmallVector<int64_t> reshapedSelfShape(selfShape.begin(),
selfShape.end() - 2);
reshapedSelfShape.push_back(selfHeight * selfWidth);
auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), RankedTensorType::get(reshapedSelfShape, selfElemTy),
self, rewriter.getDenseI64ArrayAttr(reshapedSelfShape));
// Calculate PyTorch-styled gather indices
SmallVector<int32_t> targetIndicesVec;
int64_t indexRepeat = std::accumulate(
selfShape.begin(), selfShape.end() - 2, 1, std::multiplies<int64_t>());
for (int64_t i = 0; i < indexRepeat; i++) {
for (int64_t heightIndex = 0; heightIndex < outputHeight; heightIndex++) {
for (int64_t widthIndex = 0; widthIndex < outputWidth; widthIndex++) {
targetIndicesVec.push_back(static_cast<int32_t>(
std::floor(heightIndex / scalesH) * selfWidth +
std::floor(widthIndex / scalesW)));
}
}
}
SmallVector<int64_t> targetIndicesShape(selfShape.begin(),
selfShape.end() - 2);
targetIndicesShape.push_back(outputHeight * outputWidth);
auto targetIndicesTorch =
tosa::getConstTensor<int32_t>(rewriter, op, targetIndicesVec,
targetIndicesShape)
.value();
// Convert PyTorch-styled indices to TensorFlow-styled indices
auto targetIndicesTF = tosa::convertTorchIndexToTfIndices(
rewriter, op, reshapedSelf.getResult(), targetIndicesTorch,
selfRank - 2);
if (!targetIndicesTF)
return rewriter.notifyMatchFailure(
op, "Convert PyTorch-styled indices and dim "
"to TensorFlow-styled indices failed");
// Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve
// target elements
auto gatherOp = tosa::convertGatherNdOp(
rewriter, op, RankedTensorType::get(targetIndicesShape, resultElemTy),
reshapedSelf.getResult(), targetIndicesTF.value());
if (!gatherOp)
return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");
auto result = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), resultType, gatherOp.value(),
rewriter.getDenseI64ArrayAttr(resultShape));
rewriter.replaceOp(op, {result.getResult()});
return success();
}
};
} // namespace
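
To make step 2 of the lowering above concrete, a standalone Python sketch (not part of the patch) of the gather-index computation for the static 2x3 -> 8x9 case in the LIT test below (scalesH = 4, scalesW = 3):

    import math

    self_height, self_width = 2, 3
    output_height, output_width = 8, 9
    scales_h, scales_w = 4.0, 3.0
    indices = [int(math.floor(h / scales_h) * self_width + math.floor(w / scales_w))
               for h in range(output_height) for w in range(output_width)]
    print(indices[:9])   # [0, 0, 0, 1, 1, 1, 2, 2, 2]
    print(len(indices))  # 72 indices into the flattened 1x1x6 input

These 72 values are the tosa.const index tensor that appears in the torch.aten.upsample_nearest2d$basic LIT test below.
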
// -----------------------------------------------------------------------------
@ -7891,6 +8291,13 @@ public:
INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
#undef INSERT_ACTIVATION_FUNCTION_OP_PATTERN
#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertUpsampleNearest2dForward<AtenOp>>(typeConverter, context);
INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dOp);
INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dVecOp);
#undef INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
@ -7950,6 +8357,8 @@ public:
INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp);
INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp);
INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp);
INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
INSERT_ATENOP_PATTERN(AtenOuterOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \


@ -1031,11 +1031,17 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
return std::nullopt;
}
auto input_value_casted =
tosa::promoteType(rewriter, input_value, output_type);
auto absVal = CreateOpAndInfer<tosa::AbsOp>(
rewriter, op->getLoc(),
RankedTensorType::get(input_type.getShape(), elemType),
input_value_casted)
.getResult();
auto powVal = CreateOpAndInfer<tosa::PowOp>(
rewriter, op->getLoc(),
RankedTensorType::get(input_type.getShape(), elemType),
absVal, ordVal)
.getResult();
std::optional<Value> result = convertReduceSumOp(
rewriter, op, output_type, powVal, axes_elems, keep_dims);
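
For reference, the kind of input/output dtype mismatch this change handles (a PyTorch-level sketch, not part of the patch): torch.linalg.vector_norm accepts a dtype that differs from the input's element type, and the lowering now casts the input to the output element type before the abs/pow/reduce_sum chain.

    import torch

    x = torch.randn(3, 4, dtype=torch.float32)
    # f32 input, f64 result via the dtype argument.
    n = torch.linalg.vector_norm(x, ord=1, dim=0, dtype=torch.float64)
    print(n.dtype)  # torch.float64
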


@ -1714,27 +1714,26 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleSumdims_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"GridSamplerBasic1_basic",
"GridSamplerBasic2_basic",
"GridSamplerBasic3_basic",
"GridSamplerBasic4_basic",
"ScatterSrcModule_basic",
"ScatterSrcStaticModule_basic",
"HBC_basic",
"InterpolateDynamicModule_scales_recompute_bilinear",
"InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners",
"UpSampleNearest2d_basic",
"UpSampleNearest2dStaticSize_basic",
"UpSampleNearest2dDynamicSize_basic",
"UpSampleNearest2dDynamicFactor_basic",
"UpSampleNearest2dStaticFactor_basic",
}
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"Deg2radModule_basic",
"ElementwiseIntTensorLtFloatTensorModule_basic",
"L1LossMeanReductionModule_basic",
"L1LossNoReductionModule_basic",
"L1LossSumReductionModule_basic",
"PixelShuffleModuleStaticRank3Int64_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"RandIntLowModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"RenormModuleFloat16_basic",
"SplitDimStaticModule_basic",
"ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic",
"ReflectionPad1dModule3dInput_Left",
@ -3462,8 +3461,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"Conv_Transpose2dStaticModule_basic",
"Conv_Transpose3dModule_basic",
"Conv_Transpose3dStaticModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"ElementwiseIntTensorLtFloatTensorModule_basic",
"IndexPutWithNoneAndBroadcastModule_basic",
"MaskedScatterStaticBasic_basic",
"MaxUnpool3dModulePad0_basic",
@ -3471,7 +3468,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"MultinomialModule2D_F32",
"MultinomialModule2D_basic",
"MultinomialModule_basic",
"RenormModuleFloat16_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScatterAddStaticModule_basic",
"TensorsConcatComplex128FloatModule_basic",
@ -3635,7 +3631,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ElementwiseExpIntModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseGeluApproximateTanhModule_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic",
"ElementwiseLog10IntModule_basic",
"ElementwiseLog10Module_basic",
@ -3691,8 +3686,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic",
"InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners",
"InterpolateDynamicModule_scales_recompute_bilinear",
"IntFloatModule_basic",
"IntImplicitModule_basic",
@ -3703,7 +3696,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"LinalgVectorNormComplexModule_basic",
"LinspaceDtypeModule_basic",
"LinspaceEmptyModule_basic",
"MaskedFillTensorFloatValueModule_basic",
"MaskedScatterStaticBasic_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dModule_basic",
@ -3764,11 +3756,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"NumelModule_basic",
"NumelZeroRankModule_basic",
"OnesLikeModule_falsePinMemory",
"PixelShuffleModuleFullDynamic_basic",
"PixelShuffleModuleSpatiallyDynamic_basic",
"PixelShuffleModuleSpatiallyStatic_basic",
"PixelShuffleModuleStaticRank3Int64_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"PowIntFloatModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic",
@ -3785,9 +3772,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"QuantizedReluInt8_basic",
"QuantizedReluUint8_basic",
"QuantizedSingleLayer_basic",
"RandIntLowModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"RandnDtypeDeviceModule_basic",
"RandnGeneratorF64Module_basic",
"RandnGeneratorModule_basic",
@ -3797,26 +3781,11 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"ReduceAllDimEmpty_basic",
"ReduceFrobeniusNormComplexModule_basic",
"ReduceL1NormComplexModule_basic",
"ReduceL1NormWithDTypeModule_basic",
"ReduceL2NormComplexModule_basic",
"ReduceL3NormKeepDimComplexModule_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"ReduceSumDimIntListEmptyDimModule_basic",
"ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic",
"ReflectionPad1dModule3dInput_Left",
"ReflectionPad1dModule3dInput_basic",
"ReflectionPad2dModule_Bottom",
"ReflectionPad2dModule_Left",
"ReflectionPad2dModule_Right",
"ReflectionPad2dModule_Top",
"ReflectionPad2dModule_basic",
"ReplicationPad2dModule_basic",
"ReplicationPad2dModule_bottom0",
"ReplicationPad2dModule_left0",
"ReplicationPad2dModule_right0",
"ReplicationPad2dModule_top0",
"RollModule_basic",
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
@ -3888,11 +3857,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"UpSampleNearest2dBackwardScalesNone_basic",
"UpSampleNearest2dBackward_basic",
"UpSampleNearest2dDynamicFactor_basic",
"UpSampleNearest2dDynamicSize_basic",
"UpSampleNearest2dStaticFactor_basic",
"UpSampleNearest2dStaticSize_basic",
"UpSampleNearest2d_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic",
"VisionTransformerModule_basic",
@ -3939,6 +3903,13 @@ ONNX_TOSA_CRASHING_SET = {
}
ONNX_TOSA_XFAIL_SET = {
"ColumnStack0dModule_basic",
"ColumnStack1dModule_basic",
"ColumnStackBasicIntModule_basic",
"Deg2radModule_basic",
"L1LossMeanReductionModule_basic",
"L1LossNoReductionModule_basic",
"L1LossSumReductionModule_basic",
"FloatPowerTensorTensorStaticModule_basic",
"IsInfiniteModule_basic",
"ElementwiseCopysignModule_basic",
@ -4648,7 +4619,6 @@ ONNX_TOSA_XFAIL_SET = {
"QuantizedSingleLayer_basic",
"RandIntDtypeModule_basic",
"RandIntLowDtypeModule_basic",
"RandIntLowModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"RandLikeDtypeModule_basic",


@ -2519,3 +2519,117 @@ func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f3
%1 = torch.aten.replication_pad2d %arg0, %0 : !torch.vtensor<[1,1,3,3],f32>, !torch.list<int> -> !torch.vtensor<[1,1,10,6],f32>
return %1 : !torch.vtensor<[1,1,10,6],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.outer$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4],f32> -> tensor<4xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3],f32> -> tensor<3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 3, 1>} : (tensor<3xf32>) -> tensor<3x1xf32>
// CHECK: %[[VAL_5:.*]] = tosa.tile %[[VAL_4]] {multiples = array<i64: 1, 4>} : (tensor<3x1xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 4>} : (tensor<4xf32>) -> tensor<1x4xf32>
// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array<i64: 3, 1>} : (tensor<1x4xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_9]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.outer$basic(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> {
%0 = torch.aten.outer %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}
// -----
// CHECK-LABEL: func.func @torch.prims.split_dim$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,8,3,3],si64>) -> !torch.vtensor<[1,2,2,2,3,3],si64> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,8,3,3],si64> -> tensor<1x8x3x3xi64>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 2, 4, 3, 3>} : (tensor<1x8x3x3xi64>) -> tensor<1x2x4x3x3xi64>
// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 1, 2, 2, 2, 3, 3>} : (tensor<1x2x4x3x3xi64>) -> tensor<1x2x2x2x3x3xi64>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x2x2x3x3xi64> -> !torch.vtensor<[1,2,2,2,3,3],si64>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,2,2,2,3,3],si64>
// CHECK: }
func.func @torch.prims.split_dim$basic(%arg0: !torch.vtensor<[1,8,3,3],si64>) -> !torch.vtensor<[1,2,2,2,3,3],si64> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%0 = torch.prims.split_dim %arg0, %int1, %int2 : !torch.vtensor<[1,8,3,3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2,4,3,3],si64>
%1 = torch.prims.split_dim %0, %int2, %int2 : !torch.vtensor<[1,2,4,3,3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,3,3],si64>
return %1 : !torch.vtensor<[1,2,2,2,3,3],si64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.upsample_nearest2d$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,2,3],f64>) -> !torch.vtensor<[1,1,8,9],f64> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,2,3],f64> -> tensor<1x1x2x3xf64>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 4.000000e+00
// CHECK: %[[VAL_3:.*]] = torch.constant.float 3.000000e+00
// CHECK: %[[VAL_4:.*]] = torch.constant.int 8
// CHECK: %[[VAL_5:.*]] = torch.constant.int 9
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 1, 6>} : (tensor<1x1x2x3xf64>) -> tensor<1x1x6xf64>
// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5]]]> : tensor<1x1x72xi32>}> : () -> tensor<1x1x72xi32>
// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 1, 72, 1>} : (tensor<1x1x72xi32>) -> tensor<1x1x72x1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32>
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32>
// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_11]], %[[VAL_9]] {axis = 3 : i32} : (tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>) -> tensor<1x1x72x3xi32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 1, 6, 1>} : (tensor<1x1x6xf64>) -> tensor<1x6x1xf64>
// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 72, 3>} : (tensor<1x1x72x3xi32>) -> tensor<72x3xi32>
// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[6, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_14]], %[[VAL_15]] {shift = 0 : i8} : (tensor<72x3xi32>, tensor<3xi32>) -> tensor<72x3xi32>
// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<72x3xi32>) -> tensor<72x1xi32>
// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array<i64: 1, 72>} : (tensor<72x1xi32>) -> tensor<1x72xi32>
// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_13]], %[[VAL_18]] : (tensor<1x6x1xf64>, tensor<1x72xi32>) -> tensor<1x72x1xf64>
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 1, 1, 72>} : (tensor<1x72x1xf64>) -> tensor<1x1x72xf64>
// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array<i64: 1, 1, 8, 9>} : (tensor<1x1x72xf64>) -> tensor<1x1x8x9xf64>
// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x1x8x9xf64> -> !torch.vtensor<[1,1,8,9],f64>
// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,1,8,9],f64>
// CHECK: }
func.func @torch.aten.upsample_nearest2d$basic(%arg0: !torch.vtensor<[1,1,2,3],f64>) -> !torch.vtensor<[1,1,8,9],f64> {
%float4.000000e00 = torch.constant.float 4.000000e+00
%float3.000000e00 = torch.constant.float 3.000000e+00
%int8 = torch.constant.int 8
%int9 = torch.constant.int 9
%0 = torch.prim.ListConstruct %int8, %int9 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.upsample_nearest2d %arg0, %0, %float4.000000e00, %float3.000000e00 : !torch.vtensor<[1,1,2,3],f64>, !torch.list<int>, !torch.float, !torch.float -> !torch.vtensor<[1,1,8,9],f64>
return %1 : !torch.vtensor<[1,1,8,9],f64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.upsample_nearest2d.vec$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4,5],f32>) -> !torch.vtensor<[1,1,2,7],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,4,5],f32> -> tensor<1x1x4x5xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.none
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = torch.constant.int 7
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 1, 20>} : (tensor<1x1x4x5xf32>) -> tensor<1x1x20xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0, 0, 1, 2, 2, 3, 4, 10, 10, 11, 12, 12, 13, 14]]]> : tensor<1x1x14xi32>}> : () -> tensor<1x1x14xi32>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 1, 1, 14, 1>} : (tensor<1x1x14xi32>) -> tensor<1x1x14x1xi32>
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32>
// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32>
// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>) -> tensor<1x1x14x3xi32>
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array<i64: 1, 20, 1>} : (tensor<1x1x20xf32>) -> tensor<1x20x1xf32>
// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 14, 3>} : (tensor<1x1x14x3xi32>) -> tensor<14x3xi32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[20, 20, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<14x3xi32>, tensor<3xi32>) -> tensor<14x3xi32>
// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<14x3xi32>) -> tensor<14x1xi32>
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 14>} : (tensor<14x1xi32>) -> tensor<1x14xi32>
// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x20x1xf32>, tensor<1x14xi32>) -> tensor<1x14x1xf32>
// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 1, 1, 14>} : (tensor<1x14x1xf32>) -> tensor<1x1x14xf32>
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 1, 1, 2, 7>} : (tensor<1x1x14xf32>) -> tensor<1x1x2x7xf32>
// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<1x1x2x7xf32> -> !torch.vtensor<[1,1,2,7],f32>
// CHECK: return %[[VAL_21]] : !torch.vtensor<[1,1,2,7],f32>
// CHECK: }
func.func @torch.aten.upsample_nearest2d.vec$basic(%arg0: !torch.vtensor<[1,1,4,5],f32>) -> !torch.vtensor<[1,1,2,7],f32> {
%none = torch.constant.none
%int2 = torch.constant.int 2
%int7 = torch.constant.int 7
%0 = torch.prim.ListConstruct %int2, %int7 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.upsample_nearest2d.vec %arg0, %0, %none : !torch.vtensor<[1,1,4,5],f32>, !torch.list<int>, !torch.none -> !torch.vtensor<[1,1,2,7],f32>
return %1 : !torch.vtensor<[1,1,2,7],f32>
}