Add support for prim.NumToTensor

With this, we can import BERT!
```
pt_util ~/tmp/bert.pt  --import --exported-name=forward \
| npcomp-opt -torch-globalize-object-graph -inline -symbol-dce
```
https://gist.github.com/silvasean/fe7735ff5d065cc9216f7b0346d0e977

The test case here is a bit unconventional -- it isn't actually valid
Python. To figure out how to generate it I had to go search the PyTorch
codebase for "NumToTensor" and work backward. In this case I found
this
[code](649760e5f1/torch/csrc/jit/frontend/ir_emitter.cpp (L464))
which via a wild guess I was able to turn into a test case.

In this case it didn't take me too long, but when doing this kind of
"add a bunch of trivial stuff to bring up a real model", I'm starting to
think that we might skimp on test cases when it's fairly trivial and not
obvious how to test with a small test.
pull/172/head
Sean Silva 2021-02-25 16:35:29 -08:00
parent 7b6fa27838
commit 59a3f46795
4 changed files with 54 additions and 0 deletions

View File

@ -149,6 +149,15 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
return;
}
if (kind == c10::prim::NumToTensor) {
MlirOperation operation =
createMlirOperationAtEnd(appendToBlock, "torch.prim.NumToTensor", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
}
// Unhandled.
{
std::stringstream msg;

View File

@ -0,0 +1,25 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !numpy.ndarray<*:!numpy.any_dtype> {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor %[[ARG]] : i64 -> !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: }
@mb.import_function
@torch.jit.script
def prim_NumToTensor(i: int):
return _to_tensor(i)
mb.module.operation.print()
print()

View File

@ -365,4 +365,15 @@ def Torch_PrintOp : Torch_Op<"prim.Print", []> {
}];
}
def Torch_PrimNumToTensorOp : Torch_Op<"prim.NumToTensor", []> {
let summary = "TorchScript prim::NumToTensor op";
let arguments = (ins AnyTorchNumberType:$num);
let results = (outs AnyTorchTensorType:$result);
let assemblyFormat = [{
$num attr-dict `:` type($num) `->` type($result)
}];
}
#endif // TORCH_OPS

View File

@ -139,6 +139,15 @@ def AnyTorchScalarType : AnyTypeOf<[
AnySignlessInteger,
], "Any primitive type suitable to be passed as a Torch Scalar">;
def AnyTorchNumberType : AnyTypeOf<[
AnySignedInteger,
AnyFloat,
Basicpy_BoolType,
// Allow signless integers for ease of conversions. In general, this
// dialect uses signed integers.
AnySignlessInteger,
], "Any primitive numeric type">;
def AnyTorchBoolType : AnyTypeOf<[
I1,
Basicpy_BoolType,