diff --git a/e2e_testing/torchscript/view.py b/e2e_testing/torchscript/view.py index e3c33695b..a9971020e 100644 --- a/e2e_testing/torchscript/view.py +++ b/e2e_testing/torchscript/view.py @@ -9,7 +9,6 @@ from torch_mlir_e2e_test.torchscript.registry import register_test_case from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export # ============================================================================== - class ViewExpandModule(torch.nn.Module): def __init__(self): super().__init__() @@ -46,3 +45,60 @@ class ViewDynamicExpandModule(torch.nn.Module): def ViewDynamicExpandModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 30, 384)) + +# ============================================================================== +class ViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(a.size(0), a.size(1), 12, 32) + +@register_test_case(module_factory=lambda: ViewDynamicExpandWithAtenSizeIntModule()) +def ViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 384)) + +# ============================================================================== +class ViewCollapseModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(8) + +@register_test_case(module_factory=lambda: ViewCollapseModule()) +def ViewCollapseModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4)) + + +# ============================================================================== +class ViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1, -1], torch.float32, True), + ([], torch.int64, True), + ([], torch.int64, True), + ]) + + def forward(self, a, b, c): + return a.view(a.size(0), int(b), int(c), a.size(3), 384) + +@register_test_case(module_factory=lambda: ViewCollapseDynamicWithAtenSizeIntModule()) +def ViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5)) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index f8aff87fa..1c556a11a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -108,6 +108,34 @@ m_TorchConstantIntList(SmallVectorImpl &bind_values) { return detail::torch_list_construct_op_binder(bind_values); } +namespace detail { +/// Matches the expected tensor and dim from `torch.aten.size.int`. +struct torch_tensor_size_int_op_binder { + int64_t *dim; + Value tensor; + + /// Creates a matcher instance that binds the value to dim if match succeeds. + torch_tensor_size_int_op_binder(Value tensor, int64_t *dim) + : dim(dim), tensor(tensor) {} + + bool match(Operation *op) { + if (auto atenSizeIntOp = dyn_cast(op)) { + if (atenSizeIntOp.self() == tensor) { + if (matchPattern(atenSizeIntOp.dim(), m_TorchConstantInt(dim))) + return true; + } + } + return false; + } +}; +} // namespace detail + +/// Matches the tensor and dim of `torch.size.int`. +inline detail::torch_tensor_size_int_op_binder +m_TorchTensorSizeInt(Value tensor, int64_t *dim) { + return detail::torch_tensor_size_int_op_binder(tensor, dim); +} + /// Create code to copy `tensor` to type `newType`. /// /// This involves two independent steps, which we keep orthogonal in our diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 7e558a84f..d2190b646 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -133,8 +133,13 @@ static Value castIndexToInt(OpBuilder &b, Location loc, Value idx) { return b.create(loc, b.getI64Type(), idx); } -static Value getDimOp(OpBuilder &b, Location loc, Value v, int dimension) { - return b.create(loc, v, dimension); +static Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) { + if (auto tensorType = v.getType().cast()) { + if (!tensorType.isDynamicDim(dim)) + return b.create( + loc, b.getIndexAttr(tensorType.getShape()[dim])); + } + return b.create(loc, v, dim); } static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim, @@ -2671,84 +2676,214 @@ public: Location loc = op.getLoc(); Value input = adaptor.self(); auto inputType = input.getType().cast(); + ArrayRef inputShape = inputType.getShape(); int64_t inputRank = inputType.getRank(); TypeConverter *typeConverter = getTypeConverter(); auto resultType = typeConverter->convertType(op.getType()).cast(); int64_t resultRank = resultType.getRank(); - // When we only have expansion of dimensions in `aten.View`, the output - // tensor rank will be strictly greater than the input tensor rank. - // TODO: Handle the cases of `aten.View` op where, - // 1. One or multiple dimensions are collapsed. - // 2. Few dimensions are expanded and few other dimensions are collapsed. - if (inputRank >= resultRank) { + // Currently, we only handle the expanding OR collapsing cases, we do not + // handle expanding And collapsing happening at the same time or cases where + // it's neither collapsing nor expanding like view of [2,3] for 3x2 tensor. + // TODO: For the expanding And collapsing case, we will need to identify + // which dimensions are collapsing and which are expanding and do it in two + // steps. + // TODO: For neither collapsing nor expanding, we could find a intermediate + // shape to collapse and then expanded to the target shape. Like [2,3] => + // [6] => [3, 2]. + if (inputRank == resultRank) return rewriter.notifyMatchFailure( - op, "unimplemented: operand tensor rank should be strictly less than " - "the desired output rank"); - } + op, "unimplemented: the view op is neither expanding nor collapsing"); + + if (resultRank == 0) + return rewriter.notifyMatchFailure(op, + "result shape of rank 0 is invalid"); + + // TODO: add support for case inputRank 0 expanded to size 1 + if (inputRank == 0) + return rewriter.notifyMatchFailure( + op, "unimplemented: input rank 0 is not supported"); + + bool isCollapse = inputRank > resultRank ? true : false; + int64_t collapsedRank = isCollapse ? resultRank : inputRank; + int64_t expandedRank = isCollapse ? inputRank : resultRank; // Extract the desired output size as a list of integers. This list should // have been created using the operation `torch.prim.ListConstruct`. - SmallVector expectedSizeTorchInt; - if (!getListConstructElements(op.size(), expectedSizeTorchInt)) { + SmallVector outputSizeTorchInt; + if (!getListConstructElements(op.size(), outputSizeTorchInt)) { return rewriter.notifyMatchFailure(op, - "unimplemented: the desired size is " + "unimplemented: the target size is " "not constructed from ListConstruct"); } - SmallVector expectedSize = getTypeConvertedValues( - rewriter, loc, typeConverter, expectedSizeTorchInt); - if (resultRank != (int64_t)expectedSize.size()) { + SmallVector outputSizeInt = getTypeConvertedValues( + rewriter, loc, typeConverter, outputSizeTorchInt); + if (resultRank != (int64_t)outputSizeInt.size()) { return rewriter.notifyMatchFailure( op, "desired size list length mismatches with the result type rank"); } + SmallVector inputSizeTorchInt = getTensorSizes(rewriter, loc, input); + ArrayRef expandedShapeTorchInt = + llvm::makeArrayRef(isCollapse ? inputSizeTorchInt : outputSizeInt); + ArrayRef collapsedShapeTorchInt = + llvm::makeArrayRef(isCollapse ? outputSizeInt : inputSizeTorchInt); - // Check if the `aten.View` can be legalized to `linalg.TensorExpandShape`. - // It only handles the case of static dimension expansion. If the dimension - // is dynamic, it must not be expanded/splitted. - // TODO: Handle the case of dynamic dimension expansion. - SmallVector reassociation(inputRank); - SmallVector resultShape; - int64_t j = 0; - for (auto i : llvm::seq(0, inputRank)) { - if (inputType.isDynamicDim(i)) { - Value dim = getDimOp(rewriter, loc, input, i); - if (j >= resultRank) { - return rewriter.notifyMatchFailure( - op, "desired size is not compatible with the input tensor size"); + // Iterate through the view op size list to do the following: + // + // 1. Combine output size list and input tensor type info to get the most + // static outputShape. + // + // 2. Fill in the reassociation for size list item where the output dim size + // is got from `torch.aten.size.int(inputTensor, inputDim)`. We naively + // assume this means the corresponding dimension is not expanded or + // collapsed. Note this may technically not always be true. + // TODO: think of a way better way to at least detect when this assumption + // is violated. + SmallVector outputShape(resultRank, kUnknownSize); + SmallVector reassociation(collapsedRank); + for (auto en : llvm::enumerate(outputSizeTorchInt)) { + int64_t inputDim; + int64_t outputDim = en.index(); + // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim + if (matchPattern(en.value(), + m_TorchTensorSizeInt(op.self(), &inputDim))) { + auto collapsedDim = isCollapse ? outputDim : inputDim; + auto expandedDim = isCollapse ? inputDim : outputDim; + reassociation[collapsedDim].push_back(expandedDim); + if (!inputType.isDynamicDim(inputDim)) { + outputShape[outputDim] = inputShape[inputDim]; + continue; } - checkDimEqualHelper(rewriter, loc, dim, expectedSize[j]); - reassociation[i].push_back(j++); - resultShape.push_back(kUnknownSize); - } else { - int64_t expandedDim = inputType.getDimSize(i); - int64_t outputDim; - // A do-while loop is used here to handle the cases where the input - // tensor has a dimension of size 1. - do { - if (j >= resultRank || - !matchPattern(expectedSizeTorchInt[j], - m_TorchConstantInt(&outputDim)) || - expandedDim % outputDim != 0) { - return rewriter.notifyMatchFailure( - op, "total number of elements mismatch in the expansion"); + } + + int64_t size; + if (matchPattern(en.value(), m_TorchConstantInt(&size))) + outputShape[outputDim] = size; + } + + SmallVector collapsedShape = + isCollapse ? outputShape : llvm::to_vector(inputShape); + SmallVector expandedShape = + isCollapse ? llvm::to_vector(inputShape) : outputShape; + + // The while loop does the following: + // 1. Fill in the reassociation indices for dimensions that are expanded. + // Check the interval dimensions between two unchanged dims in the + // collapsedShape. If the interval is size 1, associate all the dims + // in the expandedShape shape until the next unchanged dim. If the interval + // is larger than size 1, figure out the associations with assumptions that + // dynamic dimensions are not splitted. + // 2. Set collapsedShape and expandedShape following the requirements by + // tensor.expand_shape verification code: + // a. As long as one or more of the related dimensions in the expanded + // shape is dynamic the collapsed dimension is dynamic. + // b. If all of the related dimensions are static, the collapsed + // dimension must be static. In other words, if a collapsed dimension is + // dynamic, at least one of the related dimensions need to be dynamic. + int64_t collapsedDim = 0, expandedDim = 0; + while (collapsedDim < collapsedRank && expandedDim < expandedRank) { + // Not empty means the associations has been filled in and the dimension + // is unchanged. + if (!reassociation[collapsedDim].empty()) { + if (expandedDim != reassociation[collapsedDim][0]) + return op.emitOpError("Unsupported: expanded dims are off from the " + "expected dim got from reassociation"); + collapsedDim++; + expandedDim++; + continue; + } + + // Collect the dims that are collapsed until hitting the next dim that's + // unchanged. + SmallVector collapsedDims; + while (collapsedDim < collapsedRank && + reassociation[collapsedDim].empty()) { + collapsedDims.push_back(collapsedDim); + collapsedDim++; + } + // the next reassociation is for a dim that's unchanged. + int64_t expandedDimNext = collapsedDim != collapsedRank + ? reassociation[collapsedDim][0] + : expandedRank; + if (collapsedDims.size() == 1) { + int64_t collapsedDimSize = 1; + int64_t collapsedDim = collapsedDims[0]; + for (auto i : llvm::seq(expandedDim, expandedDimNext)) { + reassociation[collapsedDim].push_back(i); + if (collapsedDimSize == kUnknownSize) + continue; + + int64_t expandedDimSize = expandedShape[i]; + if (expandedDimSize == kUnknownSize) { + collapsedDimSize = kUnknownSize; + continue; } - reassociation[i].push_back(j++); - resultShape.push_back(outputDim); - expandedDim /= outputDim; - } while (expandedDim != 1); + collapsedDimSize *= expandedShape[i]; + } + // To meet both requirements from tensor.expand_shape verification code. + collapsedShape[collapsedDim] = collapsedDimSize; + expandedDim = expandedDimNext; + continue; + } + + // collpasedDims are expanded to [expandedDim, expandedDimNext) + if (expandedDimNext - expandedDim < (int64_t)collapsedDims.size()) + op.emitError("unimplemented: mixed of expanding and collapsing " + "operations for view"); + for (auto collapsedDim : collapsedDims) { + if (collapsedShape[collapsedDim] == kUnknownSize) { + if (expandedDim >= expandedDimNext) { + return rewriter.notifyMatchFailure( + op, + "desired size is not compatible with the input tensor size"); + } + checkDimEqualHelper(rewriter, loc, + collapsedShapeTorchInt[collapsedDim], + expandedShapeTorchInt[expandedDim]); + // To meet the second requirement from tensor.expand_shape + // verification code. + expandedShape[expandedDim] = kUnknownSize; + reassociation[collapsedDim].push_back(expandedDim++); + } else { + int64_t remainingSizeToExpand = collapsedShape[collapsedDim]; + // A do-while loop is used here to handle the cases where the + // collapsed shape tensor has a dimension of size 1. + do { + int64_t expandedDimSize = expandedShape[expandedDim]; + if (expandedDim >= expandedDimNext || + expandedShape[expandedDim] == kUnknownSize || + remainingSizeToExpand % expandedDimSize != 0) { + return rewriter.notifyMatchFailure( + op, "total number of elements mismatch in the expansion"); + } + reassociation[collapsedDim].push_back(expandedDim++); + remainingSizeToExpand /= expandedDimSize; + } while (remainingSizeToExpand != 1); + } } } - // Make sure that the splitted dimensions have the same number of elements - // as the dimension got splitted from. - if (j != resultRank) - return rewriter.notifyMatchFailure( - op, "desired size is not compatible with the input tensor size"); - Type expandType = - RankedTensorType::get(resultShape, resultType.getElementType()); - Value expandOp = rewriter.create( - loc, expandType, adaptor.self(), reassociation); - rewriter.replaceOpWithNewOp(op, resultType, expandOp); + if (collapsedDim != collapsedRank || expandedDim != expandedRank) + return rewriter.notifyMatchFailure(op, "view shape is not supported"); + Type adjustedResultType = + RankedTensorType::get(isCollapse ? collapsedShape : expandedShape, + resultType.getElementType()); + Type adjustedInputType = + RankedTensorType::get(isCollapse ? expandedShape : collapsedShape, + resultType.getElementType()); + Value castedInput = + rewriter.create(loc, adjustedInputType, input); + Value result = + isCollapse + ? rewriter + .create( + loc, adjustedResultType, castedInput, reassociation) + .result() + : rewriter + .create( + loc, adjustedResultType, castedInput, reassociation) + .result(); + rewriter.replaceOpWithNewOp(op, resultType, result); return success(); } };