mirror of https://github.com/llvm/torch-mlir
[TOSA] Fix rsub; add clamp.Tensor, avg_pool1d, max_pool1d, prims.collapse (#3855)
- Fix aten.rsub.Scalar legalization with appropriate type casting - Add legalization for aten.clamp.Tensor - Resolve some unexpected test failures from PyTorch update by adding legalization for the following ops: + aten.avg_pool1d + aten.max_pool1d + torch.prims.collapse - Update xfail_sets with new e2e results - Add new LIT tests to basic.mlir Change-Id: I9762c7d36ca0b0f75ca68d0c71d7f5d5309a96ad --------- Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3860/head
parent
8519ecc4d7
commit
b6f04fa32b
|
@ -2072,25 +2072,29 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types supported in TOSA Rsub");
|
||||
|
||||
auto resultTy =
|
||||
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
auto resultElemTy = resultTy.getElementType();
|
||||
|
||||
self = tosa::promoteType(rewriter, self, resultTy);
|
||||
|
||||
Value otherTensor, alphaTensor;
|
||||
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
|
||||
selfTy.getElementType(), {})))
|
||||
resultElemTy, {})))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Currently only scalar constants are supported for "
|
||||
"conversion in TOSA Rsub operation");
|
||||
|
||||
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar,
|
||||
alphaTensor, selfTy.getElementType(),
|
||||
alphaTensor, resultElemTy,
|
||||
/*checkForUnity=*/true)))
|
||||
return failure();
|
||||
|
||||
auto multTensor = rewriter.create<tosa::MulOp>(
|
||||
op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
|
||||
auto multTensor = rewriter.create<tosa::MulOp>(op->getLoc(), resultTy, self,
|
||||
alphaTensor, /*shift=*/0);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::SubOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), otherTensor,
|
||||
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, resultTy, otherTensor,
|
||||
multTensor);
|
||||
|
||||
return success();
|
||||
|
@ -4730,6 +4734,108 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Legalization for aten.clamp.Tensor
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenClampTensorOp>::matchAndRewrite(
|
||||
AtenClampTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// We are not using tosa.clamp to lower aten.clamp.Tensor, as
|
||||
// aten.clamp.Tensor's min and max attributes are tensors that can have size
|
||||
// greater than 1, which is not compatible with tosa.clamp.
|
||||
//
|
||||
// Instead, we use the following formula:
|
||||
// yi = min(max(xi, min_valuei), max_valuei)
|
||||
auto self = adaptor.getSelf();
|
||||
|
||||
// Not a tensor type
|
||||
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||
auto selfElemTy = selfType.getElementType();
|
||||
|
||||
auto resultType =
|
||||
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
|
||||
|
||||
// Get min tensor. If None, there is no lower bound.
|
||||
Value min;
|
||||
if (succeeded(checkNotNone(rewriter, op, adaptor.getMin()))) {
|
||||
min = adaptor.getMin();
|
||||
} else {
|
||||
min =
|
||||
TypeSwitch<Type, Value>(selfElemTy)
|
||||
.Case<mlir::FloatType>([&](auto) {
|
||||
return tosa::getConstTensor<float>(
|
||||
rewriter, op, std::numeric_limits<float>::lowest(), {},
|
||||
selfElemTy)
|
||||
.value();
|
||||
})
|
||||
.Case<mlir::IntegerType>([&](auto intType) {
|
||||
switch (intType.getWidth()) {
|
||||
case 8:
|
||||
return tosa::getConstTensor<int8_t>(
|
||||
rewriter, op, std::numeric_limits<int8_t>::min(), {})
|
||||
.value();
|
||||
case 32:
|
||||
return tosa::getConstTensor<int32_t>(
|
||||
rewriter, op, std::numeric_limits<int32_t>::min(),
|
||||
{})
|
||||
.value();
|
||||
case 64:
|
||||
return tosa::getConstTensor<int64_t>(
|
||||
rewriter, op, std::numeric_limits<int64_t>::min(),
|
||||
{})
|
||||
.value();
|
||||
}
|
||||
llvm_unreachable("Invalid integer width");
|
||||
});
|
||||
}
|
||||
|
||||
// Get max tensor. If None, there is no upper bound.
|
||||
Value max;
|
||||
if (succeeded(checkNotNone(rewriter, op, adaptor.getMax()))) {
|
||||
max = adaptor.getMax();
|
||||
} else {
|
||||
max =
|
||||
TypeSwitch<Type, Value>(selfElemTy)
|
||||
.Case<mlir::FloatType>([&](auto) {
|
||||
return tosa::getConstTensor<float>(
|
||||
rewriter, op, std::numeric_limits<float>::max(), {},
|
||||
selfElemTy)
|
||||
.value();
|
||||
})
|
||||
.Case<mlir::IntegerType>([&](auto intType) {
|
||||
switch (intType.getWidth()) {
|
||||
case 8:
|
||||
return tosa::getConstTensor<int8_t>(
|
||||
rewriter, op, std::numeric_limits<int8_t>::max(), {})
|
||||
.value();
|
||||
case 32:
|
||||
return tosa::getConstTensor<int32_t>(
|
||||
rewriter, op, std::numeric_limits<int32_t>::max(),
|
||||
{})
|
||||
.value();
|
||||
case 64:
|
||||
return tosa::getConstTensor<int64_t>(
|
||||
rewriter, op, std::numeric_limits<int64_t>::max(),
|
||||
{})
|
||||
.value();
|
||||
}
|
||||
llvm_unreachable("Invalid integer width");
|
||||
});
|
||||
}
|
||||
|
||||
// max(xi, min_valuei)
|
||||
auto minThresholdCheck = tosa::createBinaryOpAndCast<tosa::MaximumOp>(
|
||||
rewriter, op, resultType, self, min);
|
||||
|
||||
// yi = min(max(xi, min_valuei), max_valuei)
|
||||
auto result = tosa::createBinaryOpAndCast<tosa::MinimumOp>(
|
||||
rewriter, op, resultType, minThresholdCheck, max);
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
||||
AtenArangeStartStepOp op, OpAdaptor adaptor,
|
||||
|
@ -5236,11 +5342,29 @@ public:
|
|||
ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingOutputToChw(
|
||||
op, rewriter, pooledOutput);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
op,
|
||||
Value result = transposedOutput;
|
||||
auto resultTy = dyn_cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
transposedOutput);
|
||||
op.getType()));
|
||||
|
||||
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
|
||||
std::is_same<AtenOpT, AtenAvgPool1dOp>()) {
|
||||
auto resultShape = resultTy.getShape();
|
||||
auto resultElemTy = resultTy.getElementType();
|
||||
|
||||
result = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(makeShapeTorchCompatible(resultShape),
|
||||
resultElemTy),
|
||||
transposedOutput,
|
||||
rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
op, resultTy,
|
||||
// OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
// op.getType()),
|
||||
result);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -5387,6 +5511,12 @@ static LogicalResult getOutputTypeAndPoolingParameters(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "Non-const kernel_size for pooling op unsupported");
|
||||
|
||||
// Expand kernel size parameter to size 2 to be compatible with
|
||||
// tosa::MaxPool2dOp or tosa::AvgPool2dOp
|
||||
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
|
||||
std::is_same<AtenOpT, AtenAvgPool1dOp>())
|
||||
kernelSizeInts.push_back(1);
|
||||
|
||||
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Non-const stride for pooling op unsupported");
|
||||
|
@ -5394,13 +5524,26 @@ static LogicalResult getOutputTypeAndPoolingParameters(
|
|||
// list during import. For such a case, the stride value is the kernel size.
|
||||
// See:
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
|
||||
if (strideInts.empty())
|
||||
if (strideInts.empty()) {
|
||||
strideInts.assign(kernelSizeInts);
|
||||
} else {
|
||||
// Expand stride parameter to size 2 to be compatible with
|
||||
// tosa::MaxPool2dOp or tosa::AvgPool2dOp
|
||||
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
|
||||
std::is_same<AtenOpT, AtenAvgPool1dOp>())
|
||||
strideInts.push_back(1);
|
||||
}
|
||||
|
||||
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Non-const padding factor for pooling op unsupported");
|
||||
|
||||
// Expand padding parameter to size 2 to be compatible with
|
||||
// tosa::MaxPool2dOp or tosa::AvgPool2dOp
|
||||
if constexpr (std::is_same<AtenOpT, AtenMaxPool1dOp>() ||
|
||||
std::is_same<AtenOpT, AtenAvgPool1dOp>())
|
||||
paddingInts.push_back(0);
|
||||
|
||||
SmallVector<int64_t, 4> padArr = {paddingInts[0], paddingInts[0],
|
||||
paddingInts[1], paddingInts[1]};
|
||||
kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts);
|
||||
|
@ -5456,6 +5599,68 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// Legalization for aten.max_pool1d
|
||||
class ConvertAtenMaxPool1dOp
|
||||
: public ConvertAtenPoolingBaseOp<AtenMaxPool1dOp, tosa::MaxPool2dOp> {
|
||||
public:
|
||||
using ConvertAtenPoolingBaseOp<AtenMaxPool1dOp,
|
||||
tosa::MaxPool2dOp>::ConvertAtenPoolingBaseOp;
|
||||
LogicalResult processInputs(AtenMaxPool1dOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Value &input,
|
||||
DenseI64ArrayAttr &kernel,
|
||||
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
|
||||
Type &outputTy) const override {
|
||||
auto self = adaptor.getSelf();
|
||||
|
||||
// Not a RankedTensorType
|
||||
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor type inputs are supported");
|
||||
auto selfShape = selfTy.getShape();
|
||||
|
||||
// Expected a rank 3 input tensor
|
||||
if (selfTy.getRank() != 3)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Input tensor for MaxPool1d should have rank 3");
|
||||
|
||||
// Unsqueeze input tensor to rank 4 to be compatible with tosa::MaxPool2dOp
|
||||
SmallVector<int64_t> rank4Shape(selfShape);
|
||||
rank4Shape.push_back(1);
|
||||
auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
|
||||
selfTy.getElementType()),
|
||||
self, rewriter.getDenseI64ArrayAttr(rank4Shape));
|
||||
|
||||
SmallVector<int64_t> dilationArray;
|
||||
if (!matchPattern(op.getDilation(),
|
||||
m_TorchListOfConstantInts(dilationArray)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Non-const dilation for pooling op unsupported.");
|
||||
// TOSA pooling only supports unit dilation.
|
||||
if (dilationArray[0] > 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Cannot process non-unit pooling dilation.");
|
||||
|
||||
// Expand dilation to size 2 to be compatible with tosa::MaxPool2dOp
|
||||
dilationArray.push_back(1);
|
||||
|
||||
if (failed(getOutputTypeAndPoolingParameters<AtenMaxPool1dOp,
|
||||
tosa::MaxPool2dOp>(
|
||||
op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
|
||||
kernel, stride, pad)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "invalid pooling parameters or input type");
|
||||
|
||||
// Transpose to xHWC
|
||||
input = ConvertAtenPoolingBaseOp<AtenMaxPool1dOp, tosa::MaxPool2dOp>::
|
||||
transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertAtenAvgPool2dOp
|
||||
: public ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp> {
|
||||
public:
|
||||
|
@ -5504,6 +5709,68 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// Legalization for aten.avg_pool1d
|
||||
class ConvertAtenAvgPool1dOp
|
||||
: public ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp> {
|
||||
public:
|
||||
using ConvertAtenPoolingBaseOp<AtenAvgPool1dOp,
|
||||
tosa::AvgPool2dOp>::ConvertAtenPoolingBaseOp;
|
||||
LogicalResult processInputs(AtenAvgPool1dOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Value &input,
|
||||
DenseI64ArrayAttr &kernel,
|
||||
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
|
||||
Type &outputTy) const override {
|
||||
auto self = adaptor.getSelf();
|
||||
|
||||
// Not a RankedTensorType
|
||||
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor type inputs are supported");
|
||||
auto selfShape = selfTy.getShape();
|
||||
|
||||
// Expected a rank 3 input tensor
|
||||
if (selfTy.getRank() != 3)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Input tensor for MaxPool1d should have rank 3");
|
||||
|
||||
// Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
|
||||
SmallVector<int64_t> rank4Shape(selfShape);
|
||||
rank4Shape.push_back(1);
|
||||
auto reshapedSelf = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(makeShapeTorchCompatible(rank4Shape),
|
||||
selfTy.getElementType()),
|
||||
self, rewriter.getDenseI64ArrayAttr(rank4Shape));
|
||||
|
||||
// Currently, we can not represent `count_include_pad` with the existing
|
||||
// TOSA AvgPool2d specification. Without the below check, we produce silent
|
||||
// wrong answers (SWA) when the `count_include_pad` value is `true.`
|
||||
bool countIncludePad;
|
||||
if (!matchPattern(op.getCountIncludePad(),
|
||||
m_TorchConstantBool(&countIncludePad)) ||
|
||||
countIncludePad) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
|
||||
"`count_include_pad` value should be `False`.");
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 2> dilationArray{1, 1};
|
||||
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
|
||||
tosa::AvgPool2dOp>(
|
||||
op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy,
|
||||
kernel, stride, pad)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "invalid pooling parameters or input type");
|
||||
|
||||
// Transpose to xHWC
|
||||
input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
|
||||
transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Ref: Error checking based on the Torch to LinAlg lowering
|
||||
template <typename AtenOpT, int fillVal>
|
||||
class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
|
||||
|
@ -6880,6 +7147,49 @@ LogicalResult ConvertAtenOp<AtenAsStridedOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Legalization for torch.prims.collapse
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
|
||||
PrimsCollapseOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.getA();
|
||||
|
||||
// Not a tensor type
|
||||
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||
if (!selfType)
|
||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||
|
||||
auto resultType =
|
||||
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
|
||||
auto resultShape = resultType.getShape();
|
||||
|
||||
int64_t start, end;
|
||||
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only constant int start value is supported");
|
||||
|
||||
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only constant int end value is supported");
|
||||
|
||||
// Identity case
|
||||
if (start == end) {
|
||||
rewriter.replaceOp(op, self);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Technically, I should calculate the output shape based on the input shape,
|
||||
// start value, and end value. However, that would just give the same result
|
||||
// as me taking the result shape straight from resultType and applying
|
||||
// tosa::ReshapeOp to the input. Therefore, I'm opting for the latter approach
|
||||
// here, which is more simple and quicker.
|
||||
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
|
||||
op, resultType, self,
|
||||
rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -7101,9 +7411,15 @@ public:
|
|||
target.addIllegalOp<AtenMaxPool2dOp>();
|
||||
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenMaxPool1dOp>();
|
||||
patterns.add<ConvertAtenMaxPool1dOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenAvgPool2dOp>();
|
||||
patterns.add<ConvertAtenAvgPool2dOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenAvgPool1dOp>();
|
||||
patterns.add<ConvertAtenAvgPool1dOp>(typeConverter, context);
|
||||
|
||||
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
|
||||
|
@ -7199,6 +7515,8 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenUniformOp);
|
||||
INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp);
|
||||
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
|
||||
INSERT_ATENOP_PATTERN(AtenClampTensorOp);
|
||||
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
|
|
|
@ -1744,6 +1744,23 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
|||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||
"CollapseAllDimensionsModule_basic",
|
||||
"CollapseRank1DynamicModule_basic",
|
||||
"CollapseStaticModule_basic",
|
||||
"ElementwiseClampMinTensorFloatModule_basic",
|
||||
"ElementwiseClampMinTensorIntModule_basic",
|
||||
"ElementwiseClampTensorFloatModule_basic",
|
||||
"ElementwiseClampTensorIntModule_basic",
|
||||
"ElementwiseFracModule_basic",
|
||||
"ElementwiseLdexpModule_basic",
|
||||
"ElementwiseSignbitIntModule_basic",
|
||||
"Exp2StaticIntModule_basic",
|
||||
"MaxPool1dEmptyStrideStaticModule_basic",
|
||||
"MaxPool1dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool1dStaticModule_basic",
|
||||
"RepeatInterleaveSelfIntModule_basic",
|
||||
"RsubIntModule_noalpha_basic",
|
||||
"Aten_TrilinearModuleSumAllDims_basic",
|
||||
"Aten_TrilinearModuleSumdims_basic",
|
||||
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
|
||||
|
@ -3373,9 +3390,10 @@ ONNX_CRASHING_SET = LINALG_CRASHING_SET | {
|
|||
}
|
||||
|
||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||
"ElementwiseCopysignModule_basic",
|
||||
"ElementwiseSignbitModule_basic",
|
||||
"Aten_TrilinearModuleVaryingRanks_basic",
|
||||
"Aten_TrilinearModuleZerodDimBug_basic",
|
||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||
"ElementwiseRreluWithNoiseTrainModule_basic",
|
||||
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
|
||||
"MaxPool3dEmptyStrideStaticModule_basic",
|
||||
|
@ -3519,11 +3537,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"BoolIntTrueModule_basic",
|
||||
"BroadcastDynamicDimModule_basic",
|
||||
"CeilFloatModule_basic",
|
||||
"CollapseAllDimensionsModule_basic",
|
||||
"CollapseFullDynamicModule_basic",
|
||||
"CollapsePartialDynamicModule_basic",
|
||||
"CollapseRank1DynamicModule_basic",
|
||||
"CollapseStaticModule_basic",
|
||||
"ConstantBoolParameterModule_basic",
|
||||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
|
@ -3585,10 +3598,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ElementwiseAtanhIntModule_basic",
|
||||
"ElementwiseAtanhModule_basic",
|
||||
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
|
||||
"ElementwiseClampMinTensorFloatModule_basic",
|
||||
"ElementwiseClampMinTensorIntModule_basic",
|
||||
"ElementwiseClampTensorFloatModule_basic",
|
||||
"ElementwiseClampTensorIntModule_basic",
|
||||
"ElementwiseCosIntModule_basic",
|
||||
"ElementwiseCoshIntModule_basic",
|
||||
"ElementwiseCoshModule_basic",
|
||||
|
@ -3784,7 +3793,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"ReplicationPad2dModule_right0",
|
||||
"ReplicationPad2dModule_top0",
|
||||
"RollModule_basic",
|
||||
"RsubIntModule_noalpha_basic",
|
||||
"ScalarConstantTupleModule_basic",
|
||||
"ScalarImplicitFloatModule_basic",
|
||||
"ScalarImplicitIntModule_basic",
|
||||
|
@ -3897,16 +3905,12 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"IndexPutImpl2DFloatNonAccumulateModule_basic",
|
||||
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||
"IouOfModule_basic",
|
||||
"MaxPool1dEmptyStrideStaticModule_basic",
|
||||
"MaxPool1dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool1dStaticModule_basic",
|
||||
"MeshgridIndexingIJ_basic",
|
||||
"MeshgridIndexingXY_basic",
|
||||
"Meshgrid_basic",
|
||||
"OneHotModule_basic",
|
||||
"ReduceFrobeniusNormKeepDimModule_basic",
|
||||
"ReduceFrobeniusNormModule_basic",
|
||||
"RepeatInterleaveSelfIntModule_basic",
|
||||
"ScaledDotProductAttentionBoolMaskModule_basic",
|
||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
|
||||
|
@ -3927,6 +3931,16 @@ ONNX_TOSA_CRASHING_SET = {
|
|||
}
|
||||
|
||||
ONNX_TOSA_XFAIL_SET = {
|
||||
"ElementwiseCopysignModule_basic",
|
||||
"ElementwiseFracModule_basic",
|
||||
"ElementwiseLdexpModule_basic",
|
||||
"ElementwiseSignbitIntModule_basic",
|
||||
"ElementwiseSignbitModule_basic",
|
||||
"Exp2StaticIntModule_basic",
|
||||
"NllLossStaticModule_basic",
|
||||
"NllLossStaticModule_mean_basic",
|
||||
"NllLossStaticModule_sum_basic",
|
||||
"NllLossStaticModule_weight_basic",
|
||||
"Exp2StaticModule_basic",
|
||||
"ElementwiseRreluWithNoiseEvalModule_basic",
|
||||
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
|
||||
|
@ -3950,7 +3964,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"TriuIndicesAllZerosModule_basic",
|
||||
"ElementwiseCreateComplexModule_basic",
|
||||
"ReduceAllDimFloatModule_basic",
|
||||
"AdaptiveMaxPool1dDimOneStatic_basic",
|
||||
"ScaledDotProductAttentionDifferentCausalModule_basic",
|
||||
"HstackBasicComplexModule_basic",
|
||||
"HstackBasicFloatModule_basic",
|
||||
|
@ -4029,7 +4042,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
||||
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
|
||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||
"AdaptiveAvgPool2dDynamicNoBatch_basic",
|
||||
"AdaptiveAvgPool2dDynamic_basic",
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
|
||||
|
@ -4285,10 +4297,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseBitwiseRightShiftInt8Module_basic",
|
||||
"ElementwiseBitwiseXorModule_basic",
|
||||
"ElementwiseBitwiseXorStaticShapeModule_basic",
|
||||
"ElementwiseClampMaxModule_basic",
|
||||
"ElementwiseClampMinModule_basic",
|
||||
"ElementwiseClampModule_basic",
|
||||
"ElementwiseClampTensorInt8Module_basic",
|
||||
"ElementwiseCosIntModule_basic",
|
||||
"ElementwiseCoshIntModule_basic",
|
||||
"ElementwiseCoshModule_basic",
|
||||
|
@ -4335,7 +4343,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseQuantizePerTensorModule_basic",
|
||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||
"ElementwiseReciprocalIntModule_basic",
|
||||
"ElementwiseRelu6Module_basic",
|
||||
"ElementwiseRemainderScalarModule_Bool_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_Float_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
||||
|
@ -4414,8 +4421,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"GtFloatIntModule_basic",
|
||||
"GtIntModule_basic",
|
||||
"HBC_basic",
|
||||
"HardTanhIntModule_basic",
|
||||
"HardTanhModule_basic",
|
||||
"HardtanhBackward_basic",
|
||||
"IndexPut1DFloatAccumulateModule_basic",
|
||||
"IndexPut1DFloatNonAccumulateModule_basic",
|
||||
|
@ -4463,7 +4468,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
|
||||
"IndexTensorModule3dInput_basic",
|
||||
"IndexTensorModule_basic",
|
||||
"IndexTensorMultiIndexStaticModule_basic",
|
||||
"IndexTensorMultiInputContiguousCenter_basic",
|
||||
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
|
||||
"IndexTensorMultiInputNonContiguousDynamic_basic",
|
||||
|
@ -4474,8 +4478,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"IndexTensorMultiInputThreeIndexers_basic",
|
||||
"IndexTensorMultiInput_basic",
|
||||
"IndexTensorSelectDimModule_basic",
|
||||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
"InterpolateDynamicModule_scales_recompute_bilinear",
|
||||
|
@ -4503,10 +4505,7 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"Matmul_matvec",
|
||||
"Matmul_vecmat",
|
||||
"MaxPool1dCeilModeTrueModule_basic",
|
||||
"MaxPool1dEmptyStrideStaticModule_basic",
|
||||
"MaxPool1dModule_basic",
|
||||
"MaxPool1dStaticCeilModeTrueModule_basic",
|
||||
"MaxPool1dStaticModule_basic",
|
||||
"MaxPool2dCeilModeTrueModule_basic",
|
||||
"MaxPool2dModule_basic",
|
||||
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
||||
|
@ -4607,7 +4606,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"NormScalarOptDimKeepDimModule_basic",
|
||||
"NormScalarOptDimModule_basic",
|
||||
"NormalFunctionalModule_basic",
|
||||
"NormalizeModule_basic",
|
||||
"NumToTensorFloatModule_basic",
|
||||
"NumToTensorIntModule_basic",
|
||||
"NumelModule_basic",
|
||||
|
@ -4730,7 +4728,6 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ReplicationPad2dModule_right0",
|
||||
"ReplicationPad2dModule_top0",
|
||||
"ResNet18Module_basic",
|
||||
"ResNet18StaticModule_basic",
|
||||
"ReshapeAliasCollapseModule_basic",
|
||||
"ReshapeAliasExpandModule_basic",
|
||||
"ReshapeCollapseModule_basic",
|
||||
|
|
|
@ -2258,16 +2258,8 @@ func.func @torch.aten.logical_and$basic(%arg0: !torch.vtensor<[4,5],i1>, %arg1:
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.uniform$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
|
||||
// CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e+01
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1.00007045, 2.18384027, 7.80044794, 5.12785149], [5.79490519, 2.97063255, 1.42340159, 7.10978221], [7.11366796, 9.41223621, 4.45151854, 5.67474747]]> : tensor<3x4xf32>}> : () -> tensor<3x4xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf64>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf64> -> !torch.vtensor<[3,4],f64>
|
||||
// CHECK: return %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>
|
||||
// CHECK: }
|
||||
// CHECK-LABEL: torch.aten.uniform$basic
|
||||
// CHECK: tosa.const
|
||||
func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) {
|
||||
%float1.000000e00 = torch.constant.float 1.000000e+00
|
||||
%float1.000000e01 = torch.constant.float 1.000000e+01
|
||||
|
@ -2313,3 +2305,122 @@ func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !tor
|
|||
%2 = torch.aten.as_strided %arg0, %0, %1, %none : !torch.vtensor<[5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.none -> !torch.vtensor<[3,3],f32>
|
||||
return %2 : !torch.vtensor<[3,3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.max_pool1d$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,112],f32> -> tensor<1x64x112xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 64, 112, 1>} : (tensor<1x64x112xf32>) -> tensor<1x64x112x1xf32>
|
||||
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||
// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_10]], %[[VAL_11]] : (tensor<1x64x112x1xf32>, tensor<4xi32>) -> tensor<1x112x1x64xf32>
|
||||
// CHECK: %[[VAL_13:.*]] = tosa.max_pool2d %[[VAL_12]] {kernel = array<i64: 3, 1>, pad = array<i64: 1, 0, 0, 0>, stride = array<i64: 2, 1>} : (tensor<1x112x1x64xf32>) -> tensor<1x56x1x64xf32>
|
||||
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||
// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x56x1x64xf32>, tensor<4xi32>) -> tensor<1x64x56x1xf32>
|
||||
// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array<i64: 1, 64, 56>} : (tensor<1x64x56x1xf32>) -> tensor<1x64x56xf32>
|
||||
// CHECK: %[[VAL_17:.*]] = tensor.cast %[[VAL_16]] : tensor<1x64x56xf32> to tensor<1x64x56xf32>
|
||||
// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<1x64x56xf32> -> !torch.vtensor<[1,64,56],f32>
|
||||
// CHECK: return %[[VAL_18]] : !torch.vtensor<[1,64,56],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.max_pool1d$basic(%arg0: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> {
|
||||
%false = torch.constant.bool false
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
|
||||
%1 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
|
||||
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||
%3 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||
%4 = torch.aten.max_pool1d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[1,64,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,64,56],f32>
|
||||
return %4 : !torch.vtensor<[1,64,56],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.avg_pool1d$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 1, 512, 10, 1>} : (tensor<1x512x10xf32>) -> tensor<1x512x10x1xf32>
|
||||
// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||
// CHECK: %[[VAL_10:.*]] = tosa.transpose %[[VAL_8]], %[[VAL_9]] : (tensor<1x512x10x1xf32>, tensor<4xi32>) -> tensor<1x10x1x512xf32>
|
||||
// CHECK: %[[VAL_11:.*]] = tosa.avg_pool2d %[[VAL_10]] {acc_type = f32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x10x1x512xf32>) -> tensor<1x10x1x512xf32>
|
||||
// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
|
||||
// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_11]], %[[VAL_12]] : (tensor<1x10x1x512xf32>, tensor<4xi32>) -> tensor<1x512x10x1xf32>
|
||||
// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array<i64: 1, 512, 10>} : (tensor<1x512x10x1xf32>) -> tensor<1x512x10xf32>
|
||||
// CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
|
||||
// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
|
||||
// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,10],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.avg_pool1d$basic(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%false = torch.constant.bool false
|
||||
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||
%2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
|
||||
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %false : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
|
||||
return %3 : !torch.vtensor<[1,512,10],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.clamp.Tensor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],f32>,
|
||||
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) {
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],f32> -> tensor<1xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<3.40282347E+38> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_8:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_4]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32>
|
||||
// CHECK: %[[VAL_9:.*]] = tosa.minimum %[[VAL_8]], %[[VAL_7]] : (tensor<3x5xf32>, tensor<f32>) -> tensor<3x5xf32>
|
||||
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32>
|
||||
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_12:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_11]] : (tensor<3x5xf32>, tensor<f32>) -> tensor<3x5xf32>
|
||||
// CHECK: %[[VAL_13:.*]] = tosa.minimum %[[VAL_12]], %[[VAL_3]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32>
|
||||
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32>
|
||||
// CHECK: %[[VAL_15:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_4]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32>
|
||||
// CHECK: %[[VAL_16:.*]] = tosa.minimum %[[VAL_15]], %[[VAL_3]] : (tensor<3x5xf32>, tensor<1xf32>) -> tensor<3x5xf32>
|
||||
// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32>
|
||||
// CHECK: return %[[VAL_10]], %[[VAL_14]], %[[VAL_17]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.clamp.Tensor$basic(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) {
|
||||
%none = torch.constant.none
|
||||
%0 = torch.aten.clamp.Tensor %arg0, %arg1, %none : !torch.vtensor<[3,5],f32>, !torch.vtensor<[1],f32>, !torch.none -> !torch.vtensor<[3,5],f32>
|
||||
%1 = torch.aten.clamp.Tensor %arg0, %none, %arg2 : !torch.vtensor<[3,5],f32>, !torch.none, !torch.vtensor<[1],f32> -> !torch.vtensor<[3,5],f32>
|
||||
%2 = torch.aten.clamp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[3,5],f32>
|
||||
return %0, %1, %2 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.prims.collapse$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 2, 12>} : (tensor<2x3x4xf32>) -> tensor<2x12xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x12xf32> -> !torch.vtensor<[2,12],f32>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,12],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32>
|
||||
return %0 : !torch.vtensor<[2,12],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue