[MLIR][TORCH] Add support for same input and output shapes for view op

This commit adds support for the cases of the view op where the input and
the result have the same rank and shape.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/701/head
Vivek Khandelwal 2022-03-24 17:33:12 +05:30
parent 02b6d04eb4
commit 88c216da13
2 changed files with 99 additions and 21 deletions
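
For illustration (not part of this patch), the newly supported pattern is a view whose target rank and shape match the input exactly; in eager PyTorch this is an identity view:

    import torch

    t = torch.rand(4, 5, 6)
    v = t.view(4, 5, 6)  # same rank and same shape: an identity view
    assert v.shape == t.shape
    assert torch.equal(v, t)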

@@ -315,3 +315,76 @@ class ReshapeCollapseModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ReshapeCollapseModule())
 def ReshapeCollapseModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4))
+
+# ==============================================================================
+
+class ViewNoChange1dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return a.view(6)
+
+
+@register_test_case(module_factory=lambda: ViewNoChange1dModule())
+def ViewNoChange1dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6))
+
+
+class ViewNoChange2dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return a.view(5, 6)
+
+
+@register_test_case(module_factory=lambda: ViewNoChange2dModule())
+def ViewNoChange2dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 6))
+
+
+class ViewNoChange3dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return a.view(4, 5, 6)
+
+
+@register_test_case(module_factory=lambda: ViewNoChange3dModule())
+def ViewNoChange3dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5, 6))
+
+
+class ViewNoChangeStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([4, 5, 6], torch.float32, True),
+    ])
+    def forward(self, a):
+        return a.view(4, 5, 6)
+
+
+@register_test_case(module_factory=lambda: ViewNoChangeStaticModule())
+def ViewNoChangeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5, 6))
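
The four tests above cover dynamic ranks 1 through 3 plus a fully static shape; in the dynamic cases only the rank is known at compile time, so the lowering must verify each dimension at runtime (see checkDimEqualHelper in the second file). As a quick eager-mode sanity check, one of the new modules can be exercised directly, outside the e2e harness (a sketch, not part of the patch):

    import torch

    m = ViewNoChange2dModule()
    out = m.forward(torch.rand(5, 6))
    assert out.shape == (5, 6)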

@@ -107,19 +107,6 @@ public:
     auto resultType =
         typeConverter->convertType(op.getType()).cast<RankedTensorType>();
     int64_t resultRank = resultType.getRank();
-    // Currently, we only handle the expanding OR collapsing cases, we do not
-    // handle expanding And collapsing happening at the same time or cases where
-    // it's neither collapsing nor expanding like view of [2,3] for 3x2 tensor.
-    // TODO: For the expanding And collapsing case, we will need to identify
-    // which dimensions are collapsing and which are expanding and do it in two
-    // steps.
-    // TODO: For neither collapsing nor expanding, we could find a intermediate
-    // shape to collapse and then expanded to the target shape. Like [2,3] =>
-    // [6] => [3, 2].
-    if (inputRank == resultRank)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: the view op is neither expanding nor collapsing");
     if (resultRank == 0)
       return rewriter.notifyMatchFailure(op,
                                          "result shape of rank 0 is invalid");
@@ -147,11 +134,30 @@ public:
       return rewriter.notifyMatchFailure(
           op, "desired size list length mismatches with the result type rank");
     }
-    SmallVector<Value> inputSizeTorchInt = getTensorSizes(rewriter, loc, input);
-    ArrayRef<Value> expandedShapeTorchInt =
-        llvm::makeArrayRef(isCollapse ? inputSizeTorchInt : outputSizeInt);
-    ArrayRef<Value> collapsedShapeTorchInt =
-        llvm::makeArrayRef(isCollapse ? outputSizeInt : inputSizeTorchInt);
+    SmallVector<Value> inputSize = getTensorSizes(rewriter, loc, input);
+    ArrayRef<Value> expandedShapeInt =
+        llvm::makeArrayRef(isCollapse ? inputSize : outputSizeInt);
+    ArrayRef<Value> collapsedShapeInt =
+        llvm::makeArrayRef(isCollapse ? outputSizeInt : inputSize);
+
+    // Currently, we only handle the expanding or collapsing cases or the
+    // identity cases where the rank and shape of the input and result are
+    // equal, and the input itself is the result. We do not handle expanding
+    // and collapsing happening at the same time or cases where it's neither
+    // collapsing nor expanding, like a view of [2,3] for a 3x2 tensor.
+    // TODO: For the expanding and collapsing case, we will need to identify
+    // which dimensions are collapsing and which are expanding and do it in
+    // two steps.
+    // TODO: For neither collapsing nor expanding, we could find an
+    // intermediate shape to collapse to and then expand to the target shape,
+    // like [2,3] => [6] => [3,2].
+    if (inputRank == resultRank) {
+      for (unsigned i = 0; i < inputRank; i++)
+        checkDimEqualHelper(rewriter, loc, inputSize[i], outputSizeInt[i]);
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
+      return success();
+    }
     // Iterate through the view op size list to do the following:
     //
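
A minimal Python sketch of the identity path added above, useful for seeing the control flow outside of MLIR; the real lowering emits one checkDimEqualHelper runtime check per dimension and then replaces the op with a tensor.cast of the input (no data movement):

    import torch

    def lower_identity_view(t: torch.Tensor, target_shape) -> torch.Tensor:
        # Ranks are equal, so the view can only be the identity: verify
        # every dimension pair (done at runtime in the real lowering).
        assert t.dim() == len(target_shape)
        for in_dim, out_dim in zip(t.shape, target_shape):
            if in_dim != out_dim:
                raise RuntimeError("view: input and target dims differ")
        # The result is the input itself, possibly with a refined static
        # type (tensor.cast in the real lowering).
        return t

    lower_identity_view(torch.rand(5, 6), (5, 6))  # ok: identity view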
@@ -307,9 +313,8 @@ public:
               op,
               "desired size is not compatible with the input tensor size");
         }
-        checkDimEqualHelper(rewriter, loc,
-                            collapsedShapeTorchInt[collapsedDim],
-                            expandedShapeTorchInt[expandedDim]);
+        checkDimEqualHelper(rewriter, loc, collapsedShapeInt[collapsedDim],
+                            expandedShapeInt[expandedDim]);
         // To meet the second requirement from tensor.expand_shape
         // verification code.
         expandedShape[expandedDim] = kUnknownSize;
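
Taken together, a rough sketch of the lowering's coverage after this change; the eager PyTorch calls below are shown only for their shape semantics, and the support notes follow the comments in the diff above:

    import torch

    t = torch.rand(5, 6)
    t.view(30)       # pure collapse: already supported
    t.view(5, 3, 2)  # pure expand: already supported
    t.view(5, 6)     # identity (same rank and shape): added by this commit
    t.view(3, 10)    # same rank, different dims: still a TODO
                     # (e.g. via an intermediate shape, [5,6] => [30] => [3,10])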