mirror of https://github.com/llvm/torch-mlir
Handle `nn.Linear(..., bias=False)` case for TorchToLinalg (#919)
parent 298d095acf
commit 5c85ac3100
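
For context: constructing `nn.Linear` with `bias=False` registers no bias parameter at all, so the `aten.linear` op reaching this lowering carries a `None` bias operand, which the changes below branch on instead of unconditionally casting it to a ranked tensor. A quick eager-mode illustration of the semantics (plain PyTorch, not part of the commit):

    import torch
    import torch.nn as nn

    fc = nn.Linear(3, 4, bias=False)
    assert fc.bias is None          # this is what becomes the None bias operand
    x = torch.rand(5, 3)
    # Without a bias, linear is just a matmul against the transposed weights.
    assert torch.allclose(fc(x), x @ fc.weight.t())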
@@ -366,39 +366,40 @@ public:
     Value input = adaptor.input();
     Value weight = adaptor.weight();
     Value bias = adaptor.bias();
-    // TODO: Handle the case of bias being None (bias is optional).
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
     auto inputType = input.getType().cast<RankedTensorType>();
     auto weightType = weight.getType().cast<RankedTensorType>();
-    auto biasType = bias.getType().cast<RankedTensorType>();
 
     if (inputType.getRank() != 2 && inputType.getRank() != 3) {
       return rewriter.notifyMatchFailure(
           op, "expected input to be rank 2 or rank 3");
     }
 
-    // Only handle the case of rank 2 `weight` for now.
-    // TODO: Insert the appropriate reshape to collapse any leading dimensions.
-    if (weightType.getRank() != 2 || biasType.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          op, "expected weight to be rank 2 and bias to be rank 1");
-    }
-    // TODO: Handle type promotion. What are ATen's promotion rules?
-    if (inputType.getElementType() != weightType.getElementType() ||
-        inputType.getElementType() != biasType.getElementType()) {
-      return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      auto biasType = bias.getType().cast<RankedTensorType>();
+      // Only handle the case of rank 2 `weight` for now.
+      // TODO: Insert the appropriate reshape to collapse any leading dimensions.
+      if (weightType.getRank() != 2 || biasType.getRank() != 1) {
+        return rewriter.notifyMatchFailure(
+            op, "expected weight to be rank 2 and bias to be rank 1");
+      }
+      // TODO: Handle type promotion. What are ATen's promotion rules?
+      if (inputType.getElementType() != weightType.getElementType() ||
+          inputType.getElementType() != biasType.getElementType()) {
+        return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
+      }
+      // TODO: We can handle a static size 1 here at some complexity cost, but the
+      // dynamic case is not representable in linalg. We don't handle either for
+      // now. Biases are generally statically shaped for most models (since for
+      // inference they are constants, and for training they don't change shape
+      // typically), so this is not too constraining.
+      auto biasSize = bias.getType().cast<RankedTensorType>().getShape()[0];
+      if (biasSize == 1 || biasSize == ShapedType::kDynamicSize)
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: size-1 broadcasting for aten::LinearOp");
     }
 
-    // TODO: We can handle a static size 1 here at some complexity cost, but the
-    // dynamic case is not representable in linalg. We don't handle either for
-    // now. Biases are generally statically shaped for most models (since for
-    // inference they are constants, and for training they don't change shape
-    // typically), so this is not too constraining.
-    auto biasSize = bias.getType().cast<RankedTensorType>().getShape()[0];
-    if (biasSize == 1 || biasSize == ShapedType::kDynamicSize)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: size-1 broadcasting for aten::LinearOp");
-
     Value batchDim = nullptr;
     int restDim = 0;
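
Note that all bias-specific legality checks (rank-1 bias, matching element type, no size-1 or dynamic bias dimension) are now scoped inside the `!isa<Torch::NoneType>` guard, while the input rank restriction stays unconditional. Eager `F.linear` accepts arbitrary leading batch dimensions; the lowering currently supports only the rank-2 and rank-3 input cases, illustrated here (example only, not part of the commit):

    import torch
    import torch.nn.functional as F

    w = torch.rand(4, 3)
    print(F.linear(torch.rand(5, 3), w).shape)     # rank-2 input  -> (5, 4)
    print(F.linear(torch.rand(2, 5, 3), w).shape)  # rank-3 input  -> (2, 5, 4)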
@@ -411,20 +412,23 @@ public:
     Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1);
     Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
     Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
-    Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
     Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::eq, inputDim1, weightDim1);
     rewriter.create<cf::AssertOp>(
         loc, contractingDimEqual,
         rewriter.getStringAttr(
             "mismatching contracting dimension for aten.linear"));
-    // Here we take advantage of ruling out the size-1 case above.
-    // In the static-size-1 case, we will not emit this check at all.
-    Value biasSizeCorrect = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, weightDim0, biasDim0);
-    rewriter.create<cf::AssertOp>(
-        loc, biasSizeCorrect,
-        rewriter.getStringAttr("mismatching bias size for aten.linear"));
 
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
+      // Here we take advantage of ruling out the size-1 case above.
+      // In the static-size-1 case, we will not emit this check at all.
+      Value biasSizeCorrect = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, weightDim0, biasDim0);
+      rewriter.create<cf::AssertOp>(
+          loc, biasSizeCorrect,
+          rewriter.getStringAttr("mismatching bias size for aten.linear"));
+    }
+
     Value initTensor;
     SmallVector<AffineMap> broadcastIndexingMaps;
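
The `cf.assert` on the bias size is likewise only emitted when a bias is present; the contracting-dimension assert is unchanged. These runtime checks mirror what eager PyTorch enforces, as in this illustration (not part of the commit):

    import torch
    import torch.nn.functional as F

    x, w = torch.rand(5, 3), torch.rand(4, 3)
    F.linear(x, w, torch.rand(4))      # bias size matches weight dim 0: ok
    try:
        F.linear(x, w, torch.rand(7))  # mismatching bias size
    except RuntimeError as e:
        print("rejected:", e)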
@@ -455,16 +459,26 @@ public:
     }
 
     SmallVector<StringRef> iteratorTypes(inputType.getRank(), "parallel");
-    Value broadcasted =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, initTensor.getType(), bias, initTensor,
-                /*indexingMaps=*/broadcastIndexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [](OpBuilder &b, Location loc, ValueRange args) {
-                  b.create<linalg::YieldOp>(loc, args[0]);
-                })
-            .getResult(0);
+    Value broadcasted;
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      broadcasted =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, initTensor.getType(), bias, initTensor,
+                  /*indexingMaps=*/broadcastIndexingMaps,
+                  /*iteratorTypes=*/iteratorTypes,
+                  [](OpBuilder &b, Location loc, ValueRange args) {
+                    b.create<linalg::YieldOp>(loc, args[0]);
+                  })
+              .getResult(0);
+    } else {
+      Type elementType =
+          input.getType().cast<RankedTensorType>().getElementType();
+      Value c0float = rewriter.create<arith::ConstantOp>(
+          loc, FloatAttr::get(elementType, 0.0));
+      broadcasted = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
+                        .getResult(0);
+    }
     // We need a matmul with dimension ordering (N, K) * (M, K), so transpose
     // the weights to fit into linalg::MatmulOp which is (N, K) * (K, M).
     // TODO: This whole aten.linear lowering should eventually be generated from
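
In the new `else` branch the accumulator is zero-filled with `linalg.fill` instead of being seeded by broadcasting the bias, so the subsequent matmul alone produces the result. A rough eager-mode mirror of that dataflow, with assumed shapes (N, K) = (5, 3) and (M, K) = (2, 3):

    import torch
    import torch.nn.functional as F

    x = torch.rand(5, 3)      # (N, K) input
    w = torch.rand(2, 3)      # (M, K) weight
    acc = torch.zeros(5, 2)   # stands in for the zero-filled init tensor
    out = acc + x @ w.t()     # matmul in (N, K) * (K, M) form, per the comment above
    assert torch.allclose(out, F.linear(x, w, None))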
@@ -57,6 +57,30 @@ class Mlp2LayerModule(torch.nn.Module):
 def Mlp2LayerModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 3))
 
+
+class Mlp2LayerModuleNoBias(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        # Reset seed to make model deterministic.
+        torch.manual_seed(0)
+        N_HIDDEN = 5
+        self.fc0 = nn.Linear(3, N_HIDDEN, bias=False)
+        self.tanh0 = nn.Tanh()
+        self.fc1 = nn.Linear(N_HIDDEN, 2, bias=False)
+        self.tanh1 = nn.Tanh()
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        x = self.tanh0(self.fc0(x))
+        x = self.tanh1(self.fc1(x))
+        return x
+
+@register_test_case(module_factory=lambda: Mlp2LayerModuleNoBias())
+def Mlp2LayerModuleNoBias_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3))
+
+
 class BatchMlpLayerModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
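
A quick eager-mode sanity check of the new no-bias test module (assuming the class above is in scope):

    import torch

    module = Mlp2LayerModuleNoBias()
    assert module.fc0.bias is None and module.fc1.bias is None
    print(module.forward(torch.rand(5, 3)).shape)  # torch.Size([5, 2])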