Add support for mv decomposition.

pull/1214/head
Daniel Ellis 2022-10-03 18:32:17 +00:00
parent 6777a9484d
commit 2ba71af651
8 changed files with 67 additions and 2 deletions

View File

@ -569,7 +569,6 @@ LTC_XFAIL_SET = {
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
"LiftFreshCopyModule_basic", "LiftFreshCopyModule_basic",
"Matmul_dot", "Matmul_dot",
"Matmul_matvec",
"MulIntModule_basic", "MulIntModule_basic",
"NeFloatIntModule_basic", "NeFloatIntModule_basic",
"NeIntModule_basic", "NeIntModule_basic",

View File

@ -3387,6 +3387,30 @@ def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
}]; }];
} }
def Torch_AtenMvOp : Torch_Op<"aten.mv", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::mv : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$vec
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMvOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenMvOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -631,6 +631,21 @@ public:
}; };
} // namespace } // namespace
// Decompose aten.mv into: aten.matmul.
namespace {
class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenMvOp op,
PatternRewriter &rewriter) const override {
Value lhs = op.self();
Value rhs = op.vec();
rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), lhs, rhs);
return success();
}
};
} // namespace
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6) // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
static Value getRelu6Results(PatternRewriter &rewriter, Location loc, static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
Value input) { Value input) {
@ -2859,6 +2874,8 @@ public:
patterns.add<DecomposeAtenSelectIntOp>(context); patterns.add<DecomposeAtenSelectIntOp>(context);
target.addIllegalOp<AtenSelectIntOp>(); target.addIllegalOp<AtenSelectIntOp>();
patterns.add<DecomposeAtenMatmulOp>(context); patterns.add<DecomposeAtenMatmulOp>(context);
target.addIllegalOp<AtenMvOp>();
patterns.add<DecomposeAtenMvOp>(context);
target.addIllegalOp<AtenTOp>(); target.addIllegalOp<AtenTOp>();
patterns.add<DecomposeAtenTOp>(context); patterns.add<DecomposeAtenTOp>(context);
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context); patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);

View File

@ -754,7 +754,7 @@ void TypeAnalysis::visitOperation(Operation *op,
// Promote the two dtypes assuming non-zero rank. // Promote the two dtypes assuming non-zero rank.
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp, if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp,
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) { AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
auto knowledge = auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext()); ValueKnowledge::getTensorPessimisticValueState(op->getContext());

View File

@ -5864,6 +5864,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.mv\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.mv(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.mm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.mm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"

View File

@ -600,6 +600,9 @@ def atennumpy_T(self: List[int]) -> List[int]:
def atenmatmul(self: List[int], other: List[int]) -> List[int]: def atenmatmul(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.matmul(self, other) return upstream_shape_functions.matmul(self, other)
def atenmv(self: List[int], vec: List[int]) -> List[int]:
return upstream_shape_functions.mv(self, vec)
def atenmm(self: List[int], mat2: List[int]) -> List[int]: def atenmm(self: List[int], mat2: List[int]) -> List[int]:
return upstream_shape_functions.mm(self, mat2) return upstream_shape_functions.mm(self, mat2)

View File

@ -335,6 +335,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::mm : (Tensor, Tensor) -> (Tensor)") emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::matmul : (Tensor, Tensor) -> (Tensor)") emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
emit("aten::mv : (Tensor, Tensor) -> (Tensor)")
emit( emit(
"aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
) )

View File

@ -209,3 +209,20 @@ class MatmulBroadcastBatchDim(torch.nn.Module):
def MatmulBroadcastBatchDim_basic(module, tu: TestUtils): def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6)) module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
# ==============================================================================
class Mv(torch.nn.Module):
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.float32, True),
])
def forward(self, m, v):
return torch.mv(m, v)
@register_test_case(module_factory=lambda: Mv())
def Mv_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2), tu.rand(2))