mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add different dtype support for aten.bmm op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
Branches/tags: pull/2458/head, snapshot-20230912.959
parent c1f379f3bf
commit 23b72244b1
@@ -430,7 +430,8 @@ STABLEHLO_PASS_SET = {
     "BatchNorm3DModule_basic",
     "BatchNorm1DStaticShapeModule_basic",
     "ResNet18StaticModule_basic",
-    "BmmModule_basic",
+    "BmmFloatModule_basic",
+    "BmmIntModule_basic",
    "BroadcastToModule_basic",
     "BroadcastToSameRankStaticModule_basic",
     "BroadcastZeroRankInputStaticModule_basic",
@@ -976,7 +977,7 @@ TOSA_PASS_SET = {
     "ReturnTwoTensorF32I64_basic",
     "ElementwiseSignModule_basic",
     "ElementwisePowModule_basic",
-    "BmmModule_basic",
+    "BmmFloatModule_basic",
     "MmDagModule_basic",
     "Matmul4dStatic_basic",
     "Matmul_dot",
@@ -1097,15 +1097,9 @@ public:
         typeConverter->convertType(op.getType()).cast<RankedTensorType>();

     auto outElemType = newResultType.getElementType();
-    auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
-                                ValueRange payloadArgs) {
-      Value elem =
-          convertScalarToDtype(builder, loc, payloadArgs[0], outElemType);
-      builder.create<linalg::YieldOp>(loc, elem);
-    };
     for (size_t i = 0; i < tensors.size(); ++i) {
-      tensors[i] = torch_to_linalg::createElementwiseLinalgGeneric(
-          rewriter, loc, {tensors[i]}, outElemType, dtypePromoteBody);
+      tensors[i] = torch_to_linalg::convertTensorToElementType(
+          rewriter, loc, tensors[i], outElemType);
     }

     int rank = newResultType.getRank();
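The hunk above swaps the inline dtype-promotion lambda in the concatenation lowering for the shared torch_to_linalg::convertTensorToElementType helper introduced later in this change. The per-element cast is delegated to convertScalarToDtype; the following is a minimal sketch of the kind of cast that function must emit for the new integer bmm path (castScalarSketch is a hypothetical name, and the dispatch shown covers only the int<->float cases — it is not torch-mlir's actual implementation):

    #include "mlir/Dialect/Arith/IR/Arith.h"
    #include "mlir/IR/Builders.h"

    using namespace mlir;

    // Illustrative only: cast one scalar payload value to `dstTy`.
    // torch-mlir's real convertScalarToDtype also handles bit-width
    // changes, unsigned integers, bool, and more.
    static Value castScalarSketch(OpBuilder &b, Location loc, Value scalar,
                                  Type dstTy) {
      Type srcTy = scalar.getType();
      if (srcTy == dstTy)
        return scalar;
      if (srcTy.isa<IntegerType>() && dstTy.isa<FloatType>())
        return b.create<arith::SIToFPOp>(loc, dstTy, scalar); // signed int -> float
      if (srcTy.isa<FloatType>() && dstTy.isa<IntegerType>())
        return b.create<arith::FPToSIOp>(loc, dstTy, scalar); // float -> signed int
      return scalar; // remaining cases elided in this sketch
    }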
@@ -441,16 +441,28 @@ public:
     Value rhs = adaptor.getMat2();
     RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
     RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    Type resultElementType = newResultType.cast<RankedTensorType>().getElementType();
+    Type lhsElementType = lhsType.cast<RankedTensorType>().getElementType();
+    Type rhsElementType = rhsType.cast<RankedTensorType>().getElementType();

     if (lhsType.getRank() != 3 || rhsType.getRank() != 3) {
       return rewriter.notifyMatchFailure(
           op, "expected both operands to aten.bmm to be rank 3");
     }
-    if (!lhsType.getElementType().isa<mlir::FloatType>() ||
-        lhsType.getElementType() != rhsType.getElementType())
-      return op.emitError(
-          "unimplemented: non floating point operands or operands of "
-          "different types");
+
+    // Convert the inputs element type equivalent to the result' element type.
+    if (lhsElementType != rhsElementType) {
+      if (lhsElementType != resultElementType) {
+        // True if the lhs element type is not equal to the result' element type.
+        lhs = torch_to_linalg::convertTensorToElementType(
+            rewriter, loc, lhs, resultElementType);
+      } else {
+        // True if the rhs element type is not equal to the result' element type.
+        rhs = torch_to_linalg::convertTensorToElementType(
+            rewriter, loc, rhs, resultElementType);
+      }
+    }

     Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
     Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
@@ -465,10 +477,8 @@ public:
     // Check the matrixs shapes are valid for mulplication.
     checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1);

-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    Type elementType = newResultType.cast<TensorType>().getElementType();
     Value initTensor0 = createZeroInitTensor(
-        rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, elementType);
+        rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, resultElementType);

     Value bmm =
         rewriter
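The hunk is cut off at the start of the matmul itself; based on the surrounding context, the lowering plausibly continues by feeding the promoted operands and the zero-filled accumulator into linalg.batch_matmul, roughly as follows (a reconstruction from context, not quoted verbatim from the file):

    Value bmm =
        rewriter
            .create<linalg::BatchMatmulOp>(loc, initTensor0.getType(),
                                           ValueRange{lhs, rhs}, initTensor0)
            .getResult(0);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, bmm);

Since linalg.batch_matmul requires its inputs and accumulator to share an element type, promoting lhs/rhs to resultElementType beforehand is what lets the integer and mixed-dtype cases lower at all.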
@@ -435,3 +435,16 @@ Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
   return b.create<tensor::CastOp>(
       loc, tensorType.clone(makeShapeLLVMCompatible(unknownSizes)), tensor);
 }
+
+Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc,
+                                                  Value tensor,
+                                                  Type elementType) {
+  auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
+                              ValueRange payloadArgs) {
+    Value elem =
+        convertScalarToDtype(builder, loc, payloadArgs[0], elementType);
+    builder.create<linalg::YieldOp>(loc, elem);
+  };
+  return torch_to_linalg::createElementwiseLinalgGeneric(
+      b, loc, {tensor}, elementType, dtypePromoteBody);
+}
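A hypothetical call site for the new helper, assuming a rewriter, a loc, and some tensor-typed `operand` (the f32 target type here is illustrative, not from the patch):

    // Promote `operand` to a tensor with f32 elements and the same shape.
    Value promoted = torch_to_linalg::convertTensorToElementType(
        rewriter, loc, operand, rewriter.getF32Type());

The helper wraps the input in a rank-preserving linalg.generic whose payload converts each scalar via convertScalarToDtype, so call sites like the aten.bmm and concatenation lowerings no longer spell out the payload lambda themselves.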
@@ -81,6 +81,11 @@ broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input,
 // Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
 // <?x?xf32>
 Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor);
+
+// Converts a tensor' element type to the specified `elementType`.
+Value convertTensorToElementType(OpBuilder &b, Location loc, Value tensor,
+                                 Type elementType);
+
 } // namespace torch_to_linalg
 } // namespace torch
 } // namespace mlir
@@ -60,7 +60,7 @@ def MmModule_chained(module, tu: TestUtils):
 # ==============================================================================


-class BmmModule(torch.nn.Module):
+class BmmFloatModule(torch.nn.Module):

     def __init__(self):
         super().__init__()
@@ -75,11 +75,31 @@ class BmmModule(torch.nn.Module):
         return torch.bmm(lhs, rhs)


-@register_test_case(module_factory=lambda: BmmModule())
-def BmmModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: BmmFloatModule())
+def BmmFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))


+class BmmIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int64, True),
+        ([-1, -1, -1], torch.int64, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.bmm(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: BmmIntModule())
+def BmmIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, high=100), tu.randint(3, 5, 4, high=100))
+
+
 # ==============================================================================