mirror of https://github.com/llvm/torch-mlir
Linalg lowering for aten.conv2d(bias=True)
Previously, aten.conv2d was lowered only when it had no bias. This commit extends the lowering to support a bias tensor.
pull/469/head
parent c598e01529
commit 2414bdb1f0
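The strategy behind the change: a biased convolution is an unbiased one accumulated on top of the bias broadcast across the output channels, and the linalg convolution op accumulates into its output operand, so it suffices to seed the output with the broadcast bias instead of zeros. A minimal PyTorch sketch of that identity (illustrative only, not part of the commit; tensor names are made up):

    import torch
    import torch.nn.functional as F

    x = torch.randn(5, 2, 10, 20)   # NCHW input
    w = torch.randn(10, 2, 3, 3)    # FCHW weights
    b = torch.randn(10)             # one bias value per output channel

    # conv2d with bias == conv2d without bias, added onto the bias
    # broadcast along the channel dimension (dim 1) of the output.
    out_ref = F.conv2d(x, w, bias=b)
    out_alt = F.conv2d(x, w) + b.view(1, -1, 1, 1)
    assert torch.allclose(out_ref, out_alt, atol=1e-6)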
@@ -33,6 +33,28 @@ def Conv2dNoPaddingModule_basic(module, tu: TestUtils):
     module.forward(t)
 
 
+class Conv2dBiasNoPaddingModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        torch.manual_seed(0)
+        self.conv = torch.nn.Conv2d(2, 10, 3, bias=True)
+        self.train(False)
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.conv(x)
+
+
+@register_test_case(module_factory=lambda: Conv2dBiasNoPaddingModule())
+def Conv2dBiasNoPaddingModule_basic(module, tu: TestUtils):
+    t = tu.rand(5, 2, 10, 20)
+    module.forward(t)
+
+
 class Conv2dWithPaddingModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
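For reference, the module under test wraps a stock torch.nn.Conv2d, so its expected behavior can be sanity-checked in plain PyTorch outside the e2e harness (an illustrative snippet, not part of the commit):

    import torch

    conv = torch.nn.Conv2d(2, 10, 3, bias=True)
    t = torch.rand(5, 2, 10, 20)
    # 3x3 kernel, no padding, stride 1: H 10 -> 8, W 20 -> 18
    print(conv(t).shape)  # torch.Size([5, 10, 8, 18])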
@@ -442,9 +442,6 @@ public:
     if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int dilations");
-    if (!op.bias().getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(op, "only support None bias");
-
     Value c1 =
         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(intType, 1));
     Value groupEqual1 = rewriter.create<arith::CmpIOp>(
@@ -473,22 +470,47 @@
         rewriter, loc, Win, paddingIntValues[1], dilationIntValues[1],
         castIndexToInt(weightW), strideIntValues[1]);
 
-    Value c0float = rewriter.create<arith::ConstantOp>(
-        loc,
-        FloatAttr::get(
-            input.getType().cast<RankedTensorType>().getElementType(), 0.0));
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
         loc, ValueRange{N, F, Hout, Wout}, elementType);
-    Value initTensor0 =
-        rewriter.create<linalg::FillOp>(loc, c0float, initTensor).getResult(0);
+    Value bias = adaptor.bias();
+    Value biasInitTensor;
+    if (bias.getType().isa<Torch::NoneType>()) {
+      Value c0float = rewriter.create<arith::ConstantOp>(
+          loc, FloatAttr::get(elementType, 0.0));
+      biasInitTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
+                           .getResult(0);
+    } else {
+      auto biasType = bias.getType().cast<RankedTensorType>();
+      if (biasType.getRank() != 1)
+        return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
+      if (elementType != biasType.getElementType())
+        return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
+
+      auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank();
+      SmallVector<AffineMap> indexingMaps = {
+          // bias is used to initialize the channels - dimension 1 of output
+          AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
+                         rewriter.getAffineDimExpr(1), context),
+          rewriter.getMultiDimIdentityMap(resultRank)};
+      SmallVector<StringRef> iteratorTypes(resultRank, "parallel");
+      biasInitTensor = rewriter
+                           .create<linalg::GenericOp>(
+                               loc, initTensor.getType(), bias, initTensor,
+                               indexingMaps, iteratorTypes,
+                               [](OpBuilder &b, Location loc, ValueRange args) {
+                                 b.create<linalg::YieldOp>(loc, args[0]);
+                               })
+                           .getResult(0);
+    }
 
     auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
     auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
     Value conv2d =
         rewriter
             .create<linalg::Conv2DNchwFchwOp>(
-                loc, initTensor0.getType(), ValueRange{paddedInput, weight},
-                initTensor0, stridesAttr, dilationAttr)
+                loc, biasInitTensor.getType(), ValueRange{paddedInput, weight},
+                biasInitTensor, stridesAttr, dilationAttr)
             .getResult(0);
     Type newResultType = getTypeConverter()->convertType(op.getType());
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv2d);
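The linalg::GenericOp above seeds the NxFxHxW init tensor from the rank-1 bias via the indexing map (d0, d1, d2, d3) -> (d1): every output element gets the bias of its channel (output dimension 1), and the convolution then accumulates into that tensor instead of zeros. A rough NumPy sketch of the computation the generic op expresses (assumed shapes for illustration; not the generated IR):

    import numpy as np

    def broadcast_bias(bias, n, f, hout, wout):
        """Mimics the linalg.generic: out[n, f, h, w] = bias[f]."""
        out = np.empty((n, f, hout, wout), dtype=bias.dtype)
        for i in range(n):
            for j in range(f):      # the (d0,d1,d2,d3) -> (d1) map picks dim 1
                out[i, j, :, :] = bias[j]
        return out

    bias = np.arange(10, dtype=np.float32)
    init = broadcast_bias(bias, 5, 10, 8, 18)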