lowered addcmul and addcdiv to linalg

pull/431/head
nodlabs 2021-11-24 14:01:48 -08:00 committed by Yi Zhang
parent 8d8d2c2fb8
commit 67ce816fca
6 changed files with 124 additions and 0 deletions
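Note on the approach: the new ops are not lowered to linalg directly; the patch decomposes them into existing elementwise ops (aten.mul.Tensor / aten.div.Tensor followed by aten.add.Tensor) that already have linalg lowerings, relying on the identities addcmul(input, t1, t2, value) = input + value * t1 * t2 and addcdiv(input, t1, t2, value) = input + value * t1 / t2. A minimal eager-mode PyTorch sketch of these identities (illustrative only, not part of the patch):

import torch

inp, t1, t2 = torch.rand(2, 3), torch.rand(2, 3), torch.rand(2, 3)
value = 0.5

# addcmul(inp, t1, t2, value) == inp + value * t1 * t2
assert torch.allclose(torch.addcmul(inp, t1, t2, value=value),
                      inp + value * t1 * t2)

# addcdiv(inp, t1, t2, value) == inp + value * t1 / t2
assert torch.allclose(torch.addcdiv(inp, t1, t2, value=value),
                      inp + value * t1 / t2)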


@ -684,3 +684,40 @@ class ReturnThreeTensorFloat32(torch.nn.Module):
def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))


class AddCMulModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, input, tensor1, tensor2):
        return torch.addcmul(input, tensor1, tensor2, value=1.0)


@register_test_case(module_factory=lambda: AddCMulModule())
def AddCMulModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 3), tu.rand(1, 3), tu.rand(1, 3))


class AddCDivModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, input, tensor1, tensor2):
        return torch.addcdiv(input, tensor1, tensor2, value=1.0)


@register_test_case(module_factory=lambda: AddCDivModule())
def AddCDivModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 3), tu.rand(1, 3), tu.rand(1, 3))


@ -30,4 +30,6 @@ TOSA_PASS_SET = {
"ElementwiseLogModule_basic",
"TanhBackward_basic",
"ReturnThreeTensorFloat32_basic",
"AddCMulModule_basic",
"AddCDivModule_basic",
}


@ -2895,3 +2895,37 @@ def Torch_Aten_LogSoftmaxBackwardDataOp : Torch_Op<"aten._log_softmax_backward_d
let assemblyFormat = "$grad_output `,` $output `,` $dim `,` $input_dtype attr-dict `:` type($grad_output) `,` type($output) `,` type($dim) `,` type($input_dtype) `->` type($result)";
}
def Torch_AtenAddCMulOp : Torch_Op<"aten.addcmul", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$tensor1,
AnyTorchTensorType:$tensor2,
AnyTorchScalarType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $tensor1 `,` $tensor2 `,` $value attr-dict `:` type($self) `,` type($tensor1) `,` type($tensor2) `,` type($value) `->` type($result)";
}
def Torch_AtenAddCDivOp : Torch_Op<"aten.addcdiv", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$tensor1,
AnyTorchTensorType:$tensor2,
AnyTorchScalarType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $tensor1 `,` $tensor2 `,` $value attr-dict `:` type($self) `,` type($tensor1) `,` type($tensor2) `,` type($value) `->` type($result)";
}


@ -375,6 +375,26 @@ public:
};
} // namespace
namespace {
// Decompose aten.addcmul and aten.addcdiv: multiply (respectively divide)
// tensor1 and tensor2 elementwise, then add the result to `self`, passing
// `value` as the alpha operand of aten.add.Tensor.
template <typename OpTy, typename T1T2Op>
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.self();
    Value tensor1 = op.tensor1();
    Value tensor2 = op.tensor2();
    Value value = op.value();

    Value product =
        rewriter.create<T1T2Op>(loc, op.getType(), tensor1, tensor2);
    rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), input,
                                                 product, value);
    return success();
  }
};
} // namespace
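The rewrite above maps directly onto PyTorch's add-with-alpha semantics: the `value` scalar becomes the alpha of aten.add.Tensor. An eager-mode analogue of what the decomposed IR computes (illustrative only, not part of the patch):

import torch

inp, t1, t2 = torch.rand(2, 3), torch.rand(2, 3), torch.rand(2, 3)
value = 2.0

product = t1 * t2                                  # T1T2Op = aten.mul.Tensor (t1 / t2 for addcdiv)
decomposed = torch.add(inp, product, alpha=value)  # aten.add.Tensor with alpha = value
assert torch.allclose(decomposed, torch.addcmul(inp, t1, t2, value=value))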
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -408,6 +428,10 @@ class DecomposeComplexOpsPass
      // Make aten.matmul legal if the following condition is satisfied.
      return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
    });
    patterns.add<DecomposeAtenAddCLikeOp<AtenAddCMulOp, AtenMulTensorOp>>(
        context);
    target.addIllegalOp<AtenAddCMulOp>();
    patterns.add<DecomposeAtenAddCLikeOp<AtenAddCDivOp, AtenDivTensorOp>>(
        context);
    target.addIllegalOp<AtenAddCDivOp>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      return signalPassFailure();


@ -422,6 +422,8 @@ public:
      return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
    } else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
      return visitNumToTensorOp(numToTensorOp);
    } else if (isa<AtenAddCMulOp, AtenAddCDivOp>(op)) {
      return visitAtenAddCLikeOp(op, operands);
    }

    // Otherwise, this is an unknown operation. Just mark all results as
@ -535,6 +537,10 @@ private:
  ChangeResult
  visitAtenSoftmaxLikeOp(OpTy op,
                         ArrayRef<LatticeElement<ValueKnowledge> *> operands);

  ChangeResult
  visitAtenAddCLikeOp(Operation *op,
                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
};
} // namespace
@ -1376,6 +1382,25 @@ ChangeResult TypeAnalyzer::visitAtenMatmulOp(
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

// aten.addcmul / aten.addcdiv: the three tensor operands broadcast, so only
// the result rank (the maximum of the operand ranks) is known here; every
// dimension is left unknown. The dtype is the promoted type of the operands.
ChangeResult TypeAnalyzer::visitAtenAddCLikeOp(
    Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  auto self = operands[0]->getValue();
  auto tensor1 = operands[1]->getValue();
  auto tensor2 = operands[2]->getValue();
  if (tensor1.hasSizes && tensor2.hasSizes && self.hasSizes) {
    knowledge.hasSizes = true;
    knowledge.sizes.resize(
        std::max(self.sizes.size(),
                 std::max(tensor1.sizes.size(), tensor2.sizes.size())),
        kUnknownSize);
  }
  knowledge.dtype =
      getPromotedResultType(getContext(), {&self, &tensor1, &tensor2});
  return getLatticeElement(op->getResult(0)).join(knowledge);
}
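The transfer function above only refines the rank and the promoted dtype. The eager-mode behavior it approximates should look roughly like this (illustrative only; the mixed-dtype promotion assumes a PyTorch version where addcmul supports type promotion):

import torch

inp = torch.rand(4, 2, 3, dtype=torch.float64)
t1 = torch.rand(2, 3, dtype=torch.float32)
t2 = torch.rand(3, dtype=torch.float32)

out = torch.addcmul(inp, t1, t2)
assert out.shape == (4, 2, 3)      # result rank == max rank of the operands
assert out.dtype == torch.float64  # dtype promoted across the three inputs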
// -----------------------------------------------------------------------------
// Transforms.
// -----------------------------------------------------------------------------


@ -471,6 +471,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
        emit_with_mutating_variants(key)

    # Elementwise tensor compute ops that don't have the standard mutating
    # variants.
    emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
    emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
    emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
    emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
    emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")