Add CastOpInterface to torch.prim.unchecked_cast.

This allows it to fold away in trivial cases.
pull/239/head
Sean Silva 2021-06-22 13:56:12 -07:00
parent 45f2edfc7a
commit 60a947b4a7
4 changed files with 24 additions and 5 deletions

View File

@ -393,7 +393,8 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
emit("prim::max.int : (int, int) -> (int)")
emit("prim::RaiseException : (str) -> ()")
emit("prim::Uninitialized : () -> (Any)")
emit("prim::unchecked_cast : (t) -> (t)")
emit("prim::unchecked_cast : (t) -> (t)",
traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
emit("prim::Print : (...) -> ()")

View File

@ -184,6 +184,7 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
}
def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
AllowsTypeRefinement,
HasValueSemantics
]> {

View File

@ -690,6 +690,15 @@ void Torch::ConstantBoolOp::getAsmResultNames(
setNameFn(getResult(), value() ? "true" : "false");
}
//===----------------------------------------------------------------------===//
// PrimUncheckedCastOp
//===----------------------------------------------------------------------===//
bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs,
mlir::TypeRange outputs) {
return isValidSubtype(outputs[0], inputs[0]);
}
//===----------------------------------------------------------------------===//
// Aten__Getitem__TOp
//===----------------------------------------------------------------------===//

View File

@ -115,10 +115,18 @@ func @f(%arg0: i32 {torch.type_bound = i32})
// -----
func @derefine(%arg0: !torch.optional<tensor<f32>>) -> tensor<f32> {
// expected-error @+1 {{operand type '!torch.optional<tensor<f32>>' and result type 'tensor<f32>' are cast incompatible}}
%0 = torch.derefine %arg0 : !torch.optional<tensor<f32>> to tensor<f32>
return %0 : tensor<f32>
func @derefine(%arg0: !torch.optional<!torch.tensor>) -> !torch.tensor {
// expected-error @+1 {{operand type '!torch.optional<!torch.tensor>' and result type '!torch.tensor' are cast incompatible}}
%0 = torch.derefine %arg0 : !torch.optional<!torch.tensor> to !torch.tensor
return %0 : !torch.tensor
}
// -----
func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<!torch.tensor> {
// expected-error @+1 {{operand type '!torch.tensor' and result type '!torch.optional<!torch.tensor>' are cast incompatible}}
%0 = torch.prim.unchecked_cast %arg0 : !torch.tensor -> !torch.optional<!torch.tensor>
return %0 : !torch.optional<!torch.tensor>
}
// -----