Linalg lowering for aten.conv2d(bias=True)

Previously aten.conv2d was only lowered if there was no bias.
Here lowering is extended to support bias.
pull/469/head
Liam Fitzpatrick 2021-12-08 21:52:29 +00:00 committed by Sean Silva
parent c598e01529
commit 2414bdb1f0
2 changed files with 55 additions and 11 deletions

View File

@ -33,6 +33,28 @@ def Conv2dNoPaddingModule_basic(module, tu: TestUtils):
module.forward(t) module.forward(t)
class Conv2dBiasNoPaddingModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.conv = torch.nn.Conv2d(2, 10, 3, bias=True)
self.train(False)
@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return self.conv(x)
@register_test_case(module_factory=lambda: Conv2dBiasNoPaddingModule())
def Conv2dBiasNoPaddingModule_basic(module, tu: TestUtils):
t = tu.rand(5, 2, 10, 20)
module.forward(t)
class Conv2dWithPaddingModule(torch.nn.Module): class Conv2dWithPaddingModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -442,9 +442,6 @@ public:
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts))) if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"only support constant int dilations"); "only support constant int dilations");
if (!op.bias().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(op, "only support None bias");
Value c1 = Value c1 =
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(intType, 1)); rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(intType, 1));
Value groupEqual1 = rewriter.create<arith::CmpIOp>( Value groupEqual1 = rewriter.create<arith::CmpIOp>(
@ -473,22 +470,47 @@ public:
rewriter, loc, Win, paddingIntValues[1], dilationIntValues[1], rewriter, loc, Win, paddingIntValues[1], dilationIntValues[1],
castIndexToInt(weightW), strideIntValues[1]); castIndexToInt(weightW), strideIntValues[1]);
Value c0float = rewriter.create<arith::ConstantOp>(
loc,
FloatAttr::get(
input.getType().cast<RankedTensorType>().getElementType(), 0.0));
Value initTensor = rewriter.create<linalg::InitTensorOp>( Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{N, F, Hout, Wout}, elementType); loc, ValueRange{N, F, Hout, Wout}, elementType);
Value initTensor0 =
rewriter.create<linalg::FillOp>(loc, c0float, initTensor).getResult(0); Value bias = adaptor.bias();
Value biasInitTensor;
if (bias.getType().isa<Torch::NoneType>()) {
Value c0float = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.0));
biasInitTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
.getResult(0);
} else {
auto biasType = bias.getType().cast<RankedTensorType>();
if (biasType.getRank() != 1)
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
if (elementType != biasType.getElementType())
return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank();
SmallVector<AffineMap> indexingMaps = {
// bias is used to initialize the channels - dimension 1 of output
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
rewriter.getAffineDimExpr(1), context),
rewriter.getMultiDimIdentityMap(resultRank)};
SmallVector<StringRef> iteratorTypes(resultRank, "parallel");
biasInitTensor = rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), bias, initTensor,
indexingMaps, iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
}
auto stridesAttr = rewriter.getI64VectorAttr(strideInts); auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts); auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
Value conv2d = Value conv2d =
rewriter rewriter
.create<linalg::Conv2DNchwFchwOp>( .create<linalg::Conv2DNchwFchwOp>(
loc, initTensor0.getType(), ValueRange{paddedInput, weight}, loc, biasInitTensor.getType(), ValueRange{paddedInput, weight},
initTensor0, stridesAttr, dilationAttr) biasInitTensor, stridesAttr, dilationAttr)
.getResult(0); .getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv2d); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv2d);