mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Fix integer type handling for aten.mm (#2615)
Despite aten.mm requiring the input and output types match, we still opt to maintain signedness semantics in case later passes try to do any sort of integer type narrowing.pull/2616/head snapshot-20231207.1045
parent
c0115706a0
commit
141202bc01
|
@ -51,12 +51,24 @@ public:
|
||||||
// The compiler cannot crash even if the user wrote an erroneous program!
|
// The compiler cannot crash even if the user wrote an erroneous program!
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
return failure();
|
return failure();
|
||||||
if (lhs.getType().cast<RankedTensorType>().getRank() != 2 ||
|
|
||||||
rhs.getType().cast<RankedTensorType>().getRank() != 2) {
|
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
|
||||||
|
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
|
if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected both operands to aten.mm to be rank 2");
|
op, "expected both operands to aten.mm to be rank 2");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ValueTensorType lhsTorchType =
|
||||||
|
op.getSelf().getType().cast<ValueTensorType>();
|
||||||
|
ValueTensorType rhsTorchType =
|
||||||
|
op.getMat2().getType().cast<ValueTensorType>();
|
||||||
|
if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unsupported: aten.mm with different input element types");
|
||||||
|
}
|
||||||
|
|
||||||
Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
|
Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
|
||||||
Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
|
Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
|
||||||
|
|
||||||
|
@ -73,16 +85,22 @@ public:
|
||||||
|
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
Type elementType = newResultType.cast<TensorType>().getElementType();
|
Type elementType = newResultType.cast<TensorType>().getElementType();
|
||||||
Value initTensor = rewriter.create<tensor::EmptyOp>(
|
Value zeroFill = createZeroInitTensor(
|
||||||
loc, ArrayRef<OpFoldResult>{lhsDim0, rhsDim1}, elementType);
|
rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
|
||||||
Value c0 = rewriter.create<arith::ConstantOp>(
|
|
||||||
loc, FloatAttr::get(elementType, 0.0));
|
Value matmul;
|
||||||
Value zeroFill =
|
auto intType = dyn_cast<mlir::IntegerType>(lhsTorchType.getDtype());
|
||||||
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
if (intType && intType.isUnsigned()) {
|
||||||
Value matmul = rewriter
|
matmul = rewriter
|
||||||
|
.create<linalg::MatmulUnsignedOp>(
|
||||||
|
loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill)
|
||||||
|
.getResult(0);
|
||||||
|
} else {
|
||||||
|
matmul = rewriter
|
||||||
.create<linalg::MatmulOp>(loc, zeroFill.getType(),
|
.create<linalg::MatmulOp>(loc, zeroFill.getType(),
|
||||||
ValueRange{lhs, rhs}, zeroFill)
|
ValueRange{lhs, rhs}, zeroFill)
|
||||||
.getResult(0);
|
.getResult(0);
|
||||||
|
}
|
||||||
// When constructed with just dynamic sizes, EmptyOp will have a result
|
// When constructed with just dynamic sizes, EmptyOp will have a result
|
||||||
// type which has all `?`'s for dimensions, which might not be the result
|
// type which has all `?`'s for dimensions, which might not be the result
|
||||||
// type of `op`. The constraints on later linalg ops means that the result
|
// type of `op`. The constraints on later linalg ops means that the result
|
||||||
|
|
|
@ -226,3 +226,39 @@ class Mv(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: Mv())
|
@register_test_case(module_factory=lambda: Mv())
|
||||||
def Mv_basic(module, tu: TestUtils):
|
def Mv_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 2), tu.rand(2))
|
module.forward(tu.rand(2, 2), tu.rand(2))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class AtenMmFloatTypes(torch.nn.Module):
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.ops.aten.mm(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenMmFloatTypes())
|
||||||
|
def AtenMmFloatTypes_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(8, 8), tu.rand(8, 8))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class AtenMmIntTypes(torch.nn.Module):
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.ops.aten.mm(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenMmIntTypes())
|
||||||
|
def AtenMmIntTypes_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(16, 4, high=100), tu.randint(4, 16, high=100))
|
||||||
|
|
|
@ -40,6 +40,17 @@ func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.mm$basic_unsigned(
|
||||||
|
// CHECK: linalg.matmul_unsigned
|
||||||
|
func.func @torch.aten.mm$basic_unsigned(%arg0: !torch.vtensor<[?,?],ui32>, %arg1: !torch.vtensor<[?,?],ui32>) -> !torch.vtensor<[?,2],ui32>
|
||||||
|
attributes {torch.assume_strict_symbolic_shapes}
|
||||||
|
{
|
||||||
|
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],ui32>, !torch.vtensor<[?,?],ui32> -> !torch.vtensor<[?,2],ui32>
|
||||||
|
return %0 : !torch.vtensor<[?,2],ui32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// If the operands are missing dtype, we cannot lower it.
|
// If the operands are missing dtype, we cannot lower it.
|
||||||
func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
|
func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
|
||||||
// expected-error@+1 {{failed to legalize}}
|
// expected-error@+1 {{failed to legalize}}
|
||||||
|
|
Loading…
Reference in New Issue