[Torch] fix aten.linear's decomposition (#3391)

* support aten.linear with more rank.
pull/3392/head
Yuanqiang Liu 2024-05-27 15:49:50 +08:00 committed by GitHub
parent 05929f9171
commit e0a5adb1db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 180 additions and 25 deletions

View File

@ -5513,38 +5513,57 @@ public:
Value bias = op.getBias(); Value bias = op.getBias();
BaseTensorType inputType = cast<BaseTensorType>(input.getType()); BaseTensorType inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasSizes() || inputType.getSizes().size() < 2) if (!inputType.hasSizes())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(op, "expected input to have sizes");
op, "expected input to be rank 2 or greater");
BaseTensorType weightType = cast<BaseTensorType>(weight.getType()); BaseTensorType weightType = cast<BaseTensorType>(weight.getType());
// `weight` must be a rank 2 matrix. if (!weightType.hasSizes())
if (!weightType.hasSizes() || weightType.getSizes().size() != 2) return rewriter.notifyMatchFailure(op, "expected weight to have sizes");
return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2");
SmallVector<int64_t> transposeShape = auto transposeWeight = [&]() -> Value {
llvm::to_vector(llvm::reverse(weightType.getSizes())); SmallVector<int64_t> transposeShape =
Type transposeType = weightType.getWithSizesAndDtype( llvm::to_vector(llvm::reverse(weightType.getSizes()));
llvm::ArrayRef(transposeShape), weightType.getOptionalDtype()); Type transposeType = weightType.getWithSizesAndDtype(
Value transposeWeight = llvm::ArrayRef(transposeShape), weightType.getOptionalDtype());
rewriter.create<AtenTOp>(loc, transposeType, weight); Value transposeWeight =
rewriter.create<AtenTOp>(loc, transposeType, weight);
return transposeWeight;
};
Value matmul = rewriter.create<AtenMatmulOp>(loc, op.getType(), input,
transposeWeight);
if (bias.getType().isa<Torch::NoneType>()) { if (bias.getType().isa<Torch::NoneType>()) {
rewriter.replaceOp(op, matmul); auto weightRank = weightType.getSizes().size();
if (weightRank > 2 || weightRank <= 0)
return rewriter.notifyMatchFailure(
op, "expected weight's rank <= 2 && >= 1");
if (weightRank == 1) {
rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), input,
weight);
return success();
} else if (weightRank == 2) {
rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), input,
transposeWeight());
return success();
}
llvm_unreachable("unsupported weightRank");
} else {
BaseTensorType biasType = cast<BaseTensorType>(bias.getType());
if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
// `weight` must be a rank 2 matrix.
auto weightRank = weightType.getSizes().size();
if (weightRank != 2)
return rewriter.notifyMatchFailure(op,
"expected weight to be a rank 2");
Value matmul = rewriter.create<AtenMatmulOp>(loc, op.getType(), input,
transposeWeight());
Value alpha =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), matmul,
op.getBias(), alpha);
return success(); return success();
} }
BaseTensorType biasType = cast<BaseTensorType>(bias.getType());
if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
Value alpha =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), matmul,
op.getBias(), alpha);
return success();
} }
}; };
} // namespace } // namespace

View File

@ -814,6 +814,12 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
} }
STABLEHLO_PASS_SET = { STABLEHLO_PASS_SET = {
"AtenLinear1D_basic",
"AtenLinear2D_basic",
"AtenLinear3DBias_basic",
"AtenLinearMatVec_basic",
"AtenLinearVecMatBias_basic",
"AtenLinearVecMat_basic",
"ReduceAminSingleDim_basic", "ReduceAminSingleDim_basic",
"AtenDotModule_basic", "AtenDotModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
@ -1447,6 +1453,8 @@ STABLEHLO_CRASHING_SET = set()
# Write the TOSA set as a "passing" set as it is very early in development # Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet. # and very few tests work yet.
TOSA_PASS_SET = { TOSA_PASS_SET = {
"AtenLinear2D_basic",
"AtenLinear3DBias_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseDivTensorFloatModule_basic", "ElementwiseDivTensorFloatModule_basic",
"ElementwiseMulTensorFloatModule_basic", "ElementwiseMulTensorFloatModule_basic",
@ -1911,6 +1919,9 @@ MAKE_FX_TOSA_PASS_SET = (
TOSA_PASS_SET TOSA_PASS_SET
| { | {
### Tests additionally passing in make_fx_tosa ### Tests additionally passing in make_fx_tosa
"AtenLinear1D_basic",
"AtenLinearMatVec_basic",
"AtenLinearVecMatBias_basic",
"MaxPool1dEmptyStrideStaticModule_basic", "MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic", "MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic", "MaxPool1dStaticModule_basic",

View File

@ -622,6 +622,131 @@ def AtenMatmulQMixedSigni8Transpose_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class AtenLinear1D(torch.nn.Module):
@export
@annotate_args(
[
None,
([3], torch.float32, True),
([3], torch.float32, True),
]
)
def forward(self, a, b):
return torch.ops.aten.linear(a, b)
@register_test_case(module_factory=lambda: AtenLinear1D())
def AtenLinear1D_basic(module, tu: TestUtils):
module.forward(tu.rand(3), tu.rand(3))
# ==============================================================================
class AtenLinearMatVec(torch.nn.Module):
@export
@annotate_args(
[
None,
([3, 4], torch.float32, True),
([4], torch.float32, True),
]
)
def forward(self, a, b):
return torch.ops.aten.linear(a, b)
@register_test_case(module_factory=lambda: AtenLinearMatVec())
def AtenLinearMatVec_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4), tu.rand(4))
# ==============================================================================
class AtenLinearVecMat(torch.nn.Module):
@export
@annotate_args(
[
None,
([4], torch.float32, True),
([3, 4], torch.float32, True),
]
)
def forward(self, a, b):
return torch.ops.aten.linear(a, b)
@register_test_case(module_factory=lambda: AtenLinearVecMat())
def AtenLinearVecMat_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand(3, 4))
class AtenLinearVecMatBias(torch.nn.Module):
@export
@annotate_args(
[
None,
([4], torch.float32, True),
([3, 4], torch.float32, True),
([3], torch.float32, True),
]
)
def forward(self, a, b, c):
return torch.ops.aten.linear(a, b, c)
@register_test_case(module_factory=lambda: AtenLinearVecMatBias())
def AtenLinearVecMatBias_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand(3, 4), tu.rand(3))
# ==============================================================================
class AtenLinear2D(torch.nn.Module):
@export
@annotate_args(
[
None,
([3, 4], torch.float32, True),
([5, 4], torch.float32, True),
]
)
def forward(self, a, b):
return torch.ops.aten.linear(a, b)
@register_test_case(module_factory=lambda: AtenLinear2D())
def AtenLinear2D_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4), tu.rand(5, 4))
# ==============================================================================
class AtenLinear3DBias(torch.nn.Module):
@export
@annotate_args(
[
None,
([3, 6, 4], torch.float32, True),
([5, 4], torch.float32, True),
([5], torch.float32, True),
]
)
def forward(self, a, b, c):
return torch.ops.aten.linear(a, b, c)
@register_test_case(module_factory=lambda: AtenLinear3DBias())
def AtenLinear3DBias_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 6, 4), tu.rand(5, 4), tu.rand(5))
# ==============================================================================
class AtenLinalgCrossInt(torch.nn.Module): class AtenLinalgCrossInt(torch.nn.Module):
@export @export
@annotate_args( @annotate_args(