Add support for prim::RaiseException.

Used by resnet18.

It seems to originate from the `_verify_batch_size` helper in `torch/nn/functional.py`, which raises a `ValueError` when batch norm runs in training mode with too few values per channel:
[code link](b3bf08e67f/torch/nn/functional.py (L2099)).

I couldn't find a way to test `prim::RaiseException` without also having
`prim::Uninitialized`.
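
For reference, a minimal sketch (not part of this change) of why the two ops travel together: scripting a function that unconditionally raises still has to yield a value for the function's return slot, and TorchScript fills that slot with `prim::Uninitialized`.

```python
import torch

@torch.jit.script
def raises():
    raise Exception("Error")

# The printed graph should contain both ops: prim::Uninitialized supplies the
# never-used return value and prim::RaiseException aborts execution.
print(raises.graph)
```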
pull/176/head
Sean Silva 2021-03-01 13:47:50 -08:00
parent 7bb3b2eb6e
commit 7dfd6f697e
4 changed files with 58 additions and 2 deletions

View File

@@ -176,6 +176,24 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
    return;
  }
  if (kind == c10::prim::RaiseException) {
    MlirOperation operation = createMlirOperationAtEnd(
        appendToBlock, "torch.prim.RaiseException", loc,
        getMlirTypesFromValues(loc, node->outputs()),
        lookupMappedValues(node->inputs()));
    mapResults(node, operation);
    return;
  }
  if (kind == c10::prim::Uninitialized) {
    MlirOperation operation =
        createMlirOperationAtEnd(appendToBlock, "torch.prim.Uninitialized", loc,
                                 getMlirTypesFromValues(loc, node->outputs()),
                                 lookupMappedValues(node->inputs()));
    mapResults(node, operation);
    return;
  }
  // Unhandled.
  {
    std::stringstream msg;
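
A sketch (not part of the change) of how to see the node kinds this importer dispatches on, by walking the TorchScript graph from Python:

```python
import torch

@torch.jit.script
def raises():
    raise Exception("Error")

# Each node's kind() string corresponds to a c10::prim::* symbol matched
# above, e.g. "prim::RaiseException" and "prim::Uninitialized".
for node in raises.graph.nodes():
    print(node.kind())
```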

View File

@@ -136,9 +136,12 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
    // TODO: Don't lose the element type information.
    return npcompTupleTypeGet(context);
  }
  case TypeKind::StringType: {
    return npcompBytesTypeGet(context);
  }
  default: {
    std::stringstream message;
-    message << "unable to map Torch type " << *torchType << " to MLIR type";
+    message << "unable to map Torch type '" << *torchType << "' to MLIR type";
    mlirEmitError(loc, message.str().c_str());
    return {nullptr};
  }
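
The new `StringType` case matters here because the exception message is a plain `str` at the TorchScript level. A small sketch (independent of this change) of where such a type appears in a graph:

```python
import torch

@torch.jit.script
def greet() -> str:
    return "hello"

# The graph contains a str-typed prim::Constant; with this change the importer
# maps that TorchScript StringType to the Basicpy bytes type on the MLIR side.
print(greet.graph)
```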

View File

@@ -30,5 +30,15 @@ def prim_NumToTensor(i: int):
def prim_Print(x):
    print("x", x)

# CHECK-LABEL: func @prim_RaiseException() -> !basicpy.NoneType {
# CHECK: %[[ERRORSTR:.*]] = basicpy.bytes_constant "Error"
# CHECK: %[[NONE:.*]] = torch.prim.Uninitialized : !basicpy.NoneType
# CHECK: torch.prim.RaiseException %[[ERRORSTR]]
# CHECK: return %[[NONE]] : !basicpy.NoneType
@mb.import_function
@torch.jit.script
def prim_RaiseException():
    raise Exception("Error")

mb.module.operation.print()
print()
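
The hunk shows only the new test function; the `mb` builder it relies on is defined earlier in the file (not shown here). A rough sketch of that setup, assuming the `torch_mlir.ModuleBuilder` API used by these frontend tests:

```python
import torch
import torch_mlir  # npcomp's PyTorch frontend bindings (assumed import here)

mb = torch_mlir.ModuleBuilder()

@mb.import_function
@torch.jit.script
def prim_RaiseException():
    raise Exception("Error")

# Print the imported MLIR; FileCheck matches it against the CHECK lines above.
mb.module.operation.print()
```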

View File

@@ -366,7 +366,7 @@ def Torch_PrimCallMethodOp : Torch_Op<"prim.CallMethod", []> {
  }];
}

-def Torch_PrintOp : Torch_Op<"prim.Print", []> {
+def Torch_PrimPrintOp : Torch_Op<"prim.Print", []> {
  let summary = "TorchScript prim::Print op";
  let arguments = (ins Variadic<AnyTorchType>:$operands);
@@ -388,4 +388,29 @@ def Torch_PrimNumToTensorOp : Torch_Op<"prim.NumToTensor", []> {
  }];
}

def Torch_PrimRaiseExceptionOp : Torch_Op<"prim.RaiseException", []> {
  let summary = "TorchScript prim::RaiseException op";
  // TODO: Error messages suggest that any exception derived from BaseException
  // is allowed at the Python level, but they seem to just be strings at the
  // IR level.
  let arguments = (ins Basicpy_BytesType:$errorMsg);
  let results = (outs);
  let assemblyFormat = [{
    $errorMsg attr-dict
  }];
}

def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", []> {
  let summary = "TorchScript prim::Uninitialized op";
  let arguments = (ins);
  let results = (outs AnyTorchType:$result);
  let assemblyFormat = [{
    attr-dict `:` type($result)
  }];
}

#endif // TORCH_OPS