mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] enhance aten.int.tensor's canonicalize (#3058)
support fold with literal vtensor. change it to canonicalize because this pattern will create new op.pull/3087/head
parent
e2343cf4ce
commit
0a581a97a7
|
@ -11444,7 +11444,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
|
||||||
printDefaultTorchOp(printer, *this, 1, 1);
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
let hasFolder = 1;
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenFloatTensorOp : Torch_Op<"aten.Float.Tensor", [
|
def Torch_AtenFloatTensorOp : Torch_Op<"aten.Float.Tensor", [
|
||||||
|
|
|
@ -150,7 +150,7 @@ static Value getScalarIntValue(Value input, Location loc,
|
||||||
|
|
||||||
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
||||||
auto val = valueTensorLiteralOp.getValue()
|
auto val = valueTensorLiteralOp.getValue()
|
||||||
.cast<DenseElementsAttr>()
|
.cast<DenseIntElementsAttr>()
|
||||||
.getSplatValue<int64_t>();
|
.getSplatValue<int64_t>();
|
||||||
return rewriter.create<Torch::ConstantIntOp>(
|
return rewriter.create<Torch::ConstantIntOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(val));
|
loc, rewriter.getI64IntegerAttr(val));
|
||||||
|
@ -3646,14 +3646,15 @@ OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) {
|
||||||
// AtenIntTensorOp
|
// AtenIntTensorOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
|
void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
// If a scalar number is converted to a 0-d tensor and passed on to
|
MLIRContext *context) {
|
||||||
// aten.Int.Tensor, fold to the scalar number.
|
patterns.add(+[](AtenIntTensorOp op, PatternRewriter &rewriter) {
|
||||||
if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
|
Value scalarInt = getScalarIntValue(op.getA(), op.getLoc(), rewriter);
|
||||||
return numToTensorScalar.getA();
|
if (!scalarInt)
|
||||||
if (auto tensorIntOp = getA().getDefiningOp<AtenTensorIntOp>())
|
return failure();
|
||||||
return tensorIntOp.getT();
|
rewriter.replaceOp(op, scalarInt);
|
||||||
return nullptr;
|
return success();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -4099,6 +4100,9 @@ OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) {
|
||||||
a = IntegerAttr::get(dty, iattr.getInt());
|
a = IntegerAttr::get(dty, iattr.getInt());
|
||||||
} else if (auto fattr = dyn_cast<FloatAttr>(a)) {
|
} else if (auto fattr = dyn_cast<FloatAttr>(a)) {
|
||||||
a = FloatAttr::get(dty, fattr.getValueAsDouble());
|
a = FloatAttr::get(dty, fattr.getValueAsDouble());
|
||||||
|
} else {
|
||||||
|
// doesn't handle other types, like complex type
|
||||||
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
auto mlirTensorType =
|
auto mlirTensorType =
|
||||||
|
|
|
@ -679,7 +679,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
|
emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
|
||||||
emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
|
emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
|
||||||
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
|
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
|
||||||
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
|
emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
|
||||||
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
|
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
|
||||||
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
||||||
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
|
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
|
||||||
|
|
|
@ -1425,6 +1425,15 @@ func.func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
|
||||||
return %scalar : !torch.int
|
return %scalar : !torch.int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @torch.aten.Int.Tensor$canonicalize_0d_const() -> !torch.int {
|
||||||
|
// CHECK: %[[NUM:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: return %[[NUM]] : !torch.int
|
||||||
|
func.func @torch.aten.Int.Tensor$canonicalize_0d_const() -> !torch.int {
|
||||||
|
%cst = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
|
%scalar = torch.aten.Int.Tensor %cst : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
return %scalar : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.Int.float() -> !torch.int {
|
// CHECK-LABEL: func.func @torch.aten.Int.float() -> !torch.int {
|
||||||
// CHECK: %[[NUM:.*]] = torch.constant.int 1
|
// CHECK: %[[NUM:.*]] = torch.constant.int 1
|
||||||
// CHECK: return %[[NUM]] : !torch.int
|
// CHECK: return %[[NUM]] : !torch.int
|
||||||
|
|
Loading…
Reference in New Issue