Add NumToTensor

pull/406/head
George Petterson 2021-11-08 15:28:51 -05:00 committed by Yi Zhang
parent a75ae82530
commit f41958037a
5 changed files with 90 additions and 10 deletions

View File

@ -542,3 +542,19 @@ class LogSoftmaxIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: LogSoftmaxIntModule())
def LogSoftmaxIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4).double())


class NumToTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.ops.prim.NumToTensor(1)


@register_test_case(module_factory=lambda: NumToTensorModule())
def NumToTensorModule_basic(module, tu: TestUtils):
    module.forward()
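
The op under test can also be exercised directly in eager PyTorch, outside the e2e harness. A minimal sketch (illustration only, not part of this commit): prim::NumToTensor wraps a Python scalar into a rank-0 tensor, and an int scalar yields an int64 tensor, which is why the lowering and type-inference tests below expect an si64 element type.

import torch

# prim::NumToTensor wraps a Python scalar into a rank-0 tensor;
# an int scalar produces an int64 tensor.
t = torch.ops.prim.NumToTensor(1)
print(t.shape)  # torch.Size([])  (rank-0)
print(t.dtype)  # torch.int64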

View File

@ -2794,6 +2794,29 @@ public:
};
} // namespace
namespace {
// Convert `torch.prim.NumToTensor.Scalar` by allocating a rank-0 tensor with
// `linalg.init_tensor` and writing the scalar into it with `linalg.fill`.
class ConvertPrimNumToTensorScalarOp
    : public OpConversionPattern<PrimNumToTensorScalarOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(PrimNumToTensorScalarOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    PrimNumToTensorScalarOp::Adaptor adaptor(operands);
    Location loc = op.getLoc();
    Value a = adaptor.a();
    Value outTensor =
        rewriter.create<linalg::InitTensorOp>(loc, ValueRange{}, a.getType())
            ->getResult(0);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(op, a, outTensor);
    return success();
  }
};
} // namespace
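
In short, the pattern materializes the scalar as a rank-0 tensor: `linalg.init_tensor []` allocates a zero-dimensional tensor of the scalar's converted type, and `linalg.fill` writes the scalar into it. A rough NumPy analogue of the emitted computation, assuming an i64 scalar (hypothetical helper, illustration only):

import numpy as np

def num_to_tensor_scalar(a: int) -> np.ndarray:
    # "linalg.init_tensor []": allocate a rank-0 tensor (shape ()).
    out = np.empty((), dtype=np.int64)
    # "linalg.fill": fill the tensor with the scalar value.
    out[()] = a
    return out

print(repr(num_to_tensor_scalar(1)))  # array(1)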
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@ -2878,6 +2901,9 @@ public:
    patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
    target.addIllegalOp<AtenIntTensorOp>();
    patterns.add<ConvertAtenIntTensorOp>(typeConverter, context);
    target.addIllegalOp<PrimNumToTensorScalarOp>();
    patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();

View File

@ -416,6 +416,8 @@ public:
      return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
    } else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) {
      return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
    } else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
      return visitNumToTensorOp(numToTensorOp);
    }

    // Otherwise, this is an unknown operation. Just mark all results as
@ -477,6 +479,7 @@ private:
  ChangeResult
  visitAtenPermuteOp(AtenPermuteOp op,
                     ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult visitNumToTensorOp(PrimNumToTensorScalarOp op);
  template <typename OpTy>
  ChangeResult visitScalarToTensorConversionOp(OpTy op);
  ChangeResult visitAtenTensorOp(AtenTensorOp op);
@ -1262,6 +1265,14 @@ ChangeResult TypeAnalyzer::visitAtenShapeAsTensorOp(
  return getLatticeElement(op.getResult()).join(knowledge);
}

// NumToTensor produces a rank-0 tensor; the dtype follows the default tensor
// dtype for the scalar operand's Torch type.
ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  knowledge.hasSizes = true;
  knowledge.dtype = getDefaultDtypeForTorchScalar(op.a().getType());
  return getLatticeElement(op.getResult()).join(knowledge);
}
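
The dtype here comes from getDefaultDtypeForTorchScalar: Torch scalars refine to wide tensor dtypes, with !torch.int mapping to a 64-bit signed integer and !torch.float to a 64-bit float. A hedged Python sketch of that rule (hypothetical helper name, covering only the cases relevant to this patch):

def default_dtype_for_torch_scalar(scalar_type: str) -> str:
    # Torch scalar types default to wide tensor dtypes:
    # !torch.int -> si64, !torch.float -> f64.
    if scalar_type == "!torch.int":
        return "si64"
    if scalar_type == "!torch.float":
        return "f64"
    raise ValueError(f"unhandled scalar type: {scalar_type}")

assert default_dtype_for_torch_scalar("!torch.int") == "si64"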
ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
    AtenEmbeddingOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto knowledge =

View File

@ -67,3 +67,17 @@ func @integer_extract(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
  %0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
  return %0 : !torch.int
}
// -----
// CHECK: func @torch.prim.NumToTensor.Scalar$basic(%[[IN:.*]]: !torch.int) -> !torch.vtensor<[],si64> {
// CHECK: %[[INI64:.*]] = torch_c.to_i64 %[[IN]]
// CHECK: %[[NEWVEC:.*]] = linalg.init_tensor [] : tensor<i64>
// CHECK: %[[FILLVEC:.*]] = linalg.fill(%[[INI64]], %[[NEWVEC]]) : i64, tensor<i64> -> tensor<i64>
// CHECK: %[[OUTVEC:.*]] = torch_c.from_builtin_tensor %[[FILLVEC]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[OUTVEC]] : !torch.vtensor<[],si64>
func @torch.prim.NumToTensor.Scalar$basic(%arg0: !torch.int) -> !torch.vtensor<[],si64> {
  %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor<[],si64>
  return %0 : !torch.vtensor<[],si64>
}

View File

@ -979,28 +979,28 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
// -----
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>)
// CHECK-SAME: -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
  return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>)
// CHECK-SAME: -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.tensor {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor
  return %0 : !torch.tensor
}
@ -1022,3 +1022,16 @@ func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
  %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
  return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func @torch.prim.NumToTensor.Scalar(
// CHECK-SAME: %[[SELF:.*]]: !torch.int)
// CHECK-SAME: -> !torch.tensor {
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<[],si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
  %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.tensor
  return %0 : !torch.tensor
}