mirror of https://github.com/llvm/torch-mlir

Add support for mv decomposition.

parent 6777a9484d
commit 2ba71af651
@@ -569,7 +569,6 @@ LTC_XFAIL_SET = {
    "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
    "LiftFreshCopyModule_basic",
    "Matmul_dot",
    "Matmul_matvec",
    "MulIntModule_basic",
    "NeFloatIntModule_basic",
    "NeIntModule_basic",
@@ -3387,6 +3387,30 @@ def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
   }];
 }
 
+def Torch_AtenMvOp : Torch_Op<"aten.mv", [
+  AllowsTypeRefinement,
+  HasValueSemantics,
+  ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::mv : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$vec
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMvOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenMvOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
   AllowsTypeRefinement,
   HasValueSemantics,
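For reference, aten::mv is PyTorch's matrix-vector product: a 2-D self multiplied by a 1-D vec yields a 1-D result. A minimal eager-mode sketch of that shape contract (illustrative only, not part of this commit; the variable names are made up):

    import torch

    m = torch.rand(3, 4)  # "self": a 2-D matrix
    v = torch.rand(4)     # "vec": a 1-D vector; its length must equal m's column count
    out = torch.mv(m, v)  # 1-D result with one entry per row of m
    assert out.shape == (3,)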
@@ -631,6 +631,21 @@ public:
 };
 } // namespace
 
+// Decompose aten.mv into: aten.matmul.
+namespace {
+class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenMvOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.self();
+    Value rhs = op.vec();
+    rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), lhs, rhs);
+    return success();
+  }
+};
+} // namespace
+
 // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
 static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
                              Value input) {
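The pattern relies on torch.matmul already implementing mv semantics for the 2-D x 1-D case, so a straight swap preserves both shapes and values. A quick eager-mode check of that equivalence (illustrative, not from the commit):

    import torch

    m, v = torch.rand(3, 4), torch.rand(4)
    # matmul on an (n, k) matrix and a (k,) vector computes exactly the
    # matrix-vector product, which is why the rewrite above is sound.
    assert torch.allclose(torch.mv(m, v), torch.matmul(m, v))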
@@ -2859,6 +2874,8 @@ public:
     patterns.add<DecomposeAtenSelectIntOp>(context);
     target.addIllegalOp<AtenSelectIntOp>();
     patterns.add<DecomposeAtenMatmulOp>(context);
+    target.addIllegalOp<AtenMvOp>();
+    patterns.add<DecomposeAtenMvOp>(context);
     target.addIllegalOp<AtenTOp>();
     patterns.add<DecomposeAtenTOp>(context);
     patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
@@ -754,7 +754,7 @@ void TypeAnalysis::visitOperation(Operation *op,
 
   // Promote the two dtypes assuming non-zero rank.
   if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
-          Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp,
+          Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp,
          AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
    auto knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
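Listing AtenMvOp here makes the dtype analysis compute the result element type by promoting the two operand dtypes, the same treatment the other matmul-family ops get. The eager-mode analogue of that promotion is torch.result_type (an illustration of the promotion rule, not code from the commit):

    import torch

    m = torch.rand(3, 4)        # float32
    v = torch.rand(4).double()  # float64
    # Promoting the two operand dtypes picks the wider type, float64 here.
    assert torch.result_type(m, v) == torch.float64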
@@ -5864,6 +5864,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.mv\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.mv(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.mm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -600,6 +600,9 @@ def aten〇numpy_T(self: List[int]) -> List[int]:
 def aten〇matmul(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.matmul(self, other)
 
+def aten〇mv(self: List[int], vec: List[int]) -> List[int]:
+    return upstream_shape_functions.mv(self, vec)
+
 def aten〇mm(self: List[int], mat2: List[int]) -> List[int]:
     return upstream_shape_functions.mm(self, mat2)
 
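upstream_shape_functions.mv comes from torch.jit._shape_functions upstream. Roughly, it checks that self is 2-D, vec is 1-D, and the inner dimensions agree, then returns [self[0]]. A hedged re-sketch of that logic (mv_shape is a made-up name, and the asserts approximate rather than reproduce the upstream error checks):

    from typing import List

    def mv_shape(self: List[int], vec: List[int]) -> List[int]:
        # A matrix-vector product needs a 2-D matrix and a 1-D vector whose
        # length matches the matrix's column count; the result has one
        # entry per matrix row.
        assert len(self) == 2 and len(vec) == 1 and self[1] == vec[0]
        return [self[0]]

    assert mv_shape([3, 4], [4]) == [3]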
@@ -335,6 +335,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
     emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
     emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::mv : (Tensor, Tensor) -> (Tensor)")
     emit(
         "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
     )
@@ -209,3 +209,20 @@ class MatmulBroadcastBatchDim(torch.nn.Module):
 def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
 
+# ==============================================================================
+
+class Mv(torch.nn.Module):
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, m, v):
+        return torch.mv(m, v)
+
+
+@register_test_case(module_factory=lambda: Mv())
+def Mv_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 2), tu.rand(2))