mirror of https://github.com/llvm/torch-mlir
build: update llvm tag to de3f0f7f (#1789)
Credit to @vivekkhandelwal1 for finding the necessary changes.

Summary of changes:
- Switch Tosa_IntArrayAttr[N] and Tosa_IntArrayAttrUpto[N] to DenseI64ArrayAttr.
- Replace kNoIterationLimit with kNoLimit (https://reviews.llvm.org/D140525).
- Add a dependency on MhloPasses when MHLO is enabled.
- Specify the result type when using mhlo::DotOp.
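For illustration, a minimal sketch of what the attribute migration looks like when building a TOSA op from a conversion pattern; the names loc, newType, input, and newShape are hypothetical stand-ins, and the pre-change call is kept in a comment for comparison:

  // Sketch only: assumes an MLIR ConversionPatternRewriter `rewriter` in scope,
  // with hypothetical `loc`, `newType`, `input`, and `newShape` values.
  SmallVector<int64_t> newShape = {1, 8, 3};
  // Before this LLVM bump, shape-like TOSA attributes were plain ArrayAttr:
  //   rewriter.create<tosa::ReshapeOp>(loc, newType, input,
  //                                    rewriter.getI64ArrayAttr(newShape));
  // After the bump, the same operand is a DenseI64ArrayAttr:
  auto reshaped = rewriter.create<tosa::ReshapeOp>(
      loc, newType, input, rewriter.getDenseI64ArrayAttr(newShape));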
parent 0979df6589
commit 0faba6d2fc
@@ -1 +1 @@
-Subproject commit 7ccbb4dff10efe6c26219204e361ddb0264938b8
+Subproject commit de3f0f7fa0c7b902dde840913db7e773a02c4173
@@ -1 +1 @@
-Subproject commit 8c703fabd60d4447bc86f432446e9ad0eacab600
+Subproject commit 2c8823d255a777d3053ef891f4dbeea1c32819f4
@@ -18,7 +18,9 @@ set(linked_libs TorchMLIRTorchToLinalg
 TorchMLIRTorchConversionToMLProgram
 TorchMLIRConversionUtils)
 if(TORCH_MLIR_ENABLE_MHLO)
-list(APPEND linked_libs TorchMLIRTorchToMhlo)
+list(APPEND linked_libs
+MhloPasses
+TorchMLIRTorchToMhlo)
 endif()

 add_mlir_library(TorchMLIRConversionPasses
@@ -11,7 +11,7 @@

 #ifdef TORCH_MLIR_ENABLE_MHLO
 #include "mhlo/transforms/passes.h"
-#include "mlir-hlo/Transforms/passes.h"
+#include "transforms/passes.h"
 #endif // TORCH_MLIR_ENABLE_MHLO
 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
 #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
@@ -37,7 +37,7 @@ void mlir::torch::registerConversionPasses() {
 return mlir::mhlo::createLegalizeHloToLinalgPass();
 });
 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
-return mlir::createSymbolicShapeOptimizationPass();
+return mlir::mhlo::createSymbolicShapeOptimizationPass();
 });
 #endif // TORCH_MLIR_ENABLE_MHLO
 }
@@ -216,7 +216,10 @@ public:
 }

 if (lhsRank <= 2 && rhsRank <= 2) {
-output = rewriter.create<mhlo::DotOp>(op->getLoc(), lhs, rhs, nullptr);
+auto tensorType =
+    ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
+output = rewriter.create<mhlo::DotOp>(op->getLoc(), tensorType, lhs, rhs,
+    nullptr);
 return success();
 }

@@ -881,7 +881,7 @@ public:
 op->getLoc(),
 OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
 newOutputTy),
-self, rewriter.getI64ArrayAttr(newOutputShape));
+self, rewriter.getDenseI64ArrayAttr(newOutputShape));
 rewriter.replaceOpWithNewOp<tensor::CastOp>(
 op,
 OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
@@ -1076,7 +1076,7 @@ public:
 op->getLoc(),
 OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
 lhsBroadcastedTy),
-lhs, rewriter.getI64ArrayAttr(lhsBroadcastedShape));
+lhs, rewriter.getDenseI64ArrayAttr(lhsBroadcastedShape));

 auto rankBroadcastedRhs =
 rhsRank == maxInputRank
@@ -1085,7 +1085,7 @@ public:
 op->getLoc(),
 OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
 rhsBroadcastedTy),
-rhs, rewriter.getI64ArrayAttr(rhsBroadcastedShape));
+rhs, rewriter.getDenseI64ArrayAttr(rhsBroadcastedShape));

 // TOSA matmul is performed on two 3D inputs and generates a 3D output.
 // Lower ranked tensors are dim-1 reshaped up to 3D
@@ -1113,7 +1113,7 @@ public:
 op->getLoc(),
 OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
 newType),
-tensor, rewriter.getI64ArrayAttr(newShape));
+tensor, rewriter.getDenseI64ArrayAttr(newShape));
 };

 // Where broadcasting is required in one or more batch dims, the following
@@ -1303,7 +1303,7 @@ public:
 op->getLoc(),
 OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
 newLhsType),
-lhsReshapeInput, rewriter.getI64ArrayAttr(newLhsShape));
+lhsReshapeInput, rewriter.getDenseI64ArrayAttr(newLhsShape));

 SmallVector<int64_t> transposedRhsShape;
 SmallVector<int32_t> transposedRhsDims;
@@ -1375,7 +1375,7 @@ public:
 op->getLoc(),
 OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
 newRhsType),
-transposedRhsValue, rewriter.getI64ArrayAttr(newRhsShape));
+transposedRhsValue, rewriter.getDenseI64ArrayAttr(newRhsShape));
 }

 auto matmulLhsShape = makeShapeTorchCompatible(
@@ -1506,7 +1506,7 @@ public:
 op->getLoc(),
 OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
 reshapedOpType),
-mmOpResult, rewriter.getI64ArrayAttr(reshapedOpShape));
+mmOpResult, rewriter.getDenseI64ArrayAttr(reshapedOpShape));

 if (opNeedsTranspose) {

@@ -1915,9 +1915,9 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
 .create<tosa::Conv2DOp>(op->getLoc(),
 getTypeConverter()->convertType(convOpTy),
 transposedInput, transposedWeight, bias,
-rewriter.getI64ArrayAttr(padding),
-rewriter.getI64ArrayAttr(stride),
-rewriter.getI64ArrayAttr(dilation))
+rewriter.getDenseI64ArrayAttr(padding),
+rewriter.getDenseI64ArrayAttr(stride),
+rewriter.getDenseI64ArrayAttr(dilation))
 .getResult();

 std::optional<Value> nhwcToNchwTransposeConst =
@@ -1979,7 +1979,7 @@ LogicalResult ConvertAtenOp<AtenReshapeOp>::matchAndRewrite(

 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
 op, getTypeConverter()->convertType(newType), self,
-rewriter.getI64ArrayAttr(newShape));
+rewriter.getDenseI64ArrayAttr(newShape));

 return success();
 }
@@ -2078,7 +2078,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
 outTensorType.getElementType());

 result = rewriter.create<tosa::ReshapeOp>(
-op->getLoc(), newType, toBcast, rewriter.getI64ArrayAttr(newShape));
+op->getLoc(), newType, toBcast,
+rewriter.getDenseI64ArrayAttr(newShape));

 return success();
 };
@@ -2203,8 +2204,8 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
 sumDiv, rewriter.getI64IntegerAttr(i));
 }

-return rewriter.create<tosa::ReshapeOp>(op.getLoc(), outType, sumDiv,
-rewriter.getI64ArrayAttr(outShape));
+return rewriter.create<tosa::ReshapeOp>(
+op.getLoc(), outType, sumDiv, rewriter.getDenseI64ArrayAttr(outShape));
 };

 // TOSA has integer Div so, compute reciprocal of element count to be used in
@@ -2260,11 +2261,11 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(

 Value weightVal = rewriter.create<tosa::ReshapeOp>(
 op.getLoc(), weightAndMeanBcastType, adaptor.getWeight(),
-rewriter.getI64ArrayAttr(weightAndBiasBcastShape));
+rewriter.getDenseI64ArrayAttr(weightAndBiasBcastShape));

 Value biasVal = rewriter.create<tosa::ReshapeOp>(
 op.getLoc(), weightAndMeanBcastType, adaptor.getBias(),
-rewriter.getI64ArrayAttr(weightAndBiasBcastShape));
+rewriter.getDenseI64ArrayAttr(weightAndBiasBcastShape));

 double eps;
 if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps)))
@@ -2365,8 +2366,9 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(

 auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape),
 selfType.getElementType());
-auto reshapeOp = rewriter.create<tosa::ReshapeOp>(
-op.getLoc(), newType, adaptor.getSelf(), rewriter.getI64ArrayAttr(newShape));
+auto reshapeOp =
+rewriter.create<tosa::ReshapeOp>(op.getLoc(), newType, adaptor.getSelf(),
+rewriter.getDenseI64ArrayAttr(newShape));

 rewriter.replaceOpWithNewOp<tensor::CastOp>(
 op, getTypeConverter()->convertType(op.getType()), reshapeOp);
@@ -2530,7 +2532,7 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(

 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
 op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
-rewriter.getI64ArrayAttr(outShape));
+rewriter.getDenseI64ArrayAttr(outShape));

 return success();
 }
@@ -2603,7 +2605,7 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(

 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
 op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
-rewriter.getI64ArrayAttr(outShape));
+rewriter.getDenseI64ArrayAttr(outShape));

 return success();
 }
@@ -2838,7 +2840,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
 op->getLoc(),
 RankedTensorType::get(makeShapeLLVMCompatible(newWeightShape),
 weightType.getElementType()),
-weight, rewriter.getI64ArrayAttr(newWeightShape));
+weight, rewriter.getDenseI64ArrayAttr(newWeightShape));

 int64_t numIndices = 1;
 if (indicesType.hasStaticShape()) {
@@ -2853,7 +2855,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
 op->getLoc(),
 RankedTensorType::get(makeShapeLLVMCompatible(newIndicesShape),
 indicesType.getElementType()),
-indices, rewriter.getI64ArrayAttr(newIndicesShape));
+indices, rewriter.getDenseI64ArrayAttr(newIndicesShape));

 auto castIndices = rewriter.create<tosa::CastOp>(
 op->getLoc(),
@@ -2870,7 +2872,8 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(

 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
 op, outType, gatherOp,
-rewriter.getI64ArrayAttr(makeShapeTorchCompatible(outType.getShape())));
+rewriter.getDenseI64ArrayAttr(
+makeShapeTorchCompatible(outType.getShape())));

 return success();
 }
@@ -2960,7 +2963,7 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
 }

 auto dimAttr = rewriter.getIntegerAttr(rewriter.getI64Type(), dim);
-auto prunedShapeAttr = rewriter.getI64ArrayAttr(prunedShape);
+auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape);

 Value reduceMax = rewriter.create<tosa::ReduceMaxOp>(
 op->getLoc(),
@@ -2975,7 +2978,7 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
 if (argMax.getType() != indicesType) {
 argMax = rewriter.create<tosa::ReshapeOp>(
 op->getLoc(), indicesType, argMax,
-rewriter.getI64ArrayAttr(reducedShape));
+rewriter.getDenseI64ArrayAttr(reducedShape));
 }

 if (!keepDim) {
@@ -3043,8 +3046,8 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(

 rewriter.replaceOpWithNewOp<tosa::SliceOp>(
 op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
-rewriter.getI64ArrayAttr(startSlice),
-rewriter.getI64ArrayAttr(sizeSlice));
+rewriter.getDenseI64ArrayAttr(startSlice),
+rewriter.getDenseI64ArrayAttr(sizeSlice));

 return success();
 }
@@ -3427,8 +3430,9 @@ public:
 // function also transposes inputs.
 virtual LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
 ConversionPatternRewriter &rewriter,
-Value &input, ArrayAttr &kernel,
-ArrayAttr &stride, ArrayAttr &pad,
+Value &input, DenseI64ArrayAttr &kernel,
+DenseI64ArrayAttr &stride,
+DenseI64ArrayAttr &pad,
 Type &outputTy) const {
 return rewriter.notifyMatchFailure(
 op, "Unimplemented pooling input parsing function");
@@ -3503,7 +3507,7 @@ public:
 matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
 ConversionPatternRewriter &rewriter) const override {
 Value input;
-ArrayAttr kernel, stride, pad;
+DenseI64ArrayAttr kernel, stride, pad;
 Type outputTy;

 // Attempts to read input and kernel parameters, or synthesize them in the
@@ -3540,8 +3544,9 @@ public:
 using OpAdaptor = typename AtenOpT::Adaptor;
 LogicalResult processInputs(AtenOpT op, OpAdaptor adaptor,
 ConversionPatternRewriter &rewriter, Value &input,
-ArrayAttr &kernel, ArrayAttr &stride,
-ArrayAttr &pad, Type &outputTy) const override {
+DenseI64ArrayAttr &kernel,
+DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
+Type &outputTy) const override {
 auto inputXchw = adaptor.getSelf();
 auto inputTy = inputXchw.getType().template cast<RankedTensorType>();
 if (!inputTy)
@@ -3603,12 +3608,12 @@ public:
 input =
 ConvertAtenPoolingBaseOp<AtenOpT, TosaOpT>::transposePoolingInputToHwc(
 op, rewriter, inputXchw);
-kernel = rewriter.getI64ArrayAttr(kernelDims);
-stride = rewriter.getI64ArrayAttr({strideH, strideW});
+kernel = rewriter.getDenseI64ArrayAttr(kernelDims);
+stride = rewriter.getDenseI64ArrayAttr({strideH, strideW});
 // Adaptive pooling does unit dilation and zero pad.
-pad = rewriter.getI64ArrayAttr({0, 0, 0, 0});
-outputTy =
-RankedTensorType::get(makeShapeLLVMCompatible(outputShape), inputElemTy);
+pad = rewriter.getDenseI64ArrayAttr({0, 0, 0, 0});
+outputTy = RankedTensorType::get(makeShapeLLVMCompatible(outputShape),
+inputElemTy);

 return success();
 }
@@ -3643,8 +3648,9 @@ static Type getOutputTypeForNonAdaptivePoolingOp(
 template <typename AtenOpT, typename tosaOp>
 static LogicalResult getOutputTypeAndPoolingParameters(
 AtenOpT op, ConversionPatternRewriter &rewriter, Value inputXchw,
-SmallVectorImpl<int64_t> &dilationArray, Type &outputTy, ArrayAttr &kernel,
-ArrayAttr &stride, ArrayAttr &pad) {
+SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
+DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
+DenseI64ArrayAttr &pad) {

 RankedTensorType inputTy = inputXchw.getType().cast<RankedTensorType>();
 if (!inputTy)
@@ -3669,9 +3675,9 @@ static LogicalResult getOutputTypeAndPoolingParameters(
 return rewriter.notifyMatchFailure(
 op, "Non-const padding factor for pooling op unsupported");

-kernel = rewriter.getI64ArrayAttr(kernelSizeInts);
-stride = rewriter.getI64ArrayAttr(strideInts);
-pad = rewriter.getI64ArrayAttr(
+kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts);
+stride = rewriter.getDenseI64ArrayAttr(strideInts);
+pad = rewriter.getDenseI64ArrayAttr(
 {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});

 // FIXME: add ceil_mode support.
@@ -3696,10 +3702,12 @@ public:
 tosa::MaxPool2dOp>::ConvertAtenPoolingBaseOp;
 LogicalResult processInputs(AtenMaxPool2dOp op, OpAdaptor adaptor,
 ConversionPatternRewriter &rewriter, Value &input,
-ArrayAttr &kernel, ArrayAttr &stride,
-ArrayAttr &pad, Type &outputTy) const override {
+DenseI64ArrayAttr &kernel,
+DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
+Type &outputTy) const override {
 SmallVector<int64_t, 2> dilationArray;
-if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationArray)))
+if (!matchPattern(op.getDilation(),
+m_TorchListOfConstantInts(dilationArray)))
 return rewriter.notifyMatchFailure(
 op, "Non-const dilation for pooling op unsupported.");
 // TOSA pooling only supports unit dilation.
@@ -3729,8 +3737,9 @@ public:
 tosa::AvgPool2dOp>::ConvertAtenPoolingBaseOp;
 LogicalResult processInputs(AtenAvgPool2dOp op, OpAdaptor adaptor,
 ConversionPatternRewriter &rewriter, Value &input,
-ArrayAttr &kernel, ArrayAttr &stride,
-ArrayAttr &pad, Type &outputTy) const override {
+DenseI64ArrayAttr &kernel,
+DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
+Type &outputTy) const override {
 SmallVector<int64_t, 2> dilationArray{1, 1};
 if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
 tosa::AvgPool2dOp>(
@@ -151,7 +151,7 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
 auto indicesChosenAxis = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
 rewriter, op->getLoc(),
 GetTypeFromTensorShape(indicesOneDimShape, indexType.getElementType()),
-indexValue, rewriter.getI64ArrayAttr(indicesOneDimShape));
+indexValue, rewriter.getDenseI64ArrayAttr(indicesOneDimShape));

 SmallVector<Value> concatInputs;
 for (auto dim = 0; dim < paramsRank; dim++) {
@@ -312,14 +312,14 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
 auto tosaValuesReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
 rewriter, op->getLoc(),
 GetTypeFromTensorShape(tosaValuesShape, paramsType.getElementType()),
-paramsValue, rewriter.getI64ArrayAttr(tosaValuesShape));
+paramsValue, rewriter.getDenseI64ArrayAttr(tosaValuesShape));

 // %3 = "tosa.reshape"(%1) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) ->
 // tensor<8x3xi32> Flatten the input indices tensor to an [W, ND] matrix.
 auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
 rewriter, op->getLoc(),
 GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()),
-indicesValue, rewriter.getI64ArrayAttr(indicesMatrixShape));
+indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape));

 SmallVector<int32_t> flattenedCoeffVec; // [12,3,1]
 // flattenedCoeffVec = [4,3,1]
@@ -367,7 +367,7 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
 rewriter, op->getLoc(),
 GetTypeFromTensorShape(tosaIndicesShape, indicesType.getElementType()),
 flattenedIndicesReduceOp.getResult(),
-rewriter.getI64ArrayAttr(tosaIndicesShape));
+rewriter.getDenseI64ArrayAttr(tosaIndicesShape));

 // Now the gather op itself
 // %9 = "tosa.gather"(%2, %7) : (tensor<1x12x1xf32>, tensor<1x8xi32>) ->
@@ -384,7 +384,7 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
 // %10 : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
 return tosa::CreateOpAndInfer<tosa::ReshapeOp>(
 rewriter, op->getLoc(), resultType, tosaGatherOp.getResult(),
-rewriter.getI64ArrayAttr(resultType.getShape()))
+rewriter.getDenseI64ArrayAttr(resultType.getShape()))
 .getResult();
 }

@@ -446,7 +446,7 @@ std::optional<Value> convertReduceOpCommon(
 if (!keep_dims) {
 auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
 rewriter, op->getLoc(), output_type, val,
-rewriter.getI64ArrayAttr(output_shape));
+rewriter.getDenseI64ArrayAttr(output_shape));
 val = reshape_op.getResult();
 }
 }
@@ -32,9 +32,9 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op,
 rewriter, op->getLoc(), output_type, input_val,
 rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
 rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
-rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}),
-rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round),
-rewriter.getBoolAttr(false));
+rewriter.getDenseI32ArrayAttr({multiplier}),
+rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32),
+rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(false));

 return rescale_op.getResult();
 }
@@ -85,8 +85,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
 auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
 rewriter, op->getLoc(), output_type, conv_val,
 rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
-rewriter.getI32ArrayAttr({multiplier}),
-rewriter.getI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32),
+rewriter.getDenseI32ArrayAttr({multiplier}),
+rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32),
 rewriter.getBoolAttr(true), rewriter.getBoolAttr(false));

 return rescale_op.getResult();
@@ -121,8 +121,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
 auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
 rewriter, op->getLoc(), output_type, conv_val,
 rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
-rewriter.getI32ArrayAttr(multiplier_arr),
-rewriter.getI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32),
+rewriter.getDenseI32ArrayAttr(multiplier_arr),
+rewriter.getDenseI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32),
 rewriter.getBoolAttr(true), rewriter.getBoolAttr(true));

 return rescale_op.getResult();
@@ -3667,7 +3667,7 @@ public:

 GreedyRewriteConfig config;
 config.useTopDownTraversal = true;
-config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;
+config.maxIterations = GreedyRewriteConfig::kNoLimit;

 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
 config))) {
@@ -194,7 +194,7 @@ class SimplifyDtypeCalculationsPass
 // A single linear scan should suffice.
 GreedyRewriteConfig config;
 config.useTopDownTraversal = true;
-config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;
+config.maxIterations = GreedyRewriteConfig::kNoLimit;
 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
 config))) {
 return signalPassFailure();
@@ -384,7 +384,7 @@ class SimplifyShapeCalculationsPass
 // A single linear scan should suffice.
 GreedyRewriteConfig config;
 config.useTopDownTraversal = true;
-config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;
+config.maxIterations = GreedyRewriteConfig::kNoLimit;
 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
 config))) {
 return signalPassFailure();
@@ -224,7 +224,7 @@ func.func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
 // CHECK: %[[ARG3:.*]] = torch.constant.int 0
 // CHECK: %[[ARG3_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG3]] : (!torch.int) -> !torch.list<int>
 // CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-9223372036854775808, -9223372036854775808, -9223372036854775808]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808>} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32>
 func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -246,7 +246,7 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
 // CHECK: %[[REDUCE2:.*]] = "tosa.reduce_sum"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xf32>) -> tensor<1x1x?x?xf32>
 // CHECK: %[[REDUCE3:.*]] = "tosa.reduce_sum"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xf32>) -> tensor<1x1x1x?xf32>
 // CHECK: %[[REDUCE4:.*]] = "tosa.reduce_sum"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xf32>) -> tensor<1x1x1x1xf32>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xf32>) -> tensor<1xf32>
+// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = array<i64: 1>} : (tensor<1x1x1x1xf32>) -> tensor<1xf32>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
 func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> {
@@ -264,7 +264,7 @@ func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
 // CHECK: %[[REDUCE2:.*]] = "tosa.reduce_all"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1>
 // CHECK: %[[REDUCE3:.*]] = "tosa.reduce_all"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1>
 // CHECK: %[[REDUCE4:.*]] = "tosa.reduce_all"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
+// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = array<i64: 1>} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
 func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
@@ -280,7 +280,7 @@ func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.
 // CHECK: %[[ARG1:.*]] = torch.constant.int 0
 // CHECK: %[[ARG2:.*]] = torch.constant.bool false
 // CHECK: %[[REDUCE:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = [-9223372036854775808, -9223372036854775808, -9223372036854775808]} : (tensor<1x?x?x?xi1>) -> tensor<?x?x?xi1>
+// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808>} : (tensor<1x?x?x?xi1>) -> tensor<?x?x?xi1>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xi1> -> !torch.vtensor<[?,?,?],i1>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],i1>
 func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> {
@@ -299,7 +299,7 @@ func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !to
 // CHECK: %[[REDUCE2:.*]] = "tosa.reduce_any"(%[[REDUCE1]]) {axis = 1 : i64} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1>
 // CHECK: %[[REDUCE3:.*]] = "tosa.reduce_any"(%[[REDUCE2]]) {axis = 2 : i64} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1>
 // CHECK: %[[REDUCE4:.*]] = "tosa.reduce_any"(%[[REDUCE3]]) {axis = 3 : i64} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1>
-// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = [1]} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
+// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) {new_shape = array<i64: 1>} : (tensor<1x1x1x1xi1>) -> tensor<1xi1>
 // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1>
 // CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1>
 func.func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> {
@@ -467,7 +467,7 @@ func.func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
 // CHECK: %[[VAL_2:.*]] = torch.constant.int -1
 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [-1]} : (tensor<?x?x?x?xf32>) -> tensor<?xf32>
+// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array<i64: -1>} : (tensor<?x?x?x?xf32>) -> tensor<?xf32>
 // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
 // CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32>
 // CHECK: }
@@ -489,10 +489,10 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to
 // CHECK: %[[VAL_5:.*]] = torch.constant.float 1.000000e-05
 // CHECK: %[[VAL_6:.*]] = torch.constant.bool true
 // CHECK: %[[VAL_7:.*]] = torch.constant.bool false
-// CHECK: %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = [4, 1]} : (tensor<4xf32>) -> tensor<4x1xf32>
-// CHECK: %[[VAL_9:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [4, 1]} : (tensor<4xf32>) -> tensor<4x1xf32>
-// CHECK: %[[VAL_10:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [4, 1]} : (tensor<4xf32>) -> tensor<4x1xf32>
-// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = [4, 1]} : (tensor<4xf32>) -> tensor<4x1xf32>
+// CHECK: %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array<i64: 4, 1>} : (tensor<4xf32>) -> tensor<4x1xf32>
+// CHECK: %[[VAL_9:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = array<i64: 4, 1>} : (tensor<4xf32>) -> tensor<4x1xf32>
+// CHECK: %[[VAL_10:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = array<i64: 4, 1>} : (tensor<4xf32>) -> tensor<4x1xf32>
+// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array<i64: 4, 1>} : (tensor<4xf32>) -> tensor<4x1xf32>
 // CHECK: %[[VAL_12:.*]] = "tosa.const"() {value = dense<9.99999974E-6> : tensor<f32>} : () -> tensor<f32>
 // CHECK: %[[VAL_13:.*]] = "tosa.sub"(%[[VAL_1]], %[[VAL_8]]) : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32>
 // CHECK: %[[VAL_14:.*]] = "tosa.add"(%[[VAL_9]], %[[VAL_12]]) : (tensor<4x1xf32>, tensor<f32>) -> tensor<4x1xf32>
@@ -521,7 +521,7 @@ func.func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32>
 // CHECK: %[[VAL_2:.*]] = torch.constant.int 4
 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2
-// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [10, 3, 216, 4]} : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32>
+// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array<i64: 10, 3, 216, 4>} : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32>
 // CHECK: %[[VAL_5:.*]] = tensor.cast %[[VAL_4]] : tensor<10x3x216x4xf32> to tensor<10x3x?x4xf32>
 // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32>
 // CHECK: return %[[VAL_6]] : !torch.vtensor<[10,3,?,4],f32>
@@ -551,17 +551,17 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor
 // CHECK: %[[VAL_12:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32>
 // CHECK: %[[VAL_13:.*]] = "tosa.reduce_sum"(%[[VAL_12]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32>
 // CHECK: %[[VAL_14:.*]] = "tosa.reduce_sum"(%[[VAL_13]]) {axis = 1 : i64} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32>
-// CHECK: %[[VAL_15:.*]] = "tosa.reshape"(%[[VAL_14]]) {new_shape = [5, 1, 1, 1]} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.reshape"(%[[VAL_14]]) {new_shape = array<i64: 5, 1, 1, 1>} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
 // CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_15]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
 // CHECK: %[[VAL_17:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
 // CHECK: %[[VAL_18:.*]] = "tosa.mul"(%[[VAL_17]], %[[VAL_17]]) {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32>
 // CHECK: %[[VAL_19:.*]] = "tosa.reduce_sum"(%[[VAL_18]]) {axis = 3 : i64} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32>
 // CHECK: %[[VAL_20:.*]] = "tosa.reduce_sum"(%[[VAL_19]]) {axis = 2 : i64} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32>
 // CHECK: %[[VAL_21:.*]] = "tosa.reduce_sum"(%[[VAL_20]]) {axis = 1 : i64} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32>
-// CHECK: %[[VAL_22:.*]] = "tosa.reshape"(%[[VAL_21]]) {new_shape = [5, 1, 1, 1]} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
+// CHECK: %[[VAL_22:.*]] = "tosa.reshape"(%[[VAL_21]]) {new_shape = array<i64: 5, 1, 1, 1>} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32>
 // CHECK: %[[VAL_23:.*]] = "tosa.mul"(%[[VAL_22]], %[[VAL_11]]) {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32>
-// CHECK: %[[VAL_24:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = [1, 2, 2, 3]} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
-// CHECK: %[[VAL_25:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = [1, 2, 2, 3]} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
+// CHECK: %[[VAL_24:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = array<i64: 1, 2, 2, 3>} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
+// CHECK: %[[VAL_25:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = array<i64: 1, 2, 2, 3>} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32>
 // CHECK: %[[VAL_26:.*]] = "tosa.const"() {value = dense<5.000000e-01> : tensor<f32>} : () -> tensor<f32>
 // CHECK: %[[VAL_27:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32>
 // CHECK: %[[VAL_28:.*]] = "tosa.add"(%[[VAL_23]], %[[VAL_26]]) : (tensor<5x1x1x1xf32>, tensor<f32>) -> tensor<5x1x1x1xf32>
@@ -681,7 +681,7 @@ func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32>
 // CHECK: %[[VAL_2:.*]] = torch.constant.int 2
-// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [4, 3, 1]} : (tensor<4x3xi32>) -> tensor<4x3x1xi32>
+// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array<i64: 4, 3, 1>} : (tensor<4x3xi32>) -> tensor<4x3x1xi32>
 // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32>
 // CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32>
 // CHECK: }
@@ -698,7 +698,7 @@ func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !to
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32>
 // CHECK: %[[VAL_2:.*]] = torch.constant.int -1
-// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [4, 3, 1]} : (tensor<4x3xi32>) -> tensor<4x3x1xi32>
+// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = array<i64: 4, 3, 1>} : (tensor<4x3xi32>) -> tensor<4x3x1xi32>
 // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32>
 // CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32>
 // CHECK: }
@@ -778,7 +778,7 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch
 // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_11:.*]] = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
 // CHECK: %[[VAL_12:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_11]]) : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32>
-// CHECK: %[[VAL_13:.*]] = "tosa.avg_pool2d"(%[[VAL_12]]) {kernel = [7, 7], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
+// CHECK: %[[VAL_13:.*]] = "tosa.avg_pool2d"(%[[VAL_12]]) {kernel = array<i64: 7, 7>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32>
 // CHECK: %[[VAL_14:.*]] = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
 // CHECK: %[[VAL_15:.*]] = "tosa.transpose"(%[[VAL_13]], %[[VAL_14]]) : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32>
 // CHECK: %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32>
@@ -809,7 +809,7 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> )
 // CHECK: %[[VAL_I2:.*]] = torch.constant.int 2
 // CHECK: %[[VAL_2:.*]] = "tosa.reduce_max"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
 // CHECK: %[[VAL_3:.*]] = "tosa.argmax"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
-// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [3, 2, 1]} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
+// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = array<i64: 3, 2, 1>} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
 // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
 // CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
 // CHECK: return %[[VAL_6]] : tensor<3x2x1xf32>
@@ -927,18 +927,18 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten
 // CHECK: %[[VAL_4:.*]] = torch.constant.int -1
 // CHECK: %[[VAL_5:.*]] = torch.constant.bool false
 // CHECK: %[[VAL_6:.*]] = "tosa.cast"(%[[VAL_3]]) : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32>
-// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_6]]) {new_shape = [1, 4, 2, 1]} : (tensor<1x4x2xi32>) -> tensor<1x4x2x1xi32>
+// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_6]]) {new_shape = array<i64: 1, 4, 2, 1>} : (tensor<1x4x2xi32>) -> tensor<1x4x2x1xi32>
 // CHECK: %[[VAL_8:.*]] = "tosa.const"() {value = dense<0> : tensor<1x4x2x1xi32>} : () -> tensor<1x4x2x1xi32>
 // CHECK: %[[VAL_9:.*]] = "tosa.const"() {value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>} : () -> tensor<1x4x2x1xi32>
 // CHECK: %[[VAL_10:.*]] = "tosa.concat"(%[[VAL_8]], %[[VAL_9]], %[[VAL_7]]) {axis = 3 : i64} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32>
-// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = [1, 12, 1]} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32>
-// CHECK: %[[VAL_12:.*]] = "tosa.reshape"(%[[VAL_10]]) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32>
+// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array<i64: 1, 12, 1>} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32>
+// CHECK: %[[VAL_12:.*]] = "tosa.reshape"(%[[VAL_10]]) {new_shape = array<i64: 8, 3>} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32>
 // CHECK: %[[VAL_13:.*]] = "tosa.const"() {value = dense<[12, 3, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
 // CHECK: %[[VAL_14:.*]] = "tosa.mul"(%[[VAL_12]], %[[VAL_13]]) {shift = 0 : i32} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32>
 // CHECK: %[[VAL_15:.*]] = "tosa.reduce_sum"(%[[VAL_14]]) {axis = 1 : i64} : (tensor<8x3xi32>) -> tensor<8x1xi32>
-// CHECK: %[[VAL_16:.*]] = "tosa.reshape"(%[[VAL_15]]) {new_shape = [1, 8]} : (tensor<8x1xi32>) -> tensor<1x8xi32>
+// CHECK: %[[VAL_16:.*]] = "tosa.reshape"(%[[VAL_15]]) {new_shape = array<i64: 1, 8>} : (tensor<8x1xi32>) -> tensor<1x8xi32>
 // CHECK: %[[VAL_17:.*]] = "tosa.gather"(%[[VAL_11]], %[[VAL_16]]) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32>
-// CHECK: %[[VAL_18:.*]] = "tosa.reshape"(%[[VAL_17]]) {new_shape = [1, 4, 2]} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32>
+// CHECK: %[[VAL_18:.*]] = "tosa.reshape"(%[[VAL_17]]) {new_shape = array<i64: 1, 4, 2>} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32>
 // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
 // CHECK: return %[[VAL_19]] : !torch.vtensor<[1,4,2],f32>
 // CHECK: }