mirror of https://github.com/llvm/torch-mlir
Add NumToTensor
parent
a75ae82530
commit
f41958037a
|
@ -542,3 +542,19 @@ class LogSoftmaxIntModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: LogSoftmaxIntModule())
|
||||
def LogSoftmaxIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4).double())
|
||||
|
||||
class NumToTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.ops.prim.NumToTensor(1)
|
||||
|
||||
@register_test_case(module_factory=lambda: NumToTensorModule())
|
||||
def NumToTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
|
|
@ -2794,6 +2794,29 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertPrimNumToTensorScalarOp
|
||||
: public OpConversionPattern<PrimNumToTensorScalarOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(PrimNumToTensorScalarOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
PrimNumToTensorScalarOp::Adaptor adaptor(operands);
|
||||
Location loc = op.getLoc();
|
||||
Value a = adaptor.a();
|
||||
Value outTensor =
|
||||
rewriter.create<linalg::InitTensorOp>(loc, ValueRange{}, a.getType())
|
||||
->getResult(0);
|
||||
rewriter.replaceOpWithNewOp<linalg::FillOp>(op, a, outTensor);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -2878,6 +2901,9 @@ public:
|
|||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenIntTensorOp>();
|
||||
patterns.add<ConvertAtenIntTensorOp>(typeConverter, context);
|
||||
target.addIllegalOp<PrimNumToTensorScalarOp>();
|
||||
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
|
|
@ -416,6 +416,8 @@ public:
|
|||
return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
|
||||
} else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) {
|
||||
return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
|
||||
} else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
|
||||
return visitNumToTensorOp(numToTensorOp);
|
||||
}
|
||||
|
||||
// Otherwise, this is an unknown operation. Just mark all results as
|
||||
|
@ -477,6 +479,7 @@ private:
|
|||
ChangeResult
|
||||
visitAtenPermuteOp(AtenPermuteOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ChangeResult visitNumToTensorOp(PrimNumToTensorScalarOp op);
|
||||
template <typename OpTy>
|
||||
ChangeResult visitScalarToTensorConversionOp(OpTy op);
|
||||
ChangeResult visitAtenTensorOp(AtenTensorOp op);
|
||||
|
@ -1262,6 +1265,14 @@ ChangeResult TypeAnalyzer::visitAtenShapeAsTensorOp(
|
|||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.dtype = getDefaultDtypeForTorchScalar(op.a().getType());
|
||||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
|
||||
AtenEmbeddingOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge =
|
||||
|
|
|
@ -67,3 +67,17 @@ func @integer_extract(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
|
|||
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func @torch.prim.NumToTensor.Scalar$basic(%[[IN:.*]]: !torch.int) -> !torch.vtensor<[],si64> {
|
||||
// CHECK: %[[INI64:.*]] = torch_c.to_i64 %[[IN]]
|
||||
// CHECK: %[[NEWVEC:.*]] = linalg.init_tensor [] : tensor<i64>
|
||||
// CHECK: %[[FILLVEC:.*]] = linalg.fill(%[[INI64]], %[[NEWVEC]]) : i64, tensor<i64> -> tensor<i64>
|
||||
// CHECK: %[[OUTVEC:.*]] = torch_c.from_builtin_tensor %[[FILLVEC]] : tensor<i64> -> !torch.vtensor<[],si64>
|
||||
// CHECK: return %[[OUTVEC]] : !torch.vtensor<[],si64>
|
||||
|
||||
func @torch.prim.NumToTensor.Scalar$basic(%arg0: !torch.int) -> !torch.vtensor<[],si64> {
|
||||
%0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor<[],si64>
|
||||
return %0 : !torch.vtensor<[],si64>
|
||||
}
|
||||
|
|
|
@ -979,28 +979,28 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
|
|||
|
||||
|
||||
// ----
|
||||
// CHECK-LABEL: func @aten_matmul_broadcast_matrix(
|
||||
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix(
|
||||
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>)
|
||||
// CHECK-SAME: -> !torch.tensor {
|
||||
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func @aten_matmul_broadcast_matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
|
||||
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
|
||||
// ----
|
||||
// CHECK-LABEL: func @aten_matmul_broadcast_vector(
|
||||
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector(
|
||||
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>)
|
||||
// CHECK-SAME: -> !torch.tensor {
|
||||
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func @aten_matmul_broadcast_vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.tensor {
|
||||
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
@ -1022,3 +1022,16 @@ func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
|
|||
%0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// ----
|
||||
// CHECK-LABEL: func @torch.prim.NumToTensor.Scalar(
|
||||
// CHECK-SAME: %[[SELF:.*]]: !torch.int)
|
||||
// CHECK-SAME: -> !torch.tensor {
|
||||
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<[],si64> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
|
||||
func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
|
||||
%0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor
|
||||
return %0: !torch.tensor
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue