From 60a947b4a7cc0edd21f00e0cc10e6e3bacaf3299 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Tue, 22 Jun 2021 13:56:12 -0700 Subject: [PATCH] Add CastOpInterface to torch.prim.unchecked_cast. This allows it to fold away in trivial cases. --- .../torch_mlir_utils/codegen/torch_ods_gen.py | 3 ++- .../npcomp/Dialect/Torch/IR/GeneratedPrimOps.td | 1 + lib/Dialect/Torch/IR/TorchOps.cpp | 9 +++++++++ test/Dialect/Torch/invalid.mlir | 16 ++++++++++++---- 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py index 76cc59734..a1476a6dd 100644 --- a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py +++ b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py @@ -393,7 +393,8 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry): emit("prim::max.int : (int, int) -> (int)") emit("prim::RaiseException : (str) -> ()") emit("prim::Uninitialized : () -> (Any)") - emit("prim::unchecked_cast : (t) -> (t)") + emit("prim::unchecked_cast : (t) -> (t)", + traits=["DeclareOpInterfaceMethods"]) emit("prim::Print : (...) -> ()") diff --git a/include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td b/include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td index 49af69d5c..291ca98ae 100644 --- a/include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td +++ b/include/npcomp/Dialect/Torch/IR/GeneratedPrimOps.td @@ -184,6 +184,7 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [ } def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [ + DeclareOpInterfaceMethods, AllowsTypeRefinement, HasValueSemantics ]> { diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 39dd7bbff..00ddce894 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -690,6 +690,15 @@ void Torch::ConstantBoolOp::getAsmResultNames( setNameFn(getResult(), value() ? "true" : "false"); } +//===----------------------------------------------------------------------===// +// PrimUncheckedCastOp +//===----------------------------------------------------------------------===// + +bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs, + mlir::TypeRange outputs) { + return isValidSubtype(outputs[0], inputs[0]); +} + //===----------------------------------------------------------------------===// // Aten__Getitem__TOp //===----------------------------------------------------------------------===// diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index 3cec79ac4..593ed32ee 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -115,10 +115,18 @@ func @f(%arg0: i32 {torch.type_bound = i32}) // ----- -func @derefine(%arg0: !torch.optional>) -> tensor { - // expected-error @+1 {{operand type '!torch.optional>' and result type 'tensor' are cast incompatible}} - %0 = torch.derefine %arg0 : !torch.optional> to tensor - return %0 : tensor +func @derefine(%arg0: !torch.optional) -> !torch.tensor { + // expected-error @+1 {{operand type '!torch.optional' and result type '!torch.tensor' are cast incompatible}} + %0 = torch.derefine %arg0 : !torch.optional to !torch.tensor + return %0 : !torch.tensor +} + +// ----- + +func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional { + // expected-error @+1 {{operand type '!torch.tensor' and result type '!torch.optional' are cast incompatible}} + %0 = torch.prim.unchecked_cast %arg0 : !torch.tensor -> !torch.optional + return %0 : !torch.optional } // -----