mirror of https://github.com/llvm/torch-mlir
Support aten::linear with rank 3 inputs
aten::linear now supports rank 3 inputs. This is a fix for the upcoming BERT inference task. The proper long-term fix is to support broadcasting in the `aten.matmul` op and decompose `aten.linear` into the right ops.
parent 146f109152
commit f8ff6d84f4
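For context (not part of the commit), a minimal PyTorch sketch of the semantics this change has to lower: applying `torch.nn.Linear` to a rank-3 input is a matmul against the transposed weight plus a broadcast bias, i.e. exactly the BERT-style (batch, seq, features) case.

    import torch

    fc = torch.nn.Linear(3, 5)
    x = torch.rand(7, 5, 3)  # (batch, seq, features), matching the new test below

    y = fc(x)  # rank-3 aten::linear; result shape is (7, 5, 5)

    # Equivalent formulation: matmul with the transposed weight, then add the
    # bias, which broadcasts over the two leading dimensions.
    y_ref = torch.matmul(x, fc.weight.t()) + fc.bias
    assert torch.allclose(y, y_ref)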
@@ -56,3 +56,22 @@ class Mlp2LayerModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: Mlp2LayerModule())
 def Mlp2LayerModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 3))
+
+
+class BatchMlpLayerModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        # Reset seed to make model deterministic.
+        torch.manual_seed(0)
+        self.fc0 = nn.Linear(3, 5)
+        self.tanh0 = nn.Tanh()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.tanh0(self.fc0(x))
+
+
+@register_test_case(module_factory=lambda: BatchMlpLayerModule())
+def BatchMlpLayerModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(7, 5, 3))
@@ -1151,13 +1151,17 @@ public:
     auto inputType = input.getType().cast<RankedTensorType>();
     auto weightType = weight.getType().cast<RankedTensorType>();
     auto biasType = bias.getType().cast<RankedTensorType>();
-    // Only handle the case of rank 2 `input` for now.
-    // TODO: Insert the appropriate reshape to collapse any leading dimensions.
-    if (inputType.getRank() != 2 || weightType.getRank() != 2 ||
-        biasType.getRank() != 1) {
+    if (inputType.getRank() != 2 && inputType.getRank() != 3) {
+      return rewriter.notifyMatchFailure(
+          op, "expected input to be rank 2 or rank 3");
+    }
+    // Only handle the case of rank 2 `weight` for now.
+    // TODO: Insert the appropriate reshape to collapse any leading dimensions.
+    if (weightType.getRank() != 2 || biasType.getRank() != 1) {
       return rewriter.notifyMatchFailure(
-          op,
-          "expected both input and weight to be rank 2 and bias to be rank 1");
+          op, "expected weight to be rank 2 and bias to be rank 1");
     }
     // TODO: Handle type promotion. What are ATen's promotion rules?
     if (inputType.getElementType() != weightType.getElementType() ||
@@ -1175,8 +1179,15 @@ public:
       return rewriter.notifyMatchFailure(
           op, "unimplemented: size-1 broadcasting for aten::LinearOp");

-    Value inputDim0 = getDimOp(rewriter, loc, input, 0);
-    Value inputDim1 = getDimOp(rewriter, loc, input, 1);
+    Value batchDim = nullptr;
+    int restDim = 0;
+    if (inputType.getRank() == 3) {
+      batchDim = getDimOp(rewriter, loc, input, 0);
+      restDim = 1;
+    }
+
+    Value inputDim0 = getDimOp(rewriter, loc, input, restDim + 0);
+    Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1);
     Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
     Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
     Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
@@ -1194,13 +1205,35 @@ public:
         loc, biasSizeCorrect,
         rewriter.getStringAttr("mismatching bias size for aten.linear"));

-    Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, ValueRange{inputDim0, weightDim0}, inputType.getElementType());
-    SmallVector<AffineMap> broadcastIndexingMaps = {
-        AffineMap::get(
-            /*dimCount=*/2, /*symbolCount=*/0, rewriter.getAffineDimExpr(1)),
-        rewriter.getMultiDimIdentityMap(2)};
-    SmallVector<StringRef> iteratorTypes(2, "parallel");
+    Value initTensor;
+    SmallVector<AffineMap> broadcastIndexingMaps;
+    Value transposedWeightInitTensor;
+    if (inputType.getRank() > 2) {
+      initTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, ValueRange{batchDim, inputDim0, weightDim0},
+          inputType.getElementType());
+      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, ValueRange{batchDim, weightDim1, weightDim0},
+          weightType.getElementType());
+      broadcastIndexingMaps = {
+          AffineMap::get(
+              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
+              {rewriter.getAffineDimExpr(1 + restDim)}, context),
+          rewriter.getMultiDimIdentityMap(inputType.getRank())};
+    } else {
+      initTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, ValueRange{inputDim0, weightDim0},
+          inputType.getElementType());
+      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType());
+      broadcastIndexingMaps = {
+          AffineMap::get(
+              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
+              {rewriter.getAffineDimExpr(1)}, context),
+          rewriter.getMultiDimIdentityMap(inputType.getRank())};
+    }
+
+    SmallVector<StringRef> iteratorTypes(inputType.getRank(), "parallel");
     Value broadcasted =
         rewriter
             .create<linalg::GenericOp>(
@@ -1217,12 +1250,11 @@ public:
     // a single linalg ODS generator statement. Both the bias and matmul part.
     SmallVector<AffineMap> transposeIndexingMaps = {
         AffineMap::get(
-            /*dimCount=*/2, /*symbolCount=*/0,
-            {rewriter.getAffineDimExpr(1), rewriter.getAffineDimExpr(0)},
+            /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
+            {rewriter.getAffineDimExpr(1 + restDim),
+             rewriter.getAffineDimExpr(0 + restDim)},
             context),
-        rewriter.getMultiDimIdentityMap(2)};
-    Value transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType());
+        rewriter.getMultiDimIdentityMap(inputType.getRank())};
     Value transposedWeights =
         rewriter
             .create<linalg::GenericOp>(
@@ -1234,11 +1266,20 @@ public:
               b.create<linalg::YieldOp>(loc, args[0]);
             })
             .getResult(0);
-    Value matmul = rewriter
-                       .create<linalg::MatmulOp>(
-                           loc, broadcasted.getType(),
-                           ValueRange{input, transposedWeights}, broadcasted)
-                       .getResult(0);
+    Value matmul;
+    if (batchDim)
+      matmul = rewriter
+                   .create<linalg::BatchMatmulOp>(
+                       loc, broadcasted.getType(),
+                       ValueRange{input, transposedWeights}, broadcasted)
+                   .getResult(0);
+    else
+      matmul = rewriter
+                   .create<linalg::MatmulOp>(
+                       loc, broadcasted.getType(),
+                       ValueRange{input, transposedWeights}, broadcasted)
+                   .getResult(0);
+
     Type newResultType = getTypeConverter()->convertType(op.getType());
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
     return success();
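Restated in eager PyTorch (a sketch, not from the commit; the helper name is hypothetical), the rank-3 path above broadcasts the bias into a (batch, M, out) init tensor, materializes a per-batch transposed weight, and then accumulates a linalg.batch_matmul on top of the broadcast bias:

    import torch

    def linear_rank3_lowering(x, weight, bias):
        # x: (B, M, in), weight: (out, in), bias: (out,)
        B, M, in_features = x.shape
        out_features = weight.shape[0]
        # First linalg.generic: broadcast bias into the (B, M, out) init
        # tensor via the (d0, d1, d2) -> (d2) indexing map.
        broadcasted = bias.expand(B, M, out_features)
        # Second linalg.generic: transpose weight into a (B, in, out) tensor
        # via the (d0, d1, d2) -> (d2, d1) indexing map.
        transposed_weights = weight.t().expand(B, in_features, out_features)
        # linalg.batch_matmul: accumulate x @ w^T on top of the broadcast bias.
        return torch.baddbmm(broadcasted, x, transposed_weights)

    fc = torch.nn.Linear(3, 5)
    x = torch.rand(7, 5, 3)
    assert torch.allclose(linear_rank3_lowering(x, fc.weight, fc.bias), fc(x),
                          atol=1e-6)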