[TorchToLinalg] Fix integer type handling for aten.mm (#2615)

Despite aten.mm requiring that the input and output types match, we still
opt to maintain signedness semantics in case later passes try to do any
sort of integer type narrowing.
Quinn Dawkins 2023-12-07 00:13:53 -05:00 committed by GitHub
parent c0115706a0
commit 141202bc01
3 changed files with 78 additions and 13 deletions
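
Concretely, an aten.mm on unsigned-integer tensors now lowers to linalg.matmul_unsigned rather than linalg.matmul. A minimal before/after sketch of the IR (names abbreviated and hypothetical; assumes the usual type conversion of ui32 vtensors to signless i32 tensors):

    // Torch dialect input:
    %mm = torch.aten.mm %lhs, %rhs
        : !torch.vtensor<[?,?],ui32>, !torch.vtensor<[?,?],ui32>
        -> !torch.vtensor<[?,?],ui32>

    // After TorchToLinalg (sketch): the unsigned torch dtype selects the
    // unsigned variant of the named matmul op over a zero-filled init tensor.
    %zero = linalg.fill ins(%c0_i32 : i32)
              outs(%empty : tensor<?x?xi32>) -> tensor<?x?xi32>
    %res = linalg.matmul_unsigned
              ins(%lhs_t, %rhs_t : tensor<?x?xi32>, tensor<?x?xi32>)
              outs(%zero : tensor<?x?xi32>) -> tensor<?x?xi32>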


@@ -51,12 +51,24 @@ public:
     // The compiler cannot crash even if the user wrote an erroneous program!
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
-    if (lhs.getType().cast<RankedTensorType>().getRank() != 2 ||
-        rhs.getType().cast<RankedTensorType>().getRank() != 2) {
+
+    RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
+    RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
+
+    if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
       return rewriter.notifyMatchFailure(
           op, "expected both operands to aten.mm to be rank 2");
     }
+
+    ValueTensorType lhsTorchType =
+        op.getSelf().getType().cast<ValueTensorType>();
+    ValueTensorType rhsTorchType =
+        op.getMat2().getType().cast<ValueTensorType>();
+    if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported: aten.mm with different input element types");
+    }
+
     Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
     Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
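
The dtype equality check above is new. aten.mm already requires matching operand types, so this is a defensive guard: operands whose torch-level element types disagree bail out with a match failure instead of lowering to a named op with ambiguous signedness. A hypothetical input that would now be rejected with the message above:

    // si32 x si64: no single signedness/width interpretation, so the pattern
    // reports "unsupported: aten.mm with different input element types".
    %0 = torch.aten.mm %arg0, %arg1
        : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si64>
        -> !torch.vtensor<[?,?],si64>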
@@ -73,16 +85,22 @@ public:
     Type newResultType = getTypeConverter()->convertType(op.getType());
     Type elementType = newResultType.cast<TensorType>().getElementType();
-    Value initTensor = rewriter.create<tensor::EmptyOp>(
-        loc, ArrayRef<OpFoldResult>{lhsDim0, rhsDim1}, elementType);
-    Value c0 = rewriter.create<arith::ConstantOp>(
-        loc, FloatAttr::get(elementType, 0.0));
-    Value zeroFill =
-        rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
-    Value matmul = rewriter
-                       .create<linalg::MatmulOp>(loc, zeroFill.getType(),
-                                                 ValueRange{lhs, rhs}, zeroFill)
-                       .getResult(0);
+    Value zeroFill = createZeroInitTensor(
+        rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
+
+    Value matmul;
+    auto intType = dyn_cast<mlir::IntegerType>(lhsTorchType.getDtype());
+    if (intType && intType.isUnsigned()) {
+      matmul = rewriter
+                   .create<linalg::MatmulUnsignedOp>(
+                       loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill)
+                   .getResult(0);
+    } else {
+      matmul = rewriter
+                   .create<linalg::MatmulOp>(loc, zeroFill.getType(),
+                                             ValueRange{lhs, rhs}, zeroFill)
+                   .getResult(0);
+    }
+
     // When constructed with just dynamic sizes, EmptyOp will have a result
     // type which has all `?`'s for dimensions, which might not be the result
     // type of `op`. The constraints on later linalg ops means that the result
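
Two things happen in this hunk. First, the zero-init switches from an explicit tensor::EmptyOp plus a FloatAttr-based fill, which is only valid for floating-point element types, to the createZeroInitTensor helper, which builds a zero of whatever element type the result has. Second, the torch-level dtype picks between the signed and unsigned named matmul ops. That distinction is invisible while input and output widths match, but matters if a later pass narrows the integers; a sketch, assuming hypothetical i8 operands accumulated into i32:

    // The named ops differ in how they extend sub-accumulator-width inputs:
    // linalg.matmul sign-extends (bit pattern 0xFF reads as -1), while
    // linalg.matmul_unsigned zero-extends (0xFF reads as 255).
    %s = linalg.matmul ins(%a, %b : tensor<4x4xi8>, tensor<4x4xi8>)
           outs(%acc0 : tensor<4x4xi32>) -> tensor<4x4xi32>
    %u = linalg.matmul_unsigned ins(%a, %b : tensor<4x4xi8>, tensor<4x4xi8>)
           outs(%acc1 : tensor<4x4xi32>) -> tensor<4x4xi32>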


@@ -226,3 +226,39 @@ class Mv(torch.nn.Module):
 @register_test_case(module_factory=lambda: Mv())
 def Mv_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 2), tu.rand(2))
+
+
+# ==============================================================================
+
+
+class AtenMmFloatTypes(torch.nn.Module):
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a, b):
+        return torch.ops.aten.mm(a, b)
+
+
+@register_test_case(module_factory=lambda: AtenMmFloatTypes())
+def AtenMmFloatTypes_basic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 8), tu.rand(8, 8))
+
+
+# ==============================================================================
+
+
+class AtenMmIntTypes(torch.nn.Module):
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, a, b):
+        return torch.ops.aten.mm(a, b)
+
+
+@register_test_case(module_factory=lambda: AtenMmIntTypes())
+def AtenMmIntTypes_basic(module, tu: TestUtils):
+    module.forward(tu.randint(16, 4, high=100), tu.randint(4, 16, high=100))


@@ -40,6 +40,17 @@ func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.mm$basic_unsigned(
+// CHECK: linalg.matmul_unsigned
+func.func @torch.aten.mm$basic_unsigned(%arg0: !torch.vtensor<[?,?],ui32>, %arg1: !torch.vtensor<[?,?],ui32>) -> !torch.vtensor<[?,2],ui32>
+    attributes {torch.assume_strict_symbolic_shapes}
+{
+  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],ui32>, !torch.vtensor<[?,?],ui32> -> !torch.vtensor<[?,2],ui32>
+  return %0 : !torch.vtensor<[?,2],ui32>
+}
+
+// -----
+
 // If the operands are missing dtype, we cannot lower it.
 func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
   // expected-error@+1 {{failed to legalize}}