mirror of https://github.com/llvm/torch-mlir
[TOSA] Add legalization for fill, flip, and round (#3768)
- Add Torch to TOSA lowering for aten.fill.Scalar/Tensor, aten.flip, and aten.round - Fix torchScalarToTosaTensor function to correctly convert Torch scalar input to TOSA tensor - Update xfail_sets.py with new e2e results - Update basic.mlir with LIT tests for new ops Change-Id: If1e42c2e582710dd8ad0465eed29806fbcdbde41 Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3749/head
parent
f4840ed886
commit
b08d08682f
|
@ -153,11 +153,17 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Unable to extract the scalar constant");
|
"Unable to extract the scalar constant");
|
||||||
|
|
||||||
|
int64_t numElem = 1;
|
||||||
|
for (int64_t dim : dshape)
|
||||||
|
numElem *= dim;
|
||||||
|
|
||||||
if (isa<mlir::FloatType>(dtype)) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
tosaTensor = tosa::getConstTensor<float>(rewriter, op,
|
tosaTensor =
|
||||||
(isFloat ? doubleValue : intValue),
|
tosa::getConstTensor<float>(
|
||||||
dshape, dtype)
|
rewriter, op,
|
||||||
.value();
|
SmallVector<float>(numElem, (isFloat ? doubleValue : intValue)),
|
||||||
|
dshape, dtype)
|
||||||
|
.value();
|
||||||
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
} else if (auto intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||||
auto w = intType.getWidth();
|
auto w = intType.getWidth();
|
||||||
if (w != 1 && w != 32 && w != 64)
|
if (w != 1 && w != 32 && w != 64)
|
||||||
|
@ -173,8 +179,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
||||||
}
|
}
|
||||||
bool d = isFloat ? static_cast<bool>(doubleValue)
|
bool d = isFloat ? static_cast<bool>(doubleValue)
|
||||||
: static_cast<bool>(intValue);
|
: static_cast<bool>(intValue);
|
||||||
tosaTensor =
|
tosaTensor = tosa::getConstTensor<bool>(
|
||||||
tosa::getConstTensor<bool>(rewriter, op, {d}, dshape).value();
|
rewriter, op, SmallVector<bool>(numElem, d), dshape)
|
||||||
|
.value();
|
||||||
} else if (w == 32) {
|
} else if (w == 32) {
|
||||||
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
|
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -183,8 +190,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
||||||
}
|
}
|
||||||
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
|
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
|
||||||
: static_cast<int32_t>(intValue);
|
: static_cast<int32_t>(intValue);
|
||||||
tosaTensor =
|
tosaTensor = tosa::getConstTensor<int32_t>(
|
||||||
tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).value();
|
rewriter, op, SmallVector<int32_t>(numElem, d), dshape)
|
||||||
|
.value();
|
||||||
} else if (w == 64) {
|
} else if (w == 64) {
|
||||||
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
|
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -192,8 +200,9 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
||||||
"of destination type");
|
"of destination type");
|
||||||
}
|
}
|
||||||
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
|
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
|
||||||
tosaTensor =
|
tosaTensor = tosa::getConstTensor<int64_t>(
|
||||||
tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).value();
|
rewriter, op, SmallVector<int64_t>(numElem, d), dshape)
|
||||||
|
.value();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return rewriter.notifyMatchFailure(op, "Usupported element type");
|
return rewriter.notifyMatchFailure(op, "Usupported element type");
|
||||||
|
@ -5320,7 +5329,7 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename AtenOpT>
|
template <typename AtenOpT>
|
||||||
class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
|
class ConvertAtenFillOp : public OpConversionPattern<AtenOpT> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||||
|
@ -5336,18 +5345,48 @@ public:
|
||||||
op, "Only Tensor types with static shapes are currently supported");
|
op, "Only Tensor types with static shapes are currently supported");
|
||||||
|
|
||||||
Type outElemTy = outType.getElementType();
|
Type outElemTy = outType.getElementType();
|
||||||
if (!outElemTy.isIntOrFloat()) {
|
if (!outElemTy.isIntOrFloat())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only floating-point or integer datatype legalization supported");
|
op, "Only floating-point or integer datatype legalization supported");
|
||||||
}
|
|
||||||
Value constOp;
|
|
||||||
if (failed(torchScalarToTosaTensor(
|
|
||||||
rewriter, op, op.getValue(), constOp, outElemTy,
|
|
||||||
makeShapeTorchCompatible(outType.getShape()))))
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "Supplied value must be a Scalar constant");
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
|
Value fillValueTargetTensor;
|
||||||
|
if constexpr (std::is_same<AtenOpT, AtenFillTensorOp>()) {
|
||||||
|
// Reshape value tensor to have same rank and shape as input
|
||||||
|
auto inputRank =
|
||||||
|
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||||
|
|
||||||
|
auto fillValue = adaptor.getValue();
|
||||||
|
auto fillValueType = dyn_cast<TensorType>(fillValue.getType());
|
||||||
|
if (!fillValueType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Fill value is not a tensor");
|
||||||
|
auto fillValueElemTy = fillValueType.getElementType();
|
||||||
|
|
||||||
|
SmallVector<int64_t> fillValueMatchedInputRankShape(inputRank, 1);
|
||||||
|
|
||||||
|
auto fillValueMatchedInputRankType = RankedTensorType::get(
|
||||||
|
makeShapeTorchCompatible(fillValueMatchedInputRankShape),
|
||||||
|
fillValueElemTy);
|
||||||
|
|
||||||
|
auto fillValueMatchedInputRankTensor = rewriter.create<tosa::ReshapeOp>(
|
||||||
|
op->getLoc(), fillValueMatchedInputRankType, fillValue,
|
||||||
|
rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape));
|
||||||
|
|
||||||
|
fillValueTargetTensor = rewriter.create<tosa::TileOp>(
|
||||||
|
op->getLoc(),
|
||||||
|
RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()),
|
||||||
|
fillValueElemTy),
|
||||||
|
fillValueMatchedInputRankTensor.getResult(),
|
||||||
|
makeShapeTorchCompatible(outType.getShape()));
|
||||||
|
} else {
|
||||||
|
if (failed(torchScalarToTosaTensor(
|
||||||
|
rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy,
|
||||||
|
makeShapeTorchCompatible(outType.getShape()))))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Fill value must be a scalar constant");
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType,
|
||||||
|
fillValueTargetTensor);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -5869,6 +5908,127 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.flip
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
|
||||||
|
AtenFlipOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
|
||||||
|
if (!selfTy)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only ranked tensor types are currently supported");
|
||||||
|
|
||||||
|
SmallVector<int64_t> dims;
|
||||||
|
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dims)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Only constant dims are currently supported");
|
||||||
|
|
||||||
|
auto selfRank = selfTy.getRank();
|
||||||
|
|
||||||
|
auto resultTy = getTypeConverter()->convertType(op.getType());
|
||||||
|
Value result = self;
|
||||||
|
|
||||||
|
for (auto &dim : dims) {
|
||||||
|
dim = toPositiveDim(dim, selfRank);
|
||||||
|
if (!isValidDim(dim, selfRank))
|
||||||
|
return rewriter.notifyMatchFailure(op, "Not all dims are valid");
|
||||||
|
|
||||||
|
result = rewriter.create<tosa::ReverseOp>(op->getLoc(), resultTy, result,
|
||||||
|
static_cast<int32_t>(dim));
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.round:
|
||||||
|
// Rounds elements of input to the nearest integer.
|
||||||
|
// Implements "round half to even" to break ties when a number is equidistant
|
||||||
|
// from two integers.
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
|
||||||
|
AtenRoundOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
// To round to the nearest integer, we will consider the fractional part of
|
||||||
|
// the input element (= input element - integer part of element). If the
|
||||||
|
// fractional part is smaller than 0.5, round the number down. If the
|
||||||
|
// fractional part is 0.5, apply "round half to even" rule. If the fractional
|
||||||
|
// part is greater than 0.5, round up.
|
||||||
|
//
|
||||||
|
// if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)):
|
||||||
|
// res = floor(input)
|
||||||
|
// else:
|
||||||
|
// res = ceil(input)
|
||||||
|
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
auto selfTy = dyn_cast<TensorType>(self.getType());
|
||||||
|
if (!selfTy)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Only tensor types supported");
|
||||||
|
|
||||||
|
auto resultTy =
|
||||||
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
|
|
||||||
|
auto boolTy =
|
||||||
|
RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1));
|
||||||
|
|
||||||
|
auto resultElemTy = resultTy.getElementType();
|
||||||
|
|
||||||
|
auto oneHalf =
|
||||||
|
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, resultElemTy).value();
|
||||||
|
|
||||||
|
auto two =
|
||||||
|
tosa::getConstTensor<float>(rewriter, op, 2, {}, resultElemTy).value();
|
||||||
|
|
||||||
|
auto floorInput =
|
||||||
|
rewriter.create<tosa::FloorOp>(op->getLoc(), resultTy, self);
|
||||||
|
|
||||||
|
// input - floor(input)
|
||||||
|
auto fractionalPart = rewriter.create<tosa::SubOp>(
|
||||||
|
op->getLoc(), resultTy, self, floorInput.getResult());
|
||||||
|
|
||||||
|
auto ceilInput = rewriter.create<tosa::CeilOp>(op->getLoc(), resultTy, self);
|
||||||
|
|
||||||
|
auto floorInputDivByTwo = rewriter.create<tosa::MulOp>(
|
||||||
|
op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
|
||||||
|
|
||||||
|
auto floorDivResult = rewriter.create<tosa::FloorOp>(
|
||||||
|
op->getLoc(), resultTy, floorInputDivByTwo.getResult());
|
||||||
|
|
||||||
|
// (floor(input) // 2) * 2
|
||||||
|
auto evenComparison = rewriter.create<tosa::MulOp>(
|
||||||
|
op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0);
|
||||||
|
|
||||||
|
// floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
|
||||||
|
auto floorInputEven = rewriter.create<tosa::EqualOp>(
|
||||||
|
op->getLoc(), boolTy, floorInput.getResult(), evenComparison.getResult());
|
||||||
|
|
||||||
|
auto fracEqualOneHalf = rewriter.create<tosa::EqualOp>(
|
||||||
|
op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf);
|
||||||
|
|
||||||
|
auto fracLtOneHalf = rewriter.create<tosa::GreaterOp>(
|
||||||
|
op->getLoc(), boolTy, oneHalf, fractionalPart.getResult());
|
||||||
|
|
||||||
|
// (frac == 0.5) && (floor(input) % 2 == 0)
|
||||||
|
auto fracEqualOneHalfCond = rewriter.create<tosa::LogicalAndOp>(
|
||||||
|
op->getLoc(), boolTy, fracEqualOneHalf.getResult(),
|
||||||
|
floorInputEven.getResult());
|
||||||
|
|
||||||
|
// (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0))
|
||||||
|
auto floorResultCond = rewriter.create<tosa::LogicalOrOp>(
|
||||||
|
op->getLoc(), boolTy, fracLtOneHalf.getResult(),
|
||||||
|
fracEqualOneHalfCond.getResult());
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::SelectOp>(
|
||||||
|
op, resultTy, floorResultCond.getResult(), floorInput.getResult(),
|
||||||
|
ceilInput.getResult());
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// Template to create supporting diagonal mask tensor for aten.diagonal
|
// Template to create supporting diagonal mask tensor for aten.diagonal
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
|
Value createDiagonalMask(PatternRewriter &rewriter, Operation *op,
|
||||||
|
@ -6052,6 +6212,7 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -6283,11 +6444,13 @@ public:
|
||||||
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
||||||
#undef INSERT_CONSTANT_FILL_PATTERN
|
#undef INSERT_CONSTANT_FILL_PATTERN
|
||||||
|
|
||||||
#define INSERT_FILL_SCALAR_PATTERN(AtenOp) \
|
#define INSERT_FILL_PATTERN(AtenOp) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
patterns.add<ConvertAtenFillScalarOp<AtenOp>>(typeConverter, context);
|
patterns.add<ConvertAtenFillOp<AtenOp>>(typeConverter, context);
|
||||||
INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
|
INSERT_FILL_PATTERN(AtenFill_ScalarOp);
|
||||||
#undef INSERT_FILL_SCALAR_PATTERN
|
INSERT_FILL_PATTERN(AtenFillScalarOp);
|
||||||
|
INSERT_FILL_PATTERN(AtenFillTensorOp);
|
||||||
|
#undef INSERT_FILL_PATTERN
|
||||||
|
|
||||||
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
|
#define INSERT_MASKED_FILL_PATTERN(AtenOp) \
|
||||||
target.addIllegalOp<AtenOp>(); \
|
target.addIllegalOp<AtenOp>(); \
|
||||||
|
@ -6359,6 +6522,8 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
|
INSERT_ATENOP_PATTERN(AtenDiagonalOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenRoundOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||||
|
|
|
@ -1663,6 +1663,22 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
"AtenRoundFloatHalfToEvenModule_basic",
|
||||||
|
"AtenRoundFloatModule_basic",
|
||||||
|
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
||||||
|
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||||
|
"FakeQuantizePerTensorAffineModule_basic",
|
||||||
|
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
||||||
|
"Fill_TensorFloat64WithFloat32Static_basic",
|
||||||
|
"Fill_TensorFloat64WithInt64Static_basic",
|
||||||
|
"FlipModuleStaticShape_basic",
|
||||||
|
"FlipModule_basic",
|
||||||
|
"FlipNegativeIndexModule_basic",
|
||||||
|
"Rot90BasicModule_basic",
|
||||||
|
"Rot90DynamicDimsModule_basic",
|
||||||
|
"Rot90MultipleRotationsModule_basic",
|
||||||
|
"Rot90NegativeEvenRotationsModule_basic",
|
||||||
|
"Rot90NegativeOddRotationsModule_basic",
|
||||||
"AtenLinalgCrossBroadcast_basic",
|
"AtenLinalgCrossBroadcast_basic",
|
||||||
"AtenLinalgCrossCustomDim_basic",
|
"AtenLinalgCrossCustomDim_basic",
|
||||||
"AtenLinalgCrossFloat_basic",
|
"AtenLinalgCrossFloat_basic",
|
||||||
|
@ -1819,7 +1835,6 @@ TOSA_PASS_SET = {
|
||||||
"ArangeStartOutModule_basic",
|
"ArangeStartOutModule_basic",
|
||||||
"ArangeStartOutViewModule_basic",
|
"ArangeStartOutViewModule_basic",
|
||||||
"ArangeStartStepIntModule_basic",
|
"ArangeStartStepIntModule_basic",
|
||||||
"ArangeZeroElementOutputModule_basic",
|
|
||||||
"ArangeDtypeIntModule_basic",
|
"ArangeDtypeIntModule_basic",
|
||||||
"ArangeFalsePinMemoryModule_basic",
|
"ArangeFalsePinMemoryModule_basic",
|
||||||
"ArangeFloatModule_basic",
|
"ArangeFloatModule_basic",
|
||||||
|
@ -2120,7 +2135,6 @@ TOSA_PASS_SET = {
|
||||||
"NormScalarOptDimModule_basic",
|
"NormScalarOptDimModule_basic",
|
||||||
"NumToTensorFloatModule_basic",
|
"NumToTensorFloatModule_basic",
|
||||||
"NumToTensorIntModule_basic",
|
"NumToTensorIntModule_basic",
|
||||||
"NumpyTRank0Module_basic",
|
|
||||||
"NumpyTRank1Module_basic",
|
"NumpyTRank1Module_basic",
|
||||||
"NumpyTRank2Module_basic",
|
"NumpyTRank2Module_basic",
|
||||||
"NumpyTRankNDynamicModule_basic",
|
"NumpyTRankNDynamicModule_basic",
|
||||||
|
@ -2132,7 +2146,6 @@ TOSA_PASS_SET = {
|
||||||
"OnesModuleInt_basic",
|
"OnesModuleInt_basic",
|
||||||
"PadModule_basic",
|
"PadModule_basic",
|
||||||
"PadWithNoneValModule_basic",
|
"PadWithNoneValModule_basic",
|
||||||
"Permute0RankModule_basic",
|
|
||||||
"PermuteModule_basic",
|
"PermuteModule_basic",
|
||||||
"PermuteNegativeIndexModule_basic",
|
"PermuteNegativeIndexModule_basic",
|
||||||
"PrimListUnpackNumMismatchModule_basic",
|
"PrimListUnpackNumMismatchModule_basic",
|
||||||
|
@ -2171,7 +2184,6 @@ TOSA_PASS_SET = {
|
||||||
"ScalarTensorInt64Module_basic",
|
"ScalarTensorInt64Module_basic",
|
||||||
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||||
"SiluModule_basic",
|
"SiluModule_basic",
|
||||||
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
|
||||||
"SliceStaticModule_basic",
|
"SliceStaticModule_basic",
|
||||||
"SplitTensorGetItem_Module_basic",
|
"SplitTensorGetItem_Module_basic",
|
||||||
"SplitTensorLastSmallerModule_basic",
|
"SplitTensorLastSmallerModule_basic",
|
||||||
|
@ -3222,6 +3234,12 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
|
"ArangeZeroElementOutputModule_basic",
|
||||||
|
"NumpyTRank0Module_basic",
|
||||||
|
"Permute0RankModule_basic",
|
||||||
|
"SliceOutOfUpperBoundIndexModule_basic",
|
||||||
|
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
||||||
|
"SliceStartEqEndModule_basic",
|
||||||
"ChunkListUnpackDynamic_Module_basic",
|
"ChunkListUnpackDynamic_Module_basic",
|
||||||
"ChunkListUnpackUnevenDynamic_Module_basic",
|
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||||
"ChunkListUnpackUneven_Module_basic",
|
"ChunkListUnpackUneven_Module_basic",
|
||||||
|
@ -3240,11 +3258,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"HstackBasicFloatModule_basic",
|
"HstackBasicFloatModule_basic",
|
||||||
"HstackBasicIntFloatModule_basic",
|
"HstackBasicIntFloatModule_basic",
|
||||||
"HstackBasicIntModule_basic",
|
"HstackBasicIntModule_basic",
|
||||||
"Rot90BasicModule_basic",
|
|
||||||
"Rot90DynamicDimsModule_basic",
|
|
||||||
"Rot90MultipleRotationsModule_basic",
|
|
||||||
"Rot90NegativeEvenRotationsModule_basic",
|
|
||||||
"Rot90NegativeOddRotationsModule_basic",
|
|
||||||
"AtenIntMM_basic",
|
"AtenIntMM_basic",
|
||||||
"AtenKthvalueDynamicDimsModule_basic",
|
"AtenKthvalueDynamicDimsModule_basic",
|
||||||
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
"AtenKthvalueFloat64DynamicDimsModule_basic",
|
||||||
|
@ -3263,7 +3276,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"ElementwiseRreluEvalStaticModule_basic",
|
"ElementwiseRreluEvalStaticModule_basic",
|
||||||
"ElementwiseRreluTrainModule_basic",
|
"ElementwiseRreluTrainModule_basic",
|
||||||
"ElementwiseRreluTrainStaticModule_basic",
|
"ElementwiseRreluTrainStaticModule_basic",
|
||||||
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
|
||||||
"IndexPutWithNoneAndBroadcastModule_basic",
|
"IndexPutWithNoneAndBroadcastModule_basic",
|
||||||
"MaskedScatterStaticBasic_basic",
|
"MaskedScatterStaticBasic_basic",
|
||||||
"MaxUnpool3dModulePad0_basic",
|
"MaxUnpool3dModulePad0_basic",
|
||||||
|
@ -3342,8 +3354,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"AtenMmQuint8_basic",
|
"AtenMmQuint8_basic",
|
||||||
"AtenRealView128Module_basic",
|
"AtenRealView128Module_basic",
|
||||||
"AtenRealView64Module_basic",
|
"AtenRealView64Module_basic",
|
||||||
"AtenRoundFloatHalfToEvenModule_basic",
|
|
||||||
"AtenRoundFloatModule_basic",
|
|
||||||
"AtenSubFloatModule_basic",
|
"AtenSubFloatModule_basic",
|
||||||
"AtenTopKModule_basic",
|
"AtenTopKModule_basic",
|
||||||
"AtenTopKSmallestModule_basic",
|
"AtenTopKSmallestModule_basic",
|
||||||
|
@ -3504,20 +3514,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"EqIntModule_basic",
|
"EqIntModule_basic",
|
||||||
"ExpandModule_basic",
|
"ExpandModule_basic",
|
||||||
"ExponentialModule_basic",
|
"ExponentialModule_basic",
|
||||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
|
||||||
"FakeQuantizePerTensorAffineModule_basic",
|
|
||||||
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
|
||||||
"Fill_TensorFloat32WithFloat32_basic",
|
|
||||||
"Fill_TensorFloat32WithFloat64_basic",
|
|
||||||
"Fill_TensorFloat32WithInt64_basic",
|
|
||||||
"Fill_TensorFloat64WithFloat32Static_basic",
|
|
||||||
"Fill_TensorFloat64WithFloat32_basic",
|
|
||||||
"Fill_TensorFloat64WithFloat64_basic",
|
|
||||||
"Fill_TensorFloat64WithInt64Static_basic",
|
|
||||||
"Fill_TensorFloat64WithInt64_basic",
|
|
||||||
"FlipModuleStaticShape_basic",
|
|
||||||
"FlipModule_basic",
|
|
||||||
"FlipNegativeIndexModule_basic",
|
|
||||||
"FloatImplicitModule_basic",
|
"FloatImplicitModule_basic",
|
||||||
"FullLikeModuleInt2D_basic",
|
"FullLikeModuleInt2D_basic",
|
||||||
"FullLikeModuleInt3D_basic",
|
"FullLikeModuleInt3D_basic",
|
||||||
|
@ -3847,9 +3843,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
"VarMeanUnbiasedModule_basic",
|
"VarMeanUnbiasedModule_basic",
|
||||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
"ViewSizeFromOtherTensor_basic",
|
"ViewSizeFromOtherTensor_basic",
|
||||||
"ZeroFloat32Module_basic",
|
"VisionTransformerModule_basic",
|
||||||
"ZeroInt32Module_basic",
|
|
||||||
"ZeroInt64Module_basic",
|
|
||||||
"ZerosLikeModule_falsePinMemory",
|
"ZerosLikeModule_falsePinMemory",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3862,6 +3856,12 @@ ONNX_TOSA_CRASHING_SET = {
|
||||||
}
|
}
|
||||||
|
|
||||||
ONNX_TOSA_XFAIL_SET = {
|
ONNX_TOSA_XFAIL_SET = {
|
||||||
|
"ArangeZeroElementOutputModule_basic",
|
||||||
|
"LinspaceEmptyModule_basic",
|
||||||
|
"RepeatInterleaveSelfIntNoDimModule_basic",
|
||||||
|
"SliceOutOfUpperBoundIndexStaticModule_basic",
|
||||||
|
"TrilIndicesAllZerosModule_basic",
|
||||||
|
"TriuIndicesAllZerosModule_basic",
|
||||||
"ElementwiseCreateComplexModule_basic",
|
"ElementwiseCreateComplexModule_basic",
|
||||||
"ReduceAllDimFloatModule_basic",
|
"ReduceAllDimFloatModule_basic",
|
||||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||||
|
@ -4026,8 +4026,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"AtenPolarDoubleModule_basic",
|
"AtenPolarDoubleModule_basic",
|
||||||
"AtenRealView128Module_basic",
|
"AtenRealView128Module_basic",
|
||||||
"AtenRealView64Module_basic",
|
"AtenRealView64Module_basic",
|
||||||
"AtenRoundFloatHalfToEvenModule_basic",
|
|
||||||
"AtenRoundFloatModule_basic",
|
|
||||||
"AtenSubFloatModule_basic",
|
"AtenSubFloatModule_basic",
|
||||||
"AtenTopKModule_basic",
|
"AtenTopKModule_basic",
|
||||||
"AtenTopKSmallestModule_basic",
|
"AtenTopKSmallestModule_basic",
|
||||||
|
@ -4071,8 +4069,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"BucketizeTensorFloatModule_basic",
|
"BucketizeTensorFloatModule_basic",
|
||||||
"BucketizeTensorModule_basic",
|
"BucketizeTensorModule_basic",
|
||||||
"BucketizeTensorOutInt32RightModule_basic",
|
"BucketizeTensorOutInt32RightModule_basic",
|
||||||
"BucketizeTensorStaticFloatModule_basic",
|
|
||||||
"BucketizeTensorStaticModule_basic",
|
|
||||||
"CeilFloatModule_basic",
|
"CeilFloatModule_basic",
|
||||||
"ChunkListUnpackDynamic_Module_basic",
|
"ChunkListUnpackDynamic_Module_basic",
|
||||||
"ChunkListUnpackUnevenDynamic_Module_basic",
|
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||||
|
|
|
@ -1917,3 +1917,84 @@ func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !t
|
||||||
%0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32>
|
%0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32>
|
||||||
return %0 : !torch.vtensor<[4,5,2],f32>
|
return %0 : !torch.vtensor<[4,5,2],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.fill.Scalar(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x12x128x128xf32>}> : () -> tensor<1x12x128x128xf32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
// CHECK: return %[[VAL_4]] : !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%0 = torch.aten.fill.Scalar %arg0, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.fill.Tensor(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],si32> -> tensor<1xi32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 1, 1, 1, 1>} : (tensor<1xi32>) -> tensor<1x1x1x1xi32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.tile %[[VAL_3]] {multiples = array<i64: 1, 12, 128, 128>} : (tensor<1x1x1x1xi32>) -> tensor<1x12x128x128xi32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> {
|
||||||
|
%0 = torch.aten.fill.Tensor %arg0, %arg1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
return %0 : !torch.vtensor<[1,12,128,128],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.flip(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_1]] {axis = 1 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4,5],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.flip(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%0 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.aten.flip %arg0, %0 : !torch.vtensor<[3,4,5],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
return %1 : !torch.vtensor<[3,4,5],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.round(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_1]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_2]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = tosa.equal %[[VAL_4]], %[[VAL_9]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tosa.equal %[[VAL_5]], %[[VAL_2]] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xi1>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.greater %[[VAL_2]], %[[VAL_5]] : (tensor<f32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = tosa.logical_and %[[VAL_11]], %[[VAL_10]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = tosa.logical_or %[[VAL_12]], %[[VAL_13]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = tosa.select %[[VAL_14]], %[[VAL_4]], %[[VAL_6]] : (tensor<3x4x5xi1>, tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
// CHECK: return %[[VAL_16]] : !torch.vtensor<[3,4,5],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||||
|
%0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue