mirror of https://github.com/llvm/torch-mlir
Add support for mv decomposition.
parent
6777a9484d
commit
2ba71af651
|
@ -569,7 +569,6 @@ LTC_XFAIL_SET = {
|
||||||
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
|
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
|
||||||
"LiftFreshCopyModule_basic",
|
"LiftFreshCopyModule_basic",
|
||||||
"Matmul_dot",
|
"Matmul_dot",
|
||||||
"Matmul_matvec",
|
|
||||||
"MulIntModule_basic",
|
"MulIntModule_basic",
|
||||||
"NeFloatIntModule_basic",
|
"NeFloatIntModule_basic",
|
||||||
"NeIntModule_basic",
|
"NeIntModule_basic",
|
||||||
|
|
|
@ -3387,6 +3387,30 @@ def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenMvOp : Torch_Op<"aten.mv", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::mv : (Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$vec
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenMvOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenMvOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
|
def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -631,6 +631,21 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Decompose aten.mv into: aten.matmul.
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenMvOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Value lhs = op.self();
|
||||||
|
Value rhs = op.vec();
|
||||||
|
rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), lhs, rhs);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
|
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
|
||||||
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
|
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
|
||||||
Value input) {
|
Value input) {
|
||||||
|
@ -2859,6 +2874,8 @@ public:
|
||||||
patterns.add<DecomposeAtenSelectIntOp>(context);
|
patterns.add<DecomposeAtenSelectIntOp>(context);
|
||||||
target.addIllegalOp<AtenSelectIntOp>();
|
target.addIllegalOp<AtenSelectIntOp>();
|
||||||
patterns.add<DecomposeAtenMatmulOp>(context);
|
patterns.add<DecomposeAtenMatmulOp>(context);
|
||||||
|
target.addIllegalOp<AtenMvOp>();
|
||||||
|
patterns.add<DecomposeAtenMvOp>(context);
|
||||||
target.addIllegalOp<AtenTOp>();
|
target.addIllegalOp<AtenTOp>();
|
||||||
patterns.add<DecomposeAtenTOp>(context);
|
patterns.add<DecomposeAtenTOp>(context);
|
||||||
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
|
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
|
||||||
|
|
|
@ -754,7 +754,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
|
|
||||||
// Promote the two dtypes assuming non-zero rank.
|
// Promote the two dtypes assuming non-zero rank.
|
||||||
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
|
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
|
||||||
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp,
|
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp,
|
||||||
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
|
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
|
||||||
auto knowledge =
|
auto knowledge =
|
||||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||||
|
|
|
@ -5864,6 +5864,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.mv\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.torch.jit._shape_functions.mv(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.mm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.mm\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
|
|
@ -600,6 +600,9 @@ def aten〇numpy_T(self: List[int]) -> List[int]:
|
||||||
def aten〇matmul(self: List[int], other: List[int]) -> List[int]:
|
def aten〇matmul(self: List[int], other: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.matmul(self, other)
|
return upstream_shape_functions.matmul(self, other)
|
||||||
|
|
||||||
|
def aten〇mv(self: List[int], vec: List[int]) -> List[int]:
|
||||||
|
return upstream_shape_functions.mv(self, vec)
|
||||||
|
|
||||||
def aten〇mm(self: List[int], mat2: List[int]) -> List[int]:
|
def aten〇mm(self: List[int], mat2: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.mm(self, mat2)
|
return upstream_shape_functions.mm(self, mat2)
|
||||||
|
|
||||||
|
|
|
@ -335,6 +335,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
|
emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
|
||||||
emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
|
||||||
|
emit("aten::mv : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit(
|
emit(
|
||||||
"aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
|
"aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
|
||||||
)
|
)
|
||||||
|
|
|
@ -209,3 +209,20 @@ class MatmulBroadcastBatchDim(torch.nn.Module):
|
||||||
def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
|
def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
|
module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class Mv(torch.nn.Module):
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, m, v):
|
||||||
|
return torch.mv(m, v)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Mv())
|
||||||
|
def Mv_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 2), tu.rand(2))
|
Loading…
Reference in New Issue