mirror of https://github.com/llvm/torch-mlir
Add canonicalization for aten.add.tensor op
parent
e143a34948
commit
5cff40c88a
|
@ -790,55 +790,6 @@ def Torch_AtenBitwiseNot_Op : Torch_Op<"aten.bitwise_not_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other,
|
||||
AnyTorchScalarType:$alpha
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAddTensorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenAddTensorOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAdd_TensorOp : Torch_Op<"aten.add_.Tensor", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::add_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other,
|
||||
AnyTorchScalarType:$alpha
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAdd_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenAdd_TensorOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -2439,6 +2390,56 @@ def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other,
|
||||
AnyTorchScalarType:$alpha
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAddTensorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenAddTensorOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenAdd_TensorOp : Torch_Op<"aten.add_.Tensor", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::add_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other,
|
||||
AnyTorchScalarType:$alpha
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAdd_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenAdd_TensorOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -98,6 +98,29 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
|
|||
return FloatAttr::get(Float64Type::get(context), value);
|
||||
}
|
||||
|
||||
static Value getScalarValue(Value input, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
Value scalar = nullptr;
|
||||
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
||||
if (valueTensorLiteralOp &&
|
||||
getTensorRank(valueTensorLiteralOp.getResult()) == 0) {
|
||||
auto tensorType =
|
||||
valueTensorLiteralOp.value().getType().cast<RankedTensorType>();
|
||||
if (tensorType.getElementType().isa<mlir::IntegerType>()) {
|
||||
auto val = valueTensorLiteralOp.value()
|
||||
.cast<DenseElementsAttr>()
|
||||
.getSplatValue<int64_t>();
|
||||
scalar = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(val));
|
||||
}
|
||||
}
|
||||
} else if (auto primNumToTensorScalarOp =
|
||||
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
|
||||
scalar = primNumToTensorScalarOp.a();
|
||||
}
|
||||
return scalar;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MethodOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -763,6 +786,38 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenAddTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](AtenAddTensorOp op, PatternRewriter &rewriter) {
|
||||
// The lhs and rhs of the add.tensor op should be 0d tensors for the
|
||||
// canonicalization to be carried out.
|
||||
// `aten.add.tensor(self, other, alpha)` is canonicalized to
|
||||
// `aten.add.int(self, aten.mul.int(other, alpha))`.
|
||||
|
||||
Value lhs = getScalarValue(op.self(), op.getLoc(), rewriter);
|
||||
if (!lhs)
|
||||
return rewriter.notifyMatchFailure(op, "lhs scalar is empyty");
|
||||
if (!lhs.getType().isa<Torch::IntType>())
|
||||
return rewriter.notifyMatchFailure(op, "lhs scalar is not IntType");
|
||||
|
||||
Value rhs = getScalarValue(op.other(), op.getLoc(), rewriter);
|
||||
if (!rhs)
|
||||
return rewriter.notifyMatchFailure(op, "rhs scalar is empyty");
|
||||
if (!rhs.getType().isa<Torch::IntType>())
|
||||
return rewriter.notifyMatchFailure(op, "rhs scalar is not IntType");
|
||||
|
||||
Value mul = rewriter.create<AtenMulIntOp>(op->getLoc(), rhs, op.alpha());
|
||||
Value add = rewriter.create<AtenAddIntOp>(op->getLoc(), lhs, mul);
|
||||
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(
|
||||
op, op.self().getType(), add);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -255,7 +255,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::floor : (Tensor) -> (Tensor)",
|
||||
"aten::ceil : (Tensor) -> (Tensor)",
|
||||
"aten::bitwise_not : (Tensor) -> (Tensor)",
|
||||
"aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
|
||||
"aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
|
||||
"aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||
|
@ -294,6 +293,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit_with_mutating_variants(key)
|
||||
# Elementwise tensor compute ops that don't have the standard mutating
|
||||
# variants.
|
||||
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
|
||||
emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
||||
|
|
|
@ -1267,3 +1267,37 @@ func.func @torch.aten.Bool.int$fold_cst() -> !torch.bool {
|
|||
%1 = torch.aten.Bool.int %int : !torch.int -> !torch.bool
|
||||
return %1 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
|
||||
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[PR3]] : !torch.vtensor<[],si64>
|
||||
func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
|
||||
%1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64>
|
||||
%2 = torch.aten.add.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
|
||||
return %2 : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
|
||||
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[PR2]] : !torch.vtensor<[],si64>
|
||||
func.func @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||
%1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64>
|
||||
%2 = torch.aten.add.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64>
|
||||
return %2 : !torch.vtensor<[],si64>
|
||||
}
|
Loading…
Reference in New Issue