mirror of https://github.com/llvm/torch-mlir
[TOSA] Add legalization for fill, flip, and round (#3768)
- Add Torch to TOSA lowering for aten.fill.Scalar/Tensor, aten.flip, and aten.round - Fix torchScalarToTosaTensor function to correctly convert Torch scalar input to TOSA tensor - Update xfail_sets.py with new e2e results - Update basic.mlir with LIT tests for new ops Change-Id: If1e42c2e582710dd8ad0465eed29806fbcdbde41 Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3749/head
parent
f4840ed886
commit
b08d08682f
|
@ -153,11 +153,17 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
return rewriter.notifyMatchFailure(op,
|
||||
"Unable to extract the scalar constant");
|
||||
|
||||
int64_t numElem = 1;
|
||||
for (int64_t dim : dshape)
|
||||
numElem *= dim;
|
||||
|
||||
if (isa<mlir::FloatType>(dtype)) {
|
||||
tosaTensor = tosa::getConstTensor<float>(rewriter, op,
|
||||
(isFloat ? doubleValue : intValue),
|
||||
dshape, dtype)
|
||||
.value();
|
||||
tosaTensor =
|
||||
tosa::getConstTensor<float>(
|
||||
rewriter, op,
|
||||
SmallVector<float>(numElem, (isFloat ? doubleValue : intValue)),
|
||||
dshape, dtype)
|
||||
.value();
|
||||
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||
auto w = intType.getWidth();
|
||||
if (w != 1 && w != 32 && w != 64)
|
||||
|
@ -173,8 +179,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
}
|
||||
bool d = isFloat ? static_cast<bool>(doubleValue)
|
||||
: static_cast<bool>(intValue);
|
||||
tosaTensor =
|
||||
tosa::getConstTensor<bool>(rewriter, op, {d}, dshape).value();
|
||||
tosaTensor = tosa::getConstTensor<bool>(
|
||||
rewriter, op, SmallVector<bool>(numElem, d), dshape)
|
||||
.value();
|
||||
} else if (w == 32) {
|
||||
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -183,8 +190,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
}
|
||||
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
|
||||
: static_cast<int32_t>(intValue);
|
||||
tosaTensor =
|
||||
tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).value();
|
||||
tosaTensor = tosa::getConstTensor<int32_t>(
|
||||
rewriter, op, SmallVector<int32_t>(numElem, d), dshape)
|
||||
.value();
|
||||
} else if (w == 64) {
|
||||
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -192,8 +200,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
"of destination type");
|
||||
}
|
||||
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
|
||||
tosaTensor =
|
||||
tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).value();
|
||||
tosaTensor = tosa::getConstTensor<int64_t>(
|
||||
rewriter, op, SmallVector<int64_t>(numElem, d), dshape)
|
||||
.value();
|
||||
}
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(op, "Usupported element type");
|
||||
|
@ -5320,7 +5329,7 @@ public:
|
|||
};
|
||||
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
|
||||
class ConvertAtenFillOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
|
@ -5336,18 +5345,48 @@ public:
|
|||
op, "Only Tensor types with static shapes are currently supported");
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat()) {
|
||||
if (!outElemTy.isIntOrFloat())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
Value constOp;
|
||||
if (failed(torchScalarToTosaTensor(
|
||||
rewriter, op, op.getValue(), constOp, outElemTy,
|
||||
makeShapeTorchCompatible(outType.getShape()))))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Supplied value must be a Scalar constant");
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
|
||||
Value fillValueTargetTensor;
|
||||
if constexpr (std::is_same<AtenOpT, AtenFillTensorOp>()) {
|
||||
// Reshape value tensor to have same rank and shape as input
|
||||
auto inputRank =
|
||||
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||
|
||||
auto fillValue = adaptor.getValue();
|
||||
auto fillValueType = dyn_cast<TensorType>(fillValue.getType());
|
||||
if (!fillValueType)
|
||||
return rewriter.notifyMatchFailure(op, "Fill value is not a tensor");
|
||||
auto fillValueElemTy = fillValueType.getElementType();
|
||||
|
||||
SmallVector<int64_t> fillValueMatchedInputRankShape(inputRank, 1);
|
||||
|
||||
auto fillValueMatchedInputRankType = RankedTensorType::get(
|
||||
makeShapeTorchCompatible(fillValueMatchedInputRankShape),
|
||||
fillValueElemTy);
|
||||
|
||||
auto fillValueMatchedInputRankTensor = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(), fillValueMatchedInputRankType, fillValue,
|
||||
rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape));
|
||||
|
||||
fillValueTargetTensor = rewriter.create<tosa::TileOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()),
|
||||
fillValueElemTy),
|
||||
fillValueMatchedInputRankTensor.getResult(),
|
||||
makeShapeTorchCompatible(outType.getShape()));
|
||||
} else {
|
||||
if (failed(torchScalarToTosaTensor(
|
||||
rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy,
|
||||
makeShapeTorchCompatible(outType.getShape()))))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Fill value must be a scalar constant");
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
|
||||
fillValueTargetTensor);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -5869,6 +5908,127 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Legalization for aten.flip
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
|
||||
AtenFlipOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
auto self = adaptor.getSelf();
|
||||
|
||||
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types are currently supported");
|
||||
|
||||
SmallVector<int64_t> dims;
|
||||
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dims)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only constant dims are currently supported");
|
||||
|
||||
auto selfRank = selfTy.getRank();
|
||||
|
||||
auto resultTy = getTypeConverter()->convertType(op.getType());
|
||||
Value result = self;
|
||||
|
||||
for (auto &dim : dims) {
|
||||
dim = toPositiveDim(dim, selfRank);
|
||||
if (!isValidDim(dim, selfRank))
|
||||
return rewriter.notifyMatchFailure(op, "Not all dims are valid");
|
||||
|
||||
result = rewriter.create<tosa::ReverseOp>(op->getLoc(), resultTy, result,
|
||||
static_cast<int32_t>(dim));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Legalization for aten.round:
|
||||
// Rounds elements of input to the nearest integer.
|
||||
// Implements "round half to even" to break ties when a number is equidistant
|
||||
// from two integers.
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
|
||||
AtenRoundOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// To round to the nearest integer, we will consider the fractional part of
|
||||
// the input element (= input element - integer part of element). If the
|
||||
// fractional part is smaller than 0.5, round the number down. If the
|
||||
// fractional part is 0.5, apply "round half to even" rule. If the fractional
|
||||
// part is greater than 0.5, round up.
|
||||
//
|
||||
// if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
|
||||
// res = floor(input)
|
||||
// else:
|
||||
// res = ceil(input)
|
||||
|
||||
auto self = adaptor.getSelf();
|
||||
|
||||
auto selfTy = dyn_cast<TensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(op, "Only tensor types supported");
|
||||
|
||||
auto resultTy =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
auto boolTy =
|
||||
RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1));
|
||||
|
||||
auto resultElemTy = resultTy.getElementType();
|
||||
|
||||
auto oneHalf =
|
||||
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, resultElemTy).value();
|
||||
|
||||
auto two =
|
||||
tosa::getConstTensor<float>(rewriter, op, 2, {}, resultElemTy).value();
|
||||
|
||||
auto floorInput =
|
||||
rewriter.create<tosa::FloorOp>(op->getLoc(), resultTy, self);
|
||||
|
||||
// input - floor(input)
|
||||
auto fractionalPart = rewriter.create<tosa::SubOp>(
|
||||
op->getLoc(), resultTy, self, floorInput.getResult());
|
||||
|
||||
auto ceilInput = rewriter.create<tosa::CeilOp>(op->getLoc(), resultTy, self);
|
||||
|
||||
auto floorInputDivByTwo = rewriter.create<tosa::MulOp>(
|
||||
op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
|
||||
|
||||
auto floorDivResult = rewriter.create<tosa::FloorOp>(
|
||||
op->getLoc(), resultTy, floorInputDivByTwo.getResult());
|
||||
|
||||
// (floor(input) // 2) * 2
|
||||
auto evenComparison = rewriter.create<tosa::MulOp>(
|
||||
op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0);
|
||||
|
||||
// floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
|
||||
auto floorInputEven = rewriter.create<tosa::EqualOp>(
|
||||
op->getLoc(), boolTy, floorInput.getResult(), evenComparison.getResult());
|
||||
|
||||
auto fracEqualOneHalf = rewriter.create<tosa::EqualOp>(
|
||||
op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf);
|
||||
|
||||
auto fracLtOneHalf = rewriter.create<tosa::GreaterOp>(
|
||||
op->getLoc(), boolTy, oneHalf, fractionalPart.getResult());
|
||||
|
||||
// (frac == 0.5) && (floor(input) % 2 == 0)
|
||||
auto fracEqualOneHalfCond = rewriter.create<tosa::LogicalAndOp>(
|
||||
op->getLoc(), boolTy, fracEqualOneHalf.getResult(),
|
||||
floorInputEven.getResult());
|
||||
|
||||
// (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
|
||||
auto floorResultCond = rewriter.create<tosa::LogicalOrOp>(
|
||||
op->getLoc(), boolTy, fracLtOneHalf.getResult(),
|
||||
fracEqualOneHalfCond.getResult());
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::SelectOp>(
|
||||
op, resultTy, floorResultCond.getResult(), floorInput.getResult(),
|
||||
ceilInput.getResult());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Template to create supporting diagonal mask tensor for aten.diagonal
|
||||
template <typename T>
|
||||
Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
|
||||
|
@ -6052,6 +6212,7 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
|
|||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -6283,11 +6444,13 @@ public:
|
|||
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
||||
#undef INSERT_CONSTANT_FILL_PATTERN
|
||||
|
||||
#define INSERT_FILL_SCALAR_PATTERN(AtenOp) \
|
||||
#define INSERT_FILL_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenFillScalarOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
|
||||
#undef INSERT_FILL_SCALAR_PATTERN
|
||||
patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_FILL_PATTERN(AtenFill_ScalarOp);
|
||||
INSERT_FILL_PATTERN(AtenFillScalarOp);
|
||||
INSERT_FILL_PATTERN(AtenFillTensorOp);
|
||||
#undef INSERT_FILL_PATTERN
|
||||
|
||||
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
|
@ -6359,6 +6522,8 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
||||
INSERT_ATENOP_PATTERN(AtenRoundOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
|
|
|
@ -1663,6 +1663,22 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
|||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"AtenRoundFloatHalfToEvenModule_basic",
|
||||
"AtenRoundFloatModule_basic",
|
||||
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
||||
"Fill_TensorFloat64WithFloat32Static_basic",
|
||||
"Fill_TensorFloat64WithInt64Static_basic",
|
||||
"FlipModuleStaticShape_basic",
|
||||
"FlipModule_basic",
|
||||
"FlipNegativeIndexModule_basic",
|
||||
"Rot90BasicModule_basic",
|
||||
"Rot90DynamicDimsModule_basic",
|
||||
"Rot90MultipleRotationsModule_basic",
|
||||
"Rot90NegativeEvenRotationsModule_basic",
|
||||
"Rot90NegativeOddRotationsModule_basic",
|
||||
"AtenLinalgCrossBroadcast_basic",
|
||||
"AtenLinalgCrossCustomDim_basic",
|
||||
"AtenLinalgCrossFloat_basic",
|
||||
|
@ -1819,7 +1835,6 @@ TOSA_PASS_SET = {
|
|||
"ArangeStartOutModule_basic",
|
||||
"ArangeStartOutViewModule_basic",
|
||||
"ArangeStartStepIntModule_basic",
|
||||
"ArangeZeroElementOutputModule_basic",
|
||||
"ArangeDtypeIntModule_basic",
|
||||
"ArangeFalsePinMemoryModule_basic",
|
||||
"ArangeFloatModule_basic",
|
||||
|
@ -2120,7 +2135,6 @@ TOSA_PASS_SET = {
|
|||
"NormScalarOptDimModule_basic",
|
||||
"NumToTensorFloatModule_basic",
|
||||
"NumToTensorIntModule_basic",
|
||||
"NumpyTRank0Module_basic",
|
||||
"NumpyTRank1Module_basic",
|
||||
"NumpyTRank2Module_basic",
|
||||
"NumpyTRankNDynamicModule_basic",
|
||||
|
@ -2132,7 +2146,6 @@ TOSA_PASS_SET = {
|
|||
"OnesModuleInt_basic",
|
||||
"PadModule_basic",
|
||||
"PadWithNoneValModule_basic",
|
||||
"Permute0RankModule_basic",
|
||||
"PermuteModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
"PrimListUnpackNumMismatchModule_basic",
|
||||
|
@ -2171,7 +2184,6 @@ TOSA_PASS_SET = {
|
|||
"ScalarTensorInt64Module_basic",
|
||||
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||
"SiluModule_basic",
|
||||
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
||||
"SliceStaticModule_basic",
|
||||
"SplitTensorGetItem_Module_basic",
|
||||
"SplitTensorLastSmallerModule_basic",
|
||||
|
@ -3222,6 +3234,12 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
|||
}
|
||||
|
||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||
"ArangeZeroElementOutputModule_basic",
|
||||
"NumpyTRank0Module_basic",
|
||||
"Permute0RankModule_basic",
|
||||
"SliceOutOfUpperBoundIndexModule_basic",
|
||||
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
||||
"SliceStartEqEndModule_basic",
|
||||
"ChunkListUnpackDynamic_Module_basic",
|
||||
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||
"ChunkListUnpackUneven_Module_basic",
|
||||
|
@ -3240,11 +3258,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"HstackBasicFloatModule_basic",
|
||||
"HstackBasicIntFloatModule_basic",
|
||||
"HstackBasicIntModule_basic",
|
||||
"Rot90BasicModule_basic",
|
||||
"Rot90DynamicDimsModule_basic",
|
||||
"Rot90MultipleRotationsModule_basic",
|
||||
"Rot90NegativeEvenRotationsModule_basic",
|
||||
"Rot90NegativeOddRotationsModule_basic",
|
||||
"AtenIntMM_basic",
|
||||
"AtenKthvalueDynamicDimsModule_basic",
|
||||
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||
|
@ -3263,7 +3276,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseRreluEvalStaticModule_basic",
|
||||
"ElementwiseRreluTrainModule_basic",
|
||||
"ElementwiseRreluTrainStaticModule_basic",
|
||||
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
||||
"IndexPutWithNoneAndBroadcastModule_basic",
|
||||
"MaskedScatterStaticBasic_basic",
|
||||
"MaxUnpool3dModulePad0_basic",
|
||||
|
@ -3342,8 +3354,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"AtenMmQuint8_basic",
|
||||
"AtenRealView128Module_basic",
|
||||
"AtenRealView64Module_basic",
|
||||
"AtenRoundFloatHalfToEvenModule_basic",
|
||||
"AtenRoundFloatModule_basic",
|
||||
"AtenSubFloatModule_basic",
|
||||
"AtenTopKModule_basic",
|
||||
"AtenTopKSmallestModule_basic",
|
||||
|
@ -3504,20 +3514,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"EqIntModule_basic",
|
||||
"ExpandModule_basic",
|
||||
"ExponentialModule_basic",
|
||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
||||
"Fill_TensorFloat32WithFloat32_basic",
|
||||
"Fill_TensorFloat32WithFloat64_basic",
|
||||
"Fill_TensorFloat32WithInt64_basic",
|
||||
"Fill_TensorFloat64WithFloat32Static_basic",
|
||||
"Fill_TensorFloat64WithFloat32_basic",
|
||||
"Fill_TensorFloat64WithFloat64_basic",
|
||||
"Fill_TensorFloat64WithInt64Static_basic",
|
||||
"Fill_TensorFloat64WithInt64_basic",
|
||||
"FlipModuleStaticShape_basic",
|
||||
"FlipModule_basic",
|
||||
"FlipNegativeIndexModule_basic",
|
||||
"FloatImplicitModule_basic",
|
||||
"FullLikeModuleInt2D_basic",
|
||||
"FullLikeModuleInt3D_basic",
|
||||
|
@ -3847,9 +3843,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"VarMeanUnbiasedModule_basic",
|
||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
"ZeroFloat32Module_basic",
|
||||
"ZeroInt32Module_basic",
|
||||
"ZeroInt64Module_basic",
|
||||
"VisionTransformerModule_basic",
|
||||
"ZerosLikeModule_falsePinMemory",
|
||||
}
|
||||
|
||||
|
@ -3862,6 +3856,12 @@ ONNX_TOSA_CRASHING_SET = {
|
|||
}
|
||||
|
||||
ONNX_TOSA_XFAIL_SET = {
|
||||
"ArangeZeroElementOutputModule_basic",
|
||||
"LinspaceEmptyModule_basic",
|
||||
"RepeatInterleaveSelfIntNoDimModule_basic",
|
||||
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
||||
"TrilIndicesAllZerosModule_basic",
|
||||
"TriuIndicesAllZerosModule_basic",
|
||||
"ElementwiseCreateComplexModule_basic",
|
||||
"ReduceAllDimFloatModule_basic",
|
||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||
|
@ -4026,8 +4026,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"AtenPolarDoubleModule_basic",
|
||||
"AtenRealView128Module_basic",
|
||||
"AtenRealView64Module_basic",
|
||||
"AtenRoundFloatHalfToEvenModule_basic",
|
||||
"AtenRoundFloatModule_basic",
|
||||
"AtenSubFloatModule_basic",
|
||||
"AtenTopKModule_basic",
|
||||
"AtenTopKSmallestModule_basic",
|
||||
|
@ -4071,8 +4069,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"BucketizeTensorFloatModule_basic",
|
||||
"BucketizeTensorModule_basic",
|
||||
"BucketizeTensorOutInt32RightModule_basic",
|
||||
"BucketizeTensorStaticFloatModule_basic",
|
||||
"BucketizeTensorStaticModule_basic",
|
||||
"CeilFloatModule_basic",
|
||||
"ChunkListUnpackDynamic_Module_basic",
|
||||
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||
|
|
|
@ -1917,3 +1917,84 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !t
|
|||
%0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32>
|
||||
return %0 : !torch.vtensor<[4,5,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.fill.Scalar(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x12x128x128xf32>}> : () -> tensor<1x12x128x128xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||
// CHECK: return %[[VAL_4]] : !torch.vtensor<[1,12,128,128],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.fill.Scalar %arg0, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
|
||||
return %0 : !torch.vtensor<[1,12,128,128],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.fill.Tensor(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],si32> -> tensor<1xi32>
|
||||
// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1, 1, 1>} : (tensor<1xi32>) -> tensor<1x1x1x1xi32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.tile %[[VAL_3]] {multiples = array<i64: 1, 12, 128, 128>} : (tensor<1x1x1x1xi32>) -> tensor<1x12x128x128xi32>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,12,128,128],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||
%0 = torch.aten.fill.Tensor %arg0, %arg1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||
return %0 : !torch.vtensor<[1,12,128,128],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.flip(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_1]] {axis = 1 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
|
||||
// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4,5],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.flip(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
%1 = torch.aten.flip %arg0, %0 : !torch.vtensor<[3,4,5],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
|
||||
return %1 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.round(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_1]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_2]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_10:.*]] = tosa.equal %[[VAL_4]], %[[VAL_9]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
|
||||
// CHECK: %[[VAL_11:.*]] = tosa.equal %[[VAL_5]], %[[VAL_2]] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xi1>
|
||||
// CHECK: %[[VAL_12:.*]] = tosa.greater %[[VAL_2]], %[[VAL_5]] : (tensor<f32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
|
||||
// CHECK: %[[VAL_13:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_10]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
|
||||
// CHECK: %[[VAL_14:.*]] = tosa.logical_or %[[VAL_12]], %[[VAL_13]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
|
||||
// CHECK: %[[VAL_15:.*]] = tosa.select %[[VAL_14]], %[[VAL_4]], %[[VAL_6]] : (tensor<3x4x5xi1>, tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
|
||||
// CHECK: return %[[VAL_16]] : !torch.vtensor<[3,4,5],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||
%0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
|
||||
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue