Support aten::linear with rank 3 inputs

Now, aten::linear supports rank 3 inputs. This is a fix
for upcoming bert-inference task. The correct way should be
to support broadcasting in `aten.matmul` op and decompose
`aten.linear` into right ops.
pull/413/head
Prashant Kumar 2021-11-12 04:40:16 +00:00
parent 146f109152
commit f8ff6d84f4
2 changed files with 85 additions and 25 deletions

View File

@ -56,3 +56,22 @@ class Mlp2LayerModule(torch.nn.Module):
@register_test_case(module_factory=lambda: Mlp2LayerModule())
def Mlp2LayerModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3))
class BatchMlpLayerModule(torch.nn.Module):
def __init__(self):
super().__init__()
# Reset seed to make model deterministic.
torch.manual_seed(0)
self.fc0 = nn.Linear(3, 5)
self.tanh0 = nn.Tanh()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return self.tanh0(self.fc0(x))
@register_test_case(module_factory=lambda: BatchMlpLayerModule())
def BatchMlpLayerModule_basic(module, tu: TestUtils):
module.forward(tu.rand(7, 5, 3))

View File

@ -1151,13 +1151,17 @@ public:
auto inputType = input.getType().cast<RankedTensorType>();
auto weightType = weight.getType().cast<RankedTensorType>();
auto biasType = bias.getType().cast<RankedTensorType>();
// Only handle the case of rank 2 `input` for now.
// TODO: Insert the appropriate reshape to collapse any leading dimensions.
if (inputType.getRank() != 2 || weightType.getRank() != 2 ||
biasType.getRank() != 1) {
if (inputType.getRank() != 2 && inputType.getRank() != 3) {
return rewriter.notifyMatchFailure(
op,
"expected both input and weight to be rank 2 and bias to be rank 1");
op, "expected input to be rank 2 or rank 3");
}
// Only handle the case of rank 2 `weight` for now.
// TODO: Insert the appropriate reshape to collapse any leading dimensions.
if (weightType.getRank() != 2 || biasType.getRank() != 1) {
return rewriter.notifyMatchFailure(
op, "expected weight to be rank 2 and bias to be rank 1");
}
// TODO: Handle type promotion. What are ATen's promotion rules?
if (inputType.getElementType() != weightType.getElementType() ||
@ -1175,8 +1179,15 @@ public:
return rewriter.notifyMatchFailure(
op, "unimplemented: size-1 broadcasting for aten::LinearOp");
Value inputDim0 = getDimOp(rewriter, loc, input, 0);
Value inputDim1 = getDimOp(rewriter, loc, input, 1);
Value batchDim = nullptr;
int restDim = 0;
if (inputType.getRank() == 3) {
batchDim = getDimOp(rewriter, loc, input, 0);
restDim = 1;
}
Value inputDim0 = getDimOp(rewriter, loc, input, restDim + 0);
Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1);
Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
@ -1194,13 +1205,35 @@ public:
loc, biasSizeCorrect,
rewriter.getStringAttr("mismatching bias size for aten.linear"));
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{inputDim0, weightDim0}, inputType.getElementType());
SmallVector<AffineMap> broadcastIndexingMaps = {
Value initTensor;
SmallVector<AffineMap> broadcastIndexingMaps;
Value transposedWeightInitTensor;
if (inputType.getRank() > 2) {
initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{batchDim, inputDim0, weightDim0},
inputType.getElementType());
transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{batchDim, weightDim1, weightDim0},
weightType.getElementType());
broadcastIndexingMaps = {
AffineMap::get(
/*dimCount=*/2, /*symbolCount=*/0, rewriter.getAffineDimExpr(1)),
rewriter.getMultiDimIdentityMap(2)};
SmallVector<StringRef> iteratorTypes(2, "parallel");
/*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
{rewriter.getAffineDimExpr(1 + restDim)}, context),
rewriter.getMultiDimIdentityMap(inputType.getRank())};
} else {
initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{inputDim0, weightDim0},
inputType.getElementType());
transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType());
broadcastIndexingMaps = {
AffineMap::get(
/*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
{rewriter.getAffineDimExpr(1)}, context),
rewriter.getMultiDimIdentityMap(inputType.getRank())};
}
SmallVector<StringRef> iteratorTypes(inputType.getRank(), "parallel");
Value broadcasted =
rewriter
.create<linalg::GenericOp>(
@ -1217,12 +1250,11 @@ public:
// a single linalg ODS generator statement. Both the bias and matmul part.
SmallVector<AffineMap> transposeIndexingMaps = {
AffineMap::get(
/*dimCount=*/2, /*symbolCount=*/0,
{rewriter.getAffineDimExpr(1), rewriter.getAffineDimExpr(0)},
/*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
{rewriter.getAffineDimExpr(1 + restDim),
rewriter.getAffineDimExpr(0 + restDim)},
context),
rewriter.getMultiDimIdentityMap(2)};
Value transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType());
rewriter.getMultiDimIdentityMap(inputType.getRank())};
Value transposedWeights =
rewriter
.create<linalg::GenericOp>(
@ -1234,11 +1266,20 @@ public:
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
Value matmul = rewriter
Value matmul;
if (batchDim)
matmul = rewriter
.create<linalg::BatchMatmulOp>(
loc, broadcasted.getType(),
ValueRange{input, transposedWeights}, broadcasted)
.getResult(0);
else
matmul = rewriter
.create<linalg::MatmulOp>(
loc, broadcasted.getType(),
ValueRange{input, transposedWeights}, broadcasted)
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();