[Torch] emit aten.dot and canonicalize it to aten.matmul (#3361)

* canonicalize `aten.dot` to `aten.matmul`
pull/3365/head
Yuanqiang Liu 2024-05-18 22:45:14 +08:00 committed by GitHub
parent e80f072ba4
commit 8814d0ae64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 99 additions and 0 deletions

View File

@ -5955,6 +5955,31 @@ def Torch_AtenMvOp : Torch_Op<"aten.mv", [
}]; }];
} }
def Torch_AtenDotOp : Torch_Op<"aten.dot", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::dot : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$tensor
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDotOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenDotOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_AtenCosineSimilarityOp : Torch_Op<"aten.cosine_similarity", [ def Torch_AtenCosineSimilarityOp : Torch_Op<"aten.cosine_similarity", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -548,6 +548,24 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}); });
} }
//===----------------------------------------------------------------------===//
// AtenDotOp
//===----------------------------------------------------------------------===//
void AtenDotOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenDotOp op, PatternRewriter &rewriter) {
auto ty = dyn_cast<ValueTensorType>(op.getResult().getType());
if (!ty || !ty.hasSizes() || !ty.hasDtype()) {
return failure();
}
rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getResult().getType(),
op.getSelf(), op.getTensor());
return success();
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// RuntimeAssertOp // RuntimeAssertOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -7351,6 +7351,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" } : (!torch.int, !torch.bool) -> ()\n" " } : (!torch.int, !torch.bool) -> ()\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.dot\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.matmul\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.matmul\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -11437,6 +11441,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %2 : !torch.int\n" " return %2 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.dot\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.aten.eq.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %1#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

View File

@ -822,6 +822,7 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
} }
STABLEHLO_PASS_SET = { STABLEHLO_PASS_SET = {
"AtenDotModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
@ -1452,6 +1453,7 @@ STABLEHLO_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development # Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet. # and very few tests work yet.
TOSA_PASS_SET = { TOSA_PASS_SET = {
"AtenDotModule_basic",
"ElementwiseFloatTensorGtIntScalarModule_basic", "ElementwiseFloatTensorGtIntScalarModule_basic",
"ElementwiseLogSigmoidModule_basic", "ElementwiseLogSigmoidModule_basic",
"ElementwiseTruncModule_basic", "ElementwiseTruncModule_basic",

View File

@ -724,6 +724,10 @@ def atennumpy_T〡shape(self: List[int]) -> List[int]:
result_shape.insert(0, i) result_shape.insert(0, i)
return result_shape return result_shape
@check_shape_function([Invocation(TensorOfShape(3), TensorOfShape(3))])
def atendot〡shape(self: List[int], tensor: List[int]) -> List[int]:
return []
def atenmatmul〡shape(self: List[int], other: List[int]) -> List[int]: def atenmatmul〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.matmul(self, other) return upstream_shape_functions.matmul(self, other)
@ -3303,6 +3307,13 @@ def atendivScalar_mode〡dtype(self_rank_dtype: Tuple[int, int], other: Un
else: else:
return torch.float32 return torch.float32
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4,), (4,)]))
def atendot〡dtype(self_rank_dtype: Tuple[int, int], tensor_rank_dtype: Tuple[int, int]) -> int:
other_rank, other_dtype = tensor_rank_dtype
self_rank, self_dtype = self_rank_dtype
assert self_dtype == other_dtype
return self_dtype
@check_dtype_function( @check_dtype_function(
_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) + _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
# Different width # Different width

View File

@ -532,6 +532,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::matmul : (Tensor, Tensor) -> (Tensor)") emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
emit("aten::mv : (Tensor, Tensor) -> (Tensor)") emit("aten::mv : (Tensor, Tensor) -> (Tensor)")
emit("aten::dot : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
emit("aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)") emit("aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)")
emit( emit(
"aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" "aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"

View File

@ -12,6 +12,30 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
# ============================================================================== # ==============================================================================
class AtenDotModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1], torch.float32, True),
([-1], torch.float32, True),
]
)
def forward(self, lhs, rhs):
return torch.dot(lhs, rhs)
@register_test_case(module_factory=lambda: AtenDotModule())
def AtenDotModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4), tu.rand(4))
# ==============================================================================
class MatmulDot(torch.nn.Module): class MatmulDot(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()