mirror of https://github.com/llvm/torch-mlir

[torch] Add support for `torch.view` with dynamic shapes (#3164)

We can map to `tensor.reshape` to handle outputs with multiple dynamic dimensions. Later we can perform a more complex analysis to identify expand/collapse cases from the `tensor.reshape`. Initially we planned to handle this identification at the `torch` level; however, it is easier to handle once converted to the core mlir dialects.
parent 4c21e20caa
commit 0e77de996a
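
For context, the kind of view this patch enables can be shown with a small PyTorch-level illustration (this snippet is illustrative, not taken from the patch; the IR-level equivalents are in the lit test updates below). When every output size is a runtime value, the lowering cannot commit to a static `tensor.expand_shape`/`tensor.collapse_shape` grouping, so the op is instead lowered to `tensor.reshape` with a runtime shape operand.

import torch

def swap_view(a: torch.Tensor) -> torch.Tensor:
    # Both output sizes are runtime values, so no static expand/collapse
    # grouping exists; the new lowering emits tensor.reshape instead.
    return a.view(a.size(1), a.size(0))

print(swap_view(torch.rand(3, 4)).shape)  # torch.Size([4, 3])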

@@ -1003,8 +1003,14 @@ public:
     // collapsed. Note this may technically not always be true.
     // TODO: think of a way better way to at least detect when this assumption
     // is violated for the cases of dynamic dimensions.
-    bool inputHasOneDynDim = llvm::count(inputShape, kUnknownSize) == 1;
-    bool outputHasOneDynDim = llvm::count(outputShape, kUnknownSize) == 1;
+    int64_t inputDynDim = llvm::count(inputShape, kUnknownSize);
+    int64_t outputDynDim = llvm::count(outputShape, kUnknownSize);
+    if (outputDynDim > 1)
+      return rewriter.notifyMatchFailure(
+          op, "Cannot support more than one output dynamic dimension");
+
+    bool inputHasOneDynDim = inputDynDim == 1;
+    bool outputHasOneDynDim = outputDynDim == 1;
     bool singleDynDimsAreEqual =
         inputHasOneDynDim && outputHasOneDynDim &&
         productReduce(inputShape) == productReduce(outputShape);
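
Counting the dynamic dimensions explicitly lets the existing expand/collapse pattern bail out cleanly: when the output type has more than one dynamic dimension, the grouping into collapse/expand dimension lists cannot be determined statically, so the pattern reports a match failure and the reshape-based fallback pattern added below handles the op instead.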

@@ -1271,6 +1277,85 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenViewOpToReshape : public OpConversionPattern<AtenViewOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> sizes;
+    if (!getListConstructElements(op.getSize(), sizes))
+      return op.emitError(
+          "unimplemented: the tensor size list is not from list construct");
+
+    auto loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+    auto self = adaptor.getSelf();
+    const TypeConverter *typeConverter = getTypeConverter();
+
+    // Convert to the `linalg` types, count the number of negative values,
+    // and determine the product of non-negative values. This lets us compute
+    // the inferred dimension sizes.
+    auto sizeTy =
+        cast<IntegerType>(typeConverter->convertType(sizes.front().getType()));
+    Value one =
+        b.create<arith::ConstantOp>(sizeTy, rewriter.getIntegerAttr(sizeTy, 1));
+    Value zero =
+        b.create<arith::ConstantOp>(sizeTy, rewriter.getIntegerAttr(sizeTy, 0));
+    Value count = zero;
+    Value knownSize = one;
+    for (auto &size : sizes) {
+      Value convert = typeConverter->materializeTargetConversion(rewriter, loc,
+                                                                 sizeTy, size);
+
+      Value mul = b.create<arith::MulIOp>(knownSize, convert);
+      Value add = b.create<arith::AddIOp>(count, one);
+      Value isNeg =
+          b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, convert, zero);
+
+      knownSize = b.create<arith::SelectOp>(isNeg, knownSize, mul);
+      count = b.create<arith::SelectOp>(isNeg, add, count);
+      size = convert;
+    }
+
+    // Check we are only inferring one dimension:
+    Value countPred =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, count, one);
+    b.create<cf::AssertOp>(
+        loc, countPred,
+        b.getStringAttr("must have at most one inferred (negative) dimension"));
+
+    // Determine the total size of the inferred dimension and update the
+    // inferred dimension:
+    auto selfTy = cast<RankedTensorType>(self.getType());
+    Value totalSize = one;
+    for (int i = 0, s = selfTy.getRank(); i < s; ++i) {
+      Value index = b.create<arith::ConstantIndexOp>(i);
+      Value dim = b.create<tensor::DimOp>(self, index);
+      dim = b.create<arith::IndexCastOp>(sizeTy, dim);
+      totalSize = b.create<arith::MulIOp>(totalSize, dim);
+    }
+
+    Value inferredSize = b.create<arith::DivSIOp>(totalSize, knownSize);
+    for (auto &size : sizes) {
+      Value isNeg =
+          b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, size, zero);
+      size = b.create<arith::SelectOp>(isNeg, inferredSize, size);
+    }
+
+    auto ty = RankedTensorType::get(sizes.size(), sizes.front().getType());
+    auto outputDims = b.create<tensor::FromElementsOp>(ty, sizes);
+
+    auto resultType =
+        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+    rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(op, resultType, self,
+                                                   outputDims);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenSqueezeOp : public OpConversionPattern<AtenSqueezeOp> {
 public:
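
The runtime size inference that `ConvertAtenViewOpToReshape` emits is easier to read in scalar form. Below is a minimal Python sketch of the same arithmetic (the helper name and shape of the code are illustrative, not part of the patch): a negative entry in the requested sizes marks the inferred dimension, the product of the non-negative entries is the known size, and the inferred entry becomes the total element count divided by the known size. The `cf.assert` in the pattern enforces the same at-most-one-negative invariant at runtime.

from math import prod

def infer_view_sizes(input_shape, sizes):
    # At most one entry may be inferred (negative); the lowering checks
    # this at runtime with cf.assert.
    neg_count = sum(1 for s in sizes if s < 0)
    assert neg_count <= 1, "must have at most one inferred (negative) dimension"
    # Product of the known (non-negative) requested sizes.
    known_size = prod(s for s in sizes if s >= 0)
    # Total element count of the input; the IR reads this via tensor.dim.
    total_size = prod(input_shape)
    # Replace the inferred entry with total // known (signed division,
    # mirroring arith.divsi).
    return [total_size // known_size if s < 0 else s for s in sizes]

# Viewing a 2x3x4 input as (4, -1) infers (4, 6).
assert infer_view_sizes([2, 3, 4], [4, -1]) == [4, 6]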

@@ -2348,10 +2433,12 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   patterns.add<ConvertAtenReflectionPad2dOp>(typeConverter, context);
   target.addIllegalOp<AtenFlattenUsingIntsOp>();
   patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
-  target.addIllegalOp<AtenViewOp>();
   patterns.add<ConvertAtenUnflattenIntOp>(typeConverter, context);
   target.addIllegalOp<AtenUnflattenIntOp>();
-  patterns.add<ConvertAtenViewOp>(typeConverter, context);
+  target.addIllegalOp<AtenViewOp>();
+  patterns.add<ConvertAtenViewOp>(typeConverter, context, /*benefit=*/200);
+  patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
+                                           /*benefit=*/100);
   target.addIllegalOp<AtenSqueezeOp>();
   patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
   target.addIllegalOp<AtenSqueezeDimOp>();
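
Note that both view patterns are registered against the same illegal op, and the pattern benefit determines the order in which the conversion driver tries them: `ConvertAtenViewOp` (benefit 200) runs first, and only when it fails to match, e.g. because the output has more than one dynamic dimension, does the generic `ConvertAtenViewOpToReshape` (benefit 100) fire.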

@@ -32,6 +32,7 @@ from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTen
 
 from .xfail_sets import (
     LINALG_XFAIL_SET,
+    LINALG_CRASHING_SET,
     MAKE_FX_TOSA_PASS_SET,
     STABLEHLO_PASS_SET,
     STABLEHLO_CRASHING_SET,

@@ -99,7 +100,7 @@ def main():
     if args.config == "linalg":
         config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
         xfail_set = LINALG_XFAIL_SET
-        crashing_set = set()
+        crashing_set = LINALG_CRASHING_SET
     elif args.config == "stablehlo":
         config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
         xfail_set = all_test_unique_names - STABLEHLO_PASS_SET

@@ -24,6 +24,11 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
     "SplitWithSizes_Module_basic",
 }
 
+LINALG_CRASHING_SET = {
+    # Crashes due to copy to a smaller destination buffer than the source buffer.
+    "SliceCopyStartGreaterThanDimSize_Module_basic",
+}
+
 TORCHDYNAMO_XFAIL_SET = {
     #### General TorchDynamo/PyTorch errors

@@ -2280,15 +2285,6 @@ ONNX_XFAIL_SET = {
     "ElementwiseToDtypeI64ToUI8Module_basic",
 
-    # Failure - torch.aten.view lower
-    "IndexTensorDyanmicInputContiguousWithNoneModule_basic",
-    "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
-    "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
-    "IndexTensorMultiInputContiguousCenter_basic",
-    "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
-    "IndexTensorMultiInputNonContiguous_basic",
-    "IndexTensorMultiInputOneDim_basic",
-    "IndexTensorMultiInputThreeIndexers_basic",
     "IndexTensorMultiInput_basic",
     "IndexTensorMultiInputContiguousOneDimDynamic_basic",
     "IndexTensorMultiInputNonContiguousOneDimDynamic_basic",

@@ -2327,7 +2323,6 @@
     "EmbeddingModuleF16_basic",
     "EmbeddingModuleI32_basic",
     "EmbeddingModuleI64_basic",
-    "FlattenDynamicModule_basic",
     "GluStaticModule_basic",
     "GroupNormModule_basic",
     "IndexTensorHackedTwinModule3dInput_basic",

@@ -992,6 +992,28 @@ class ReshapeAliasExpandModule(torch.nn.Module):
 def ReshapeAliasExpandModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(384))
 
 
+# ==============================================================================
+
+
+class ReshapeDynamicModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return a.view(a.size(1), a.size(0))
+
+
+@register_test_case(module_factory=lambda: ReshapeDynamicModule())
+def ReshapeDynamicModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
 # ==============================================================================
 
 class ReshapeAliasCollapseModule(torch.nn.Module):
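
The new `ReshapeDynamicModule` test above builds both output sizes from runtime values. A quick PyTorch-level sanity check of its semantics (illustrative only, not part of the test suite): swapping the sizes reinterprets the row-major data, it does not transpose it.

import torch

a = torch.arange(12.0).reshape(3, 4)
v = a.view(a.size(1), a.size(0))  # viewed as [4, 3]
assert v.shape == (4, 3)
# A view reinterprets contiguous data; it is not a transpose.
assert not torch.equal(v, a.t())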

@@ -1153,4 +1175,4 @@ class EinsumStaticWithEllipsisSlicingAndBroadcastModule(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingAndBroadcastModule())
 def EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5))
\ No newline at end of file
+    module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5))

@@ -23,7 +23,8 @@ func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torc
 // CHECK-LABEL:  func.func @torch.aten.view$dynamictest(
 // CHECK-SAME:   %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK:        %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK:        %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[BUILTIN_TENSOR]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:        %[[RESHAPE:.*]] = tensor.reshape %[[BUILTIN_TENSOR]]
+// CHECK:        %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[RESHAPE]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK:        return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,?],f32>
 
 func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {

@@ -31,7 +32,7 @@ func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
   %int0 = torch.constant.int 0
   %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
   %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
-  %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %1, %0 : (!torch.int, !torch.int) -> !torch.list<int>
   %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
   return %3 : !torch.vtensor<[?,?],f32>
 }

@@ -41,7 +42,7 @@ func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
 // CHECK-LABEL:  func.func @torch.aten.view$dynamictest2(
 // CHECK-SAME:   %[[ARG:.*]]: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> {
 // CHECK:        %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,6,?],f32> -> tensor<?x6x?xf32>
-// CHECK:        %[[EXPAND:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2], [3]] : tensor<?x6x?xf32> into tensor<?x2x3x?xf32>
+// CHECK:        %[[EXPAND:.*]] = tensor.reshape %[[BUILTIN_TENSOR]]
 // CHECK:        %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<?x2x3x?xf32> -> !torch.vtensor<[?,2,3,?],f32>
 // CHECK:        return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,2,3,?],f32>

@@ -174,9 +175,8 @@ func.func @torch.aten.view$singleUnknownMatches0(%arg0: !torch.vtensor<[10,3,?,2
 // CHECK:        func.func @torch.aten.view$combineConcepts(
 // CHECK-SAME:   %[[ARG:.*]]: !torch.vtensor<[8,?,?,?,2,1,3],f32>) -> !torch.vtensor<[2,2,2,?,?,?,6],f32> {
 // CHECK:        %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[8,?,?,?,2,1,3],f32> -> tensor<8x?x?x?x2x1x3xf32>
-// CHECK:        %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2], [3], [4, 5, 6]] : tensor<8x?x?x?x2x1x3xf32> into tensor<8x?x?x?x6xf32>
-// CHECK:        %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1, 2], [3], [4], [5], [6]] : tensor<8x?x?x?x6xf32> into tensor<2x2x2x?x?x?x6xf32>
-// CHECK:        %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<2x2x2x?x?x?x6xf32> -> !torch.vtensor<[2,2,2,?,?,?,6],f32>
+// CHECK:        %[[RESHAPE:.*]] = tensor.reshape %[[BUILTIN_TENSOR]]
+// CHECK:        %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[RESHAPE]] : tensor<2x2x2x?x?x?x6xf32> -> !torch.vtensor<[2,2,2,?,?,?,6],f32>
 // CHECK:        return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,2,2,?,?,?,6],f32>
 
 func.func @torch.aten.view$combineConcepts(%arg0 : !torch.vtensor<[8,?,?,?,2,1,3], f32>) -> !torch.vtensor<[2,2,2,?,?,?,6], f32> {