Add canonicalization for aten.add.Tensor op

pull/973/head
erman-gurses 2022-06-17 18:49:36 +00:00 committed by erman-gurses
parent e143a34948
commit 5cff40c88a
4 changed files with 140 additions and 50 deletions

@@ -790,55 +790,6 @@ def Torch_AtenBitwiseNot_Op : Torch_Op<"aten.bitwise_not_", [
  }];
}

def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other,
    AnyTorchScalarType:$alpha
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenAddTensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenAddTensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}

def Torch_AtenAdd_TensorOp : Torch_Op<"aten.add_.Tensor", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::add_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other,
    AnyTorchScalarType:$alpha
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenAdd_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenAdd_TensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}

def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics,
@@ -2439,6 +2390,56 @@ def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
  }];
}

def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other,
    AnyTorchScalarType:$alpha
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenAddTensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenAddTensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  let hasCanonicalizer = 1;
}

def Torch_AtenAdd_TensorOp : Torch_Op<"aten.add_.Tensor", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::add_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other,
    AnyTorchScalarType:$alpha
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenAdd_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenAdd_TensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}

def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
    AllowsTypeRefinement,
    HasValueSemantics,

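Note on `let hasCanonicalizer = 1;` above: in ODS this only declares the standard canonicalization hook on the generated `AtenAddTensorOp` class; the pattern itself is defined in the C++ changes below. A minimal sketch of the declaration TableGen emits (approximate shape, not copied from the generated header):

static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
                                        ::mlir::MLIRContext *context);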
@@ -98,6 +98,29 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
  return FloatAttr::get(Float64Type::get(context), value);
}

static Value getScalarValue(Value input, Location loc,
                            PatternRewriter &rewriter) {
  Value scalar = nullptr;
  if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
    if (valueTensorLiteralOp &&
        getTensorRank(valueTensorLiteralOp.getResult()) == 0) {
      auto tensorType =
          valueTensorLiteralOp.value().getType().cast<RankedTensorType>();
      if (tensorType.getElementType().isa<mlir::IntegerType>()) {
        auto val = valueTensorLiteralOp.value()
                       .cast<DenseElementsAttr>()
                       .getSplatValue<int64_t>();
        scalar = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(val));
      }
    }
  } else if (auto primNumToTensorScalarOp =
                 input.getDefiningOp<PrimNumToTensorScalarOp>()) {
    scalar = primNumToTensorScalarOp.a();
  }
  return scalar;
}

//===----------------------------------------------------------------------===//
// MethodOp
//===----------------------------------------------------------------------===//
@@ -763,6 +786,38 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  });
}

//===----------------------------------------------------------------------===//
// AtenAddTensorOp
//===----------------------------------------------------------------------===//
void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenAddTensorOp op, PatternRewriter &rewriter) {
    // This canonicalization only applies when both `self` and `other` are
    // 0-d integer tensors: `aten.add.Tensor(self, other, alpha)` is rewritten
    // to `aten.add.int(self, aten.mul.int(other, alpha))`, and the scalar
    // result is wrapped back into a 0-d tensor via `prim.NumToTensor.Scalar`.
    Value lhs = getScalarValue(op.self(), op.getLoc(), rewriter);
    if (!lhs)
      return rewriter.notifyMatchFailure(op, "lhs scalar is empty");
    if (!lhs.getType().isa<Torch::IntType>())
      return rewriter.notifyMatchFailure(op, "lhs scalar is not IntType");
    Value rhs = getScalarValue(op.other(), op.getLoc(), rewriter);
    if (!rhs)
      return rewriter.notifyMatchFailure(op, "rhs scalar is empty");
    if (!rhs.getType().isa<Torch::IntType>())
      return rewriter.notifyMatchFailure(op, "rhs scalar is not IntType");
    Value mul = rewriter.create<AtenMulIntOp>(op->getLoc(), rhs, op.alpha());
    Value add = rewriter.create<AtenAddIntOp>(op->getLoc(), lhs, mul);
    rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(
        op, op.self().getType(), add);
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenSizeOp
//===----------------------------------------------------------------------===//

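For reference, the scalar arithmetic the pattern emits is `self + other * alpha`. A minimal standalone sketch in plain C++ (not torch-mlir API; the helper name is made up for illustration), using the same values as the new FileCheck tests below, where 0 + 2 * 3 folds to the constant 6 that the CHECK lines expect:

#include <cassert>
#include <cstdint>

// Scalar semantics encoded by the canonicalization above:
// result = self + other * alpha.
static int64_t addTensorAsScalar(int64_t self, int64_t other, int64_t alpha) {
  return self + other * alpha;
}

int main() {
  assert(addTensorAsScalar(0, 2, 3) == 6); // value checked by both new tests
  return 0;
}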
@@ -255,7 +255,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
        "aten::floor : (Tensor) -> (Tensor)",
        "aten::ceil : (Tensor) -> (Tensor)",
        "aten::bitwise_not : (Tensor) -> (Tensor)",
        "aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
        "aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
        "aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)",
        "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
@@ -294,6 +293,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
        emit_with_mutating_variants(key)
    # Elementwise tensor compute ops that don't have the standard mutating
    # variants.
    emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
    emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
    emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
    emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")

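Follow-up note: moving `aten::add.Tensor` out of the standard-variants list and emitting it here with `emit_with_mutating_variants(..., has_canonicalizer=True)` is what relocates the `aten.add.Tensor` / `aten.add_.Tensor` definitions in the generated TableGen file above and adds `let hasCanonicalizer = 1;` to the value-semantics op; the in-place `aten.add_.Tensor` variant is regenerated unchanged.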
@@ -1267,3 +1267,37 @@ func.func @torch.aten.Bool.int$fold_cst() -> !torch.bool {
  %1 = torch.aten.Bool.int %int : !torch.int -> !torch.bool
  return %1 : !torch.bool
}

// CHECK-LABEL: func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
  %int0 = torch.constant.int 0
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
  %1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64>
  %2 = torch.aten.add.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
  return %2 : !torch.vtensor<[],si64>
}

// CHECK-LABEL: @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
func.func @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
  %int0 = torch.constant.int 0
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
  %1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64>
  %2 = torch.aten.add.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
  return %2 : !torch.vtensor<[],si64>
}
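These checks are driven by the test file's existing RUN line, which is not shown in this diff; it is presumably of the usual form `// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s`, i.e. the standard MLIR canonicalize pass applied to the functions above and verified with FileCheck.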