mirror of https://github.com/llvm/torch-mlir
Add support for prim::RaiseException.
Used by resnet18.
It seems to originate from a helper `_verify_batch_size`:
[code link](b3bf08e67f/torch/nn/functional.py (L2099)
).
I couldn't find a way to test `prim::RaiseException` without also having
`prim::Uninitialized`.
pull/176/head
parent
7bb3b2eb6e
commit
7dfd6f697e
|
@ -176,6 +176,24 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|||
return;
|
||||
}
|
||||
|
||||
if (kind == c10::prim::RaiseException) {
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.RaiseException", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
lookupMappedValues(node->inputs()));
|
||||
mapResults(node, operation);
|
||||
return;
|
||||
}
|
||||
|
||||
if (kind == c10::prim::Uninitialized) {
|
||||
MlirOperation operation =
|
||||
createMlirOperationAtEnd(appendToBlock, "torch.prim.Uninitialized", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
lookupMappedValues(node->inputs()));
|
||||
mapResults(node, operation);
|
||||
return;
|
||||
}
|
||||
|
||||
// Unhandled.
|
||||
{
|
||||
std::stringstream msg;
|
||||
|
|
|
@ -136,9 +136,12 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
// TODO: Don't lose the element type information.
|
||||
return npcompTupleTypeGet(context);
|
||||
}
|
||||
case TypeKind::StringType: {
|
||||
return npcompBytesTypeGet(context);
|
||||
}
|
||||
default: {
|
||||
std::stringstream message;
|
||||
message << "unable to map Torch type " << *torchType << " to MLIR type";
|
||||
message << "unable to map Torch type '" << *torchType << "' to MLIR type";
|
||||
mlirEmitError(loc, message.str().c_str());
|
||||
return {nullptr};
|
||||
}
|
||||
|
|
|
@ -30,5 +30,15 @@ def prim_NumToTensor(i: int):
|
|||
def prim_Print(x):
|
||||
print("x", x)
|
||||
|
||||
# CHECK-LABEL: func @prim_RaiseException() -> !basicpy.NoneType {
|
||||
# CHECK: %[[ERRORSTR:.*]] = basicpy.bytes_constant "Error"
|
||||
# CHECK: %[[NONE:.*]] = torch.prim.Uninitialized : !basicpy.NoneType
|
||||
# CHECK: torch.prim.RaiseException %[[ERRORSTR]]
|
||||
# CHECK: return %[[NONE]] : !basicpy.NoneType
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_RaiseException():
|
||||
raise Exception("Error")
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
|
|
@ -366,7 +366,7 @@ def Torch_PrimCallMethodOp : Torch_Op<"prim.CallMethod", []> {
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_PrintOp : Torch_Op<"prim.Print", []> {
|
||||
def Torch_PrimPrintOp : Torch_Op<"prim.Print", []> {
|
||||
let summary = "TorchScript prim::Print op";
|
||||
|
||||
let arguments = (ins Variadic<AnyTorchType>:$operands);
|
||||
|
@ -388,4 +388,29 @@ def Torch_PrimNumToTensorOp : Torch_Op<"prim.NumToTensor", []> {
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimRaiseExceptionOp : Torch_Op<"prim.RaiseException", []> {
|
||||
let summary = "TorchScript prim::RaiseException op";
|
||||
|
||||
// TODO: Error messages suggest that any exception derived from BaseException
|
||||
// is allowed at the Python level, but they seem to just be strings at the
|
||||
// IR level.
|
||||
let arguments = (ins Basicpy_BytesType:$errorMsg);
|
||||
let results = (outs);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$errorMsg attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", []> {
|
||||
let summary = "TorchScript prim::Uninitialized op";
|
||||
|
||||
let arguments = (ins);
|
||||
let results = (outs AnyTorchType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict `:` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TORCH_OPS
|
||||
|
|
Loading…
Reference in New Issue