mirror of https://github.com/llvm/torch-mlir
parent
8bc028af05
commit
454fa9d123
|
@ -86,4 +86,7 @@ TOSA_PASS_SET = {
|
|||
"BatchNorm1DModule_basic",
|
||||
"BatchNorm2DModule_basic",
|
||||
"BatchNorm3DModule_basic",
|
||||
"FlattenStaticModule_basic",
|
||||
"FlattenRank0Module_basic",
|
||||
"ElementwiseFlattenBroadcastModule_basic",
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||
|
@ -1928,6 +1929,64 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
|
||||
AtenFlattenUsingIntsOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a ranked tensor type
|
||||
auto selfType = adaptor.self().getType().dyn_cast<RankedTensorType>();
|
||||
if (!selfType || !selfType.hasStaticShape())
|
||||
return op.emitError(
|
||||
"Only ranked tensor types with static shapes are currently supported");
|
||||
|
||||
int64_t selfRank = selfType.getRank();
|
||||
|
||||
int64_t start_dim, end_dim;
|
||||
|
||||
if (!matchPattern(op.start_dim(), m_TorchConstantInt(&start_dim)))
|
||||
return op.emitError("start_dim must be a Scalar constant");
|
||||
start_dim = toPositiveDim(start_dim, selfRank);
|
||||
|
||||
if (!matchPattern(op.end_dim(), m_TorchConstantInt(&end_dim)))
|
||||
return op.emitError("end_dim must be a Scalar constant");
|
||||
end_dim = toPositiveDim(end_dim, selfRank);
|
||||
|
||||
if (selfRank > 0 && !isValidDim(start_dim, selfRank))
|
||||
return op.emitError("start_dim is statically invalid");
|
||||
if (selfRank > 0 && !isValidDim(end_dim, selfRank))
|
||||
return op.emitError("end_dim is statically invalid");
|
||||
if (end_dim < start_dim)
|
||||
return op.emitError("end_dim must be larger than start_dim");
|
||||
|
||||
SmallVector<int64_t> newShape;
|
||||
for (auto s : llvm::enumerate(selfType.getShape())) {
|
||||
int64_t idx = s.index();
|
||||
if (idx < start_dim || idx > end_dim) {
|
||||
newShape.push_back(s.value());
|
||||
} else {
|
||||
if (idx == start_dim)
|
||||
newShape.push_back(s.value());
|
||||
else
|
||||
newShape.back() *= s.value();
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the Scalar case
|
||||
if (newShape.size() == 0)
|
||||
newShape.push_back(1);
|
||||
|
||||
auto newType = RankedTensorType::get(newShape, selfType.getElementType());
|
||||
auto reshapeOp =
|
||||
rewriter.create<tosa::ReshapeOp>(op->getLoc(), newType, adaptor.self(),
|
||||
rewriter.getI64ArrayAttr(newShape));
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), reshapeOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -2085,6 +2144,7 @@ public:
|
|||
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
|
||||
INSERT_ATENOP_PATTERN(AtenReshapeOp);
|
||||
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
|
|
|
@ -525,3 +525,22 @@ func @forward(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f
|
|||
%2 = torch.aten.batch_norm %arg0, %1, %0, %0, %1, %false, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[10,4,3],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[10,4,3],f32>
|
||||
return %2 : !torch.vtensor<[10,4,3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @forward(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,3,8,9,3,4],f32>) -> !torch.vtensor<[10,3,?,4],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [10, 3, 216, 4]} : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = tensor.cast %[[VAL_4]] : tensor<10x3x216x4xf32> to tensor<10x3x?x4xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[10,3,?,4],f32>
|
||||
// CHECK: }
|
||||
func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor<[10,3,?,4],f32> {
|
||||
%int4 = torch.constant.int 4
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[10,3,8,9,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[10,3,?,4],f32>
|
||||
return %0 : !torch.vtensor<[10,3,?,4],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue