Support aten::linear with rank 3 inputs

aten::linear now supports rank 3 inputs. This is a fix for the
upcoming BERT inference task. The correct long-term solution is to
support broadcasting in the `aten.matmul` op and decompose
`aten.linear` into the right ops.
Branch: pull/413/head
Author: Prashant Kumar
Date:   2021-11-12 04:40:16 +00:00
Parent: 146f109152
Commit: f8ff6d84f4

2 changed files with 85 additions and 25 deletions
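For reference, a minimal PyTorch sketch (not part of this patch) of the rank 3 semantics the lowering implements: applying a linear layer to a rank 3 input is a batch matmul against the transposed weight, plus a bias add broadcast over the leading dimensions.

    import torch

    x = torch.rand(7, 5, 3)        # (batch, seq, in_features)
    fc = torch.nn.Linear(3, 5)

    # Weight is (out_features, in_features); transpose it, broadcast it
    # across the batch dimension, then add the bias over the last dim.
    expected = torch.bmm(x, fc.weight.t().expand(7, 3, 5)) + fc.bias
    assert torch.allclose(fc(x), expected, atol=1e-6)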


@@ -56,3 +56,22 @@ class Mlp2LayerModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: Mlp2LayerModule())
 def Mlp2LayerModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 3))
+
+
+class BatchMlpLayerModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        # Reset seed to make model deterministic.
+        torch.manual_seed(0)
+        self.fc0 = nn.Linear(3, 5)
+        self.tanh0 = nn.Tanh()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.tanh0(self.fc0(x))
+
+
+@register_test_case(module_factory=lambda: BatchMlpLayerModule())
+def BatchMlpLayerModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(7, 5, 3))
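A quick shape check of what the new test exercises, in plain PyTorch outside the e2e harness (illustrative only):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    m = nn.Sequential(nn.Linear(3, 5), nn.Tanh())
    out = m(torch.rand(7, 5, 3))
    assert out.shape == (7, 5, 5)  # only the last dimension is transformed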


@@ -1151,13 +1151,17 @@ public:
     auto inputType = input.getType().cast<RankedTensorType>();
     auto weightType = weight.getType().cast<RankedTensorType>();
     auto biasType = bias.getType().cast<RankedTensorType>();
-    // Only handle the case of rank 2 `input` for now.
-    // TODO: Insert the appropriate reshape to collapse any leading dimensions.
-    if (inputType.getRank() != 2 || weightType.getRank() != 2 ||
-        biasType.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          op,
-          "expected both input and weight to be rank 2 and bias to be rank 1");
+    if (inputType.getRank() != 2 && inputType.getRank() != 3) {
+      return rewriter.notifyMatchFailure(
+          op, "expected input to be rank 2 or rank 3");
+    }
+
+    // Only handle the case of rank 2 `weight` for now.
+    // TODO: Insert the appropriate reshape to collapse any leading dimensions.
+    if (weightType.getRank() != 2 || biasType.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected weight to be rank 2 and bias to be rank 1");
     }
     // TODO: Handle type promotion. What are ATen's promotion rules?
     if (inputType.getElementType() != weightType.getElementType() ||
@ -1175,8 +1179,15 @@ public:
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "unimplemented: size-1 broadcasting for aten::LinearOp"); op, "unimplemented: size-1 broadcasting for aten::LinearOp");
Value inputDim0 = getDimOp(rewriter, loc, input, 0); Value batchDim = nullptr;
Value inputDim1 = getDimOp(rewriter, loc, input, 1); int restDim = 0;
if (inputType.getRank() == 3) {
batchDim = getDimOp(rewriter, loc, input, 0);
restDim = 1;
}
Value inputDim0 = getDimOp(rewriter, loc, input, restDim + 0);
Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1);
Value weightDim0 = getDimOp(rewriter, loc, weight, 0); Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
Value weightDim1 = getDimOp(rewriter, loc, weight, 1); Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
Value biasDim0 = getDimOp(rewriter, loc, bias, 0); Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
@@ -1194,13 +1205,35 @@ public:
         loc, biasSizeCorrect,
         rewriter.getStringAttr("mismatching bias size for aten.linear"));

-    Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, ValueRange{inputDim0, weightDim0}, inputType.getElementType());
-    SmallVector<AffineMap> broadcastIndexingMaps = {
-        AffineMap::get(
-            /*dimCount=*/2, /*symbolCount=*/0, rewriter.getAffineDimExpr(1)),
-        rewriter.getMultiDimIdentityMap(2)};
-    SmallVector<StringRef> iteratorTypes(2, "parallel");
+    Value initTensor;
+    SmallVector<AffineMap> broadcastIndexingMaps;
+    Value transposedWeightInitTensor;
+    if (inputType.getRank() > 2) {
+      initTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, ValueRange{batchDim, inputDim0, weightDim0},
+          inputType.getElementType());
+      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, ValueRange{batchDim, weightDim1, weightDim0},
+          weightType.getElementType());
+      broadcastIndexingMaps = {
+          AffineMap::get(
+              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
+              {rewriter.getAffineDimExpr(1 + restDim)}, context),
+          rewriter.getMultiDimIdentityMap(inputType.getRank())};
+    } else {
+      initTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, ValueRange{inputDim0, weightDim0},
+          inputType.getElementType());
+      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType());
+      broadcastIndexingMaps = {
+          AffineMap::get(
+              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
+              {rewriter.getAffineDimExpr(1)}, context),
+          rewriter.getMultiDimIdentityMap(inputType.getRank())};
+    }
+
+    SmallVector<StringRef> iteratorTypes(inputType.getRank(), "parallel");
     Value broadcasted =
         rewriter
             .create<linalg::GenericOp>(
@ -1217,12 +1250,11 @@ public:
// a single linalg ODS generator statement. Both the bias and matmul part. // a single linalg ODS generator statement. Both the bias and matmul part.
SmallVector<AffineMap> transposeIndexingMaps = { SmallVector<AffineMap> transposeIndexingMaps = {
AffineMap::get( AffineMap::get(
/*dimCount=*/2, /*symbolCount=*/0, /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
{rewriter.getAffineDimExpr(1), rewriter.getAffineDimExpr(0)}, {rewriter.getAffineDimExpr(1 + restDim),
rewriter.getAffineDimExpr(0 + restDim)},
context), context),
rewriter.getMultiDimIdentityMap(2)}; rewriter.getMultiDimIdentityMap(inputType.getRank())};
Value transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType());
Value transposedWeights = Value transposedWeights =
rewriter rewriter
.create<linalg::GenericOp>( .create<linalg::GenericOp>(
@@ -1234,11 +1266,20 @@ public:
                   b.create<linalg::YieldOp>(loc, args[0]);
                 })
             .getResult(0);
-    Value matmul = rewriter
-                       .create<linalg::MatmulOp>(
-                           loc, broadcasted.getType(),
-                           ValueRange{input, transposedWeights}, broadcasted)
-                       .getResult(0);
+    Value matmul;
+    if (batchDim)
+      matmul = rewriter
+                   .create<linalg::BatchMatmulOp>(
+                       loc, broadcasted.getType(),
+                       ValueRange{input, transposedWeights}, broadcasted)
+                   .getResult(0);
+    else
+      matmul = rewriter
+                   .create<linalg::MatmulOp>(
+                       loc, broadcasted.getType(),
+                       ValueRange{input, transposedWeights}, broadcasted)
+                   .getResult(0);
     Type newResultType = getTypeConverter()->convertType(op.getType());
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
     return success();
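
As the commit message notes, the preferred long-term fix is broadcasting support in `aten.matmul` plus a decomposition of `aten.linear`. A minimal sketch of that decomposition in plain PyTorch (illustrative only, not code from this patch):

    import torch

    def linear_decomposed(x, weight, bias):
        # aten.linear(x, w, b) == aten.matmul(x, w.T) + b, relying on
        # matmul's batch broadcasting and the bias's trailing-dim broadcast.
        return torch.matmul(x, weight.t()) + bias

    x, w, b = torch.rand(7, 5, 3), torch.rand(5, 3), torch.rand(5)
    assert torch.allclose(linear_decomposed(x, w, b),
                          torch.nn.functional.linear(x, w, b), atol=1e-6)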