diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index a6bb3feb5..e46ecd126 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -569,7 +569,6 @@ LTC_XFAIL_SET = {
     "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
     "LiftFreshCopyModule_basic",
     "Matmul_dot",
-    "Matmul_matvec",
     "MulIntModule_basic",
     "NeFloatIntModule_basic",
     "NeIntModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 9bc03e4a4..a9f89661d 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3387,6 +3387,30 @@ def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
   }];
 }
 
+def Torch_AtenMvOp : Torch_Op<"aten.mv", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::mv : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$vec
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMvOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenMvOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 73c47ec26..49ac08b2c 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -631,6 +631,21 @@ public:
 };
 } // namespace
 
+// Decompose aten.mv into: aten.matmul.
+namespace {
+class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenMvOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.self();
+    Value rhs = op.vec();
+    rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), lhs, rhs);
+    return success();
+  }
+};
+} // namespace
+
 // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
 static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
                              Value input) {
@@ -2859,6 +2874,8 @@ public:
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
+    target.addIllegalOp<AtenMvOp>();
+    patterns.add<DecomposeAtenMvOp>(context);
     target.addIllegalOp();
     patterns.add(context);
     patterns.add(context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index b7fcb3d95..bc4ae1d25 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -754,7 +754,7 @@ void TypeAnalysis::visitOperation(Operation *op,
   // Promote the two dtypes assuming non-zero rank.
-  if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp>(op)) {
+  if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenMvOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index 9cc1d6de2..72ffd9fa5 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -5864,6 +5864,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.mv\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.mv(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.mm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index a72ac0dde..b7e5cfc14 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -600,6 +600,9 @@ def aten〇numpy_T(self: List[int]) -> List[int]:
 def aten〇matmul(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.matmul(self, other)
 
+def aten〇mv(self: List[int], vec: List[int]) -> List[int]:
+    return upstream_shape_functions.mv(self, vec)
+
 def aten〇mm(self: List[int], mat2: List[int]) -> List[int]:
     return upstream_shape_functions.mm(self, mat2)
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index e130f4cef..43c15b769 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -335,6 +335,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
     emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
     emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::mv : (Tensor, Tensor) -> (Tensor)")
     emit(
         "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
     )
diff --git a/python/torch_mlir_e2e_test/test_suite/matmul.py b/python/torch_mlir_e2e_test/test_suite/matmul.py
index e1ecfa6a3..e40086bb7 100644
--- a/python/torch_mlir_e2e_test/test_suite/matmul.py
+++ b/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -209,3 +209,20 @@ class MatmulBroadcastBatchDim(torch.nn.Module):
 def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
 
+# ==============================================================================
+
+class Mv(torch.nn.Module):
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, m, v):
+        return torch.mv(m, v)
+
+
+@register_test_case(module_factory=lambda: Mv())
+def Mv_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 2), tu.rand(2))
\ No newline at end of file
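
For reference, the DecomposeAtenMvOp pattern above is sound because torch.mv on a 2-D matrix and a 1-D vector computes the same matrix-vector product as torch.matmul on those operands, which is also the shape combination the new Mv_basic e2e test exercises. A minimal standalone PyTorch check of that equivalence (illustrative only, not part of the patch):

    import torch

    m = torch.rand(2, 2)  # 2-D matrix, same shape as in Mv_basic
    v = torch.rand(2)     # 1-D vector

    # torch.matmul on a (matrix, vector) pair performs matrix-vector
    # multiplication, so it agrees with torch.mv; the decomposition to
    # aten.matmul relies on exactly this behavior.
    assert torch.allclose(torch.mv(m, v), torch.matmul(m, v))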