[Torch Dialect] enhance aten.int.tensor's canonicalize (#3058)

support fold with literal vtensor.  
change it to canonicalize because this pattern will create new op.
pull/3087/head
Yuanqiang Liu 2024-03-27 09:51:58 +08:00 committed by GitHub
parent e2343cf4ce
commit 0a581a97a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 24 additions and 11 deletions

View File

@ -11444,7 +11444,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
def Torch_AtenFloatTensorOp : Torch_Op<"aten.Float.Tensor", [

View File

@ -150,7 +150,7 @@ static Value getScalarIntValue(Value input, Location loc,
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseElementsAttr>()
.cast<DenseIntElementsAttr>()
.getSplatValue<int64_t>();
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(val));
@ -3646,14 +3646,15 @@ OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) {
// AtenIntTensorOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
// If a scalar number is converted to a 0-d tensor and passed on to
// aten.Int.Tensor, fold to the scalar number.
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
return numToTensorScalar.getA();
if (auto tensorIntOp = getA().getDefiningOp<AtenTensorIntOp>())
return tensorIntOp.getT();
return nullptr;
void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenIntTensorOp op, PatternRewriter &rewriter) {
Value scalarInt = getScalarIntValue(op.getA(), op.getLoc(), rewriter);
if (!scalarInt)
return failure();
rewriter.replaceOp(op, scalarInt);
return success();
});
}
//===----------------------------------------------------------------------===//
@ -4099,6 +4100,9 @@ OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) {
a = IntegerAttr::get(dty, iattr.getInt());
} else if (auto fattr = dyn_cast<FloatAttr>(a)) {
a = FloatAttr::get(dty, fattr.getValueAsDouble());
} else {
// doesn't handle other types, like complex type
return {};
}
auto mlirTensorType =

View File

@ -679,7 +679,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")

View File

@ -1425,6 +1425,15 @@ func.func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
return %scalar : !torch.int
}
// CHECK-LABEL: @torch.aten.Int.Tensor$canonicalize_0d_const() -> !torch.int {
// CHECK: %[[NUM:.*]] = torch.constant.int 1
// CHECK: return %[[NUM]] : !torch.int
func.func @torch.aten.Int.Tensor$canonicalize_0d_const() -> !torch.int {
%cst = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
%scalar = torch.aten.Int.Tensor %cst : !torch.vtensor<[],si64> -> !torch.int
return %scalar : !torch.int
}
// CHECK-LABEL: func.func @torch.aten.Int.float() -> !torch.int {
// CHECK: %[[NUM:.*]] = torch.constant.int 1
// CHECK: return %[[NUM]] : !torch.int