mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Fix integer type handling for aten.mm (#2615)
Despite aten.mm requiring the input and output types match, we still opt to maintain signedness semantics in case later passes try to do any sort of integer type narrowing.pull/2616/head snapshot-20231207.1045
parent
c0115706a0
commit
141202bc01
|
@ -51,12 +51,24 @@ public:
|
|||
// The compiler cannot crash even if the user wrote an erroneous program!
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
if (lhs.getType().cast<RankedTensorType>().getRank() != 2 ||
|
||||
rhs.getType().cast<RankedTensorType>().getRank() != 2) {
|
||||
|
||||
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
|
||||
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
|
||||
|
||||
if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected both operands to aten.mm to be rank 2");
|
||||
}
|
||||
|
||||
ValueTensorType lhsTorchType =
|
||||
op.getSelf().getType().cast<ValueTensorType>();
|
||||
ValueTensorType rhsTorchType =
|
||||
op.getMat2().getType().cast<ValueTensorType>();
|
||||
if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unsupported: aten.mm with different input element types");
|
||||
}
|
||||
|
||||
Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
|
||||
Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
|
||||
|
||||
|
@ -73,16 +85,22 @@ public:
|
|||
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
Type elementType = newResultType.cast<TensorType>().getElementType();
|
||||
Value initTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, ArrayRef<OpFoldResult>{lhsDim0, rhsDim1}, elementType);
|
||||
Value c0 = rewriter.create<arith::ConstantOp>(
|
||||
loc, FloatAttr::get(elementType, 0.0));
|
||||
Value zeroFill =
|
||||
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
||||
Value matmul = rewriter
|
||||
Value zeroFill = createZeroInitTensor(
|
||||
rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
|
||||
|
||||
Value matmul;
|
||||
auto intType = dyn_cast<mlir::IntegerType>(lhsTorchType.getDtype());
|
||||
if (intType && intType.isUnsigned()) {
|
||||
matmul = rewriter
|
||||
.create<linalg::MatmulUnsignedOp>(
|
||||
loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill)
|
||||
.getResult(0);
|
||||
} else {
|
||||
matmul = rewriter
|
||||
.create<linalg::MatmulOp>(loc, zeroFill.getType(),
|
||||
ValueRange{lhs, rhs}, zeroFill)
|
||||
.getResult(0);
|
||||
}
|
||||
// When constructed with just dynamic sizes, EmptyOp will have a result
|
||||
// type which has all `?`'s for dimensions, which might not be the result
|
||||
// type of `op`. The constraints on later linalg ops means that the result
|
||||
|
|
|
@ -226,3 +226,39 @@ class Mv(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: Mv())
|
||||
def Mv_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 2), tu.rand(2))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenMmFloatTypes(torch.nn.Module):
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.ops.aten.mm(a, b)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenMmFloatTypes())
|
||||
def AtenMmFloatTypes_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(8, 8), tu.rand(8, 8))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenMmIntTypes(torch.nn.Module):
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.ops.aten.mm(a, b)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenMmIntTypes())
|
||||
def AtenMmIntTypes_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(16, 4, high=100), tu.randint(4, 16, high=100))
|
||||
|
|
|
@ -40,6 +40,17 @@ func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.mm$basic_unsigned(
|
||||
// CHECK: linalg.matmul_unsigned
|
||||
func.func @torch.aten.mm$basic_unsigned(%arg0: !torch.vtensor<[?,?],ui32>, %arg1: !torch.vtensor<[?,?],ui32>) -> !torch.vtensor<[?,2],ui32>
|
||||
attributes {torch.assume_strict_symbolic_shapes}
|
||||
{
|
||||
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],ui32>, !torch.vtensor<[?,?],ui32> -> !torch.vtensor<[?,2],ui32>
|
||||
return %0 : !torch.vtensor<[?,2],ui32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// If the operands are missing dtype, we cannot lower it.
|
||||
func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
|
||||
// expected-error@+1 {{failed to legalize}}
|
||||
|
|
Loading…
Reference in New Issue