* [tosa] Support for AtenFlattenUsingIntsOp (#548)

Anup Gangwar 2022-01-28 23:38:56 -06:00 committed by GitHub
parent 8bc028af05
commit 454fa9d123
3 changed files with 82 additions and 0 deletions


@@ -86,4 +86,7 @@ TOSA_PASS_SET = {
    "BatchNorm1DModule_basic",
    "BatchNorm2DModule_basic",
    "BatchNorm3DModule_basic",
    "FlattenStaticModule_basic",
    "FlattenRank0Module_basic",
    "ElementwiseFlattenBroadcastModule_basic",
}
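
The three new entries move the Flatten e2e tests into the set expected to pass through the TOSA backend. As a rough illustration of the behaviors those tests exercise (plain PyTorch, not the actual test code; the concrete shapes below are made up):

import torch

# Flatten with a static shape: dims in [start_dim, end_dim] are collapsed into one.
x = torch.rand(2, 3, 4)
print(torch.flatten(x, 1, 2).shape)                              # torch.Size([2, 12])

# Flattening a rank-0 tensor yields a rank-1 tensor with a single element.
print(torch.flatten(torch.tensor(7.0)).shape)                    # torch.Size([1])

# A flattened tensor feeding a broadcasted elementwise op.
print((torch.flatten(torch.rand(5)) + torch.rand(3, 5)).shape)   # torch.Size([3, 5])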


@@ -17,6 +17,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
@@ -1928,6 +1929,64 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
    AtenFlattenUsingIntsOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // Not a ranked tensor type
  auto selfType = adaptor.self().getType().dyn_cast<RankedTensorType>();
  if (!selfType || !selfType.hasStaticShape())
    return op.emitError(
        "Only ranked tensor types with static shapes are currently supported");

  int64_t selfRank = selfType.getRank();

  int64_t start_dim, end_dim;

  if (!matchPattern(op.start_dim(), m_TorchConstantInt(&start_dim)))
    return op.emitError("start_dim must be a Scalar constant");
  start_dim = toPositiveDim(start_dim, selfRank);

  if (!matchPattern(op.end_dim(), m_TorchConstantInt(&end_dim)))
    return op.emitError("end_dim must be a Scalar constant");
  end_dim = toPositiveDim(end_dim, selfRank);

  if (selfRank > 0 && !isValidDim(start_dim, selfRank))
    return op.emitError("start_dim is statically invalid");
  if (selfRank > 0 && !isValidDim(end_dim, selfRank))
    return op.emitError("end_dim is statically invalid");
  if (end_dim < start_dim)
    return op.emitError("end_dim must be larger than start_dim");

  SmallVector<int64_t> newShape;
  for (auto s : llvm::enumerate(selfType.getShape())) {
    int64_t idx = s.index();
    if (idx < start_dim || idx > end_dim) {
      newShape.push_back(s.value());
    } else {
      if (idx == start_dim)
        newShape.push_back(s.value());
      else
        newShape.back() *= s.value();
    }
  }

  // Handle the Scalar case
  if (newShape.size() == 0)
    newShape.push_back(1);

  auto newType = RankedTensorType::get(newShape, selfType.getElementType());
  auto reshapeOp =
      rewriter.create<tosa::ReshapeOp>(op->getLoc(), newType, adaptor.self(),
                                       rewriter.getI64ArrayAttr(newShape));

  rewriter.replaceOpWithNewOp<tensor::CastOp>(
      op, getTypeConverter()->convertType(op.getType()), reshapeOp);

  return success();
}
} // namespace
// -----------------------------------------------------------------------------
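
For reference, a minimal Python sketch (not part of the patch) of the shape arithmetic the pattern above performs: dimensions outside [start_dim, end_dim] are kept, dimensions inside the range are folded into the dimension at start_dim, and a rank-0 input flattens to shape [1].

def flattened_shape(shape, start_dim, end_dim):
    # Mirrors the newShape loop in the pattern above.
    new_shape = []
    for idx, size in enumerate(shape):
        if idx < start_dim or idx > end_dim:
            new_shape.append(size)
        elif idx == start_dim:
            new_shape.append(size)
        else:
            new_shape[-1] *= size
    return new_shape or [1]  # scalar (rank-0) case

# Matches the lit test below: dims 2..4 of [10, 3, 8, 9, 3, 4] collapse to 8*9*3 = 216.
assert flattened_shape([10, 3, 8, 9, 3, 4], 2, 4) == [10, 3, 216, 4]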
@@ -2085,6 +2144,7 @@ public:
    INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
    INSERT_ATENOP_PATTERN(AtenReshapeOp);
    INSERT_ATENOP_PATTERN(AtenBatchNormOp);
    INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
#undef INSERT_ATENOP_PATTERN

    if (failed(applyPartialConversion(getOperation(), target,


@@ -525,3 +525,22 @@ func @forward(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f
  %2 = torch.aten.batch_norm %arg0, %1, %0, %0, %1, %false, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[10,4,3],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[10,4,3],f32>
  return %2 : !torch.vtensor<[10,4,3],f32>
}
// -----
// CHECK-LABEL: func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 4
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [10, 3, 216, 4]} : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32>
// CHECK: %[[VAL_5:.*]] = tensor.cast %[[VAL_4]] : tensor<10x3x216x4xf32> to tensor<10x3x?x4xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[10,3,?,4],f32>
// CHECK: }
func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,3,?,4],f32> {
  %int4 = torch.constant.int 4
  %int2 = torch.constant.int 2
  %0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[10,3,8,9,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,3,?,4],f32>
  return %0 : !torch.vtensor<[10,3,?,4],f32>
}
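
Note that the declared result type leaves dim 2 dynamic (!torch.vtensor<[10,3,?,4],f32>) while tosa.reshape produces the fully static tensor<10x3x216x4xf32>; the tensor.cast in the CHECK lines is the pattern reconciling those two types. The expected 216 agrees with eager PyTorch; a quick sanity check (not part of the patch):

import torch

# Flattening dims 2..4 of a [10, 3, 8, 9, 3, 4] tensor merges 8*9*3 = 216.
print(torch.flatten(torch.rand(10, 3, 8, 9, 3, 4), 2, 4).shape)  # torch.Size([10, 3, 216, 4])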