mirror of https://github.com/llvm/torch-mlir
Lower aten::view with linalg.collapse and linalg.expand
We only handle the expanding OR collapsing cases; we do not handle expanding AND collapsing happening at the same time, or cases where it is neither collapsing nor expanding, such as a view of [2,3] for a 3x2 tensor. It is assumed that if a shape list element is obtained from `aten.size(tensor, dim)`, the corresponding dim is neither split nor collapsed. This assumption makes it easier to deal with dynamic shapes.

pull/496/head snapshot-20211217.149
parent bc9abbc1c9
commit d8ba68119e
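As a rough PyTorch-level illustration of the cases described in the commit message (a sketch for this page, not part of the patch): views that only expand or only collapse dims are lowered through linalg::TensorExpandShapeOp / linalg::TensorCollapseShapeOp, while mixed or rank-preserving reshapes make the pattern bail out with a match failure.

import torch

t = torch.rand(2, 4, 384)
t.view(2, 4, 12, 32)   # expanding only: lowered via linalg::TensorExpandShapeOp
u = torch.rand(2, 4)
u.view(8)              # collapsing only: lowered via linalg::TensorCollapseShapeOp
v = torch.rand(3, 2)
v.view(2, 3)           # neither expanding nor collapsing: match failure
w = torch.rand(2, 3, 8)
w.view(6, 2, 4)        # expanding and collapsing at once: match failure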
@@ -9,7 +9,6 @@ from torch_mlir_e2e_test.torchscript.registry import register_test_case
 from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 
 # ==============================================================================
 
 class ViewExpandModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -46,3 +45,60 @@ class ViewDynamicExpandModule(torch.nn.Module):
 def ViewDynamicExpandModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 4, 30, 384))
 
+
+# ==============================================================================
+class ViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(a.size(0), a.size(1), 12, 32)
+
+@register_test_case(module_factory=lambda: ViewDynamicExpandWithAtenSizeIntModule())
+def ViewDynamicExpandWithAtenSizeIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 384))
+
+# ==============================================================================
+class ViewCollapseModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(8)
+
+@register_test_case(module_factory=lambda: ViewCollapseModule())
+def ViewCollapseModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4))
+
+
+# ==============================================================================
+class ViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1, -1, -1], torch.float32, True),
+        ([], torch.int64, True),
+        ([], torch.int64, True),
+    ])
+
+    def forward(self, a, b, c):
+        return a.view(a.size(0), int(b), int(c), a.size(3), 384)
+
+@register_test_case(module_factory=lambda: ViewCollapseDynamicWithAtenSizeIntModule())
+def ViewCollapseDynamicWithAtenSizeIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 5, 4, 12, 32), torch.tensor(3), torch.tensor(5))
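For reference, the shapes the new test cases expect can be checked with eager PyTorch (illustration only, not part of the patch); in the last case only the trailing 12x32 pair is collapsed into 384, while the dims taken from a.size(0) and a.size(3) stay unchanged.

import torch

a = torch.rand(2, 4, 384)
print(a.view(a.size(0), a.size(1), 12, 32).shape)     # torch.Size([2, 4, 12, 32])

b = torch.rand(2, 4)
print(b.view(8).shape)                                # torch.Size([8])

c = torch.rand(2, 3, 5, 4, 12, 32)
print(c.view(c.size(0), 3, 5, c.size(3), 384).shape)  # torch.Size([2, 3, 5, 4, 384])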
@@ -108,6 +108,34 @@ m_TorchConstantIntList(SmallVectorImpl<int64_t> &bind_values) {
   return detail::torch_list_construct_op_binder(bind_values);
 }
 
+namespace detail {
+/// Matches the expected tensor and dim from `torch.aten.size.int`.
+struct torch_tensor_size_int_op_binder {
+  int64_t *dim;
+  Value tensor;
+
+  /// Creates a matcher instance that binds the value to dim if match succeeds.
+  torch_tensor_size_int_op_binder(Value tensor, int64_t *dim)
+      : dim(dim), tensor(tensor) {}
+
+  bool match(Operation *op) {
+    if (auto atenSizeIntOp = dyn_cast<Torch::AtenSizeIntOp>(op)) {
+      if (atenSizeIntOp.self() == tensor) {
+        if (matchPattern(atenSizeIntOp.dim(), m_TorchConstantInt(dim)))
+          return true;
+      }
+    }
+    return false;
+  }
+};
+} // namespace detail
+
+/// Matches the tensor and dim of `torch.aten.size.int`.
+inline detail::torch_tensor_size_int_op_binder
+m_TorchTensorSizeInt(Value tensor, int64_t *dim) {
+  return detail::torch_tensor_size_int_op_binder(tensor, dim);
+}
+
 /// Create code to copy `tensor` to type `newType`.
 ///
 /// This involves two independent steps, which we keep orthogonal in our
@@ -133,8 +133,13 @@ static Value castIndexToInt(OpBuilder &b, Location loc, Value idx) {
   return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
 }
 
-static Value getDimOp(OpBuilder &b, Location loc, Value v, int dimension) {
-  return b.create<tensor::DimOp>(loc, v, dimension);
+static Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
+  if (auto tensorType = v.getType().cast<RankedTensorType>()) {
+    if (!tensorType.isDynamicDim(dim))
+      return b.create<arith::ConstantOp>(
+          loc, b.getIndexAttr(tensorType.getShape()[dim]));
+  }
+  return b.create<tensor::DimOp>(loc, v, dim);
 }
 
 static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
@@ -2671,84 +2676,214 @@ public:
     Location loc = op.getLoc();
     Value input = adaptor.self();
     auto inputType = input.getType().cast<RankedTensorType>();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
     int64_t inputRank = inputType.getRank();
     TypeConverter *typeConverter = getTypeConverter();
     auto resultType =
         typeConverter->convertType(op.getType()).cast<RankedTensorType>();
     int64_t resultRank = resultType.getRank();
-    // When we only have expansion of dimensions in `aten.View`, the output
-    // tensor rank will be strictly greater than the input tensor rank.
-    // TODO: Handle the cases of `aten.View` op where,
-    // 1. One or multiple dimensions are collapsed.
-    // 2. Few dimensions are expanded and few other dimensions are collapsed.
-    if (inputRank >= resultRank) {
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: operand tensor rank should be strictly less than "
-              "the desired output rank");
-    }
+    // Currently, we only handle the expanding OR collapsing cases; we do not
+    // handle expanding AND collapsing happening at the same time, or cases
+    // where it is neither collapsing nor expanding, like a view of [2,3] for a
+    // 3x2 tensor.
+    // TODO: For the expanding AND collapsing case, we will need to identify
+    // which dimensions are collapsing and which are expanding and do it in
+    // two steps.
+    // TODO: For neither collapsing nor expanding, we could find an
+    // intermediate shape to collapse and then expand to the target shape,
+    // like [2,3] => [6] => [3,2].
+    if (inputRank == resultRank)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: the view op is neither expanding nor collapsing");

+    if (resultRank == 0)
+      return rewriter.notifyMatchFailure(op,
+                                         "result shape of rank 0 is invalid");

+    // TODO: add support for case inputRank 0 expanded to size 1
+    if (inputRank == 0)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: input rank 0 is not supported");

+    bool isCollapse = inputRank > resultRank ? true : false;
+    int64_t collapsedRank = isCollapse ? resultRank : inputRank;
+    int64_t expandedRank = isCollapse ? inputRank : resultRank;

     // Extract the desired output size as a list of integers. This list should
     // have been created using the operation `torch.prim.ListConstruct`.
-    SmallVector<Value> expectedSizeTorchInt;
-    if (!getListConstructElements(op.size(), expectedSizeTorchInt)) {
+    SmallVector<Value> outputSizeTorchInt;
+    if (!getListConstructElements(op.size(), outputSizeTorchInt)) {
       return rewriter.notifyMatchFailure(op,
-                                         "unimplemented: the desired size is "
+                                         "unimplemented: the target size is "
                                          "not constructed from ListConstruct");
     }
-    SmallVector<Value> expectedSize = getTypeConvertedValues(
-        rewriter, loc, typeConverter, expectedSizeTorchInt);
-    if (resultRank != (int64_t)expectedSize.size()) {
+    SmallVector<Value> outputSizeInt = getTypeConvertedValues(
+        rewriter, loc, typeConverter, outputSizeTorchInt);
+    if (resultRank != (int64_t)outputSizeInt.size()) {
       return rewriter.notifyMatchFailure(
           op, "desired size list length mismatches with the result type rank");
     }
+    SmallVector<Value> inputSizeTorchInt = getTensorSizes(rewriter, loc, input);
+    ArrayRef<Value> expandedShapeTorchInt =
+        llvm::makeArrayRef(isCollapse ? inputSizeTorchInt : outputSizeInt);
+    ArrayRef<Value> collapsedShapeTorchInt =
+        llvm::makeArrayRef(isCollapse ? outputSizeInt : inputSizeTorchInt);

-    // Check if the `aten.View` can be legalized to `linalg.TensorExpandShape`.
-    // It only handles the case of static dimension expansion. If the dimension
-    // is dynamic, it must not be expanded/splitted.
-    // TODO: Handle the case of dynamic dimension expansion.
-    SmallVector<ReassociationIndices> reassociation(inputRank);
-    SmallVector<int64_t> resultShape;
-    int64_t j = 0;
-    for (auto i : llvm::seq<int64_t>(0, inputRank)) {
-      if (inputType.isDynamicDim(i)) {
-        Value dim = getDimOp(rewriter, loc, input, i);
-        if (j >= resultRank) {
-          return rewriter.notifyMatchFailure(
-              op, "desired size is not compatible with the input tensor size");
-        }
-        checkDimEqualHelper(rewriter, loc, dim, expectedSize[j]);
-        reassociation[i].push_back(j++);
-        resultShape.push_back(kUnknownSize);
-      } else {
-        int64_t expandedDim = inputType.getDimSize(i);
-        int64_t outputDim;
-        // A do-while loop is used here to handle the cases where the input
-        // tensor has a dimension of size 1.
-        do {
-          if (j >= resultRank ||
-              !matchPattern(expectedSizeTorchInt[j],
-                            m_TorchConstantInt(&outputDim)) ||
-              expandedDim % outputDim != 0) {
-            return rewriter.notifyMatchFailure(
-                op, "total number of elements mismatch in the expansion");
-          }
-          reassociation[i].push_back(j++);
-          resultShape.push_back(outputDim);
-          expandedDim /= outputDim;
-        } while (expandedDim != 1);
-      }
-    }
+    // Iterate through the view op size list to do the following:
+    //
+    // 1. Combine the output size list and the input tensor type info to get
+    // the most static outputShape.
+    //
+    // 2. Fill in the reassociation for size list items where the output dim
+    // size comes from `torch.aten.size.int(inputTensor, inputDim)`. We naively
+    // assume this means the corresponding dimension is not expanded or
+    // collapsed. Note this may technically not always be true.
+    // TODO: think of a better way to at least detect when this assumption is
+    // violated.
+    SmallVector<int64_t> outputShape(resultRank, kUnknownSize);
+    SmallVector<ReassociationIndices> reassociation(collapsedRank);
+    for (auto en : llvm::enumerate(outputSizeTorchInt)) {
+      int64_t inputDim;
+      int64_t outputDim = en.index();
+      // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim
+      if (matchPattern(en.value(),
+                       m_TorchTensorSizeInt(op.self(), &inputDim))) {
+        auto collapsedDim = isCollapse ? outputDim : inputDim;
+        auto expandedDim = isCollapse ? inputDim : outputDim;
+        reassociation[collapsedDim].push_back(expandedDim);
+        if (!inputType.isDynamicDim(inputDim)) {
+          outputShape[outputDim] = inputShape[inputDim];
+          continue;
+        }
+      } else {
+        int64_t size;
+        if (matchPattern(en.value(), m_TorchConstantInt(&size)))
+          outputShape[outputDim] = size;
+      }
+    }

+    SmallVector<int64_t> collapsedShape =
+        isCollapse ? outputShape : llvm::to_vector(inputShape);
+    SmallVector<int64_t> expandedShape =
+        isCollapse ? llvm::to_vector(inputShape) : outputShape;

+    // The while loop does the following:
+    // 1. Fill in the reassociation indices for dimensions that are expanded.
+    // Check the interval dimensions between two unchanged dims in the
+    // collapsedShape. If the interval is size 1, associate all the dims in the
+    // expandedShape until the next unchanged dim. If the interval is larger
+    // than size 1, figure out the associations under the assumption that
+    // dynamic dimensions are not split.
+    // 2. Set collapsedShape and expandedShape following the requirements of
+    // the tensor.expand_shape verification code:
+    // a. As long as one or more of the related dimensions in the expanded
+    // shape is dynamic, the collapsed dimension is dynamic.
+    // b. If all of the related dimensions are static, the collapsed dimension
+    // must be static. In other words, if a collapsed dimension is dynamic, at
+    // least one of the related dimensions needs to be dynamic.
+    int64_t collapsedDim = 0, expandedDim = 0;
+    while (collapsedDim < collapsedRank && expandedDim < expandedRank) {
+      // Not empty means the associations have been filled in and the dimension
+      // is unchanged.
+      if (!reassociation[collapsedDim].empty()) {
+        if (expandedDim != reassociation[collapsedDim][0])
+          return op.emitOpError("Unsupported: expanded dims are off from the "
+                                "expected dim got from reassociation");
+        collapsedDim++;
+        expandedDim++;
+        continue;
+      }

+      // Collect the dims that are collapsed until hitting the next dim that's
+      // unchanged.
+      SmallVector<int64_t> collapsedDims;
+      while (collapsedDim < collapsedRank &&
+             reassociation[collapsedDim].empty()) {
+        collapsedDims.push_back(collapsedDim);
+        collapsedDim++;
+      }
+      // The next reassociation is for a dim that's unchanged.
+      int64_t expandedDimNext = collapsedDim != collapsedRank
+                                    ? reassociation[collapsedDim][0]
+                                    : expandedRank;
+      if (collapsedDims.size() == 1) {
+        int64_t collapsedDimSize = 1;
+        int64_t collapsedDim = collapsedDims[0];
+        for (auto i : llvm::seq<int64_t>(expandedDim, expandedDimNext)) {
+          reassociation[collapsedDim].push_back(i);
+          if (collapsedDimSize == kUnknownSize)
+            continue;

+          int64_t expandedDimSize = expandedShape[i];
+          if (expandedDimSize == kUnknownSize) {
+            collapsedDimSize = kUnknownSize;
+            continue;
+          }
+          collapsedDimSize *= expandedShape[i];
+        }
+        // To meet both requirements from the tensor.expand_shape verification
+        // code.
+        collapsedShape[collapsedDim] = collapsedDimSize;
+        expandedDim = expandedDimNext;
+        continue;
+      }

+      // collapsedDims are expanded to [expandedDim, expandedDimNext).
+      if (expandedDimNext - expandedDim < (int64_t)collapsedDims.size())
+        op.emitError("unimplemented: mixed of expanding and collapsing "
+                     "operations for view");
+      for (auto collapsedDim : collapsedDims) {
+        if (collapsedShape[collapsedDim] == kUnknownSize) {
+          if (expandedDim >= expandedDimNext) {
+            return rewriter.notifyMatchFailure(
+                op,
+                "desired size is not compatible with the input tensor size");
+          }
+          checkDimEqualHelper(rewriter, loc,
+                              collapsedShapeTorchInt[collapsedDim],
+                              expandedShapeTorchInt[expandedDim]);
+          // To meet the second requirement from the tensor.expand_shape
+          // verification code.
+          expandedShape[expandedDim] = kUnknownSize;
+          reassociation[collapsedDim].push_back(expandedDim++);
+        } else {
+          int64_t remainingSizeToExpand = collapsedShape[collapsedDim];
+          // A do-while loop is used here to handle the cases where the
+          // collapsed shape tensor has a dimension of size 1.
+          do {
+            int64_t expandedDimSize = expandedShape[expandedDim];
+            if (expandedDim >= expandedDimNext ||
+                expandedShape[expandedDim] == kUnknownSize ||
+                remainingSizeToExpand % expandedDimSize != 0) {
+              return rewriter.notifyMatchFailure(
+                  op, "total number of elements mismatch in the expansion");
+            }
+            reassociation[collapsedDim].push_back(expandedDim++);
+            remainingSizeToExpand /= expandedDimSize;
+          } while (remainingSizeToExpand != 1);
+        }
+      }
+    }

-    // Make sure that the splitted dimensions have the same number of elements
-    // as the dimension got splitted from.
-    if (j != resultRank)
-      return rewriter.notifyMatchFailure(
-          op, "desired size is not compatible with the input tensor size");

-    Type expandType =
-        RankedTensorType::get(resultShape, resultType.getElementType());
-    Value expandOp = rewriter.create<linalg::TensorExpandShapeOp>(
-        loc, expandType, adaptor.self(), reassociation);
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, expandOp);
+    if (collapsedDim != collapsedRank || expandedDim != expandedRank)
+      return rewriter.notifyMatchFailure(op, "view shape is not supported");
+    Type adjustedResultType =
+        RankedTensorType::get(isCollapse ? collapsedShape : expandedShape,
+                              resultType.getElementType());
+    Type adjustedInputType =
+        RankedTensorType::get(isCollapse ? expandedShape : collapsedShape,
+                              resultType.getElementType());
+    Value castedInput =
+        rewriter.create<tensor::CastOp>(loc, adjustedInputType, input);
+    Value result =
+        isCollapse
+            ? rewriter
+                  .create<linalg::TensorCollapseShapeOp>(
+                      loc, adjustedResultType, castedInput, reassociation)
+                  .result()
+            : rewriter
+                  .create<linalg::TensorExpandShapeOp>(
+                      loc, adjustedResultType, castedInput, reassociation)
+                  .result();
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
     return success();
   }
 };
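The reassociation bookkeeping in the lowering above is easier to see on a small example. Below is a rough Python sketch (a hypothetical helper written for this page, not part of the patch) of the grouping idea for the collapse case: dims whose size comes from aten.size.int are pinned one-to-one, dynamic dims are assumed never to be split, and the remaining static dims absorb expanded dims until the element counts match.

def group_dims(expanded, collapsed, pinned):
    """Group expanded dims into collapsed dims.

    expanded/collapsed are shape lists with None for dynamic sizes; pinned is
    a list of (collapsed_dim, expanded_dim) pairs known to be the same
    dimension because the size came from aten.size.int.
    """
    groups = [[] for _ in collapsed]
    for c_dim, e_dim in pinned:        # unchanged dims map one-to-one
        groups[c_dim].append(e_dim)

    c, e = 0, 0
    while c < len(collapsed) and e < len(expanded):
        if groups[c]:                  # already pinned: step over it
            c += 1
            e += 1
            continue
        if collapsed[c] is None or expanded[e] is None:
            groups[c].append(e)        # dynamic dims are matched one-to-one
            c += 1
            e += 1
            continue
        remaining = collapsed[c]       # static: absorb expanded dims until the
        while remaining != 1:          # element counts match
            assert remaining % expanded[e] == 0, "element count mismatch"
            remaining //= expanded[e]
            groups[c].append(e)
            e += 1
        c += 1
    return groups

# The dynamic collapse test above: a (2, 3, 5, 4, 12, 32) input viewed as
# (a.size(0), b, c, a.size(3), 384); dims 0 and 3 are pinned.
print(group_dims(
    expanded=[2, None, None, 4, 12, 32],
    collapsed=[2, None, None, 4, 384],
    pinned=[(0, 0), (3, 3)]))
# -> [[0], [1], [2], [3], [4, 5]], i.e. only the trailing 12x32 pair collapses.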