mirror of https://github.com/llvm/torch-mlir
Add support for prim.NumToTensor
With this, we can import BERT!
```
pt_util ~/tmp/bert.pt --import --exported-name=forward \
| npcomp-opt -torch-globalize-object-graph -inline -symbol-dce
```
https://gist.github.com/silvasean/fe7735ff5d065cc9216f7b0346d0e977
The test case here is a bit unconventional -- it isn't actually valid
Python. To figure out how to generate it I had to go search the PyTorch
codebase for "NumToTensor" and work backward. In this case I found
this
[code](649760e5f1/torch/csrc/jit/frontend/ir_emitter.cpp (L464)
)
which via a wild guess I was able to turn into a test case.
In this case it didn't take me too long, but when doing this kind of
"add a bunch of trivial stuff to bring up a real model", I'm starting to
think that we might skimp on test cases when it's fairly trivial and not
obvious how to test with a small test.
pull/172/head
parent
7b6fa27838
commit
59a3f46795
|
@ -149,6 +149,15 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (kind == c10::prim::NumToTensor) {
|
||||||
|
MlirOperation operation =
|
||||||
|
createMlirOperationAtEnd(appendToBlock, "torch.prim.NumToTensor", loc,
|
||||||
|
getMlirTypesFromValues(loc, node->outputs()),
|
||||||
|
lookupMappedValues(node->inputs()));
|
||||||
|
mapResults(node, operation);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Unhandled.
|
// Unhandled.
|
||||||
{
|
{
|
||||||
std::stringstream msg;
|
std::stringstream msg;
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @prim_NumToTensor(
|
||||||
|
# CHECK-SAME: %[[ARG:.*]]: i64) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||||
|
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor %[[ARG]] : i64 -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
# CHECK: }
|
||||||
|
|
||||||
|
@mb.import_function
|
||||||
|
@torch.jit.script
|
||||||
|
def prim_NumToTensor(i: int):
|
||||||
|
return _to_tensor(i)
|
||||||
|
|
||||||
|
mb.module.operation.print()
|
||||||
|
print()
|
|
@ -365,4 +365,15 @@ def Torch_PrintOp : Torch_Op<"prim.Print", []> {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_PrimNumToTensorOp : Torch_Op<"prim.NumToTensor", []> {
|
||||||
|
let summary = "TorchScript prim::NumToTensor op";
|
||||||
|
|
||||||
|
let arguments = (ins AnyTorchNumberType:$num);
|
||||||
|
let results = (outs AnyTorchTensorType:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$num attr-dict `:` type($num) `->` type($result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // TORCH_OPS
|
#endif // TORCH_OPS
|
||||||
|
|
|
@ -139,6 +139,15 @@ def AnyTorchScalarType : AnyTypeOf<[
|
||||||
AnySignlessInteger,
|
AnySignlessInteger,
|
||||||
], "Any primitive type suitable to be passed as a Torch Scalar">;
|
], "Any primitive type suitable to be passed as a Torch Scalar">;
|
||||||
|
|
||||||
|
def AnyTorchNumberType : AnyTypeOf<[
|
||||||
|
AnySignedInteger,
|
||||||
|
AnyFloat,
|
||||||
|
Basicpy_BoolType,
|
||||||
|
// Allow signless integers for ease of conversions. In general, this
|
||||||
|
// dialect uses signed integers.
|
||||||
|
AnySignlessInteger,
|
||||||
|
], "Any primitive numeric type">;
|
||||||
|
|
||||||
def AnyTorchBoolType : AnyTypeOf<[
|
def AnyTorchBoolType : AnyTypeOf<[
|
||||||
I1,
|
I1,
|
||||||
Basicpy_BoolType,
|
Basicpy_BoolType,
|
||||||
|
|
Loading…
Reference in New Issue