mirror of https://github.com/llvm/torch-mlir
Linalg lowering for aten.conv2d(bias=True)
Previously aten.conv2d was only lowered if there was no bias. Here lowering is extended to support bias.pull/469/head
parent
c598e01529
commit
2414bdb1f0
|
@ -33,6 +33,28 @@ def Conv2dNoPaddingModule_basic(module, tu: TestUtils):
|
|||
module.forward(t)
|
||||
|
||||
|
||||
class Conv2dBiasNoPaddingModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
torch.manual_seed(0)
|
||||
self.conv = torch.nn.Conv2d(2, 10, 3, bias=True)
|
||||
self.train(False)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Conv2dBiasNoPaddingModule())
|
||||
def Conv2dBiasNoPaddingModule_basic(module, tu: TestUtils):
|
||||
t = tu.rand(5, 2, 10, 20)
|
||||
module.forward(t)
|
||||
|
||||
|
||||
class Conv2dWithPaddingModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -442,9 +442,6 @@ public:
|
|||
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support constant int dilations");
|
||||
if (!op.bias().getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(op, "only support None bias");
|
||||
|
||||
Value c1 =
|
||||
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(intType, 1));
|
||||
Value groupEqual1 = rewriter.create<arith::CmpIOp>(
|
||||
|
@ -473,22 +470,47 @@ public:
|
|||
rewriter, loc, Win, paddingIntValues[1], dilationIntValues[1],
|
||||
castIndexToInt(weightW), strideIntValues[1]);
|
||||
|
||||
Value c0float = rewriter.create<arith::ConstantOp>(
|
||||
loc,
|
||||
FloatAttr::get(
|
||||
input.getType().cast<RankedTensorType>().getElementType(), 0.0));
|
||||
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, ValueRange{N, F, Hout, Wout}, elementType);
|
||||
Value initTensor0 =
|
||||
rewriter.create<linalg::FillOp>(loc, c0float, initTensor).getResult(0);
|
||||
|
||||
Value bias = adaptor.bias();
|
||||
Value biasInitTensor;
|
||||
if (bias.getType().isa<Torch::NoneType>()) {
|
||||
Value c0float = rewriter.create<arith::ConstantOp>(
|
||||
loc, FloatAttr::get(elementType, 0.0));
|
||||
biasInitTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
|
||||
.getResult(0);
|
||||
} else {
|
||||
auto biasType = bias.getType().cast<RankedTensorType>();
|
||||
if (biasType.getRank() != 1)
|
||||
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
|
||||
if (elementType != biasType.getElementType())
|
||||
return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
|
||||
|
||||
auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank();
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
// bias is used to initialize the channels - dimension 1 of output
|
||||
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
|
||||
rewriter.getAffineDimExpr(1), context),
|
||||
rewriter.getMultiDimIdentityMap(resultRank)};
|
||||
SmallVector<StringRef> iteratorTypes(resultRank, "parallel");
|
||||
biasInitTensor = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, initTensor.getType(), bias, initTensor,
|
||||
indexingMaps, iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
|
||||
auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
|
||||
Value conv2d =
|
||||
rewriter
|
||||
.create<linalg::Conv2DNchwFchwOp>(
|
||||
loc, initTensor0.getType(), ValueRange{paddedInput, weight},
|
||||
initTensor0, stridesAttr, dilationAttr)
|
||||
loc, biasInitTensor.getType(), ValueRange{paddedInput, weight},
|
||||
biasInitTensor, stridesAttr, dilationAttr)
|
||||
.getResult(0);
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv2d);
|
||||
|
|
Loading…
Reference in New Issue