mirror of https://github.com/llvm/torch-mlir
lowered addcmul and addcdiv to linalg
parent
8d8d2c2fb8
commit
67ce816fca
|
@ -684,3 +684,40 @@ class ReturnThreeTensorFloat32(torch.nn.Module):
|
|||
def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))
|
||||
|
||||
class AddCMulModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, input, tensor1, tensor2):
|
||||
return torch.addcmul(input, tensor1, tensor2, value=1.0)
|
||||
|
||||
@register_test_case(module_factory=lambda: AddCMulModule())
|
||||
def AddCMulModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1,3), tu.rand(1,3), tu.rand(1,3))
|
||||
|
||||
class AddCDivModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, input, tensor1, tensor2):
|
||||
return torch.addcdiv(input, tensor1, tensor2, value=1.0)
|
||||
|
||||
@register_test_case(module_factory=lambda: AddCDivModule())
|
||||
def AddCDivModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1,3), tu.rand(1,3), tu.rand(1,3))
|
||||
|
|
|
@ -30,4 +30,6 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseLogModule_basic",
|
||||
"TanhBackward_basic",
|
||||
"ReturnThreeTensorFloat32_basic",
|
||||
"AddCMulModule_basic",
|
||||
"AddCDivModule_basic",
|
||||
}
|
||||
|
|
|
@ -2895,3 +2895,37 @@ def Torch_Aten_LogSoftmaxBackwardDataOp : Torch_Op<"aten._log_softmax_backward_d
|
|||
let assemblyFormat = "$grad_output `,` $output `,` $dim `,` $input_dtype attr-dict `:` type($grad_output) `,` type($output) `,` type($dim) `,` type($input_dtype) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenAddCMulOp : Torch_Op<"aten.addcmul", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$tensor1,
|
||||
AnyTorchTensorType:$tensor2,
|
||||
AnyTorchScalarType:$value
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $tensor1 `,` $tensor2 `,` $value attr-dict `:` type($self) `,` type($tensor1) `,` type($tensor2) `,` type($value) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenAddCDivOp : Torch_Op<"aten.addcdiv", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$tensor1,
|
||||
AnyTorchTensorType:$tensor2,
|
||||
AnyTorchScalarType:$value
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $tensor1 `,` $tensor2 `,` $value attr-dict `:` type($self) `,` type($tensor1) `,` type($tensor2) `,` type($value) `->` type($result)";
|
||||
}
|
||||
|
||||
|
|
|
@ -375,6 +375,26 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
template<typename OpTy, typename T1T2Op>
|
||||
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
|
||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.self();
|
||||
Value tensor1 = op.tensor1();
|
||||
Value tensor2 = op.tensor2();
|
||||
Value value = op.value();
|
||||
|
||||
Value product = rewriter.create<T1T2Op>(loc, op.getType(), tensor1, tensor2);
|
||||
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), input, product,
|
||||
value);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeComplexOpsPass
|
||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||
|
@ -408,6 +428,10 @@ class DecomposeComplexOpsPass
|
|||
// Make aten.matmul legal if the following condition is satisfied.
|
||||
return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
|
||||
});
|
||||
patterns.add<DecomposeAtenAddCLikeOp<AtenAddCMulOp, AtenMulTensorOp>>(context);
|
||||
target.addIllegalOp<AtenAddCMulOp>();
|
||||
patterns.add<DecomposeAtenAddCLikeOp<AtenAddCDivOp, AtenDivTensorOp>>(context);
|
||||
target.addIllegalOp<AtenAddCDivOp>();
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
|
|
|
@ -422,6 +422,8 @@ public:
|
|||
return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
|
||||
} else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
|
||||
return visitNumToTensorOp(numToTensorOp);
|
||||
} else if (isa<AtenAddCMulOp, AtenAddCDivOp>(op)) {
|
||||
return visitAtenAddCLikeOp(op, operands);
|
||||
}
|
||||
|
||||
// Otherwise, this is an unknown operation. Just mark all results as
|
||||
|
@ -535,6 +537,10 @@ private:
|
|||
ChangeResult
|
||||
visitAtenSoftmaxLikeOp(OpTy op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
|
||||
ChangeResult
|
||||
visitAtenAddCLikeOp(Operation *op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -1376,6 +1382,25 @@ ChangeResult TypeAnalyzer::visitAtenMatmulOp(
|
|||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAtenAddCLikeOp(
|
||||
Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
auto self = operands[0]->getValue();
|
||||
auto tensor1 = operands[1]->getValue();
|
||||
auto tensor2 = operands[2]->getValue();
|
||||
if (tensor1.hasSizes && tensor2.hasSizes && self.hasSizes) {
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(
|
||||
std::max(self.sizes.size(),
|
||||
std::max(tensor1.sizes.size(), tensor2.sizes.size())),
|
||||
kUnknownSize);
|
||||
}
|
||||
knowledge.dtype =
|
||||
getPromotedResultType(getContext(), {&self, &tensor1, &tensor2});
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Transforms.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
|
|
@ -471,6 +471,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit_with_mutating_variants(key)
|
||||
# Elementwise tensor compute ops that don't have the standard mutating
|
||||
# variants.
|
||||
emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue