mirror of https://github.com/llvm/torch-mlir
Add CastOpInterface to torch.prim.unchecked_cast.
This allows it to fold away in trivial cases.pull/239/head
parent
45f2edfc7a
commit
60a947b4a7
|
@ -393,7 +393,8 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("prim::max.int : (int, int) -> (int)")
|
||||
emit("prim::RaiseException : (str) -> ()")
|
||||
emit("prim::Uninitialized : () -> (Any)")
|
||||
emit("prim::unchecked_cast : (t) -> (t)")
|
||||
emit("prim::unchecked_cast : (t) -> (t)",
|
||||
traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
|
||||
emit("prim::Print : (...) -> ()")
|
||||
|
||||
|
||||
|
|
|
@ -184,6 +184,7 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
|
|||
}
|
||||
|
||||
def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
|
|
|
@ -690,6 +690,15 @@ void Torch::ConstantBoolOp::getAsmResultNames(
|
|||
setNameFn(getResult(), value() ? "true" : "false");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimUncheckedCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs,
|
||||
mlir::TypeRange outputs) {
|
||||
return isValidSubtype(outputs[0], inputs[0]);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Aten__Getitem__TOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -115,10 +115,18 @@ func @f(%arg0: i32 {torch.type_bound = i32})
|
|||
|
||||
// -----
|
||||
|
||||
func @derefine(%arg0: !torch.optional<tensor<f32>>) -> tensor<f32> {
|
||||
// expected-error @+1 {{operand type '!torch.optional<tensor<f32>>' and result type 'tensor<f32>' are cast incompatible}}
|
||||
%0 = torch.derefine %arg0 : !torch.optional<tensor<f32>> to tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
func @derefine(%arg0: !torch.optional<!torch.tensor>) -> !torch.tensor {
|
||||
// expected-error @+1 {{operand type '!torch.optional<!torch.tensor>' and result type '!torch.tensor' are cast incompatible}}
|
||||
%0 = torch.derefine %arg0 : !torch.optional<!torch.tensor> to !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<!torch.tensor> {
|
||||
// expected-error @+1 {{operand type '!torch.tensor' and result type '!torch.optional<!torch.tensor>' are cast incompatible}}
|
||||
%0 = torch.prim.unchecked_cast %arg0 : !torch.tensor -> !torch.optional<!torch.tensor>
|
||||
return %0 : !torch.optional<!torch.tensor>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
Loading…
Reference in New Issue