[torch] Add support for `torch.view` with dynamic shapes (#3164)

We can map to `tensor.reshape` for handling multiple output dynamic
shapes. Later we can perform a more complex analysis for indentifying
expand/collapse cases from the tensor.reshape.

Initially we planned to handle this identification at the `torch` level
however it will be easier to handle once converted to core
mlir-dialects.
pull/3186/head
Rob Suderman 2024-04-18 11:47:19 -07:00 committed by GitHub
parent 4c21e20caa
commit 0e77de996a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 127 additions and 22 deletions

View File

@ -1003,8 +1003,14 @@ public:
// collapsed. Note this may technically not always be true. // collapsed. Note this may technically not always be true.
// TODO: think of a way better way to at least detect when this assumption // TODO: think of a way better way to at least detect when this assumption
// is violated for the cases of dynamic dimensions. // is violated for the cases of dynamic dimensions.
bool inputHasOneDynDim = llvm::count(inputShape, kUnknownSize) == 1; int64_t inputDynDim = llvm::count(inputShape, kUnknownSize);
bool outputHasOneDynDim = llvm::count(outputShape, kUnknownSize) == 1; int64_t outputDynDim = llvm::count(outputShape, kUnknownSize);
if (outputDynDim > 1)
return rewriter.notifyMatchFailure(
op, "Cannot support more than one output dynamic dimension");
bool inputHasOneDynDim = inputDynDim == 1;
bool outputHasOneDynDim = outputDynDim == 1;
bool singleDynDimsAreEqual = bool singleDynDimsAreEqual =
inputHasOneDynDim && outputHasOneDynDim && inputHasOneDynDim && outputHasOneDynDim &&
productReduce(inputShape) == productReduce(outputShape); productReduce(inputShape) == productReduce(outputShape);
@ -1271,6 +1277,85 @@ public:
}; };
} // namespace } // namespace
namespace {
class ConvertAtenViewOpToReshape : public OpConversionPattern<AtenViewOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> sizes;
if (!getListConstructElements(op.getSize(), sizes))
return op.emitError(
"unimplemented: the tensor size list is not from list construct");
auto loc = op.getLoc();
ImplicitLocOpBuilder b(loc, rewriter);
auto self = adaptor.getSelf();
const TypeConverter *typeConverter = getTypeConverter();
// Convert to the `linalg` types, count the number of negative values,
// and determine the product of non-negative values. This lets us compute
// the inferred dimensions sizes.
auto sizeTy =
cast<IntegerType>(typeConverter->convertType(sizes.front().getType()));
Value one =
b.create<arith::ConstantOp>(sizeTy, rewriter.getIntegerAttr(sizeTy, 1));
Value zero =
b.create<arith::ConstantOp>(sizeTy, rewriter.getIntegerAttr(sizeTy, 0));
Value count = zero;
Value knownSize = one;
for (auto &size : sizes) {
Value convert = typeConverter->materializeTargetConversion(rewriter, loc,
sizeTy, size);
Value mul = b.create<arith::MulIOp>(knownSize, convert);
Value add = b.create<arith::AddIOp>(count, one);
Value isNeg =
b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, convert, zero);
knownSize = b.create<arith::SelectOp>(isNeg, knownSize, mul);
count = b.create<arith::SelectOp>(isNeg, add, count);
size = convert;
}
// Check we are only inferring one dimension:
Value countPred =
b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, count, one);
b.create<cf::AssertOp>(
loc, countPred,
b.getStringAttr("must have at most one inferred (negative) dimension"));
// Determine the total size of the inferred dimension and update the
// inferred dimension:
auto selfTy = cast<RankedTensorType>(self.getType());
Value totalSize = one;
for (int i = 0, s = selfTy.getRank(); i < s; ++i) {
Value index = b.create<arith::ConstantIndexOp>(i);
Value dim = b.create<tensor::DimOp>(self, index);
dim = b.create<arith::IndexCastOp>(sizeTy, dim);
totalSize = b.create<arith::MulIOp>(totalSize, dim);
}
Value inferredSize = b.create<arith::DivSIOp>(totalSize, knownSize);
for (auto &size : sizes) {
Value isNeg =
b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, size, zero);
size = b.create<arith::SelectOp>(isNeg, inferredSize, size);
}
auto ty = RankedTensorType::get(sizes.size(), sizes.front().getType());
auto outputDims = b.create<tensor::FromElementsOp>(ty, sizes);
auto resultType =
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(op, resultType, self,
outputDims);
return success();
}
};
} // namespace
namespace { namespace {
class ConvertAtenSqueezeOp : public OpConversionPattern<AtenSqueezeOp> { class ConvertAtenSqueezeOp : public OpConversionPattern<AtenSqueezeOp> {
public: public:
@ -2348,10 +2433,12 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenReflectionPad2dOp>(typeConverter, context); patterns.add<ConvertAtenReflectionPad2dOp>(typeConverter, context);
target.addIllegalOp<AtenFlattenUsingIntsOp>(); target.addIllegalOp<AtenFlattenUsingIntsOp>();
patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context); patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
target.addIllegalOp<AtenViewOp>();
patterns.add<ConvertAtenUnflattenIntOp>(typeConverter, context); patterns.add<ConvertAtenUnflattenIntOp>(typeConverter, context);
target.addIllegalOp<AtenUnflattenIntOp>(); target.addIllegalOp<AtenUnflattenIntOp>();
patterns.add<ConvertAtenViewOp>(typeConverter, context); target.addIllegalOp<AtenViewOp>();
patterns.add<ConvertAtenViewOp>(typeConverter, context, /*benefit=*/200);
patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
/*benefit=*/100);
target.addIllegalOp<AtenSqueezeOp>(); target.addIllegalOp<AtenSqueezeOp>();
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context); patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeDimOp>(); target.addIllegalOp<AtenSqueezeDimOp>();

View File

@ -32,6 +32,7 @@ from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTen
from .xfail_sets import ( from .xfail_sets import (
LINALG_XFAIL_SET, LINALG_XFAIL_SET,
LINALG_CRASHING_SET,
MAKE_FX_TOSA_PASS_SET, MAKE_FX_TOSA_PASS_SET,
STABLEHLO_PASS_SET, STABLEHLO_PASS_SET,
STABLEHLO_CRASHING_SET, STABLEHLO_CRASHING_SET,
@ -99,7 +100,7 @@ def main():
if args.config == "linalg": if args.config == "linalg":
config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend()) config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
xfail_set = LINALG_XFAIL_SET xfail_set = LINALG_XFAIL_SET
crashing_set = set() crashing_set = LINALG_CRASHING_SET
elif args.config == "stablehlo": elif args.config == "stablehlo":
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend()) config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET xfail_set = all_test_unique_names - STABLEHLO_PASS_SET

View File

@ -24,6 +24,11 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
"SplitWithSizes_Module_basic", "SplitWithSizes_Module_basic",
} }
LINALG_CRASHING_SET = {
# Crashes due to copy to a smaller destination buffer than the source buffer.
"SliceCopyStartGreaterThanDimSize_Module_basic",
}
TORCHDYNAMO_XFAIL_SET = { TORCHDYNAMO_XFAIL_SET = {
#### General TorchDynamo/PyTorch errors #### General TorchDynamo/PyTorch errors
@ -2280,15 +2285,6 @@ ONNX_XFAIL_SET = {
"ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseToDtypeI64ToUI8Module_basic",
# Failure - torch.aten.view lower # Failure - torch.aten.view lower
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorMultiInputNonContiguous_basic",
"IndexTensorMultiInputOneDim_basic",
"IndexTensorMultiInputThreeIndexers_basic",
"IndexTensorMultiInput_basic",
"IndexTensorMultiInputContiguousOneDimDynamic_basic", "IndexTensorMultiInputContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic", "IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
@ -2327,7 +2323,6 @@ ONNX_XFAIL_SET = {
"EmbeddingModuleF16_basic", "EmbeddingModuleF16_basic",
"EmbeddingModuleI32_basic", "EmbeddingModuleI32_basic",
"EmbeddingModuleI64_basic", "EmbeddingModuleI64_basic",
"FlattenDynamicModule_basic",
"GluStaticModule_basic", "GluStaticModule_basic",
"GroupNormModule_basic", "GroupNormModule_basic",
"IndexTensorHackedTwinModule3dInput_basic", "IndexTensorHackedTwinModule3dInput_basic",

View File

@ -992,6 +992,28 @@ class ReshapeAliasExpandModule(torch.nn.Module):
def ReshapeAliasExpandModule_basic(module, tu: TestUtils): def ReshapeAliasExpandModule_basic(module, tu: TestUtils):
module.forward(tu.rand(384)) module.forward(tu.rand(384))
# ==============================================================================
class ReshapeDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return a.view(a.size(1), a.size(0))
@register_test_case(module_factory=lambda: ReshapeDynamicModule())
def ReshapeDynamicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3,4))
# ============================================================================== # ==============================================================================
class ReshapeAliasCollapseModule(torch.nn.Module): class ReshapeAliasCollapseModule(torch.nn.Module):
@ -1153,4 +1175,4 @@ class EinsumStaticWithEllipsisSlicingAndBroadcastModule(torch.nn.Module):
@register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingAndBroadcastModule()) @register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingAndBroadcastModule())
def EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic(module, tu: TestUtils): def EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5)) module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5))

