mirror of https://github.com/llvm/torch-mlir
[linalg] Implement strict mode lowering for aten.view. (#3319)
* Enables assume_strict_symbolic_shapes on fx_importer imported programs, indicating strict shape semantics.
* Reworks the view->reshape lowering to take advantage of strict mode and do one of:
  * Collapse to 0D
  * Flatten/Unflatten when there is an inferred dim.
  * Fallback to tensor.reshape
* Splits some test cases up and adds an attribute to control the old pattern (so new corners can be tested in strict mode in isolation).
* Dynamic inferred mode needs upstream work to generalize expand_shape (so that case is suppressed here).
* Deletes the assert from the existing tensor.reshape lowering if strict shape mode is enabled (since the condition it is dynamically asserting cannot happen).
parent adafd51823
commit 00efec0b73
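For orientation, here is a minimal sketch of the new strict-mode rewrite for a view with one static inferred (-1) dimension. This block is illustrative only (op spellings follow the torch dialect; the function name and shapes are hypothetical), not code from the commit:

// view(%x, [4, 3, -1]) on a [4,6] input becomes flatten-to-1D + unflatten,
// instead of a generic tensor.reshape:
func.func @strict_view_sketch(%x: !torch.vtensor<[4,6],f32>) -> !torch.vtensor<[4,3,2],f32>
    attributes {torch.assume_strict_symbolic_shapes} {
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %int3 = torch.constant.int 3
  %int4 = torch.constant.int 4
  %int-1 = torch.constant.int -1
  %sizes = torch.prim.ListConstruct %int4, %int3, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  // Flatten dims 0..1 of the [4,6] input into a single 24-element dim.
  %flat = torch.aten.flatten.using_ints %x, %int0, %int1 : !torch.vtensor<[4,6],f32>, !torch.int, !torch.int -> !torch.vtensor<[24],f32>
  // Unflatten dim 0 with the original target sizes; the -1 resolves to 2.
  %out = torch.aten.unflatten.int %flat, %int0, %sizes : !torch.vtensor<[24],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[4,3,2],f32>
  return %out : !torch.vtensor<[4,3,2],f32>
}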
@@ -940,6 +940,9 @@ public:
   LogicalResult
   matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (op->getParentOp()->hasAttr("torch.disable_legacy_view"))
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "legacy view lowering disabled");
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     Location loc = op.getLoc();
@@ -1284,6 +1287,9 @@ public:
   LogicalResult
   matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (op->getParentOp()->hasAttr("torch.disable_legacy_view"))
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "legacy view lowering disabled");
     SmallVector<Value> sizes;
     if (!getListConstructElements(op.getSize(), sizes))
       return op.emitError(
@@ -1319,12 +1325,16 @@ public:
       size = convert;
     }

-    // Check we are only inferring one dimension:
-    Value countPred =
-        b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, count, one);
-    b.create<cf::AssertOp>(
-        loc, countPred,
-        b.getStringAttr("must have at most one inferred (negative) dimension"));
+    // Check we are only inferring one dimension if not in strict mode. In
+    // strict mode, there will only ever statically be one inferred dim.
+    if (!isAssumingStrictSymbolicShapes(rewriter)) {
+      Value countPred =
+          b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, count, one);
+      b.create<cf::AssertOp>(
+          loc, countPred,
+          b.getStringAttr(
+              "must have at most one inferred (negative) dimension"));
+    }

     // Determine the total size of the inferred dimension and update the
     // inferred dimension:
@@ -1356,6 +1366,165 @@ public:
   };
 } // namespace

+namespace {
+class ConvertAtenViewOpStrict : public OpConversionPattern<AtenViewOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isAssumingStrictSymbolicShapes(rewriter))
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "not strict symbolic shapes");
+    SmallVector<Value> sizeValues;
+    if (!getListConstructElements(op.getSize(), sizeValues))
+      return op.emitError(
+          "unimplemented: the tensor size list is not from list construct");
+
+    auto loc = op.getLoc();
+    auto resultType =
+        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
+    auto self = adaptor.getSelf();
+    auto selfTy = cast<RankedTensorType>(self.getType());
+
+    // Handle collapse to 0D.
+    if (sizeValues.empty()) {
+      rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+          op, resultType, adaptor.getSelf(), ArrayRef<ReassociationIndices>{});
+      return success();
+    }
+
+    // If there is a static inferred dimension (-1), then we emit a
+    // flatten/unflatten and let that proceed through its lowering.
+    // Otherwise, emit a tensor.reshape. Note that this relies on the fact that
+    // Torch does not allow such an op to have a symbolic inferred dim.
+    int inferredDim = -1;
+    bool staticSizes = true;
+    for (int i = 0, e = sizeValues.size(); i < e; ++i) {
+      int64_t dim;
+      if (!matchPattern(sizeValues[i], m_TorchConstantInt(&dim))) {
+        staticSizes = false;
+        continue;
+      }
+      if (dim == -1) {
+        inferredDim = i;
+        break;
+      }
+    }
+
+    // While it should be illegal to have a view op with fully known sizes
+    // and a dynamic shape, in reality, torch IR is a bit loose and
+    // progressively resolves to this state. There are delicate invariants
+    // on the ops we produce that require this, so we enforce it.
+    if (staticSizes && !resultType.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(loc,
+                                         "view cannot be converted with static "
+                                         "sizes and a dynamic result type");
+    }
+
+    // Handle inferred dim case.
+    // TODO: Remove the restriction on staticSizes once flatten/unflatten
+    // reliably work with multiple dynamic dimensions.
+    if (inferredDim >= 0 && staticSizes) {
+      // This is a torch-torch conversion, so only non-adapted types are
+      // involved.
+      auto selfTy = dyn_cast<ValueTensorType>(op.getSelf().getType());
+      if (!selfTy || !selfTy.hasSizes())
+        return failure();
+
+      // Work out the 1D flattened type.
+      int64_t flatDim = 1;
+      auto selfSizes = selfTy.getSizes();
+      for (int64_t dim : selfSizes) {
+        if (dim == kUnknownSize) {
+          flatDim = kUnknownSize;
+          break;
+        }
+        flatDim *= dim;
+      }
+      // Flatten to 1D.
+      ValueTensorType flatType = rewriter.getType<ValueTensorType>(
+          ArrayRef<int64_t>{flatDim}, selfTy.getOptionalDtype());
+      Value dimStart = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(0));
+      Value dimEnd = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(selfSizes.size() - 1));
+      Value flatSelf = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
+          loc, flatType, op.getSelf(), dimStart, dimEnd);
+
+      // Unflatten to requested size.
+      rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
+          op, op.getResult().getType(), flatSelf, dimStart, op.getSize());
+      return success();
+    }
+
+    // Generate output dims, either based on whether there is an inferred dim
+    // present or all dims are specified.
+    auto sizeTy = cast<IntegerType>(
+        typeConverter->convertType(sizeValues.front().getType()));
+    SmallVector<Value> outputDimValues;
+    assert(sizeTy && "Type converter did not handle size");
+    if (inferredDim >= 0) {
+      // Inferred dim. If the above flatten/unflatten logic ever catches
+      // everything, this branch can go away entirely.
+      Value one = rewriter.create<arith::ConstantOp>(
+          loc, sizeTy, rewriter.getIntegerAttr(sizeTy, 1));
+      Value sizeProduct = one;
+      // Multiply the non-inferred target sizes.
+      for (int i = 0, e = sizeValues.size(); i < e; ++i) {
+        if (i == inferredDim)
+          continue;
+        Value size = sizeValues[i];
+        Value convertedSize = typeConverter->materializeTargetConversion(
+            rewriter, loc, sizeTy, size);
+        assert(convertedSize && "Type converter did not handle size");
+        sizeProduct =
+            rewriter.create<arith::MulIOp>(loc, sizeProduct, convertedSize);
+      }
+
+      // Multiply the self tensor sizes.
+      Value selfProduct = one;
+      for (int i = 0, e = selfTy.getRank(); i < e; ++i) {
+        Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        Value dim = rewriter.create<tensor::DimOp>(loc, self, index);
+        dim = rewriter.create<arith::IndexCastOp>(loc, sizeTy, dim);
+        selfProduct = rewriter.create<arith::MulIOp>(loc, selfProduct, dim);
+      }
+
+      Value inferredSize =
+          rewriter.create<arith::DivUIOp>(loc, selfProduct, sizeProduct);
+      for (int i = 0, e = sizeValues.size(); i < e; ++i) {
+        if (i == inferredDim) {
+          outputDimValues.push_back(inferredSize);
+        } else {
+          outputDimValues.push_back(typeConverter->materializeTargetConversion(
+              rewriter, loc, sizeTy, sizeValues[i]));
+        }
+      }
+    } else {
+      // No inferred dim, so output dims are just passed through.
+      for (Value torchSize : sizeValues) {
+        outputDimValues.push_back(typeConverter->materializeTargetConversion(
+            rewriter, loc, sizeTy, torchSize));
+      }
+    }
+
+    // Normal lowering to reshape with fully computed sizes.
+    auto outputDimsTy = RankedTensorType::get(
+        outputDimValues.size(), outputDimValues.front().getType());
+    auto outputDims = rewriter.create<tensor::FromElementsOp>(loc, outputDimsTy,
+                                                              outputDimValues);
+    rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(
+        op, resultType, adaptor.getSelf(), outputDims);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenSqueezeOp : public OpConversionPattern<AtenSqueezeOp> {
 public:
@@ -2459,6 +2628,9 @@ SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
 void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
+  // Add some legal ops for torch-torch lowering.
+  target.addLegalOp<ConstantIntOp>();
+
   MLIRContext *context = patterns.getContext();
   target.addIllegalOp<AtenReflectionPad1dOp>();
   patterns.add<ConvertAtenReflectionPad1dOp>(typeConverter, context);
@@ -2468,10 +2640,23 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
   patterns.add<ConvertAtenUnflattenIntOp>(typeConverter, context);
   target.addIllegalOp<AtenUnflattenIntOp>();
+
+  // View op sadness: In the future, we only want ConvertAtenViewOpStrict,
+  // but this requires work upstream to fully generalize reshape handling.
+  // In the meantime, the analysis-based ConvertAtenViewOp tries hard to
+  // produce expand/collapse shapes, the ConvertAtenViewOpStrict does the
+  // right thing but cannot be fully supported for dynamic shapes, and
+  // ConvertAtenViewOpToReshape overly pessimizes and generates a lot of IR
+  // due to not statically switching between inferred and non-inferred view
+  // cases. They are ordered by optimality of the lowerings they generate
+  // when they are able.
   target.addIllegalOp<AtenViewOp>();
-  patterns.add<ConvertAtenViewOp>(typeConverter, context, /*benefit=*/200);
+  patterns.add<ConvertAtenViewOp>(typeConverter, context, /*benefit=*/300);
+  patterns.add<ConvertAtenViewOpStrict>(typeConverter, context,
+                                        /*benefit=*/200);
   patterns.add<ConvertAtenViewOpToReshape>(typeConverter, context,
                                            /*benefit=*/100);

   target.addIllegalOp<AtenSqueezeOp>();
   patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
   target.addIllegalOp<AtenSqueezeDimOp>();

@@ -103,6 +103,7 @@ from ..ir import (
     StringAttr,
     SymbolTable,
     Type as IrType,
+    UnitAttr,
     Value,
 )

@@ -642,6 +643,10 @@ class FxImporter:
         func_op = func_dialect.FuncOp(
             func_name, ftype, ip=self._m_ip, visibility=func_visibility
         )
+        # Programs imported from FX have strong guarantees. Setting this attribute
+        # causes various lowerings to be able to emit more efficient code or
+        # handle more cases. See isAssumingStrictSymbolicShapes().
+        func_op.attributes["torch.assume_strict_symbolic_shapes"] = UnitAttr.get()
         entry_block = Block.create_at_start(func_op.body, ftype.inputs)

         node_importer = GraphNodeImporter(
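For reference, a minimal sketch (hypothetical function name and shapes) of what an imported function now looks like; the unit attribute on the func.func is what isAssumingStrictSymbolicShapes() finds when walking up from the op being rewritten:

func.func @main(%arg0: !torch.vtensor<[?,6],f32>) -> !torch.vtensor<[?,6],f32>
    attributes {torch.assume_strict_symbolic_shapes} {
  return %arg0 : !torch.vtensor<[?,6],f32>
}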

@@ -1,16 +1,17 @@
 // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

 // -----

 // CHECK-LABEL: func.func @torch.aten.view$twotothree(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32>
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32>
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32>
 // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32>
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,3],f32>

-func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> {
+func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32>
+    attributes {torch.assume_strict_symbolic_shapes}
+{
   %int3 = torch.constant.int 3
   %int2 = torch.constant.int 2
   %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -21,13 +22,15 @@ func.func @torch.aten.view$twotothree
 // -----

 // CHECK-LABEL: func.func @torch.aten.view$dynamictest(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32>
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[BUILTIN_TENSOR]]
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[RESHAPE]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,?],f32>

-func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32>
+    attributes {torch.assume_strict_symbolic_shapes}
+{
   %int1 = torch.constant.int 1
   %int0 = torch.constant.int 0
   %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int

@@ -40,13 +43,15 @@ func.func @torch.aten.view$dynamictest
 // -----

 // CHECK-LABEL: func.func @torch.aten.view$dynamictest2(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> {
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32>
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,6,?],f32> -> tensor<?x6x?xf32>
 // CHECK: %[[EXPAND:.*]] = tensor.reshape %[[BUILTIN_TENSOR]]
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<?x2x3x?xf32> -> !torch.vtensor<[?,2,3,?],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,2,3,?],f32>

-func.func @torch.aten.view$dynamictest2(%arg0: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> {
+func.func @torch.aten.view$dynamictest2(%arg0: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32>
+    attributes {torch.assume_strict_symbolic_shapes}
+{
   %int3 = torch.constant.int 3
   %int2 = torch.constant.int 2
   %int0 = torch.constant.int 0
@@ -60,7 +65,7 @@ func.func @torch.aten.view$dynamictest2
 // -----

 // CHECK-LABEL: func.func @torch.aten.view$dynamicVal(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> {
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32>
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[1,?,128],f32> -> tensor<1x?x128xf32>
 // CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<1x?x128xf32> to tensor<1x16x128xf32>
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1], [2]] : tensor<1x16x128xf32> into tensor<16x128xf32>

@@ -68,7 +73,9 @@ func.func @torch.aten.view$dynamictest2
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<16x1x128xf32> -> !torch.vtensor<[16,1,128],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[16,1,128],f32>

-func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> {
+func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32>
+    attributes {torch.assume_strict_symbolic_shapes}
+{
   %int128 = torch.constant.int 128
   %int1 = torch.constant.int 1
   %int16 = torch.constant.int 16

@@ -80,7 +87,7 @@ func.func @torch.aten.view$dynamicVal
 // -----

 // CHECK-LABEL: func.func @torch.aten$dynamicValOutput(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[8,1,?,1],f32> {
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[8,1,?,1],f32>
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2]] : tensor<4x5x6xf32> into tensor<120xf32>
 // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2, 3]] output_shape [8, 1, 15, 1] : tensor<120xf32> into tensor<8x1x15x1xf32>

@@ -88,7 +95,9 @@ func.func @torch.aten.view$dynamicVal
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<8x1x?x1xf32> -> !torch.vtensor<[8,1,?,1],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[8,1,?,1],f32>

-func.func @torch.aten$dynamicValOutput(%arg0: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[8,1,?,1],f32> {
+func.func @torch.aten$dynamicValOutput(%arg0: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[8,1,?,1],f32>
+    attributes {torch.assume_strict_symbolic_shapes}
+{
   %int8 = torch.constant.int 8
   %int1 = torch.constant.int 1
   %int-1 = torch.constant.int -1
@@ -100,7 +109,7 @@ func.func @torch.aten$dynamicValOutput
 // -----

 // CHECK-LABEL: func.func @torch.aten$dynamicValOutput2(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[2,1,2,3,?],f32> {
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[2,1,2,3,?],f32>
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2]] : tensor<4x5x6xf32> into tensor<4x30xf32>
 // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [2, 1, 2, 3, 10] : tensor<4x30xf32> into tensor<2x1x2x3x10xf32>

@@ -109,7 +118,9 @@ func.func @torch.aten$dynamicValOutput
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,1,2,3,?],f32>

 // 4 -> [2,1,2], [5,6] -> [3,10].
-func.func @torch.aten$dynamicValOutput2(%arg0: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[2,1,2,3,?],f32> {
+func.func @torch.aten$dynamicValOutput2(%arg0: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[2,1,2,3,?],f32>
+    attributes {torch.assume_strict_symbolic_shapes}
+{
   %int2 = torch.constant.int 2
   %int1 = torch.constant.int 1
   %int3 = torch.constant.int 3

@@ -122,14 +133,16 @@ func.func @torch.aten$dynamicValOutput2
 // -----

 // CHECK-LABEL: func.func @torch.aten.view$expandInferredDim(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> {
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32>
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,6],f32> -> tensor<2x6xf32>
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1]] : tensor<2x6xf32> into tensor<12xf32>
 // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [3, 2, 2] : tensor<12xf32> into tensor<3x2x2xf32>
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<3x2x2xf32> -> !torch.vtensor<[3,2,2],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[3,2,2],f32>

-func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> {
+func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32>
+    attributes {torch.assume_strict_symbolic_shapes}
+{
   %int2 = torch.constant.int 2
   %int3 = torch.constant.int 3
   %int-1 = torch.constant.int -1
@@ -141,7 +154,7 @@ func.func @torch.aten.view$expandInferredDim
 // -----

 // CHECK-LABEL: func.func @torch.aten.view$singleUnknownMatches0(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32> {
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32>
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[10,3,?,2,3],f32> -> tensor<10x3x?x2x3xf32>
 // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3, 4]] : tensor<10x3x?x2x3xf32> into tensor<30x?x6xf32>
 // CHECK: %[[C1:.*]] = arith.constant 1 : index

@@ -154,7 +167,9 @@ func.func @torch.aten.view$expandInferredDim
 // Associations are,
 // -- for collapse, [0,1], [2], [3,4] and
 // -- for expand, [0,1,2], [3], [4].
-func.func @torch.aten.view$singleUnknownMatches0(%arg0: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32> {
+func.func @torch.aten.view$singleUnknownMatches0(%arg0: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32>
+    attributes {torch.assume_strict_symbolic_shapes}
+{
   %int3 = torch.constant.int 3
   %int2 = torch.constant.int 2
   %int6 = torch.constant.int 6

@@ -175,13 +190,15 @@ func.func @torch.aten.view$singleUnknownMatches0
 // but one which matches between the input and the output

 // CHECK: func.func @torch.aten.view$combineConcepts(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[8,?,?,?,2,1,3],f32>) -> !torch.vtensor<[2,2,2,?,?,?,6],f32> {
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[8,?,?,?,2,1,3],f32>) -> !torch.vtensor<[2,2,2,?,?,?,6],f32>
 // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[8,?,?,?,2,1,3],f32> -> tensor<8x?x?x?x2x1x3xf32>
 // CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[BUILTIN_TENSOR]]
 // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[RESHAPE]] : tensor<2x2x2x?x?x?x6xf32> -> !torch.vtensor<[2,2,2,?,?,?,6],f32>
 // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,2,2,?,?,?,6],f32>

-func.func @torch.aten.view$combineConcepts(%arg0 : !torch.vtensor<[8,?,?,?,2,1,3], f32>) -> !torch.vtensor<[2,2,2,?,?,?,6], f32> {
+func.func @torch.aten.view$combineConcepts(%arg0 : !torch.vtensor<[8,?,?,?,2,1,3], f32>) -> !torch.vtensor<[2,2,2,?,?,?,6], f32>
+    attributes {torch.assume_strict_symbolic_shapes}
+{

   %int1 = torch.constant.int 1
   %size1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[8,?,?,?,2,1,3], f32>, !torch.int -> !torch.int
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.view$multiDynamicsInSourceOfCollapse
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,2,?,4,?],f32>) -> !torch.vtensor<[?],f32> {
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,2,?,4,?],f32>) -> !torch.vtensor<[?],f32>
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,2,?,4,?],f32> -> tensor<?x2x?x4x?xf32>
|
||||
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2, 3, 4]] : tensor<?x2x?x4x?xf32> into tensor<?xf32>
|
||||
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[COLLAPSE]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
|
||||
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?],f32>
|
||||
func.func @torch.aten.view$multiDynamicsInSourceOfCollapse (%arg0 : !torch.vtensor<[?,2,?,4,?], f32>) -> !torch.vtensor<[?], f32> {
|
||||
func.func @torch.aten.view$multiDynamicsInSourceOfCollapse (%arg0 : !torch.vtensor<[?,2,?,4,?], f32>) -> !torch.vtensor<[?], f32>
|
||||
attributes {torch.assume_strict_symbolic_shapes}
|
||||
{
|
||||
%int-1 = torch.constant.int -1
|
||||
%0 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
|
||||
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,2,?,4,?], f32>, !torch.list<int> -> !torch.vtensor<[?], f32>
|
||||
|
@ -215,7 +234,7 @@ func.func @torch.aten.view$multiDynamicsInSourceOfCollapse (%arg0 : !torch.vtens
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.view$castingView
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[3,4,5],f32>
|
||||
|
||||
// The current lowring only succeeds if the input (arg0) has shape [3,4,5],
|
||||
// determined at runtime. This is a bit limiting, and we'll probably want to
|
||||
|
@ -225,7 +244,9 @@ func.func @torch.aten.view$multiDynamicsInSourceOfCollapse (%arg0 : !torch.vtens
|
|||
// CHECK-COUNT-2: cf.assert {{.*}} "mismatching contracting dimension
|
||||
// CHECK: return {{.*}} : !torch.vtensor<[3,4,5],f32>
|
||||
|
||||
func.func @torch.aten.view$castingView (%arg0 : !torch.vtensor<[?,?,?], f32>) -> !torch.vtensor<[3,4,5], f32> {
|
||||
func.func @torch.aten.view$castingView (%arg0 : !torch.vtensor<[?,?,?], f32>) -> !torch.vtensor<[3,4,5], f32>
|
||||
attributes {torch.assume_strict_symbolic_shapes}
|
||||
{
|
||||
%int3 = torch.constant.int 3
|
||||
%int4 = torch.constant.int 4
|
||||
%int5 = torch.constant.int 5
|
||||
|
@ -240,7 +261,7 @@ func.func @torch.aten.view$castingView (%arg0 : !torch.vtensor<[?,?,?], f32>) ->
|
|||
// We expect this to lower to a collapse with [0], [1], [2,3] followed by
|
||||
// an expand with [0,1], [2], [3]:
|
||||
// CHECK: func.func @torch.aten.view$dynamicInferredSame(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,?,2,3],f32>) -> !torch.vtensor<[2,5,?,6],f32> {
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,?,2,3],f32>) -> !torch.vtensor<[2,5,?,6],f32>
|
||||
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[10,?,2,3],f32> -> tensor<10x?x2x3xf32>
|
||||
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<10x?x2x3xf32> into tensor<10x?x6xf32>
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
|
@ -249,7 +270,9 @@ func.func @torch.aten.view$castingView (%arg0 : !torch.vtensor<[?,?,?], f32>) ->
|
|||
// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<2x5x?x6xf32> -> !torch.vtensor<[2,5,?,6],f32>
|
||||
// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,5,?,6],f32>
|
||||
|
||||
func.func @torch.aten.view$dynamicInferredSame(%arg0: !torch.vtensor<[10,?,2,3],f32>) -> !torch.vtensor<[2,5,?,6],f32> {
|
||||
func.func @torch.aten.view$dynamicInferredSame(%arg0: !torch.vtensor<[10,?,2,3],f32>) -> !torch.vtensor<[2,5,?,6],f32>
|
||||
attributes {torch.assume_strict_symbolic_shapes}
|
||||
{
|
||||
%int2 = torch.constant.int 2
|
||||
%int5 = torch.constant.int 5
|
||||
%int6 = torch.constant.int 6
|
||||
|
|
|

@@ -0,0 +1,150 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
+// Since we want to migrate to the strict view op lowering, these test cases
+// verify this one pattern specifically via attributes on the functions that
+// disable the legacy behavior.
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.view$twotothree
+// CHECK: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32>
+// CHECK: %[[T3:.*]] = torch.constant.int 3
+// CHECK: %[[T2:.*]] = torch.constant.int 2
+// CHECK: %[[N2:.*]] = torch_c.to_i64 %[[T2]]
+// CHECK: %[[N3:.*]] = torch_c.to_i64 %[[T3]]
+// CHECK: %[[ELEMENTS:.*]] = tensor.from_elements %[[N2]], %[[N3]] : tensor<2xi64>
+// CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[ARG0]](%[[ELEMENTS]]) : (tensor<3x2xf32>, tensor<2xi64>) -> tensor<2x3xf32>
+func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32>
+    attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view}
+{
+  %int3 = torch.constant.int 3
+  %int2 = torch.constant.int 2
+  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[3,2],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
+  return %1 : !torch.vtensor<[2,3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.view$zerod
+// CHECK: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0
+// CHECK: tensor.collapse_shape %0 [] : tensor<?x?xf32> into tensor<f32>
+func.func @torch.aten.view$zerod(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32>
+    attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view}
+{
+  %0 = torch.prim.ListConstruct : () -> !torch.list<int>
+  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
+  return %1 : !torch.vtensor<[],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.view$dynamictest
+// CHECK: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0
+// CHECK: %[[ARG1:.*]] = torch_c.to_i64 %arg1
+// CHECK: %[[ARG2:.*]] = torch_c.to_i64 %arg2
+// CHECK: %[[ELTS:.*]] = tensor.from_elements %[[ARG1]], %[[ARG2]] : tensor<2xi64>
+// CHECK: tensor.reshape %[[ARG0]](%[[ELTS]]) : (tensor<?x?xf32>, tensor<2xi64>) -> tensor<?x?xf32>
+func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32>
+    attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view}
+{
+  %2 = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
+  return %3 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.view$dynamictest2(
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32>
+// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,6,?],f32> -> tensor<?x6x?xf32>
+// CHECK: %[[EXPAND:.*]] = tensor.reshape %[[BUILTIN_TENSOR]]
+// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<?x2x3x?xf32> -> !torch.vtensor<[?,2,3,?],f32>
+// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,2,3,?],f32>
+
+func.func @torch.aten.view$dynamictest2(%arg0: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32>
+    attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view}
+{
+  %int3 = torch.constant.int 3
+  %int2 = torch.constant.int 2
+  %int0 = torch.constant.int 0
+  %2 = torch.aten.size.int %arg0, %int2 : !torch.vtensor<[?,6,?],f32>, !torch.int -> !torch.int
+  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,6,?],f32>, !torch.int -> !torch.int
+  %1 = torch.prim.ListConstruct %0, %int2, %int3, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.view %arg0, %1 : !torch.vtensor<[?,6,?],f32>, !torch.list<int> -> !torch.vtensor<[?,2,3,?], f32>
+  return %3 : !torch.vtensor<[?,2,3,?], f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.view$dynamicVal(
+// CHECK: tensor.reshape {{.*}} : (tensor<1x?x128xf32>, tensor<3xi64>) -> tensor<16x1x128xf32>
+func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32>
+    attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view}
+{
+  %int128 = torch.constant.int 128
+  %int1 = torch.constant.int 1
+  %int16 = torch.constant.int 16
+  %0 = torch.prim.ListConstruct %int16, %int1, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,?,128],f32>, !torch.list<int> -> !torch.vtensor<[16,1,128],f32>
+  return %1 : !torch.vtensor<[16,1,128],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.view$expandInferredDim
+// CHECK: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : tensor<2x6xf32> into tensor<12xf32>
+// CHECK: %[[CAST1:.*]] = tensor.cast %[[COLLAPSED]] : tensor<12xf32> to tensor<12xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[CAST1]] {{\[\[}}0, 1, 2]] output_shape [3, 2, 2] : tensor<12xf32> into tensor<3x2x2xf32>
+func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32>
+    attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view}
+{
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %int-1 = torch.constant.int -1
+  %0 = torch.prim.ListConstruct %int3, %int2, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[2,6],f32>, !torch.list<int> -> !torch.vtensor<[3,2,2],f32>
+  return %1 : !torch.vtensor<[3,2,2],f32>
+}
+
+// -----
+// Note that this is presently going down a fallback path as an explicit
+// reshape. Someday, this should generate flatten/unflatten.
+// CHECK-LABEL: func.func @torch.aten$dynamicValOutput
+// CHECK: %[[SELF:.*]] = torch_c.to_builtin_tensor %arg0
+// CHECK: %[[CONSTANT1:.*]] = torch.constant.int 1
+// CHECK-DAG: %[[PROD1:.*]] = arith.constant 1
+// CHECK-DAG: %[[ARG1_CVT:.*]] = torch_c.to_i64 %arg1
+// CHECK-DAG: %[[PROD2:.*]] = arith.muli %[[PROD1]], %[[ARG1_CVT]]
+// CHECK-DAG: %[[ONEI64:.*]] = torch_c.to_i64 %[[CONSTANT1]]
+// CHECK-DAG: %[[PROD3:.*]] = arith.muli %[[PROD2]], %[[ONEI64]]
+// CHECK-DAG: %[[ONEI64_0:.*]] = torch_c.to_i64 %[[CONSTANT1]]
+// CHECK-DAG: %[[PROD4:.*]] = arith.muli %[[PROD3]], %[[ONEI64_0]]
+// CHECK-DAG: %[[INDEX0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[DIM0_INDEX:.*]] = tensor.dim %[[SELF]], %[[INDEX0]] : tensor<?x?x?xf32>
+// CHECK-DAG: %[[DIM0:.*]] = arith.index_cast %[[DIM0_INDEX]] : index to i64
+// CHECK-DAG: %[[KNOWN0:.*]] = arith.muli %[[PROD1]], %[[DIM0]] : i64
+// CHECK-DAG: %[[INDEX1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[DIM1_INDEX:.*]] = tensor.dim %[[SELF]], %[[INDEX1]] : tensor<?x?x?xf32>
+// CHECK-DAG: %[[DIM1:.*]] = arith.index_cast %[[DIM1_INDEX]] : index to i64
+// CHECK-DAG: %[[KNOWN1:.*]] = arith.muli %[[KNOWN0]], %[[DIM1]] : i64
+// CHECK-DAG: %[[INDEX2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM2_INDEX:.*]] = tensor.dim %[[SELF]], %[[INDEX2]] : tensor<?x?x?xf32>
+// CHECK-DAG: %[[DIM2:.*]] = arith.index_cast %[[DIM2_INDEX]] : index to i64
+// CHECK-DAG: %[[KNOWN2:.*]] = arith.muli %[[KNOWN1]], %[[DIM2]] : i64
+// CHECK-DAG: %[[DIMINFER:.*]] = arith.divui %[[KNOWN2]], %[[PROD4]] : i64
+// CHECK: %[[DIM0:.*]] = torch_c.to_i64 %arg1
+// CHECK: %[[DIM1:.*]] = torch_c.to_i64 %[[CONSTANT1]]
+// CHECK: %[[DIM3:.*]] = torch_c.to_i64 %[[CONSTANT1]]
+// CHECK: %[[OUTPUT_DIMS:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]], %[[DIMINFER]], %[[DIM3]] : tensor<4xi64>
+// CHECK: tensor.reshape %[[SELF]](%[[OUTPUT_DIMS]]) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x1x?x1xf32>
+//
+func.func @torch.aten$dynamicValOutput(%arg0: !torch.vtensor<[?, ?, ?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,1,?,1],f32>
+    attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view}
+{
+  %int1 = torch.constant.int 1
+  %int-1 = torch.constant.int -1
+  %0 = torch.prim.ListConstruct %arg1, %int1, %int-1, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?, ?, ?],f32>, !torch.list<int> -> !torch.vtensor<[?,1,?,1],f32>
+  return %1 : !torch.vtensor<[?,1,?,1],f32>
+}