diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 4de41e13b..f68916c76 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5955,6 +5955,31 @@ def Torch_AtenMvOp : Torch_Op<"aten.mv", [
   }];
 }
 
+def Torch_AtenDotOp : Torch_Op<"aten.dot", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::dot : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$tensor
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenDotOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenDotOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasCanonicalizer = 1;
+}
+
 def Torch_AtenCosineSimilarityOp : Torch_Op<"aten.cosine_similarity", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index bdd794924..b1153fa40 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -548,6 +548,24 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenDotOp
+//===----------------------------------------------------------------------===//
+
+void AtenDotOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                            MLIRContext *context) {
+  patterns.add(+[](AtenDotOp op, PatternRewriter &rewriter) {
+    auto ty = dyn_cast<ValueTensorType>(op.getResult().getType());
+    if (!ty || !ty.hasSizes() || !ty.hasDtype()) {
+      return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getResult().getType(),
+                                              op.getSelf(), op.getTensor());
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // RuntimeAssertOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 43bcc3acc..ceccb38be 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7351,6 +7351,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    } : (!torch.int, !torch.bool) -> ()\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.dot\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.matmul\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -11437,6 +11441,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.dot\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.aten.eq.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %1#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 72c495b1b..d5b682a22 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -822,6 +822,7 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
 }
 
 STABLEHLO_PASS_SET = {
+    "AtenDotModule_basic",
     "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
@@ -1452,6 +1453,7 @@ STABLEHLO_CRASHING_SET = {
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "AtenDotModule_basic",
     "ElementwiseFloatTensorGtIntScalarModule_basic",
    "ElementwiseLogSigmoidModule_basic",
     "ElementwiseTruncModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 1cf0c2c76..81a860892 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -724,6 +724,10 @@ def aten〇numpy_T〡shape(self: List[int]) -> List[int]:
         result_shape.insert(0, i)
     return result_shape
 
+@check_shape_function([Invocation(TensorOfShape(3), TensorOfShape(3))])
+def aten〇dot〡shape(self: List[int], tensor: List[int]) -> List[int]:
+    return []
+
 def aten〇matmul〡shape(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.matmul(self, other)
 
@@ -3303,6 +3307,13 @@ def aten〇div〇Scalar_mode〡dtype(self_rank_dtype: Tuple[int, int], other: Un
     else:
         return torch.float32
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4,), (4,)]))
+def aten〇dot〡dtype(self_rank_dtype: Tuple[int, int], tensor_rank_dtype: Tuple[int, int]) -> int:
+    other_rank, other_dtype = tensor_rank_dtype
+    self_rank, self_dtype = self_rank_dtype
+    assert self_dtype == other_dtype
+    return self_dtype
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)])
     + # Different width
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index c847e42d8..97b952175 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -532,6 +532,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
     emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
     emit("aten::mv : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::dot : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
     emit("aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)")
     emit(
         "aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
index 0093f13ce..3b9f022fa 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -12,6 +12,30 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
 # ==============================================================================
 
 
+class AtenDotModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.float32, True),
+            ([-1], torch.float32, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return torch.dot(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: AtenDotModule())
+def AtenDotModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4), tu.rand(4))
+
+
+# ==============================================================================
+
+
 class MatmulDot(torch.nn.Module):
     def __init__(self):
         super().__init__()
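
Illustrative note (not part of the patch): because `aten::dot` is emitted with `has_canonicalizer=True`, the new pattern in TorchOps.cpp rewrites `torch.aten.dot` into `torch.aten.matmul` whenever the result type has known sizes and dtype, which is what lets the existing matmul lowerings (and the STABLEHLO/TOSA e2e entries above) handle it. A minimal sketch of the expected behavior under `--canonicalize`, assuming 1-D `!torch.vtensor<[4],f32>` operands; the function name and exact types here are made up for the example:

// Hypothetical input IR:
func.func @dot_to_matmul(%lhs: !torch.vtensor<[4],f32>, %rhs: !torch.vtensor<[4],f32>) -> !torch.vtensor<[],f32> {
  %0 = torch.aten.dot %lhs, %rhs : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[],f32>
  return %0 : !torch.vtensor<[],f32>
}

// Expected result after canonicalization (e.g. torch-mlir-opt --canonicalize):
func.func @dot_to_matmul(%lhs: !torch.vtensor<[4],f32>, %rhs: !torch.vtensor<[4],f32>) -> !torch.vtensor<[],f32> {
  %0 = torch.aten.matmul %lhs, %rhs : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[],f32>
  return %0 : !torch.vtensor<[],f32>
}

The rewrite keeps the original result type and only swaps the op name, so the shape and dtype results registered for `aten.dot` in the abstract interpretation library carry over to the `aten.matmul` form.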