View File

@ -23,7 +23,8 @@ func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torc
// CHECK-LABEL: func.func @torch.aten.view$dynamictest( // CHECK-LABEL: func.func @torch.aten.view$dynamictest(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[BUILTIN_TENSOR]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32> // CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[BUILTIN_TENSOR]]
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[RESHAPE]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,?],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@ -31,7 +32,7 @@ func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
%1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
%2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int> %2 = torch.prim.ListConstruct %1, %0 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32> %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
return %3 : !torch.vtensor<[?,?],f32> return %3 : !torch.vtensor<[?,?],f32>
} }
@ -41,7 +42,7 @@ func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
// CHECK-LABEL: func.func @torch.aten.view$dynamictest2( // CHECK-LABEL: func.func @torch.aten.view$dynamictest2(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,6,?],f32> -> tensor<?x6x?xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,6,?],f32> -> tensor<?x6x?xf32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2], [3]] : tensor<?x6x?xf32> into tensor<?x2x3x?xf32> // CHECK: %[[EXPAND:.*]] = tensor.reshape %[[BUILTIN_TENSOR]]
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<?x2x3x?xf32> -> !torch.vtensor<[?,2,3,?],f32> // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<?x2x3x?xf32> -> !torch.vtensor<[?,2,3,?],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,2,3,?],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,2,3,?],f32>
@ -174,9 +175,8 @@ func.func @torch.aten.view$singleUnknownMatches0(%arg0: !torch.vtensor<[10,3,?,2
// CHECK: func.func @torch.aten.view$combineConcepts( // CHECK: func.func @torch.aten.view$combineConcepts(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[8,?,?,?,2,1,3],f32>) -> !torch.vtensor<[2,2,2,?,?,?,6],f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[8,?,?,?,2,1,3],f32>) -> !torch.vtensor<[2,2,2,?,?,?,6],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[8,?,?,?,2,1,3],f32> -> tensor<8x?x?x?x2x1x3xf32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[8,?,?,?,2,1,3],f32> -> tensor<8x?x?x?x2x1x3xf32>
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2], [3], [4, 5, 6]] : tensor<8x?x?x?x2x1x3xf32> into tensor<8x?x?x?x6xf32> // CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[BUILTIN_TENSOR]]
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1, 2], [3], [4], [5], [6]] : tensor<8x?x?x?x6xf32> into tensor<2x2x2x?x?x?x6xf32> // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[RESHAPE]] : tensor<2x2x2x?x?x?x6xf32> -> !torch.vtensor<[2,2,2,?,?,?,6],f32>
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<2x2x2x?x?x?x6xf32> -> !torch.vtensor<[2,2,2,?,?,?,6],f32>
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,2,2,?,?,?,6],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,2,2,?,?,?,6],f32>
func.func @torch.aten.view$combineConcepts(%arg0 : !torch.vtensor<[8,?,?,?,2,1,3], f32>) -> !torch.vtensor<[2,2,2,?,?,?,6], f32> { func.func @torch.aten.view$combineConcepts(%arg0 : !torch.vtensor<[8,?,?,?,2,1,3], f32>) -> !torch.vtensor<[2,2,2,?,?,?,6], f32> {