Handle `nn.Linear(..., bias=False)` case for TorchToLinalg (#919)

Branch: pull/924/head · Tag: snapshot-20220609.498
Maksim Levental, 2022-06-08 21:13:43 -05:00, committed by GitHub
parent 298d095acf · commit 5c85ac3100
2 changed files with 77 additions and 39 deletions
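For context: with `bias=False`, PyTorch creates no bias parameter, so `aten.linear` receives `None` for the bias operand, which reaches the lowering as a `Torch::NoneType` value rather than a rank-1 tensor. A minimal sketch of the semantics being lowered (plain PyTorch, not part of this patch):

import torch
import torch.nn as nn

fc = nn.Linear(3, 4, bias=False)
assert fc.bias is None  # bias=False creates no bias parameter

# With no bias, aten.linear is just a matmul against the transposed weight.
x = torch.rand(5, 3)
assert torch.allclose(fc(x), x @ fc.weight.t())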

File 1 of 2: the TorchToLinalg lowering (C++).

@@ -366,39 +366,40 @@ public:
     Value input = adaptor.input();
     Value weight = adaptor.weight();
     Value bias = adaptor.bias();
-    // TODO: Handle the case of bias being None (bias is optional).
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     auto inputType = input.getType().cast<RankedTensorType>();
     auto weightType = weight.getType().cast<RankedTensorType>();
-    auto biasType = bias.getType().cast<RankedTensorType>();
     if (inputType.getRank() != 2 && inputType.getRank() != 3) {
       return rewriter.notifyMatchFailure(
           op, "expected input to be rank 2 or rank 3");
     }
-    // Only handle the case of rank 2 `weight` for now.
-    // TODO: Insert the appropriate reshape to collapse any leading dimensions.
-    if (weightType.getRank() != 2 || biasType.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          op, "expected weight to be rank 2 and bias to be rank 1");
-    }
-    // TODO: Handle type promotion. What are ATen's promotion rules?
-    if (inputType.getElementType() != weightType.getElementType() ||
-        inputType.getElementType() != biasType.getElementType()) {
-      return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      auto biasType = bias.getType().cast<RankedTensorType>();
+      // Only handle the case of rank 2 `weight` for now.
+      // TODO: Insert the appropriate reshape to collapse any leading dimensions.
+      if (weightType.getRank() != 2 || biasType.getRank() != 1) {
+        return rewriter.notifyMatchFailure(
+            op, "expected weight to be rank 2 and bias to be rank 1");
+      }
+      // TODO: Handle type promotion. What are ATen's promotion rules?
+      if (inputType.getElementType() != weightType.getElementType() ||
+          inputType.getElementType() != biasType.getElementType()) {
+        return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
+      }
+      // TODO: We can handle a static size 1 here at some complexity cost, but the
+      // dynamic case is not representable in linalg. We don't handle either for
+      // now. Biases are generally statically shaped for most models (since for
+      // inference they are constants, and for training they don't change shape
+      // typically), so this is not too constraining.
+      auto biasSize = bias.getType().cast<RankedTensorType>().getShape()[0];
+      if (biasSize == 1 || biasSize == ShapedType::kDynamicSize)
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: size-1 broadcasting for aten::LinearOp");
     }
-    // TODO: We can handle a static size 1 here at some complexity cost, but the
-    // dynamic case is not representable in linalg. We don't handle either for
-    // now. Biases are generally statically shaped for most models (since for
-    // inference they are constants, and for training they don't change shape
-    // typically), so this is not too constraining.
-    auto biasSize = bias.getType().cast<RankedTensorType>().getShape()[0];
-    if (biasSize == 1 || biasSize == ShapedType::kDynamicSize)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: size-1 broadcasting for aten::LinearOp");

     Value batchDim = nullptr;
     int restDim = 0;
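The effect of this first hunk: all bias-specific legality checks now run only when a bias is actually present. A rough Python paraphrase of the guarded conditions (names and structure are illustrative, not from the patch):

def linear_is_lowerable(input_rank, weight_rank, bias_shape, dtypes_match):
    # bias_shape is None when the bias operand is Torch::NoneType.
    if input_rank not in (2, 3):
        return False  # "expected input to be rank 2 or rank 3"
    if bias_shape is not None:
        if weight_rank != 2 or len(bias_shape) != 1:
            return False  # "expected weight to be rank 2 and bias to be rank 1"
        if not dtypes_match:
            return False  # "unimplemented: type promotion"
        if bias_shape[0] == 1 or bias_shape[0] is None:  # None marks a dynamic dim
            return False  # size-1 / dynamic broadcast not representable in linalg
    return True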
@@ -411,20 +412,23 @@ public:
     Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1);
     Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
     Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
-    Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
     Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::eq, inputDim1, weightDim1);
     rewriter.create<cf::AssertOp>(
         loc, contractingDimEqual,
         rewriter.getStringAttr(
             "mismatching contracting dimension for aten.linear"));
-    // Here we take advantage of ruling out the size-1 case above.
-    // In the static-size-1 case, we will not emit this check at all.
-    Value biasSizeCorrect = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, weightDim0, biasDim0);
-    rewriter.create<cf::AssertOp>(
-        loc, biasSizeCorrect,
-        rewriter.getStringAttr("mismatching bias size for aten.linear"));
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
+      // Here we take advantage of ruling out the size-1 case above.
+      // In the static-size-1 case, we will not emit this check at all.
+      Value biasSizeCorrect = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, weightDim0, biasDim0);
+      rewriter.create<cf::AssertOp>(
+          loc, biasSizeCorrect,
+          rewriter.getStringAttr("mismatching bias size for aten.linear"));
+    }

     Value initTensor;
     SmallVector<AffineMap> broadcastIndexingMaps;
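These `cf.assert` ops become runtime shape checks on dynamic dimensions; in eager-mode terms the hunk amounts to roughly the following (an illustrative paraphrase, not generated code):

def runtime_shape_checks(x, w, b=None):
    # x: (..., K), w: (M, K), b: (M,) or None
    assert x.shape[-1] == w.shape[1], \
        "mismatching contracting dimension for aten.linear"
    if b is not None:  # the bias-size assert is emitted only when a bias exists
        assert b.shape[0] == w.shape[0], "mismatching bias size for aten.linear"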
@@ -455,16 +459,26 @@ public:
     }
     SmallVector<StringRef> iteratorTypes(inputType.getRank(), "parallel");
-    Value broadcasted =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, initTensor.getType(), bias, initTensor,
-                /*indexingMaps=*/broadcastIndexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [](OpBuilder &b, Location loc, ValueRange args) {
-                  b.create<linalg::YieldOp>(loc, args[0]);
-                })
-            .getResult(0);
+    Value broadcasted;
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      broadcasted =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, initTensor.getType(), bias, initTensor,
+                  /*indexingMaps=*/broadcastIndexingMaps,
+                  /*iteratorTypes=*/iteratorTypes,
+                  [](OpBuilder &b, Location loc, ValueRange args) {
+                    b.create<linalg::YieldOp>(loc, args[0]);
+                  })
+              .getResult(0);
+    } else {
+      Type elementType =
+          input.getType().cast<RankedTensorType>().getElementType();
+      Value c0float = rewriter.create<arith::ConstantOp>(
+          loc, FloatAttr::get(elementType, 0.0));
+      broadcasted = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
+                        .getResult(0);
+    }
     // We need a matmul with dimension ordering (N, K) * (M, K), so transpose
     // the weights to fit into linalg::MatmulOp which is (N, K) * (K, M).
     // TODO: This whole aten.linear lowering should eventually be generated from
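Net effect of this hunk: the matmul accumulator (`broadcasted`) is seeded with the broadcast bias when one exists, and zero-filled via `linalg.fill` when the bias is `None`; the subsequent `linalg.matmul` then accumulates into it. In PyTorch terms, a sketch of the resulting semantics (not the generated IR):

import torch

def linear_as_lowered(x, w, b=None):
    # Seed the accumulator: broadcast bias, or zeros when bias is None.
    n, m = x.shape[0], w.shape[0]
    acc = b.expand(n, m) if b is not None else torch.zeros(n, m, dtype=x.dtype)
    # linalg.matmul accumulates into its output operand: acc + x @ w^T.
    return acc + x @ w.t()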

File 2 of 2: the Python e2e test suite.

@@ -57,6 +57,30 @@ class Mlp2LayerModule(torch.nn.Module):
 def Mlp2LayerModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 3))

+
+class Mlp2LayerModuleNoBias(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        # Reset seed to make model deterministic.
+        torch.manual_seed(0)
+        N_HIDDEN = 5
+        self.fc0 = nn.Linear(3, N_HIDDEN, bias=False)
+        self.tanh0 = nn.Tanh()
+        self.fc1 = nn.Linear(N_HIDDEN, 2, bias=False)
+        self.tanh1 = nn.Tanh()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        x = self.tanh0(self.fc0(x))
+        x = self.tanh1(self.fc1(x))
+        return x
+
+
+@register_test_case(module_factory=lambda: Mlp2LayerModuleNoBias())
+def Mlp2LayerModuleNoBias_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3))
+
 class BatchMlpLayerModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
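As a quick sanity check outside the e2e harness, the new no-bias module can be compared against a hand-written matmul chain in plain PyTorch (illustrative only):

m = Mlp2LayerModuleNoBias()
x = torch.rand(5, 3)
expected = torch.tanh(torch.tanh(x @ m.fc0.weight.t()) @ m.fc1.weight.t())
assert torch.allclose(m(x), expected)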