diff --git a/e2e_testing/torchscript/conv.py b/e2e_testing/torchscript/conv.py index 3f43bfc42..4c8c62208 100644 --- a/e2e_testing/torchscript/conv.py +++ b/e2e_testing/torchscript/conv.py @@ -33,6 +33,28 @@ def Conv2dNoPaddingModule_basic(module, tu: TestUtils): module.forward(t) +class Conv2dBiasNoPaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.conv = torch.nn.Conv2d(2, 10, 3, bias=True) + self.train(False) + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.conv(x) + + +@register_test_case(module_factory=lambda: Conv2dBiasNoPaddingModule()) +def Conv2dBiasNoPaddingModule_basic(module, tu: TestUtils): + t = tu.rand(5, 2, 10, 20) + module.forward(t) + + class Conv2dWithPaddingModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 42e059ea1..322405ea0 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -442,9 +442,6 @@ public: if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); - if (!op.bias().getType().isa()) - return rewriter.notifyMatchFailure(op, "only support None bias"); - Value c1 = rewriter.create(loc, IntegerAttr::get(intType, 1)); Value groupEqual1 = rewriter.create( @@ -473,22 +470,47 @@ public: rewriter, loc, Win, paddingIntValues[1], dilationIntValues[1], castIndexToInt(weightW), strideIntValues[1]); - Value c0float = rewriter.create( - loc, - FloatAttr::get( - input.getType().cast().getElementType(), 0.0)); Value initTensor = rewriter.create( loc, ValueRange{N, F, Hout, Wout}, elementType); - Value initTensor0 = - rewriter.create(loc, c0float, initTensor).getResult(0); + + Value bias = adaptor.bias(); + Value biasInitTensor; + if (bias.getType().isa()) { + Value c0float = rewriter.create( + loc, FloatAttr::get(elementType, 0.0)); + biasInitTensor = rewriter.create(loc, c0float, initTensor) + .getResult(0); + } else { + auto biasType = bias.getType().cast(); + if (biasType.getRank() != 1) + return rewriter.notifyMatchFailure(op, "expect bias to be rank 1"); + if (elementType != biasType.getElementType()) + return rewriter.notifyMatchFailure(op, "unimplemented: type promotion"); + + auto resultRank = initTensor.getType().cast().getRank(); + SmallVector indexingMaps = { + // bias is used to initialize the channels - dimension 1 of output + AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0, + rewriter.getAffineDimExpr(1), context), + rewriter.getMultiDimIdentityMap(resultRank)}; + SmallVector iteratorTypes(resultRank, "parallel"); + biasInitTensor = rewriter + .create( + loc, initTensor.getType(), bias, initTensor, + indexingMaps, iteratorTypes, + [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }) + .getResult(0); + } auto stridesAttr = rewriter.getI64VectorAttr(strideInts); auto dilationAttr = rewriter.getI64VectorAttr(dilationInts); Value conv2d = rewriter .create( - loc, initTensor0.getType(), ValueRange{paddedInput, weight}, - initTensor0, stridesAttr, dilationAttr) + loc, biasInitTensor.getType(), ValueRange{paddedInput, weight}, + biasInitTensor, stridesAttr, dilationAttr) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, conv2d);