[Torch Dialect] enhance aten.int.tensor's canonicalize (#3058)

Support folding when the input is a literal vtensor.
Change the hook from a folder to a canonicalizer, because this pattern creates a new op (which a folder is not allowed to do).
Yuanqiang Liu 2024-03-27 09:51:58 +08:00 committed by GitHub
parent e2343cf4ce
commit 0a581a97a7
4 changed files with 24 additions and 11 deletions
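Background on why the hook changes kind (a sketch of the MLIR convention, not code from this commit; ExampleOp and SomeScalarToTensorOp are hypothetical names): a fold hook has no rewriter, so it may only return an existing SSA value or a constant attribute and can never create ops. A canonicalization pattern receives a PatternRewriter and may build replacement ops, which is exactly what materializing a torch.constant.int from a literal vtensor requires.

    // Sketch only; assumes the usual MLIR includes and op definitions.
    // fold(): no rewriter available, so it may only forward an existing
    // value (or return an attribute); returning nullptr means "no fold".
    OpFoldResult ExampleOp::fold(FoldAdaptor adaptor) {
      if (auto producer = getA().getDefiningOp<SomeScalarToTensorOp>())
        return producer.getA(); // reuse the original scalar SSA value
      return nullptr;
    }

    // getCanonicalizationPatterns(): patterns receive a PatternRewriter,
    // so they may create new ops, e.g. a torch.constant.int.
    void ExampleOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
      patterns.add(+[](ExampleOp op, PatternRewriter &rewriter) {
        Value cst = rewriter.create<Torch::ConstantIntOp>(
            op.getLoc(), rewriter.getI64IntegerAttr(1));
        rewriter.replaceOp(op, cst);
        return success();
      });
    }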


@@ -11444,7 +11444,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
-  let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }

 def Torch_AtenFloatTensorOp : Torch_Op<"aten.Float.Tensor", [


@@ -150,7 +150,7 @@ static Value getScalarIntValue(Value input, Location loc,
   if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
     auto val = valueTensorLiteralOp.getValue()
-                   .cast<DenseElementsAttr>()
+                   .cast<DenseIntElementsAttr>()
                    .getSplatValue<int64_t>();
     return rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(val));
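The tightened cast reflects that getSplatValue<int64_t>() is only meaningful for integer element types. For illustration only (this is not the repo's code), a defensive variant would dyn_cast and decline to rewrite non-integer or non-splat literals, returning a null Value to signal "no scalar available":

    // Hypothetical defensive variant of the literal case above.
    if (auto vtLiteral = input.getDefiningOp<ValueTensorLiteralOp>()) {
      auto intAttr = dyn_cast<DenseIntElementsAttr>(vtLiteral.getValue());
      if (!intAttr || !intAttr.isSplat())
        return nullptr; // not an integer splat: leave the op alone
      return rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(intAttr.getSplatValue<int64_t>()));
    }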
@@ -3646,14 +3646,15 @@ OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) {
 // AtenIntTensorOp
 //===----------------------------------------------------------------------===//

-OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
-  // If a scalar number is converted to a 0-d tensor and passed on to
-  // aten.Int.Tensor, fold to the scalar number.
-  if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
-    return numToTensorScalar.getA();
-  if (auto tensorIntOp = getA().getDefiningOp<AtenTensorIntOp>())
-    return tensorIntOp.getT();
-  return nullptr;
+void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add(+[](AtenIntTensorOp op, PatternRewriter &rewriter) {
+    Value scalarInt = getScalarIntValue(op.getA(), op.getLoc(), rewriter);
+    if (!scalarInt)
+      return failure();
+    rewriter.replaceOp(op, scalarInt);
+    return success();
+  });
 }

 //===----------------------------------------------------------------------===//
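One C++ detail in the new pattern worth noting: the unary plus in +[](AtenIntTensorOp op, PatternRewriter &rewriter) forces the capture-less lambda to decay to a plain function pointer, which selects the overload of RewritePatternSet::add that wraps a free function as a pattern. A standalone illustration of the idiom (nothing torch-mlir specific):

    #include <cstdio>

    int main() {
      // With '+', the capture-less lambda decays to a function pointer.
      int (*fnptr)(int) = +[](int x) { return x + 1; };
      std::printf("%d\n", fnptr(41)); // prints 42
      return 0;
    }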
@@ -4099,6 +4100,9 @@ OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) {
     a = IntegerAttr::get(dty, iattr.getInt());
   } else if (auto fattr = dyn_cast<FloatAttr>(a)) {
     a = FloatAttr::get(dty, fattr.getValueAsDouble());
+  } else {
+    // doesn't handle other types, like complex type
+    return {};
   }

   auto mlirTensorType =
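The new else branch makes PrimNumToTensorScalarOp::fold bail out on attribute kinds it does not handle (e.g. complex). Returning an empty OpFoldResult is MLIR's "no fold" signal: the op stays in the IR untouched instead of being folded incorrectly. A minimal sketch (ExampleOp and its getA() accessor are hypothetical):

    // Returning {} (an empty OpFoldResult) leaves the op as-is.
    OpFoldResult ExampleOp::fold(FoldAdaptor adaptor) {
      Attribute a = adaptor.getA();
      if (auto iattr = dyn_cast_or_null<IntegerAttr>(a))
        return iattr; // fold to the integer constant
      return {};      // unhandled kinds (e.g. complex): no fold
    }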


@@ -679,7 +679,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
     emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
     emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
-    emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
+    emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
     emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
     emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
     emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")


@@ -1425,6 +1425,15 @@ func.func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
   return %scalar : !torch.int
 }

+// CHECK-LABEL:   @torch.aten.Int.Tensor$canonicalize_0d_const() -> !torch.int {
+// CHECK:           %[[NUM:.*]] = torch.constant.int 1
+// CHECK:           return %[[NUM]] : !torch.int
+func.func @torch.aten.Int.Tensor$canonicalize_0d_const() -> !torch.int {
+  %cst = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
+  %scalar = torch.aten.Int.Tensor %cst : !torch.vtensor<[],si64> -> !torch.int
+  return %scalar : !torch.int
+}
+
 // CHECK-LABEL:   func.func @torch.aten.Int.float() -> !torch.int {
 // CHECK:           %[[NUM:.*]] = torch.constant.int 1
 // CHECK:           return %[[NUM]] : !torch.int