Add torch.aten.flatten.using_ints and torch.aten.max_pool2d linalg lowerings

- Lower torch.aten.flatten.using_ints to linalg
- Lower torch.aten.max_pool2d to linalg
- Support more flexible dilation and stride values in the torch.aten.conv2d
  lowering (see the sketch below)
pull/270/head
Yi Zhang 2021-07-07 20:59:47 +00:00
parent 496051163f
commit 0342b73bf1
10 changed files with 582 additions and 130 deletions
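For orientation, the eager-mode semantics the new flatten lowering must match (an illustrative sketch, not part of the patch):

import torch

x = torch.rand(10, 3, 8, 9, 3, 4)
# aten::flatten.using_ints collapses dims [start_dim, end_dim] (inclusive)
# into a single dim; negative dims count from the back.
assert torch.flatten(x, 2, 4).shape == (10, 3, 8 * 9 * 3, 4)
assert torch.flatten(x, -4, -2).shape == (10, 3, 8 * 9 * 3, 4)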


@@ -10,9 +10,11 @@ from torch_mlir_torchscript.annotations import annotate_args, export
# ==============================================================================
class MmModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
@@ -22,21 +24,26 @@ class MmModule(torch.nn.Module):
def forward(self, lhs, rhs):
return torch.mm(lhs, rhs)
@register_test_case(module_factory=lambda: MmModule())
def MmModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4))
@register_test_case(module_factory=lambda: MmModule())
def MmModule_chained(module, tu: TestUtils):
res = module.forward(tu.rand(4, 4), tu.rand(4, 4))
module.forward(res, res)
# ==============================================================================
# A subgraph with multiple mm ops.
class MmDagModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
@@ -46,15 +53,19 @@ class MmDagModule(torch.nn.Module):
def forward(self, lhs, rhs):
return torch.mm(lhs, torch.mm(lhs, rhs))
@register_test_case(module_factory=lambda: MmDagModule())
def MmDagModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4))
# ==============================================================================
class MmTanhModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
@@ -63,26 +74,109 @@ class MmTanhModule(torch.nn.Module):
])
def forward(self, lhs, rhs):
return torch.tanh(self.matmul(lhs, rhs))
def matmul(self, lhs, rhs):
return torch.mm(lhs, rhs)
@register_test_case(module_factory=lambda: MmTanhModule())
def MmTanhModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 2), tu.rand(2, 4))
class AdaptiveAvgPool2dModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.aap2d = torch.nn.AdaptiveAvgPool2d((1, 1))
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return self.aap2d(x)
@register_test_case(module_factory=lambda: AdaptiveAvgPool2dModule())
def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9))
class FlattenStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.flat = torch.nn.Flatten(2, 4)
@export
@annotate_args([
None,
([10, 3, 8, 9, 3, 4], torch.float32, True),
])
def forward(self, x):
return self.flat(x)
@register_test_case(module_factory=lambda: FlattenStaticModule())
def FlattenStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9, 3, 4))
class FlattenRank0Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.flat = torch.nn.Flatten(-1, -1)
@export
@annotate_args([
None,
([], torch.float32, True),
])
def forward(self, x):
return self.flat(x)
@register_test_case(module_factory=lambda: FlattenRank0Module())
def FlattenRank0Module_basic(module, tu: TestUtils):
module.forward(torch.tensor(4.0))
class FlattenDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.flat = torch.nn.Flatten(2, 4)
@export
@annotate_args([
None,
([-1, -1, -1, 9, 3, -1], torch.float32, True),
])
def forward(self, x):
return self.flat(x)
@register_test_case(module_factory=lambda: FlattenDynamicModule())
def FlattenDynamicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9, 3, 4))
class MaxPool2dModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mp2d = torch.nn.MaxPool2d(kernel_size=[6, 8],
stride=[2, 2],
padding=[3, 4],
dilation=2)
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return self.mp2d(x)
@register_test_case(module_factory=lambda: MaxPool2dModule())
def MaxPool2dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 20, 20) - 0.5)
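For reference, the expected output shape for this pooling configuration, per the floor formula used by the lowering (an illustrative eager-mode check, not part of the test suite):

import torch

mp = torch.nn.MaxPool2d(kernel_size=[6, 8], stride=[2, 2], padding=[3, 4],
                        dilation=2)
# H: floor((20 + 2*3 - 2*(6-1) - 1)/2) + 1 = 8
# W: floor((20 + 2*4 - 2*(8-1) - 1)/2) + 1 = 7
assert mp(torch.rand(1, 1, 20, 20)).shape == (1, 1, 8, 7)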


@@ -9,42 +9,75 @@ from torch_mlir_torchscript.annotations import annotate_args, export
# ==============================================================================
class Conv2dNoPaddingModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.conv = torch.nn.Conv2d(2, 10, 3, bias=False)
self.train(False)
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return self.conv(x)
@register_test_case(module_factory=lambda: Conv2dNoPaddingModule())
def Conv2dNoPaddingModule_basic(module, tu: TestUtils):
t = tu.rand(5, 2, 10, 20)
module.forward(t)
class Conv2dWithPaddingModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.conv = torch.nn.Conv2d(2, 10, 3, bias=False, padding=3)
self.train(False)
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return self.conv(x)
@register_test_case(module_factory=lambda: Conv2dWithPaddingModule())
def Conv2dWithPaddingModule_basic(module, tu: TestUtils):
t = tu.rand(5, 2, 10, 20)
module.forward(t)
class Conv2dWithPaddingDilationStrideModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.conv = torch.nn.Conv2d(in_channels=2,
out_channels=10,
kernel_size=3,
padding=3,
stride=2,
dilation=3,
bias=False)
self.train(False)
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return self.conv(x)
@register_test_case(
module_factory=lambda: Conv2dWithPaddingDilationStrideModule())
def Conv2dWithPaddingDilationStrideModule_basic(module, tu: TestUtils):
t = tu.rand(5, 2, 10, 20)
module.forward(t)
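For reference, the spatial output dims for this stride/dilation combination follow the same floor formula (an illustrative eager-mode check, not part of the test suite):

import torch

conv = torch.nn.Conv2d(2, 10, 3, padding=3, stride=2, dilation=3, bias=False)
# H: floor((10 + 2*3 - 3*(3-1) - 1)/2) + 1 = 5
# W: floor((20 + 2*3 - 3*(3-1) - 1)/2) + 1 = 10
assert conv(torch.rand(5, 2, 10, 20)).shape == (5, 10, 5, 10)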


@@ -131,6 +131,28 @@ def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseFlattenBroadcastModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
([], torch.float32, True),
])
def forward(self, a, b):
return a * b.flatten(-1, -1)
@register_test_case(module_factory=lambda: ElementwiseFlattenBroadcastModule())
def ElementwiseFlattenBroadcastModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6), tu.rand())
# ==============================================================================
class ElementwiseReluModule(torch.nn.Module):
def __init__(self):
super().__init__()


@@ -15,19 +15,16 @@ XFAIL_SETS = {}
# These represent further work needed in npcomp to lower them properly
# to the backend contract.
_common_npcomp_lowering_xfails = {
'QuantizedMLP_basic',
}
XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails
XFAIL_SETS['iree'] = _common_npcomp_lowering_xfails | {
# https://github.com/google/iree/pull/6407
'MmDagModule_basic',
'Mlp1LayerModule_basic',
'Mlp2LayerModule_basic',
'Conv2dNoPaddingModule_basic',
'AdaptiveAvgPool2dModule_basic',
# https://github.com/google/iree/issues/6416
'Conv2dWithPaddingModule_basic',
# Needs https://reviews.llvm.org/D106658 to reach an IREE release
'MaxPool2dModule_basic',
'Conv2dWithPaddingDilationStrideModule_basic',
# https://github.com/google/iree/issues/6420
'FlattenDynamicModule_basic',
'ResNet18Module_basic'
}


@@ -65,7 +65,10 @@ class ValueReport:
@property
def failed(self):
if self.value.size() != self.golden_value.size():
return True
return not torch.allclose(
self.value, self.golden_value, rtol=1e-03, atol=1e-07)
def error_str(self):
assert self.failed
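The new size check matters because torch.allclose broadcasts its arguments, so a wrong-shaped result can still compare as close; a minimal illustration (assuming stock PyTorch behavior):

import torch

# Broadcasts (1,) against (3,) and returns True despite the shape mismatch,
# which is why sizes are compared before calling allclose above.
assert torch.allclose(torch.ones(1), torch.ones(3))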


@@ -79,6 +79,87 @@ static bool isConstantIntListMatching(Value &value,
return true;
}
static Value getDimOp(OpBuilder &b, Location loc, Value v, int dimension) {
return b.create<tensor::DimOp>(loc, v, dimension);
}
// Helper function to calculate the output tensor dims for convolution-like ops.
// Along each dim:
// dim_out =
// floor((dim_in + 2 * padding - dilation * (kernelSize - 1) - 1) / stride) + 1
static Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
Value paddingInt, Value dilationInt,
Value kernelSizeInt, Value strideInt) {
Type intType = b.getIntegerType(64);
Type indexType = b.getIndexType();
auto castIndexToInt = [&](Value v) {
return b.create<IndexCastOp>(loc, intType, v);
};
auto castIntToIndex = [&](Value v) {
return b.create<IndexCastOp>(loc, indexType, v);
};
Value c1 = b.create<ConstantOp>(loc, b.getI64IntegerAttr(1));
Value c2 = b.create<ConstantOp>(loc, b.getI64IntegerAttr(2));
Value doublePadding = b.create<MulIOp>(loc, paddingInt, c2);
// in + 2 * padding
Value inAddDoublePadding =
b.create<AddIOp>(loc, castIndexToInt(in), doublePadding);
// dilation * (kernelSize - 1)
Value kernelSizeSub1 = b.create<SubIOp>(loc, kernelSizeInt, c1);
Value dilationTimesKernelSize =
b.create<MulIOp>(loc, dilationInt, kernelSizeSub1);
Value temp =
b.create<SubIOp>(loc, inAddDoublePadding, dilationTimesKernelSize);
Value dividend = b.create<SubIOp>(loc, temp, c1);
Value division = b.create<SignedFloorDivIOp>(loc, dividend, strideInt);
Value out = b.create<AddIOp>(loc, division, c1);
return castIntToIndex(out);
}
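A minimal Python transliteration of this formula, cross-checked against eager conv2d (illustrative only; the helper name is hypothetical):

import torch

def out_dim(dim_in, padding, dilation, kernel_size, stride):
    # floor((dim_in + 2*padding - dilation*(kernelSize - 1) - 1) / stride) + 1
    return (dim_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

out = torch.nn.functional.conv2d(torch.rand(1, 1, 10, 20),
                                 torch.rand(1, 1, 3, 3),
                                 stride=2, padding=3, dilation=3)
assert out.shape[2:] == (out_dim(10, 3, 3, 3, 2), out_dim(20, 3, 3, 3, 2))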
static SmallVector<Value>
getAsConstantIntValues(OpBuilder &b, Location loc,
SmallVectorImpl<int64_t> &ints) {
return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value {
return b.create<ConstantOp>(loc, b.getIntegerAttr(b.getI64Type(), val));
}));
}
static SmallVector<Value>
getAsConstantIndexValues(OpBuilder &b, Location loc,
SmallVectorImpl<int64_t> &ints) {
return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value {
return b.create<ConstantOp>(loc, b.getIndexAttr(val));
}));
}
static SmallVector<OpFoldResult>
getAsOpFoldResult(OpBuilder &b, Location loc, SmallVectorImpl<int64_t> &ints) {
return llvm::to_vector<4>(llvm::map_range(
ints, [&](int64_t val) -> OpFoldResult { return b.getIndexAttr(val); }));
}
// Helper function to get the padded tensor given the padding int values.
// It is assumed that the low-end and high-end padding are the same.
static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
SmallVectorImpl<int64_t> &paddingInts) {
assert(input.getType().isa<RankedTensorType>() &&
"input must be RankedTensorType");
Location loc = op->getLoc();
Value c0float = b.create<ConstantOp>(
loc, FloatAttr::get(
input.getType().cast<RankedTensorType>().getElementType(), 0.0));
SmallVector<OpFoldResult> paddings = getAsOpFoldResult(b, loc, paddingInts);
Type ranked4DTensorType = linalg::PadTensorOp::inferResultType(
input.getType().cast<RankedTensorType>(), paddingInts, paddingInts);
Value paddedInput = linalg::PadTensorOp::createPadScalarOp(
ranked4DTensorType, input, c0float, /*low=*/paddings, /*high=*/paddings,
loc, b);
return paddedInput;
}
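A rough eager-mode analogue of what this helper produces for the NCHW uses below: zero padding, equal on the low and high sides, with N and C left unpadded (illustrative; torch.nn.functional.pad takes pairs in last-dim-first order):

import torch

x = torch.rand(1, 2, 4, 4)
# A paddingInts of [0, 0, 2, 3] in the helper corresponds to padding H by 2
# and W by 3 on both sides.
padded = torch.nn.functional.pad(x, (3, 3, 2, 2))
assert padded.shape == (1, 2, 4 + 2 * 2, 4 + 2 * 3)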
namespace {
class ConvertAtenAdaptiveAvgPool2dOp
: public OpConversionPattern<AtenAdaptiveAvgPool2dOp> {
@@ -106,11 +187,8 @@ public:
return rewriter.notifyMatchFailure(
op, "only support output_size with H and W both equal to constant 1");
Value N = getDimOp(rewriter, loc, input, 0);
Value C = getDimOp(rewriter, loc, input, 1);
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{N, C}, elementType);
Value c0 =
@@ -144,8 +222,8 @@ public:
.getResult(0);
// Calculate H*W so that the average can be computed as sum / (H*W).
Value H = getDimOp(rewriter, loc, input, 2);
Value W = getDimOp(rewriter, loc, input, 3);
auto castIndexToInt = [&](Value v) {
return rewriter.create<IndexCastOp>(loc, IntegerType::get(context, 64),
v);
@@ -200,100 +278,79 @@ public:
input.getType().cast<RankedTensorType>().getElementType();
if (!elementType.isa<mlir::FloatType>())
op.emitError("unimplemented: non-floating point type");
Type intType = IntegerType::get(context, 64);
auto castIndexToInt = [&](Value v) {
return rewriter.create<IndexCastOp>(loc, intType, v);
};
Value N = getDimOp(rewriter, loc, input, 0);
Value Hin = getDimOp(rewriter, loc, input, 2);
Value Win = getDimOp(rewriter, loc, input, 3);
Value F = getDimOp(rewriter, loc, weight, 0);
Value weightH = getDimOp(rewriter, loc, weight, 2);
Value weightW = getDimOp(rewriter, loc, weight, 3);
llvm::SmallVector<int64_t> paddingInts;
if (!matchPattern(padding, m_TorchConstantIntList(paddingInts))) {
return rewriter.notifyMatchFailure(
op, "only support constant padding values");
}
llvm::SmallVector<int64_t, 2> strideInts;
if (!matchPattern(stride, m_TorchConstantIntList(strideInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int strides");
llvm::SmallVector<int64_t, 2> dilationInts;
if (!matchPattern(dilation, m_TorchConstantIntList(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
if (!op.bias().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(op, "only support None bias");
Value c1 = rewriter.create<ConstantOp>(loc, IntegerAttr::get(intType, 1));
Value groupEqual1 =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, groups, c1);
rewriter.create<AssertOp>(loc, groupEqual1,
rewriter.getStringAttr("expect groups to be 1"));
// Pad the input tensor according to padding.
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
paddingInts.end());
Value paddedInput =
getPaddedTensor(op, rewriter, input, paddingIncludingNC);
SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
SmallVector<Value> dilationIntValues =
getAsConstantIntValues(rewriter, loc, dilationInts);
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);
Value Hout = getOutputDimForConvOps(
rewriter, loc, Hin, paddingIntValues[0], dilationIntValues[0],
castIndexToInt(weightH), strideIntValues[0]);
Value Wout = getOutputDimForConvOps(
rewriter, loc, Win, paddingIntValues[1], dilationIntValues[1],
castIndexToInt(weightW), strideIntValues[1]);
Value c0float = rewriter.create<ConstantOp>(
loc,
FloatAttr::get(
input.getType().cast<RankedTensorType>().getElementType(), 0.0));
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{N, F, Hout, Wout}, elementType);
Value initTensor0 =
rewriter.create<linalg::FillOp>(loc, c0float, initTensor).getResult(0);
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
Value conv2d =
rewriter
.create<linalg::Conv2DNchwOp>(
loc, initTensor0.getType(), ValueRange{paddedInput, weight},
initTensor0, stridesAttr, dilationAttr)
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv2d);
@@ -518,14 +575,11 @@ public:
return rewriter.notifyMatchFailure(
op, "unimplemented: size-1 broadcasting for aten::LinearOp");
Value inputDim0 = getDimOp(rewriter, loc, input, 0);
Value inputDim1 = getDimOp(rewriter, loc, input, 1);
Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
Value contractingDimEqual =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, inputDim1, weightDim1);
rewriter.create<AssertOp>(
@@ -839,6 +893,165 @@ struct ConvertElementwiseOp : ConversionPattern {
};
} // namespace
namespace {
class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenMaxPool2dOp op, llvm::ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
AtenMaxPool2dOp::Adaptor adaptor(operands);
Value self = adaptor.self();
Value kernelSize = adaptor.kernel_size();
Value stride = adaptor.stride();
Value padding = adaptor.padding();
Value dilation = adaptor.dilation();
Value ceilMode = adaptor.ceil_mode();
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
if (!elementType.isa<mlir::FloatType>())
op.emitError("unimplemented: non-floating point type");
llvm::SmallVector<int64_t, 2> strideInts;
if (!matchPattern(stride, m_TorchConstantIntList(strideInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int strides");
llvm::SmallVector<int64_t, 2> dilationInts;
if (!matchPattern(dilation, m_TorchConstantIntList(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
llvm::SmallVector<int64_t, 2> paddingInts;
if (!matchPattern(padding, m_TorchConstantIntList(paddingInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int paddings");
llvm::SmallVector<int64_t, 2> kernelSizeInts;
if (!matchPattern(kernelSize, m_TorchConstantIntList(kernelSizeInts)))
return rewriter.notifyMatchFailure(op, "only support kernel size ints");
Value falseValue = rewriter.create<ConstantOp>(
loc, IntegerAttr::get(rewriter.getIntegerType(1), 0));
Value ceilModeFalse =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, ceilMode, falseValue);
rewriter.create<AssertOp>(
loc, ceilModeFalse,
rewriter.getStringAttr("only ceil_mode false is supported"));
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
paddingInts.end());
Value paddedInput = getPaddedTensor(op, rewriter, self, paddingIncludingNC);
Value N = getDimOp(rewriter, loc, self, 0);
Value C = getDimOp(rewriter, loc, self, 1);
Value H = getDimOp(rewriter, loc, self, 2);
Value W = getDimOp(rewriter, loc, self, 3);
SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
SmallVector<Value> dilationIntValues =
getAsConstantIntValues(rewriter, loc, dilationInts);
SmallVector<Value> kernelSizeIntValues =
getAsConstantIntValues(rewriter, loc, kernelSizeInts);
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);
Value Hout = getOutputDimForConvOps(
rewriter, loc, H, paddingIntValues[0], dilationIntValues[0],
kernelSizeIntValues[0], strideIntValues[0]);
Value Wout = getOutputDimForConvOps(
rewriter, loc, W, paddingIntValues[1], dilationIntValues[1],
kernelSizeIntValues[1], strideIntValues[1]);
// Initialize output tensor with smallest floating point value
Value outTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{N, C, Hout, Wout}, elementType);
auto initialAttr = rewriter.getFloatAttr(
elementType,
APFloat::getSmallest(
elementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true));
Value initValue = rewriter.create<ConstantOp>(loc, initialAttr);
Value outTensorInitialized =
rewriter.create<linalg::FillOp>(loc, initValue, outTensor).getResult(0);
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
Value windowTensor = rewriter.create<linalg::InitTensorOp>(
loc, getAsConstantIndexValues(rewriter, loc, kernelSizeInts),
elementType);
Value maxPool2d = rewriter
.create<linalg::PoolingNchwMaxOp>(
loc, outTensorInitialized.getType(),
ValueRange{paddedInput, windowTensor},
outTensorInitialized, stridesAttr, dilationAttr)
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
return success();
}
};
} // namespace
namespace {
class ConvertAtenFlattenUsingIntsOp
: public OpConversionPattern<AtenFlattenUsingIntsOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenFlattenUsingIntsOp op, llvm::ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
int64_t startDim;
if (!matchPattern(op.start_dim(), m_TorchConstantInt(&startDim)))
return rewriter.notifyMatchFailure(op, "start_dim must be constant");
int64_t endDim;
if (!matchPattern(op.end_dim(), m_TorchConstantInt(&endDim)))
return rewriter.notifyMatchFailure(op, "start_dim must be constant");
auto type = operands[0].getType().cast<RankedTensorType>();
auto inputRank = type.getRank();
auto resultType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
if (startDim < 0)
startDim += inputRank;
if (endDim < 0)
endDim += inputRank;
if (inputRank == 0) {
SmallVector<ReassociationIndices> reassociation;
if (!(startDim >= -1 && startDim <= 0 && endDim >= -1 && endDim <= 0))
return rewriter.notifyMatchFailure(
op, "start_dim and end_dim must be in [-1, 0] when inputRank is 0");
rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
op, resultType, operands[0], reassociation);
return success();
}
if (startDim < 0 || startDim >= inputRank || endDim < 0 ||
endDim >= inputRank || startDim > endDim)
return rewriter.notifyMatchFailure(
op, "statically invalid flattening dim range");
SmallVector<ReassociationIndices> reassociation(resultType.getRank());
int j = 0;
for (auto i : llvm::seq<int64_t>(0, inputRank)) {
reassociation[j].push_back(i);
if (i < startDim || i >= endDim)
j++;
}
Value collapsedTensor = rewriter.create<linalg::TensorCollapseShapeOp>(
op->getLoc(), operands[0], reassociation);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
collapsedTensor);
return success();
}
};
} // namespace
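The reassociation loop above assigns each input dim i to result group j, advancing j only outside the flattened range [startDim, endDim]; a hedged Python rendering of the same bookkeeping (function name is illustrative):

def flatten_reassociation(input_rank, start_dim, end_dim):
    # One group per result dim; dims in [start_dim, end_dim] share a group.
    groups = [[] for _ in range(input_rank - (end_dim - start_dim))]
    j = 0
    for i in range(input_rank):
        groups[j].append(i)
        if i < start_dim or i >= end_dim:
            j += 1
    return groups

# Matches the FileCheck expectation below for flattening dims 2..4 of a
# rank-7 input:
assert flatten_reassociation(7, 2, 4) == [[0], [1], [2, 3, 4], [5], [6]]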
namespace {
class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
public:
@@ -929,6 +1142,10 @@ public:
patterns.add<ConvertAtenConv2dOp>(typeConverter, context);
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
patterns.add<ConvertAtenAdaptiveAvgPool2dOp>(typeConverter, context);
target.addIllegalOp<AtenFlattenUsingIntsOp>();
patterns.add<ConvertAtenFlattenUsingIntsOp>(typeConverter, context);
target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenMaxPool2dOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))


@@ -9,6 +9,7 @@
#include "npcomp/Dialect/Torch/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "npcomp/Backend/Common/Passes.h"
@@ -154,6 +155,7 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<FuncOp>(createStdExpandOpsPass());
if (options.optimize) {
// Clean up any non-canonical code introduced in our linalg lowering.


@@ -265,9 +265,9 @@ public:
ValueKnowledge::getPessimisticValueState(op->getContext());
knowledge.dtype = operand.dtype;
if (operand.hasSizes && operand.sizes.size() == 0) {
// Rank 0 is special and flattens to rank 1 with size 1.
knowledge.hasSizes = true;
knowledge.sizes.push_back(1);
} else if (operand.hasSizes &&
matchPattern(flatten.start_dim(),
m_TorchConstantInt(&startDim)) &&
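The refined shape agrees with eager PyTorch, where flattening a rank-0 tensor yields a rank-1 tensor of size 1 (illustrative check):

import torch

assert torch.tensor(4.0).flatten(0, -1).shape == (1,)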


@@ -0,0 +1,84 @@
// RUN: npcomp-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// -----
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
func @torch.aten.flatten.using_ints$basic(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
%int2 = torch.constant.int 2
%int4 = torch.constant.int 4
%0 = torch.aten.flatten.using_ints %arg0, %int2, %int4 : !torch.vtensor<[3,3,2,2,3,3,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,?,3,5],f32>
return %0 : !torch.vtensor<[3,3,?,3,5],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.flatten.using_ints$basic_negative(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2,3,3,5],f32> -> tensor<3x3x2x2x3x3x5xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3, 4], [5], [6]] : tensor<3x3x2x2x3x3x5xf32> into tensor<3x3x12x3x5xf32>
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x3x12x3x5xf32> to tensor<3x3x?x3x5xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<3x3x?x3x5xf32> -> !torch.vtensor<[3,3,?,3,5],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,3,?,3,5],f32>
func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,2,3,3,5],f32>) -> !torch.vtensor<[3,3,?,3,5],f32> {
%int-5 = torch.constant.int -5
%int-3 = torch.constant.int -3
%0 = torch.aten.flatten.using_ints %arg0, %int-5, %int-3 : !torch.vtensor<[3,3,2,2,3,3,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,?,3,5],f32>
return %0 : !torch.vtensor<[3,3,?,3,5],f32>
}
// -----
// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_front(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32>
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<18x2xf32> to tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0
%int2 = torch.constant.int 2
%0 = torch.aten.flatten.using_ints %arg0, %int0, %int2 : !torch.vtensor<[3,3,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_back(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32>
// CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[COLLAPSED]] : tensor<3x12xf32> to tensor<?x12xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[DYNAMIC]] : tensor<?x12xf32> -> !torch.vtensor<[?,12],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,12],f32>
func @torch.aten.flatten.using_ints$flatten_back(%arg0: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
%int1 = torch.constant.int 1
%int-1 = torch.constant.int -1
%0 = torch.aten.flatten.using_ints %arg0, %int1, %int-1 : !torch.vtensor<[3,3,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,12],f32>
return %0 : !torch.vtensor<[?,12],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.flatten.using_ints$rank0(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_expand_shape %[[BUILTIN_TENSOR]] [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[RESULT:.*]] = torch.from_builtin_tensor %[[COLLAPSED]] : tensor<1xf32> -> !torch.vtensor<[1],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32>
func @torch.aten.flatten.using_ints$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> {
%int0 = torch.constant.int 0
%0 = torch.aten.flatten.using_ints %arg0, %int0, %int0 : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}


@@ -145,7 +145,7 @@ func @flatten_some(%arg0: !torch.tensor<[3,2,?,5],f32>) -> !torch.tensor {
}
// CHECK-LABEL: func @flatten_rank0(
// CHECK: torch.aten.flatten.using_ints{{.*}}-> !torch.tensor<[1],f32>
func @flatten_rank0(%arg0: !torch.tensor<[],f32>) -> !torch.tensor {
%end = torch.constant.int -1
%start = torch.constant.int 0