Minor fixes

pull/535/head snapshot-20220125.227
Yi Zhang 2022-01-24 18:58:27 -05:00
parent f8080bd1c5
commit ad4b9e0369
4 changed files with 3 additions and 16 deletions

View File

@ -175,6 +175,7 @@ def Torch_PrimRaiseExceptionOp : Torch_Op<"prim.RaiseException", [
} }
def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
NoSideEffect,
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics HasValueSemantics
]> { ]> {
@ -185,7 +186,6 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [
AnyTorchType:$result AnyTorchType:$result
); );
let assemblyFormat = " attr-dict `:` qualified(type($result))"; let assemblyFormat = " attr-dict `:` qualified(type($result))";
let hasCanonicalizer = 1;
} }
def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [ def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [

View File

@ -1037,20 +1037,6 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}); });
} }
//===----------------------------------------------------------------------===//
// PrimUninitializedOp
//===----------------------------------------------------------------------===//
void PrimUninitializedOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) {
if (!op.use_empty())
return failure();
rewriter.eraseOp(op);
return success();
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// PrimTupleUnpackOp // PrimTupleUnpackOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -414,7 +414,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
emit("prim::max.self_int : (int[]) -> (int)") emit("prim::max.self_int : (int[]) -> (int)")
emit("prim::max.int : (int, int) -> (int)") emit("prim::max.int : (int, int) -> (int)")
emit("prim::RaiseException : (str) -> ()") emit("prim::RaiseException : (str) -> ()")
emit("prim::Uninitialized : () -> (Any)", has_canonicalizer=True) emit("prim::Uninitialized : () -> (Any)", traits=["NoSideEffect"])
emit("prim::unchecked_cast : (t) -> (t)", emit("prim::unchecked_cast : (t) -> (t)",
traits=["DeclareOpInterfaceMethods<CastOpInterface>"]) traits=["DeclareOpInterfaceMethods<CastOpInterface>"])
emit("prim::Print : (...) -> ()") emit("prim::Print : (...) -> ()")

View File

@ -11,6 +11,7 @@
#include "function_importer.h" #include "function_importer.h"
#include "ivalue_importer.h" #include "ivalue_importer.h"
#include <ATen/TensorUtils.h>
#include <unordered_map> #include <unordered_map>
#include "mlir_utils.h" #include "mlir_utils.h"