mirror of https://github.com/llvm/torch-mlir
Add torch.aten.flatten.using_ints and aten.max_pool2d linalg lowering

- torch.aten.flatten.using_ints to linalg lowering
- torch.aten.max_pool2d to linalg lowering
- Support torch.aten.conv2d for more flexible dilation and stride values

pull/270/head
parent 496051163f
commit 0342b73bf1
@@ -10,9 +10,11 @@ from torch_mlir_torchscript.annotations import annotate_args, export

# ==============================================================================


class MmModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,

@@ -22,21 +24,26 @@ class MmModule(torch.nn.Module):
    def forward(self, lhs, rhs):
        return torch.mm(lhs, rhs)


@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 4), tu.rand(4, 4))


@register_test_case(module_factory=lambda: MmModule())
def MmModule_chained(module, tu: TestUtils):
    res = module.forward(tu.rand(4, 4), tu.rand(4, 4))
    module.forward(res, res)


# ==============================================================================


# A subgraph with multiple mm ops.
class MmDagModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,

@@ -46,15 +53,19 @@ class MmDagModule(torch.nn.Module):
    def forward(self, lhs, rhs):
        return torch.mm(lhs, torch.mm(lhs, rhs))


@register_test_case(module_factory=lambda: MmDagModule())
def MmDagModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 4), tu.rand(4, 4))


# ==============================================================================


class MmTanhModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,

@@ -63,26 +74,109 @@ class MmTanhModule(torch.nn.Module):
    ])
    def forward(self, lhs, rhs):
        return torch.tanh(self.matmul(lhs, rhs))

    def matmul(self, lhs, rhs):
        return torch.mm(lhs, rhs)


@register_test_case(module_factory=lambda: MmTanhModule())
def MmTanhModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 2), tu.rand(2, 4))


class AdaptiveAvgPool2dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.aap2d = torch.nn.AdaptiveAvgPool2d((1, 1))

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return self.aap2d(x)


@register_test_case(module_factory=lambda: AdaptiveAvgPool2dModule())
def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(10, 3, 8, 9))


class FlattenStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.flat = torch.nn.Flatten(2, 4)

    @export
    @annotate_args([
        None,
        ([10, 3, 8, 9, 3, 4], torch.float32, True),
    ])
    def forward(self, x):
        return self.flat(x)


@register_test_case(module_factory=lambda: FlattenStaticModule())
def FlattenStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(10, 3, 8, 9, 3, 4))


class FlattenRank0Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.flat = torch.nn.Flatten(-1, -1)

    @export
    @annotate_args([
        None,
        ([], torch.float32, True),
    ])
    def forward(self, x):
        return self.flat(x)


@register_test_case(module_factory=lambda: FlattenRank0Module())
def FlattenRank0Module_basic(module, tu: TestUtils):
    module.forward(torch.tensor(4.0))


class FlattenDynamicModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.flat = torch.nn.Flatten(2, 4)

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, 9, 3, -1], torch.float32, True),
    ])
    def forward(self, x):
        return self.flat(x)


@register_test_case(module_factory=lambda: FlattenDynamicModule())
def FlattenDynamicModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(10, 3, 8, 9, 3, 4))


class MaxPool2dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mp2d = torch.nn.MaxPool2d(kernel_size=[6, 8],
                                       stride=[2, 2],
                                       padding=[3, 4],
                                       dilation=2)

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return self.mp2d(x)


@register_test_case(module_factory=lambda: MaxPool2dModule())
def MaxPool2dModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 1, 20, 20) - 0.5)
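As a sanity check on the new MaxPool2d test (not part of the diff), the expected output shape follows from the standard pooling size formula; a minimal sketch assuming stock PyTorch:

import torch

# Same configuration as MaxPool2dModule above.
mp2d = torch.nn.MaxPool2d(kernel_size=[6, 8], stride=[2, 2],
                          padding=[3, 4], dilation=2)
out = mp2d(torch.rand(1, 1, 20, 20) - 0.5)

# dim_out = (dim_in + 2*padding - dilation*(kernel - 1) - 1) // stride + 1
# H: (20 + 6 - 2*5 - 1) // 2 + 1 = 8;  W: (20 + 8 - 2*7 - 1) // 2 + 1 = 7
assert out.shape == (1, 1, 8, 7)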
@@ -9,42 +9,75 @@ from torch_mlir_torchscript.annotations import annotate_args, export

# ==============================================================================

class Conv2dNoPaddingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.conv = torch.nn.Conv2d(2, 10, 3, bias=False)
        self.train(False)

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return self.conv(x)


@register_test_case(module_factory=lambda: Conv2dNoPaddingModule())
def Conv2dNoPaddingModule_basic(module, tu: TestUtils):
    t = tu.rand(5, 2, 10, 20)
    module.forward(t)

class Conv2dWithPaddingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.conv = torch.nn.Conv2d(2, 10, 3, bias=False, padding=3)
        self.train(False)

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return self.conv(x)


@register_test_case(module_factory=lambda: Conv2dWithPaddingModule())
def Conv2dWithPaddingModule_basic(module, tu: TestUtils):
    t = tu.rand(5, 2, 10, 20)
    module.forward(t)


class Conv2dWithPaddingDilationStrideModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.conv = torch.nn.Conv2d(in_channels=2,
                                    out_channels=10,
                                    kernel_size=3,
                                    padding=3,
                                    stride=2,
                                    dilation=3,
                                    bias=False)
        self.train(False)

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return self.conv(x)


@register_test_case(
    module_factory=lambda: Conv2dWithPaddingDilationStrideModule())
def Conv2dWithPaddingDilationStrideModule_basic(module, tu: TestUtils):
    t = tu.rand(5, 2, 10, 20)
    module.forward(t)
@@ -131,6 +131,28 @@ def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseFlattenBroadcastModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([], torch.float32, True),
    ])
    def forward(self, a, b):
        return a * b.flatten(-1, -1)


@register_test_case(module_factory=lambda: ElementwiseFlattenBroadcastModule())
def ElementwiseFlattenBroadcastModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(6), tu.rand())


# ==============================================================================


class ElementwiseReluModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -15,19 +15,16 @@ XFAIL_SETS = {}
# These represent further work needed in npcomp to lower them properly
# to the backend contract.
_common_npcomp_lowering_xfails = {
    'ResNet18Module_basic',
    'QuantizedMLP_basic',
}

XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails

XFAIL_SETS['iree'] = _common_npcomp_lowering_xfails | {
    # https://github.com/google/iree/pull/6407
    'MmDagModule_basic',
    'Mlp1LayerModule_basic',
    'Mlp2LayerModule_basic',
    'Conv2dNoPaddingModule_basic',
    'AdaptiveAvgPool2dModule_basic',
    # https://github.com/google/iree/issues/6416
    'Conv2dWithPaddingModule_basic',
    # https://reviews.llvm.org/D106658 needs to reach an IREE release
    'MaxPool2dModule_basic',
    'Conv2dWithPaddingDilationStrideModule_basic',
    # https://github.com/google/iree/issues/6420
    'FlattenDynamicModule_basic',
    'ResNet18Module_basic'
}
@@ -65,7 +65,10 @@ class ValueReport:

    @property
    def failed(self):
        if self.value.size() != self.golden_value.size():
            return True
        return not torch.allclose(
            self.value, self.golden_value, rtol=1e-03, atol=1e-07)

    def error_str(self):
        assert self.failed
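The new size guard matters because torch.allclose broadcasts its arguments, so tensors of different shapes can otherwise compare equal; a small illustration (not from the patch):

import torch

# Without the explicit size check, a [1]-shaped result could silently
# "pass" against a [4]-shaped golden value via broadcasting.
assert torch.allclose(torch.zeros(4), torch.zeros(1))
assert torch.zeros(4).size() != torch.zeros(1).size()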
@@ -79,6 +79,87 @@ static bool isConstantIntListMatching(Value &value,
  return true;
}

static Value getDimOp(OpBuilder &b, Location loc, Value v, int dimension) {
  return b.create<tensor::DimOp>(loc, v, dimension);
}

// Helper function to calculate the output tensor dims for convolution-like
// ops. Along each dim:
//   dim_out =
//       floor((dim_in + 2 * padding - dilation * (kernelSize - 1) - 1) / stride) + 1
static Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
                                    Value paddingInt, Value dilationInt,
                                    Value kernelSizeInt, Value strideInt) {
  Type intType = b.getIntegerType(64);
  Type indexType = b.getIndexType();
  auto castIndexToInt = [&](Value v) {
    return b.create<IndexCastOp>(loc, intType, v);
  };
  auto castIntToIndex = [&](Value v) {
    return b.create<IndexCastOp>(loc, indexType, v);
  };
  Value c1 = b.create<ConstantOp>(loc, b.getI64IntegerAttr(1));
  Value c2 = b.create<ConstantOp>(loc, b.getI64IntegerAttr(2));

  Value doublePadding = b.create<MulIOp>(loc, paddingInt, c2);
  // in + 2 * padding
  Value inAddDoublePadding =
      b.create<AddIOp>(loc, castIndexToInt(in), doublePadding);

  // dilation * (kernelSize - 1)
  Value kernelSizeSub1 = b.create<SubIOp>(loc, kernelSizeInt, c1);
  Value dilationTimesKernelSize =
      b.create<MulIOp>(loc, dilationInt, kernelSizeSub1);

  Value temp =
      b.create<SubIOp>(loc, inAddDoublePadding, dilationTimesKernelSize);
  Value dividend = b.create<SubIOp>(loc, temp, c1);
  Value division = b.create<SignedFloorDivIOp>(loc, dividend, strideInt);
  Value out = b.create<AddIOp>(loc, division, c1);
  return castIntToIndex(out);
}
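For intuition, the arithmetic this helper emits as IR can be written as plain Python; an illustrative sketch (the Python function name is ours, not part of the patch), checked against the Conv2dWithPaddingDilationStrideModule test above (10x20 spatial input, kernel 3, padding 3, stride 2, dilation 3):

def output_dim_for_conv_ops(dim_in, padding, dilation, kernel_size, stride):
    # Python mirror of the IR built by getOutputDimForConvOps:
    # floor((dim_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1
    return (dim_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

assert output_dim_for_conv_ops(10, 3, 3, 3, 2) == 5   # Hout
assert output_dim_for_conv_ops(20, 3, 3, 3, 2) == 10  # Wout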
static SmallVector<Value>
getAsConstantIntValues(OpBuilder &b, Location loc,
                       SmallVectorImpl<int64_t> &ints) {
  return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value {
    return b.create<ConstantOp>(loc, b.getIntegerAttr(b.getI64Type(), val));
  }));
}

static SmallVector<Value>
getAsConstantIndexValues(OpBuilder &b, Location loc,
                         SmallVectorImpl<int64_t> &ints) {
  return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value {
    return b.create<ConstantOp>(loc, b.getIndexAttr(val));
  }));
}

static SmallVector<OpFoldResult>
getAsOpFoldResult(OpBuilder &b, Location loc, SmallVectorImpl<int64_t> &ints) {
  return llvm::to_vector<4>(llvm::map_range(
      ints, [&](int64_t val) -> OpFoldResult { return b.getIndexAttr(val); }));
}

// Helper function to get the padding tensor given the padding int values.
// It's assumed that the padding on the low end and high end are the same.
static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
                             SmallVectorImpl<int64_t> &paddingInts) {
  assert(input.getType().isa<RankedTensorType>() &&
         "input must be RankedTensorType");
  Location loc = op->getLoc();
  Value c0float = b.create<ConstantOp>(
      loc, FloatAttr::get(
               input.getType().cast<RankedTensorType>().getElementType(), 0.0));
  SmallVector<OpFoldResult> paddings = getAsOpFoldResult(b, loc, paddingInts);
  Type ranked4DTensorType = linalg::PadTensorOp::inferResultType(
      input.getType().cast<RankedTensorType>(), paddingInts, paddingInts);
  Value paddedInput = linalg::PadTensorOp::createPadScalarOp(
      ranked4DTensorType, input, c0float, /*low=*/paddings, /*high=*/paddings,
      loc, b);
  return paddedInput;
}

namespace {
class ConvertAtenAdaptiveAvgPool2dOp
    : public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
@@ -106,11 +187,8 @@ public:
      return rewriter.notifyMatchFailure(
          op, "only support output_size with H and W both equal to constant 1");

    Value N = getDimOp(rewriter, loc, input, 0);
    Value C = getDimOp(rewriter, loc, input, 1);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, ValueRange{N, C}, elementType);
    Value c0 =
@@ -144,8 +222,8 @@ public:
        .getResult(0);

    // Calculate H*W so that the average can be computed as sum / (H*W).
    Value H = getDimOp(rewriter, loc, input, 2);
    Value W = getDimOp(rewriter, loc, input, 3);
    auto castIndexToInt = [&](Value v) {
      return rewriter.create<IndexCastOp>(loc, IntegerType::get(context, 64),
                                          v);
@@ -200,100 +278,79 @@ public:
        input.getType().cast<RankedTensorType>().getElementType();
    if (!elementType.isa<mlir::FloatType>())
      op.emitError("unimplemented: non-floating point type");

    Type intType = IntegerType::get(context, 64);
    Type indexType = IndexType::get(context);

    auto castIntToIndex = [&](Value v) {
      return rewriter.create<IndexCastOp>(loc, indexType, v);
    };

    auto castIndexToInt = [&](Value v) {
      return rewriter.create<IndexCastOp>(loc, intType, v);
    };

    Value N = getDimOp(rewriter, loc, input, 0);
    Value Hin = getDimOp(rewriter, loc, input, 2);
    Value Win = getDimOp(rewriter, loc, input, 3);
    Value F = getDimOp(rewriter, loc, weight, 0);
    Value weightH = getDimOp(rewriter, loc, weight, 2);
    Value weightW = getDimOp(rewriter, loc, weight, 3);

    llvm::SmallVector<int64_t> paddingInts;
    if (!matchPattern(padding, m_TorchConstantIntList(paddingInts))) {
      return rewriter.notifyMatchFailure(
          op, "only support constant padding values");
    }

    llvm::SmallVector<int64_t, 2> strideInts;
    if (!matchPattern(stride, m_TorchConstantIntList(strideInts)))
      return rewriter.notifyMatchFailure(op,
                                         "only support constant int strides");
    llvm::SmallVector<int64_t, 2> dilationInts;
    if (!matchPattern(dilation, m_TorchConstantIntList(dilationInts)))
      return rewriter.notifyMatchFailure(op,
                                         "only support constant int dilations");
    if (!op.bias().getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(op, "only support None bias");

    Value c1 = rewriter.create<ConstantOp>(loc, IntegerAttr::get(intType, 1));
    Value groupEqual1 =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, groups, c1);
    rewriter.create<AssertOp>(loc, groupEqual1,
                              rewriter.getStringAttr("expect groups to be 1"));

    // Pad the input tensor according to padding.
    SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
    paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
                              paddingInts.end());
    Value paddedInput =
        getPaddedTensor(op, rewriter, input, paddingIncludingNC);

    SmallVector<Value> paddingIntValues =
        getAsConstantIntValues(rewriter, loc, paddingInts);
    SmallVector<Value> dilationIntValues =
        getAsConstantIntValues(rewriter, loc, dilationInts);
    SmallVector<Value> strideIntValues =
        getAsConstantIntValues(rewriter, loc, strideInts);

    // Calculate the output tensor dims.
    Value Hout = getOutputDimForConvOps(
        rewriter, loc, Hin, paddingIntValues[0], dilationIntValues[0],
        castIndexToInt(weightH), strideIntValues[0]);
    Value Wout = getOutputDimForConvOps(
        rewriter, loc, Win, paddingIntValues[1], dilationIntValues[1],
        castIndexToInt(weightW), strideIntValues[1]);

    Value c0float = rewriter.create<ConstantOp>(
        loc,
        FloatAttr::get(
            input.getType().cast<RankedTensorType>().getElementType(), 0.0));
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, ValueRange{N, F, Hout, Wout}, elementType);
    Value initTensor0 =
        rewriter.create<linalg::FillOp>(loc, c0float, initTensor).getResult(0);

    auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
    auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
    Value conv2d =
        rewriter
            .create<linalg::Conv2DNchwOp>(
                loc, initTensor0.getType(), ValueRange{paddedInput, weight},
                initTensor0, stridesAttr, dilationAttr)
            .getResult(0);
    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv2d);
@@ -518,14 +575,11 @@ public:
      return rewriter.notifyMatchFailure(
          op, "unimplemented: size-1 broadcasting for aten::LinearOp");

    Value inputDim0 = getDimOp(rewriter, loc, input, 0);
    Value inputDim1 = getDimOp(rewriter, loc, input, 1);
    Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
    Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
    Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
    Value contractingDimEqual =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, inputDim1, weightDim1);
    rewriter.create<AssertOp>(
@@ -839,6 +893,165 @@ struct ConvertElementwiseOp : ConversionPattern {
};
} // namespace

namespace {
class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenMaxPool2dOp op, llvm::ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    AtenMaxPool2dOp::Adaptor adaptor(operands);
    Value self = adaptor.self();
    Value kernelSize = adaptor.kernel_size();
    Value stride = adaptor.stride();
    Value padding = adaptor.padding();
    Value dilation = adaptor.dilation();
    Value ceilMode = adaptor.ceil_mode();

    Type elementType = self.getType().cast<RankedTensorType>().getElementType();
    if (!elementType.isa<mlir::FloatType>())
      op.emitError("unimplemented: non-floating point type");

    llvm::SmallVector<int64_t, 2> strideInts;
    if (!matchPattern(stride, m_TorchConstantIntList(strideInts)))
      return rewriter.notifyMatchFailure(op,
                                         "only support constant int strides");
    llvm::SmallVector<int64_t, 2> dilationInts;
    if (!matchPattern(dilation, m_TorchConstantIntList(dilationInts)))
      return rewriter.notifyMatchFailure(op,
                                         "only support constant int dilations");
    llvm::SmallVector<int64_t, 2> paddingInts;
    if (!matchPattern(padding, m_TorchConstantIntList(paddingInts)))
      return rewriter.notifyMatchFailure(op,
                                         "only support constant int paddings");
    llvm::SmallVector<int64_t, 2> kernelSizeInts;
    if (!matchPattern(kernelSize, m_TorchConstantIntList(kernelSizeInts)))
      return rewriter.notifyMatchFailure(op, "only support kernel size ints");

    Value falseValue = rewriter.create<ConstantOp>(
        loc, IntegerAttr::get(rewriter.getIntegerType(1), 0));
    Value ceilModeFalse =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, ceilMode, falseValue);
    rewriter.create<AssertOp>(
        loc, ceilModeFalse,
        rewriter.getStringAttr("only ceil_mode false is supported"));

    SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
    paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
                              paddingInts.end());
    Value paddedInput = getPaddedTensor(op, rewriter, self, paddingIncludingNC);

    Value N = getDimOp(rewriter, loc, self, 0);
    Value C = getDimOp(rewriter, loc, self, 1);
    Value H = getDimOp(rewriter, loc, self, 2);
    Value W = getDimOp(rewriter, loc, self, 3);

    SmallVector<Value> paddingIntValues =
        getAsConstantIntValues(rewriter, loc, paddingInts);
    SmallVector<Value> dilationIntValues =
        getAsConstantIntValues(rewriter, loc, dilationInts);
    SmallVector<Value> kernelSizeIntValues =
        getAsConstantIntValues(rewriter, loc, kernelSizeInts);
    SmallVector<Value> strideIntValues =
        getAsConstantIntValues(rewriter, loc, strideInts);

    Value Hout = getOutputDimForConvOps(
        rewriter, loc, H, paddingIntValues[0], dilationIntValues[0],
        kernelSizeIntValues[0], strideIntValues[0]);
    Value Wout = getOutputDimForConvOps(
        rewriter, loc, W, paddingIntValues[1], dilationIntValues[1],
        kernelSizeIntValues[1], strideIntValues[1]);

    // Initialize output tensor with smallest floating point value.
    Value outTensor = rewriter.create<linalg::InitTensorOp>(
        loc, ValueRange{N, C, Hout, Wout}, elementType);
    auto initialAttr = rewriter.getFloatAttr(
        elementType,
        APFloat::getSmallest(
            elementType.cast<mlir::FloatType>().getFloatSemantics(),
            /*Negative*/ true));
    Value initValue = rewriter.create<ConstantOp>(loc, initialAttr);
    Value outTensorInitialized =
        rewriter.create<linalg::FillOp>(loc, initValue, outTensor).getResult(0);

    auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
    auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
    Value windowTensor = rewriter.create<linalg::InitTensorOp>(
        loc, getAsConstantIndexValues(rewriter, loc, kernelSizeInts),
        elementType);

    Value maxPool2d = rewriter
                          .create<linalg::PoolingNchwMaxOp>(
                              loc, outTensorInitialized.getType(),
                              ValueRange{paddedInput, windowTensor},
                              outTensorInitialized, stridesAttr, dilationAttr)
                          .getResult(0);
    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
    return success();
  }
};
} // namespace
namespace {
class ConvertAtenFlattenUsingIntsOp
    : public OpConversionPattern<AtenFlattenUsingIntsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenFlattenUsingIntsOp op, llvm::ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    int64_t startDim;
    if (!matchPattern(op.start_dim(), m_TorchConstantInt(&startDim)))
      return rewriter.notifyMatchFailure(op, "start_dim must be constant");
    int64_t endDim;
    if (!matchPattern(op.end_dim(), m_TorchConstantInt(&endDim)))
      return rewriter.notifyMatchFailure(op, "end_dim must be constant");
    auto type = operands[0].getType().cast<RankedTensorType>();
    auto inputRank = type.getRank();
    auto resultType =
        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
    if (startDim < 0)
      startDim += inputRank;
    if (endDim < 0)
      endDim += inputRank;

    if (inputRank == 0) {
      SmallVector<ReassociationIndices> reassociation;
      if (!(startDim >= -1 && startDim <= 0 && endDim >= -1 && endDim <= 0))
        return rewriter.notifyMatchFailure(
            op, "start_dim and end_dim must be in [-1, 0] when inputRank is 0");
      rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
          op, resultType, operands[0], reassociation);
      return success();
    }

    if (startDim < 0 || startDim >= inputRank || endDim < 0 ||
        endDim >= inputRank || startDim > endDim)
      return rewriter.notifyMatchFailure(
          op, "statically invalid flattening dim range");

    SmallVector<ReassociationIndices> reassociation(resultType.getRank());
    int j = 0;
    for (auto i : llvm::seq<int64_t>(0, inputRank)) {
      reassociation[j].push_back(i);
      if (i < startDim || i >= endDim)
        j++;
    }
    Value collapsedTensor = rewriter.create<linalg::TensorCollapseShapeOp>(
        op->getLoc(), operands[0], reassociation);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
                                                collapsedTensor);
    return success();
  }
};
} // namespace
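To make the grouping concrete: a short Python sketch of the reassociation computed by the loop above (hypothetical helper, for illustration only), checked against the flatten test below that collapses dims 2..4 of a rank-7 tensor:

def flatten_reassociation(input_rank, start_dim, end_dim):
    # Mirror of the C++ loop: dims in [start_dim, end_dim] share one group.
    result_rank = input_rank - (end_dim - start_dim)
    groups = [[] for _ in range(result_rank)]
    j = 0
    for i in range(input_rank):
        groups[j].append(i)
        # Advance to the next group except while inside [start_dim, end_dim).
        if i < start_dim or i >= end_dim:
            j += 1
    return groups

assert flatten_reassociation(7, 2, 4) == [[0], [1], [2, 3, 4], [5], [6]]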

namespace {
class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
public:
@@ -929,6 +1142,10 @@ public:
    patterns.add<ConvertAtenConv2dOp>(typeConverter, context);
    target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
    patterns.add<ConvertAtenAdaptiveAvgPool2dOp>(typeConverter, context);
    target.addIllegalOp<AtenFlattenUsingIntsOp>();
    patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
    target.addIllegalOp<AtenMaxPool2dOp>();
    patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
@@ -9,6 +9,7 @@
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "npcomp/Backend/Common/Passes.h"

@@ -154,6 +155,7 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
  pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
  pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
  pm.addNestedPass<FuncOp>(createStdExpandOpsPass());

  if (options.optimize) {
    // Clean up any non-canonical code introduced in our linalg lowering.
@@ -265,9 +265,9 @@ public:
        ValueKnowledge::getPessimisticValueState(op->getContext());
    knowledge.dtype = operand.dtype;
    if (operand.hasSizes && operand.sizes.size() == 0) {
      // Rank 0 is special and flattens to rank 1 with size 1.
      knowledge.hasSizes = true;
      knowledge.sizes.push_back(1);
    } else if (operand.hasSizes &&
               matchPattern(flatten.start_dim(),
                            m_TorchConstantInt(&startDim)) &&
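A quick PyTorch check of the rank-0 rule encoded above (not part of the patch):

import torch

# Rank-0 (scalar) tensors flatten to rank 1 with size 1.
x = torch.tensor(4.0)
assert x.flatten().shape == (1,)
# Same via the module used in FlattenRank0Module above.
assert torch.nn.Flatten(-1, -1)(x).shape == (1,)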
@@ -0,0 +1,84 @@
// RUN: npcomp-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

// -----

// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic(
// CHECK-SAME:    %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
// CHECK:         %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
// CHECK:         %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
// CHECK:         %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
// CHECK:         %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
// CHECK:         return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>

func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
  %int2 = torch.constant.int 2
  %int4 = torch.constant.int 4
  %0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[3,3,2,2,3,3,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,?,3,5],f32>
  return %0 : !torch.vtensor<[3,3,?,3,5],f32>
}

// -----

// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic_negative(
// CHECK-SAME:    %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
// CHECK:         %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
// CHECK:         %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
// CHECK:         %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
// CHECK:         %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
// CHECK:         return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>

func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
  %int-5 = torch.constant.int -5
  %int-3 = torch.constant.int -3
  %0 = torch.aten.flatten.using_ints %arg0, %int-5, %int-3 : !torch.vtensor<[3,3,2,2,3,3,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,?,3,5],f32>
  return %0 : !torch.vtensor<[3,3,?,3,5],f32>
}

// -----

// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_front(
// CHECK-SAME:    %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK:         %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK:         %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32>
// CHECK:         %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<18x2xf32> to tensor<?x?xf32>
// CHECK:         %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK:         return %[[RESULT]] : !torch.vtensor<[?,?],f32>

func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
  %int0 = torch.constant.int 0
  %int2 = torch.constant.int 2
  %0 = torch.aten.flatten.using_ints %arg0, %int0, %int2 : !torch.vtensor<[3,3,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}

// -----

// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_back(
// CHECK-SAME:    %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
// CHECK:         %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK:         %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32>
// CHECK:         %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x12xf32> to tensor<?x12xf32>
// CHECK:         %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<?x12xf32> -> !torch.vtensor<[?,12],f32>
// CHECK:         return %[[RESULT]] : !torch.vtensor<[?,12],f32>

func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
  %int1 = torch.constant.int 1
  %int-1 = torch.constant.int -1
  %0 = torch.aten.flatten.using_ints %arg0, %int1, %int-1 : !torch.vtensor<[3,3,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,12],f32>
  return %0 : !torch.vtensor<[?,12],f32>
}

// -----

// CHECK-LABEL: func @torch.aten.flatten.using_ints$rank0(
// CHECK-SAME:    %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK:         %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK:         %[[COLLAPSED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK:         %[[RESULT:.*]] = torch.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK:         return %[[RESULT]] : !torch.vtensor<[1],f32>

func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
  %int0 = torch.constant.int 0
  %0 = torch.aten.flatten.using_ints %arg0, %int0, %int0 : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
  return %0 : !torch.vtensor<[1],f32>
}
@@ -145,7 +145,7 @@ func @flatten_some(%arg0: !torch.tensor<[3,2,?,5],f32>) -> !torch.tensor {
}

// CHECK-LABEL: func @flatten_rank0(
// CHECK:         torch.aten.flatten.using_ints{{.*}}-> !torch.tensor<[1],f32>
func @flatten_rank0(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
  %end = torch.constant.int -1
  %start = torch.constant.int 0