[MLIR][TORCH] Add different dtype support for aten.bmm op

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/2458/head snapshot-20230912.959
Vivek Khandelwal 2023-09-11 12:58:59 +00:00
parent c1f379f3bf
commit 23b72244b1
6 changed files with 64 additions and 21 deletions
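
For a quick sense of what this change enables, the eager-mode snippet below runs torch.bmm on integer operands; the shapes and dtypes mirror the new BmmIntModule_basic test added further down. The snippet is illustrative only and is not part of the commit.

    import torch

    # Batched matmul on int64 operands, the case the new BmmIntModule_basic
    # test exercises end-to-end; shapes match that test.
    lhs = torch.randint(0, 100, (3, 4, 5), dtype=torch.int64)
    rhs = torch.randint(0, 100, (3, 5, 4), dtype=torch.int64)
    out = torch.bmm(lhs, rhs)  # int64 result of shape (3, 4, 4)
    print(out.shape, out.dtype)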

View File

@@ -430,7 +430,8 @@ STABLEHLO_PASS_SET = {
     "BatchNorm3DModule_basic",
     "BatchNorm1DStaticShapeModule_basic",
     "ResNet18StaticModule_basic",
-    "BmmModule_basic",
+    "BmmFloatModule_basic",
+    "BmmIntModule_basic",
     "BroadcastToModule_basic",
     "BroadcastToSameRankStaticModule_basic",
     "BroadcastZeroRankInputStaticModule_basic",
@@ -976,7 +977,7 @@ TOSA_PASS_SET = {
     "ReturnTwoTensorF32I64_basic",
     "ElementwiseSignModule_basic",
     "ElementwisePowModule_basic",
-    "BmmModule_basic",
+    "BmmFloatModule_basic",
    "MmDagModule_basic",
    "Matmul4dStatic_basic",
    "Matmul_dot",

View File

@@ -1097,15 +1097,9 @@ public:
         typeConverter->convertType(op.getType()).cast<RankedTensorType>();
     auto outElemType = newResultType.getElementType();
-    auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
-                                ValueRange payloadArgs) {
-      Value elem =
-          convertScalarToDtype(builder, loc, payloadArgs[0], outElemType);
-      builder.create<linalg::YieldOp>(loc, elem);
-    };
     for (size_t i = 0; i < tensors.size(); ++i) {
-      tensors[i] = torch_to_linalg::createElementwiseLinalgGeneric(
-          rewriter, loc, {tensors[i]}, outElemType, dtypePromoteBody);
+      tensors[i] = torch_to_linalg::convertTensorToElementType(
+          rewriter, loc, tensors[i], outElemType);
     }
     int rank = newResultType.getRank();

View File

@@ -441,16 +441,28 @@ public:
     Value rhs = adaptor.getMat2();
     RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
     RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    Type resultElementType = newResultType.cast<RankedTensorType>().getElementType();
+    Type lhsElementType = lhsType.cast<RankedTensorType>().getElementType();
+    Type rhsElementType = rhsType.cast<RankedTensorType>().getElementType();
     if (lhsType.getRank() != 3 || rhsType.getRank() != 3) {
       return rewriter.notifyMatchFailure(
           op, "expected both operands to aten.bmm to be rank 3");
     }
-    if (!lhsType.getElementType().isa<mlir::FloatType>() ||
-        lhsType.getElementType() != rhsType.getElementType())
-      return op.emitError(
-          "unimplemented: non floating point operands or operands of "
-          "different types");
+    // Convert the operand element types to match the result's element type.
+    if (lhsElementType != rhsElementType) {
+      if (lhsElementType != resultElementType) {
+        // The lhs element type differs from the result's element type; convert lhs.
+        lhs = torch_to_linalg::convertTensorToElementType(
+            rewriter, loc, lhs, resultElementType);
+      } else {
+        // Otherwise the rhs element type differs; convert rhs.
+        rhs = torch_to_linalg::convertTensorToElementType(
+            rewriter, loc, rhs, resultElementType);
+      }
+    }
     Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
     Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
@@ -465,10 +477,8 @@ public:
     // Check the matrixs shapes are valid for mulplication.
     checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1);
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    Type elementType = newResultType.cast<TensorType>().getElementType();
     Value initTensor0 = createZeroInitTensor(
-        rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, elementType);
+        rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, resultElementType);
     Value bmm =
         rewriter
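
The new branch above applies a simple promotion rule: when the operand element types differ, convert whichever operand does not already match the result's element type, then build the batch matmul in that type. A Python sketch of that rule for illustration only; convert is a stand-in for torch_to_linalg::convertTensorToElementType, and the branch structure mirrors the lowering.

    import torch

    def promote_bmm_operands(lhs, rhs, result_dtype, convert=lambda t, dt: t.to(dt)):
        # Mirrors the branch structure of the new AtenBmmOp lowering: only the
        # operand whose dtype differs from the result dtype is converted.
        if lhs.dtype != rhs.dtype:
            if lhs.dtype != result_dtype:
                lhs = convert(lhs, result_dtype)   # lhs differs from the result dtype
            else:
                rhs = convert(rhs, result_dtype)   # otherwise rhs must differ
        return lhs, rhs

    lhs, rhs = promote_bmm_operands(
        torch.randint(0, 10, (3, 4, 5), dtype=torch.int32),  # int32 operand
        torch.rand(3, 5, 4),                                  # float32 operand
        torch.float32)
    assert lhs.dtype == rhs.dtype == torch.float32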

View File

@@ -435,3 +435,16 @@ Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
   return b.create<tensor::CastOp>(
       loc, tensorType.clone(makeShapeLLVMCompatible(unknownSizes)), tensor);
 }
+
+Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc,
+                                                  Value tensor,
+                                                  Type elementType) {
+  auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
+                              ValueRange payloadArgs) {
+    Value elem =
+        convertScalarToDtype(builder, loc, payloadArgs[0], elementType);
+    builder.create<linalg::YieldOp>(loc, elem);
+  };
+  return torch_to_linalg::createElementwiseLinalgGeneric(
+      b, loc, {tensor}, elementType, dtypePromoteBody);
+}
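
The helper above packages the elementwise dtype-promotion pattern that was previously written inline in the cat lowering: a single-input linalg.generic whose payload casts each element with convertScalarToDtype. As a rough eager-mode analogy (an assumption for illustration, not the IR it emits), it behaves like an elementwise Tensor.to cast:

    import torch

    # Rough analogy for convertTensorToElementType: cast every element of a
    # tensor to the requested element type, leaving the shape untouched.
    x = torch.randint(0, 100, (3, 4, 5), dtype=torch.int64)
    y = x.to(torch.float32)
    assert y.shape == x.shape and y.dtype == torch.float32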

View File

@@ -81,6 +81,11 @@ broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input,
 // Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
 // <?x?xf32>
 Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor);
+
+// Converts a tensor's element type to the specified `elementType`.
+Value convertTensorToElementType(OpBuilder &b, Location loc, Value tensor,
+                                 Type elementType);
+
 } // namespace torch_to_linalg
 } // namespace torch
 } // namespace mlir

View File

@@ -60,7 +60,7 @@ def MmModule_chained(module, tu: TestUtils):
 # ==============================================================================

-class BmmModule(torch.nn.Module):
+class BmmFloatModule(torch.nn.Module):

     def __init__(self):
         super().__init__()
@@ -75,11 +75,31 @@ class BmmModule(torch.nn.Module):
         return torch.bmm(lhs, rhs)

-@register_test_case(module_factory=lambda: BmmModule())
-def BmmModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: BmmFloatModule())
+def BmmFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
+
+
+class BmmIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int64, True),
+        ([-1, -1, -1], torch.int64, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.bmm(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: BmmIntModule())
+def BmmIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, high=100), tu.randint(3, 5, 4, high=100))
+
 # ==============================================================================