diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index e6b9e6c24..8fa598f8d 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -366,39 +366,40 @@ public:
     Value input = adaptor.input();
     Value weight = adaptor.weight();
    Value bias = adaptor.bias();
-    // TODO: Handle the case of bias being None (bias is optional).
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     auto inputType = input.getType().cast<RankedTensorType>();
     auto weightType = weight.getType().cast<RankedTensorType>();
-    auto biasType = bias.getType().cast<RankedTensorType>();
     if (inputType.getRank() != 2 && inputType.getRank() != 3) {
       return rewriter.notifyMatchFailure(
           op, "expected input to be rank 2 or rank 3");
     }
-    // Only handle the case of rank 2 `weight` for now.
-    // TODO: Insert the appropriate reshape to collapse any leading dimensions.
-    if (weightType.getRank() != 2 || biasType.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          op, "expected weight to be rank 2 and bias to be rank 1");
-    }
-    // TODO: Handle type promotion. What are ATen's promotion rules?
-    if (inputType.getElementType() != weightType.getElementType() ||
-        inputType.getElementType() != biasType.getElementType()) {
-      return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      auto biasType = bias.getType().cast<RankedTensorType>();
+      // Only handle the case of rank 2 `weight` for now.
+      // TODO: Insert the appropriate reshape to collapse any leading dimensions.
+      if (weightType.getRank() != 2 || biasType.getRank() != 1) {
+        return rewriter.notifyMatchFailure(
+            op, "expected weight to be rank 2 and bias to be rank 1");
+      }
+      // TODO: Handle type promotion. What are ATen's promotion rules?
+      if (inputType.getElementType() != weightType.getElementType() ||
+          inputType.getElementType() != biasType.getElementType()) {
+        return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
+      }
+      // TODO: We can handle a static size 1 here at some complexity cost, but the
+      // dynamic case is not representable in linalg. We don't handle either for
+      // now. Biases are generally statically shaped for most models (since for
+      // inference they are constants, and for training they don't change shape
+      // typically), so this is not too constraining.
+      auto biasSize = bias.getType().cast<RankedTensorType>().getShape()[0];
+      if (biasSize == 1 || biasSize == ShapedType::kDynamicSize)
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: size-1 broadcasting for aten::LinearOp");
     }
-    // TODO: We can handle a static size 1 here at some complexity cost, but the
-    // dynamic case is not representable in linalg. We don't handle either for
-    // now. Biases are generally statically shaped for most models (since for
-    // inference they are constants, and for training they don't change shape
-    // typically), so this is not too constraining.
-    auto biasSize = bias.getType().cast<RankedTensorType>().getShape()[0];
-    if (biasSize == 1 || biasSize == ShapedType::kDynamicSize)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: size-1 broadcasting for aten::LinearOp");
 
     Value batchDim = nullptr;
     int restDim = 0;
@@ -411,20 +412,23 @@ public:
     Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1);
     Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
     Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
-    Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
     Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::eq, inputDim1, weightDim1);
     rewriter.create<AssertOp>(
         loc, contractingDimEqual,
         rewriter.getStringAttr(
             "mismatching contracting dimension for aten.linear"));
-    // Here we take advantage of ruling out the size-1 case above.
-    // In the static-size-1 case, we will not emit this check at all.
-    Value biasSizeCorrect = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, weightDim0, biasDim0);
-    rewriter.create<AssertOp>(
-        loc, biasSizeCorrect,
-        rewriter.getStringAttr("mismatching bias size for aten.linear"));
+
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
+      // Here we take advantage of ruling out the size-1 case above.
+      // In the static-size-1 case, we will not emit this check at all.
+      Value biasSizeCorrect = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, weightDim0, biasDim0);
+      rewriter.create<AssertOp>(
+          loc, biasSizeCorrect,
+          rewriter.getStringAttr("mismatching bias size for aten.linear"));
+    }
 
     Value initTensor;
     SmallVector<AffineMap> broadcastIndexingMaps;
@@ -455,16 +459,26 @@ public:
     }
     SmallVector<StringRef> iteratorTypes(inputType.getRank(), "parallel");
-    Value broadcasted =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, initTensor.getType(), bias, initTensor,
-                /*indexingMaps=*/broadcastIndexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [](OpBuilder &b, Location loc, ValueRange args) {
-                  b.create<linalg::YieldOp>(loc, args[0]);
-                })
-            .getResult(0);
+    Value broadcasted;
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      broadcasted =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, initTensor.getType(), bias, initTensor,
+                  /*indexingMaps=*/broadcastIndexingMaps,
+                  /*iteratorTypes=*/iteratorTypes,
+                  [](OpBuilder &b, Location loc, ValueRange args) {
+                    b.create<linalg::YieldOp>(loc, args[0]);
+                  })
+              .getResult(0);
+    } else {
+      Type elementType =
+          input.getType().cast<RankedTensorType>().getElementType();
+      Value c0float = rewriter.create<arith::ConstantOp>(
+          loc, FloatAttr::get(elementType, 0.0));
+      broadcasted = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
+                        .getResult(0);
+    }
     // We need a matmul with dimension ordering (N, K) * (M, K), so transpose
     // the weights to fit into linalg::MatmulOp which is (N, K) * (K, M).
     // TODO: This whole aten.linear lowering should eventually be generated from
diff --git a/python/torch_mlir_e2e_test/test_suite/mlp.py b/python/torch_mlir_e2e_test/test_suite/mlp.py
index 8713c55a2..153f35759 100644
--- a/python/torch_mlir_e2e_test/test_suite/mlp.py
+++ b/python/torch_mlir_e2e_test/test_suite/mlp.py
@@ -57,6 +57,30 @@ class Mlp2LayerModule(torch.nn.Module):
 def Mlp2LayerModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 3))
 
+class Mlp2LayerModuleNoBias(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        # Reset seed to make model deterministic.
+        torch.manual_seed(0)
+        N_HIDDEN = 5
+        self.fc0 = nn.Linear(3, N_HIDDEN, bias=False)
+        self.tanh0 = nn.Tanh()
+        self.fc1 = nn.Linear(N_HIDDEN, 2, bias=False)
+        self.tanh1 = nn.Tanh()
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        x = self.tanh0(self.fc0(x))
+        x = self.tanh1(self.fc1(x))
+        return x
+
+@register_test_case(module_factory=lambda: Mlp2LayerModuleNoBias())
+def Mlp2LayerModuleNoBias_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3))
+
 class BatchMlpLayerModule(torch.nn.Module):
     def __init__(self):
         super().__init__()