mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] enhance aten.int.tensor's canonicalize (#3058)
Support folding when the input is a literal vtensor; implemented as a canonicalization rather than a fold because the pattern creates a new op. (branch: pull/3087/head)
parent
e2343cf4ce
commit
0a581a97a7
|
@ -11444,7 +11444,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
|
|||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenFloatTensorOp : Torch_Op<"aten.Float.Tensor", [
|
||||
|
|
|
@ -150,7 +150,7 @@ static Value getScalarIntValue(Value input, Location loc,
|
|||
|
||||
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
||||
auto val = valueTensorLiteralOp.getValue()
|
||||
.cast<DenseElementsAttr>()
|
||||
.cast<DenseIntElementsAttr>()
|
||||
.getSplatValue<int64_t>();
|
||||
return rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(val));
|
||||
|
@ -3646,14 +3646,15 @@ OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) {
|
|||
// AtenIntTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
  // Fold round-trips through a 0-d tensor back to the original scalar.
  //
  // If a scalar number is converted to a 0-d tensor and passed on to
  // aten.Int.Tensor, fold to the scalar number.
  if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
    return numToTensorScalar.getA();
  // Likewise, aten.tensor.int followed by aten.Int.Tensor is the identity on
  // the wrapped int.
  if (auto tensorIntOp = getA().getDefiningOp<AtenTensorIntOp>())
    return tensorIntOp.getT();
  // No recognized producer; leave the op alone. (The literal-vtensor case is
  // handled by the canonicalization pattern, since it must create a new op.)
  return nullptr;
}
|
||||
void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  // Canonicalize aten.Int.Tensor whose operand's scalar integer value can be
  // materialized by getScalarIntValue (e.g. a splat literal vtensor) into the
  // resulting torch.constant.int. This lives here rather than in fold()
  // because it creates a new op.
  auto rewriteToConstantInt = +[](AtenIntTensorOp op,
                                  PatternRewriter &rewriter) {
    Value asConstantInt = getScalarIntValue(op.getA(), op.getLoc(), rewriter);
    if (!asConstantInt)
      return failure();
    rewriter.replaceOp(op, asConstantInt);
    return success();
  };
  patterns.add(rewriteToConstantInt);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -4099,6 +4100,9 @@ OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) {
|
|||
a = IntegerAttr::get(dty, iattr.getInt());
|
||||
} else if (auto fattr = dyn_cast<FloatAttr>(a)) {
|
||||
a = FloatAttr::get(dty, fattr.getValueAsDouble());
|
||||
} else {
|
||||
// doesn't handle other types, like complex type
|
||||
return {};
|
||||
}
|
||||
|
||||
auto mlirTensorType =
|
||||
|
|
|
@ -679,7 +679,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
|
||||
emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
|
||||
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
|
||||
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
|
||||
emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
|
||||
emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
||||
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
|
||||
|
|
|
@ -1425,6 +1425,15 @@ func.func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
|
|||
return %scalar : !torch.int
|
||||
}
|
||||
|
||||
// Canonicalization must turn aten.Int.Tensor of a 0-d splat literal vtensor
// into the corresponding torch.constant.int.
// CHECK-LABEL: @torch.aten.Int.Tensor$canonicalize_0d_const() -> !torch.int {
// CHECK: %[[CST:.*]] = torch.constant.int 1
// CHECK: return %[[CST]] : !torch.int
func.func @torch.aten.Int.Tensor$canonicalize_0d_const() -> !torch.int {
  %literal = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
  %int = torch.aten.Int.Tensor %literal : !torch.vtensor<[],si64> -> !torch.int
  return %int : !torch.int
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.Int.float() -> !torch.int {
|
||||
// CHECK: %[[NUM:.*]] = torch.constant.int 1
|
||||
// CHECK: return %[[NUM]] : !torch.int
|
||||
|
|
Loading…
Reference in New Issue