mirror of https://github.com/llvm/torch-mlir
Support aten::linear with rank 3 inputs
aten::linear now supports rank 3 inputs. This is a fix for the upcoming BERT inference task. The proper long-term fix is to support broadcasting in the `aten.matmul` op and to decompose `aten.linear` into the right ops.
pull/413/head
parent 146f109152
commit f8ff6d84f4
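For context, `aten::linear` computes `x @ W^T + b`. A minimal eager-mode sketch (not part of the commit; shapes chosen to match the new test below) of the semantics the lowering has to reproduce for a rank 3 input:

import torch

x = torch.rand(7, 5, 3)   # (batch, M, in_features)
w = torch.rand(5, 3)      # (out_features, in_features)
b = torch.rand(5)         # (out_features,)

reference = torch.nn.functional.linear(x, w, b)

# The lowering's strategy, in tensor terms: materialize the transposed
# weights per batch, broadcast the bias to the full result shape, and
# accumulate a batched matmul on top of it.
wT = w.t().expand(7, 3, 5)                    # transposed-weight init tensor
out = b.expand(7, 5, 5) + torch.bmm(x, wT)    # bias broadcast + batch_matmul
assert torch.allclose(reference, out)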
@@ -56,3 +56,22 @@ class Mlp2LayerModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: Mlp2LayerModule())
 def Mlp2LayerModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 3))
+
+class BatchMlpLayerModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        # Reset seed to make model deterministic.
+        torch.manual_seed(0)
+        self.fc0 = nn.Linear(3, 5)
+        self.tanh0 = nn.Tanh()
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.tanh0(self.fc0(x))
+
+@register_test_case(module_factory=lambda: BatchMlpLayerModule())
+def BatchMlpLayerModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(7, 5, 3))
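Assuming the class definition from the hunk above is in scope, the new test exercises a batched MLP whose input carries a leading batch dimension; in plain eager mode (illustrative only):

import torch

m = BatchMlpLayerModule()
out = m.forward(torch.rand(7, 5, 3))
print(out.shape)  # torch.Size([7, 5, 5]): batch and M dims kept, 3 -> 5 features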
@@ -1151,13 +1151,17 @@ public:
     auto inputType = input.getType().cast<RankedTensorType>();
     auto weightType = weight.getType().cast<RankedTensorType>();
     auto biasType = bias.getType().cast<RankedTensorType>();
-    // Only handle the case of rank 2 `input` for now.
-    // TODO: Insert the appropriate reshape to collapse any leading dimensions.
-    if (inputType.getRank() != 2 || weightType.getRank() != 2 ||
-        biasType.getRank() != 1) {
-      return rewriter.notifyMatchFailure(
-          op,
-          "expected both input and weight to be rank 2 and bias to be rank 1");
-    }
+    if (inputType.getRank() != 2 && inputType.getRank() != 3) {
+      return rewriter.notifyMatchFailure(
+          op, "expected input to be rank 2 or rank 3");
+    }
+
+    // Only handle the case of rank 2 `weight` for now.
+    // TODO: Insert the appropriate reshape to collapse any leading dimensions.
+    if (weightType.getRank() != 2 || biasType.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expected weight to be rank 2 and bias to be rank 1");
+    }
     // TODO: Handle type promotion. What are ATen's promotion rules?
     if (inputType.getElementType() != weightType.getElementType() ||
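The rank gating above splits the old combined check in two: the input may now be rank 2 or rank 3, while the weight and bias keep their old constraints. A hypothetical Python rendering of the predicate (names are illustrative, not from the C++ code):

def linear_ranks_supported(input_rank, weight_rank, bias_rank):
    if input_rank not in (2, 3):
        return False  # "expected input to be rank 2 or rank 3"
    if weight_rank != 2 or bias_rank != 1:
        return False  # "expected weight to be rank 2 and bias to be rank 1"
    return True

assert linear_ranks_supported(3, 2, 1)      # the newly accepted case
assert not linear_ranks_supported(4, 2, 1)  # higher ranks still bail out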
@@ -1175,8 +1179,15 @@ public:
       return rewriter.notifyMatchFailure(
           op, "unimplemented: size-1 broadcasting for aten::LinearOp");
 
-    Value inputDim0 = getDimOp(rewriter, loc, input, 0);
-    Value inputDim1 = getDimOp(rewriter, loc, input, 1);
+    Value batchDim = nullptr;
+    int restDim = 0;
+    if (inputType.getRank() == 3) {
+      batchDim = getDimOp(rewriter, loc, input, 0);
+      restDim = 1;
+    }
+
+    Value inputDim0 = getDimOp(rewriter, loc, input, restDim + 0);
+    Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1);
     Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
     Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
     Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
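restDim is the offset that shifts the (M, K) dimension reads right by one when a leading batch dimension is present. A small sketch of that bookkeeping (helper name is hypothetical):

def matmul_dims(shape):
    rest = 1 if len(shape) == 3 else 0        # restDim
    batch = shape[0] if rest else None        # batchDim
    return batch, shape[rest + 0], shape[rest + 1]

assert matmul_dims((5, 3)) == (None, 5, 3)    # rank 2: M, K at dims 0, 1
assert matmul_dims((7, 5, 3)) == (7, 5, 3)    # rank 3: M, K at dims 1, 2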
@@ -1194,13 +1205,35 @@ public:
         loc, biasSizeCorrect,
         rewriter.getStringAttr("mismatching bias size for aten.linear"));
 
-    Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, ValueRange{inputDim0, weightDim0}, inputType.getElementType());
-    SmallVector<AffineMap> broadcastIndexingMaps = {
-        AffineMap::get(
-            /*dimCount=*/2, /*symbolCount=*/0, rewriter.getAffineDimExpr(1)),
-        rewriter.getMultiDimIdentityMap(2)};
-    SmallVector<StringRef> iteratorTypes(2, "parallel");
+    Value initTensor;
+    SmallVector<AffineMap> broadcastIndexingMaps;
+    Value transposedWeightInitTensor;
+    if (inputType.getRank() > 2) {
+      initTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, ValueRange{batchDim, inputDim0, weightDim0},
+          inputType.getElementType());
+      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, ValueRange{batchDim, weightDim1, weightDim0},
+          weightType.getElementType());
+      broadcastIndexingMaps = {
+          AffineMap::get(
+              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
+              {rewriter.getAffineDimExpr(1 + restDim)}, context),
+          rewriter.getMultiDimIdentityMap(inputType.getRank())};
+    } else {
+      initTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, ValueRange{inputDim0, weightDim0},
+          inputType.getElementType());
+      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
+          loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType());
+      broadcastIndexingMaps = {
+          AffineMap::get(
+              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
+              {rewriter.getAffineDimExpr(1)}, context),
+          rewriter.getMultiDimIdentityMap(inputType.getRank())};
+    }
+
+    SmallVector<StringRef> iteratorTypes(inputType.getRank(), "parallel");
     Value broadcasted =
         rewriter
             .create<linalg::GenericOp>(
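In the rank 3 branch, the bias's indexing map affine_map<(d0, d1, d2) -> (d2)> (the getAffineDimExpr(1 + restDim) above) reads the rank 1 bias at the last output dimension, replicating it across batch and rows. A numpy rendering of that broadcast generic (a sketch, not generated code):

import numpy as np

batch, m, n = 7, 5, 4
bias = np.arange(n, dtype=np.float32)
broadcasted = np.empty((batch, m, n), dtype=np.float32)
for d0 in range(batch):
    for d1 in range(m):
        for d2 in range(n):
            broadcasted[d0, d1, d2] = bias[d2]  # map (d0, d1, d2) -> (d2)
assert (broadcasted == np.broadcast_to(bias, (batch, m, n))).all()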
@@ -1217,12 +1250,11 @@ public:
     // a single linalg ODS generator statement. Both the bias and matmul part.
     SmallVector<AffineMap> transposeIndexingMaps = {
         AffineMap::get(
-            /*dimCount=*/2, /*symbolCount=*/0,
-            {rewriter.getAffineDimExpr(1), rewriter.getAffineDimExpr(0)},
+            /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
+            {rewriter.getAffineDimExpr(1 + restDim),
+             rewriter.getAffineDimExpr(0 + restDim)},
             context),
-        rewriter.getMultiDimIdentityMap(2)};
-    Value transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType());
+        rewriter.getMultiDimIdentityMap(inputType.getRank())};
     Value transposedWeights =
         rewriter
             .create<linalg::GenericOp>(
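With restDim == 1 the transpose maps become affine_map<(d0, d1, d2) -> (d2, d1)>: the rank 2 weight (N, K) is copied into every batch slice of the (batch, K, N) init tensor, transposed. A numpy rendering (sketch only):

import numpy as np

batch, n, k = 7, 5, 3
weight = np.arange(n * k, dtype=np.float32).reshape(n, k)
transposed = np.empty((batch, k, n), dtype=np.float32)
for d0 in range(batch):
    for d1 in range(k):
        for d2 in range(n):
            transposed[d0, d1, d2] = weight[d2, d1]  # map (d0, d1, d2) -> (d2, d1)
assert (transposed == np.broadcast_to(weight.T, (batch, k, n))).all()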
@@ -1234,11 +1266,20 @@ public:
                   b.create<linalg::YieldOp>(loc, args[0]);
                 })
             .getResult(0);
-    Value matmul = rewriter
-                       .create<linalg::MatmulOp>(
-                           loc, broadcasted.getType(),
-                           ValueRange{input, transposedWeights}, broadcasted)
-                       .getResult(0);
+    Value matmul;
+    if (batchDim)
+      matmul = rewriter
+                   .create<linalg::BatchMatmulOp>(
+                       loc, broadcasted.getType(),
+                       ValueRange{input, transposedWeights}, broadcasted)
+                   .getResult(0);
+    else
+      matmul = rewriter
+                   .create<linalg::MatmulOp>(
+                       loc, broadcasted.getType(),
+                       ValueRange{input, transposedWeights}, broadcasted)
+                   .getResult(0);
+
     Type newResultType = getTypeConverter()->convertType(op.getType());
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
     return success();
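Note that linalg.matmul and linalg.batch_matmul accumulate into their output operand (C += A @ B), which is why the broadcast bias tensor is passed in as the init value: the "+ bias" comes for free. In numpy terms (illustrative shapes):

import numpy as np

a = np.random.rand(7, 5, 3).astype(np.float32)    # input
bT = np.random.rand(7, 3, 5).astype(np.float32)   # transposed weights
c = np.ones((7, 5, 5), dtype=np.float32)          # broadcasted bias
result = c + np.matmul(a, bT)                     # what batch_matmul computes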