mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add different dtype support for aten.bmm op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/2458/head snapshot-20230912.959
parent
c1f379f3bf
commit
23b72244b1
|
@ -430,7 +430,8 @@ STABLEHLO_PASS_SET = {
|
||||||
"BatchNorm3DModule_basic",
|
"BatchNorm3DModule_basic",
|
||||||
"BatchNorm1DStaticShapeModule_basic",
|
"BatchNorm1DStaticShapeModule_basic",
|
||||||
"ResNet18StaticModule_basic",
|
"ResNet18StaticModule_basic",
|
||||||
"BmmModule_basic",
|
"BmmFloatModule_basic",
|
||||||
|
"BmmIntModule_basic",
|
||||||
"BroadcastToModule_basic",
|
"BroadcastToModule_basic",
|
||||||
"BroadcastToSameRankStaticModule_basic",
|
"BroadcastToSameRankStaticModule_basic",
|
||||||
"BroadcastZeroRankInputStaticModule_basic",
|
"BroadcastZeroRankInputStaticModule_basic",
|
||||||
|
@ -976,7 +977,7 @@ TOSA_PASS_SET = {
|
||||||
"ReturnTwoTensorF32I64_basic",
|
"ReturnTwoTensorF32I64_basic",
|
||||||
"ElementwiseSignModule_basic",
|
"ElementwiseSignModule_basic",
|
||||||
"ElementwisePowModule_basic",
|
"ElementwisePowModule_basic",
|
||||||
"BmmModule_basic",
|
"BmmFloatModule_basic",
|
||||||
"MmDagModule_basic",
|
"MmDagModule_basic",
|
||||||
"Matmul4dStatic_basic",
|
"Matmul4dStatic_basic",
|
||||||
"Matmul_dot",
|
"Matmul_dot",
|
||||||
|
|
|
@ -1097,15 +1097,9 @@ public:
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
||||||
|
|
||||||
auto outElemType = newResultType.getElementType();
|
auto outElemType = newResultType.getElementType();
|
||||||
auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
|
|
||||||
ValueRange payloadArgs) {
|
|
||||||
Value elem =
|
|
||||||
convertScalarToDtype(builder, loc, payloadArgs[0], outElemType);
|
|
||||||
builder.create<linalg::YieldOp>(loc, elem);
|
|
||||||
};
|
|
||||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||||
tensors[i] = torch_to_linalg::createElementwiseLinalgGeneric(
|
tensors[i] = torch_to_linalg::convertTensorToElementType(
|
||||||
rewriter, loc, {tensors[i]}, outElemType, dtypePromoteBody);
|
rewriter, loc, tensors[i], outElemType);
|
||||||
}
|
}
|
||||||
|
|
||||||
int rank = newResultType.getRank();
|
int rank = newResultType.getRank();
|
||||||
|
|
|
@ -441,16 +441,28 @@ public:
|
||||||
Value rhs = adaptor.getMat2();
|
Value rhs = adaptor.getMat2();
|
||||||
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
|
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
|
||||||
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
|
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
|
||||||
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
|
Type resultElementType = newResultType.cast<RankedTensorType>().getElementType();
|
||||||
|
Type lhsElementType = lhsType.cast<RankedTensorType>().getElementType();
|
||||||
|
Type rhsElementType = rhsType.cast<RankedTensorType>().getElementType();
|
||||||
|
|
||||||
if (lhsType.getRank() != 3 || rhsType.getRank() != 3) {
|
if (lhsType.getRank() != 3 || rhsType.getRank() != 3) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected both operands to aten.bmm to be rank 3");
|
op, "expected both operands to aten.bmm to be rank 3");
|
||||||
}
|
}
|
||||||
if (!lhsType.getElementType().isa<mlir::FloatType>() ||
|
|
||||||
lhsType.getElementType() != rhsType.getElementType())
|
// Convert the inputs element type equivalent to the result' element type.
|
||||||
return op.emitError(
|
if (lhsElementType != rhsElementType) {
|
||||||
"unimplemented: non floating point operands or operands of "
|
if (lhsElementType != resultElementType) {
|
||||||
"different types");
|
// True if the lhs element type is not equal to the result' element type.
|
||||||
|
lhs = torch_to_linalg::convertTensorToElementType(
|
||||||
|
rewriter, loc, lhs, resultElementType);
|
||||||
|
} else {
|
||||||
|
// True if the rhs element type is not equal to the result' element type.
|
||||||
|
rhs = torch_to_linalg::convertTensorToElementType(
|
||||||
|
rewriter, loc, rhs, resultElementType);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
|
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
|
||||||
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
|
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
|
||||||
|
@ -465,10 +477,8 @@ public:
|
||||||
// Check the matrixs shapes are valid for mulplication.
|
// Check the matrixs shapes are valid for mulplication.
|
||||||
checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1);
|
checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1);
|
||||||
|
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
|
||||||
Type elementType = newResultType.cast<TensorType>().getElementType();
|
|
||||||
Value initTensor0 = createZeroInitTensor(
|
Value initTensor0 = createZeroInitTensor(
|
||||||
rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, elementType);
|
rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, resultElementType);
|
||||||
|
|
||||||
Value bmm =
|
Value bmm =
|
||||||
rewriter
|
rewriter
|
||||||
|
|
|
@ -435,3 +435,16 @@ Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
|
||||||
return b.create<tensor::CastOp>(
|
return b.create<tensor::CastOp>(
|
||||||
loc, tensorType.clone(makeShapeLLVMCompatible(unknownSizes)), tensor);
|
loc, tensorType.clone(makeShapeLLVMCompatible(unknownSizes)), tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc,
|
||||||
|
Value tensor,
|
||||||
|
Type elementType) {
|
||||||
|
auto dtypePromoteBody = [&](OpBuilder &builder, Location loc,
|
||||||
|
ValueRange payloadArgs) {
|
||||||
|
Value elem =
|
||||||
|
convertScalarToDtype(builder, loc, payloadArgs[0], elementType);
|
||||||
|
builder.create<linalg::YieldOp>(loc, elem);
|
||||||
|
};
|
||||||
|
return torch_to_linalg::createElementwiseLinalgGeneric(
|
||||||
|
b, loc, {tensor}, elementType, dtypePromoteBody);
|
||||||
|
}
|
||||||
|
|
|
@ -81,6 +81,11 @@ broadcastToGivenShape(Operation *op, PatternRewriter &rewriter, Value input,
|
||||||
// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
|
// Cast a tensor to a rank-equivalent tensor of unknown size, i.e. <1x2xf32> ->
|
||||||
// <?x?xf32>
|
// <?x?xf32>
|
||||||
Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor);
|
Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor);
|
||||||
|
|
||||||
|
// Converts a tensor' element type to the specified `elementType`.
|
||||||
|
Value convertTensorToElementType(OpBuilder &b, Location loc, Value tensor,
|
||||||
|
Type elementType);
|
||||||
|
|
||||||
} // namespace torch_to_linalg
|
} // namespace torch_to_linalg
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -60,7 +60,7 @@ def MmModule_chained(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BmmModule(torch.nn.Module):
|
class BmmFloatModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -75,11 +75,31 @@ class BmmModule(torch.nn.Module):
|
||||||
return torch.bmm(lhs, rhs)
|
return torch.bmm(lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: BmmModule())
|
@register_test_case(module_factory=lambda: BmmFloatModule())
|
||||||
def BmmModule_basic(module, tu: TestUtils):
|
def BmmFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
|
module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
|
||||||
|
|
||||||
|
|
||||||
|
class BmmIntModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.int64, True),
|
||||||
|
([-1, -1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs, rhs):
|
||||||
|
return torch.bmm(lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: BmmIntModule())
|
||||||
|
def BmmIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, 5, high=100), tu.randint(3, 5, 4, high=100))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue