mirror of https://github.com/llvm/torch-mlir
Add support for prim.NumToTensor
With this, we can import BERT!
```
pt_util ~/tmp/bert.pt --import --exported-name=forward \
| npcomp-opt -torch-globalize-object-graph -inline -symbol-dce
```
https://gist.github.com/silvasean/fe7735ff5d065cc9216f7b0346d0e977
The test case here is a bit unconventional -- it isn't actually valid
Python. To figure out how to generate it I had to go search the PyTorch
codebase for "NumToTensor" and work backward. In this case I found
this
[code](649760e5f1/torch/csrc/jit/frontend/ir_emitter.cpp (L464)),
which, via a wild guess, I was able to turn into a test case.
In this case it didn't take me too long, but when doing this kind of
"add a bunch of trivial stuff to bring up a real model" work, I'm starting
to think that we might skimp on test cases when the change is fairly
trivial and it isn't obvious how to exercise it with a small test.
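
As a side note, here is a minimal sketch (not part of this commit) of how to
see the `prim::NumToTensor` node outside of the importer: script the same
not-quite-valid-Python `_to_tensor` call that the test below uses and print
the TorchScript graph. This assumes the behavior the test itself relies on,
namely that the TorchScript compiler resolves `_to_tensor` as a builtin that
lowers to `prim::NumToTensor`.

```python
import torch

# Hedged sketch: `_to_tensor` is not a regular Python name; it is resolved by
# the TorchScript compiler itself, which is why the test case is not valid
# plain Python.
@torch.jit.script
def num_to_tensor(i: int):
    return _to_tensor(i)  # noqa: F821 -- resolved by TorchScript, not by Python

# The printed graph should contain a node along the lines of:
#   %1 : Tensor = prim::NumToTensor(%i)
print(num_to_tensor.graph)
```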
Branch: pull/172/head · parent 7b6fa27838 · commit 59a3f46795
Node importer (C++):

```diff
@@ -149,6 +149,15 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
     return;
   }
 
+  if (kind == c10::prim::NumToTensor) {
+    MlirOperation operation =
+        createMlirOperationAtEnd(appendToBlock, "torch.prim.NumToTensor", loc,
+                                 getMlirTypesFromValues(loc, node->outputs()),
+                                 lookupMappedValues(node->inputs()));
+    mapResults(node, operation);
+    return;
+  }
+
   // Unhandled.
   {
     std::stringstream msg;
```
New lit test (Python):

```diff
@@ -0,0 +1,25 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See frontends/pytorch/LICENSE for license information.
+
+import torch
+import torch_mlir
+
+# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
+
+mb = torch_mlir.ModuleBuilder()
+
+
+# CHECK-LABEL: func @prim_NumToTensor(
+# CHECK-SAME:    %[[ARG:.*]]: i64) -> !numpy.ndarray<*:!numpy.any_dtype> {
+# CHECK:         %[[RET:.*]] = torch.prim.NumToTensor %[[ARG]] : i64 -> !numpy.ndarray<*:!numpy.any_dtype>
+# CHECK:         return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
+# CHECK:       }
+
+@mb.import_function
+@torch.jit.script
+def prim_NumToTensor(i: int):
+    return _to_tensor(i)
+
+mb.module.operation.print()
+print()
```
Torch dialect ops (TableGen):

```diff
@@ -365,4 +365,15 @@ def Torch_PrintOp : Torch_Op<"prim.Print", []> {
   }];
 }
 
+def Torch_PrimNumToTensorOp : Torch_Op<"prim.NumToTensor", []> {
+  let summary = "TorchScript prim::NumToTensor op";
+
+  let arguments = (ins AnyTorchNumberType:$num);
+  let results = (outs AnyTorchTensorType:$result);
+
+  let assemblyFormat = [{
+    $num attr-dict `:` type($num) `->` type($result)
+  }];
+}
+
 #endif // TORCH_OPS
```
Torch dialect types (TableGen):

```diff
@@ -139,6 +139,15 @@ def AnyTorchScalarType : AnyTypeOf<[
   AnySignlessInteger,
 ], "Any primitive type suitable to be passed as a Torch Scalar">;
 
+def AnyTorchNumberType : AnyTypeOf<[
+  AnySignedInteger,
+  AnyFloat,
+  Basicpy_BoolType,
+  // Allow signless integers for ease of conversions. In general, this
+  // dialect uses signed integers.
+  AnySignlessInteger,
+], "Any primitive numeric type">;
+
 def AnyTorchBoolType : AnyTypeOf<[
   I1,
   Basicpy_BoolType,
```