mirror of https://github.com/llvm/torch-mlir
Add aten.tensor.int and aten.tensor.float op lowerings.
Add the required lowerings and correct test cases. These op produce zero-d tensors and it was incorrectly mentioned in refine types to produce 1d tensor of size 1.pull/484/head
parent
cce490d71d
commit
ab81f871e4
|
@ -835,6 +835,41 @@ def AddCDivModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class tensorIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
a = 1
|
||||
return torch.tensor(a)
|
||||
|
||||
@register_test_case(module_factory=lambda: tensorIntModule())
|
||||
def TensorIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
class tensorFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
a = 1.0
|
||||
return torch.tensor(a)
|
||||
|
||||
@register_test_case(module_factory=lambda: tensorFloatModule())
|
||||
def TensorFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class DropoutModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
@ -3451,6 +3451,60 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
struct ConvertAtenScalarToTensorLike : ConversionPattern {
|
||||
ConvertAtenScalarToTensorLike(TypeConverter &typeConverter,
|
||||
MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||
context) {}
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!isa<AtenTensorIntOp, AtenTensorFloatOp>(op))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "not a supported Scalar to Tensor like op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Location loc = op->getLoc();
|
||||
Value elemVal, dtype, device, requires_grad;
|
||||
if (AtenTensorIntOp tensorIntOp = dyn_cast<AtenTensorIntOp>(op)) {
|
||||
AtenTensorIntOp::Adaptor adaptor(operands);
|
||||
elemVal = adaptor.t();
|
||||
dtype = tensorIntOp.dtype();
|
||||
device = tensorIntOp.device();
|
||||
requires_grad = tensorIntOp.requires_grad();
|
||||
}
|
||||
if (AtenTensorFloatOp tensorFloatOp = dyn_cast<AtenTensorFloatOp>(op)) {
|
||||
AtenTensorFloatOp::Adaptor adaptor(operands);
|
||||
elemVal = adaptor.t();
|
||||
dtype = tensorFloatOp.dtype();
|
||||
device = tensorFloatOp.device();
|
||||
requires_grad = tensorFloatOp.requires_grad();
|
||||
}
|
||||
// TODO: Dtype conversion.
|
||||
if (!dtype.getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(op, "Unimplemented non-None dtype");
|
||||
|
||||
// TODO: Device information.
|
||||
if (!device.getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented non-None device information");
|
||||
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
Type outElementType = resultType.getElementType();
|
||||
Value elemValProm =
|
||||
convertScalarToDtype(rewriter, loc, elemVal, outElementType);
|
||||
Value zeroDTensor =
|
||||
createInitTensor(rewriter, loc, {}, outElementType, elemValProm);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, zeroDTensor);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Converts constant tensor allocation like ops.
|
||||
template <typename OpTy>
|
||||
|
@ -3751,6 +3805,8 @@ public:
|
|||
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenIndexSelectOp>();
|
||||
patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
|
||||
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
|
||||
target.addIllegalOp<AtenTensorIntOp, AtenTensorFloatOp>();
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
|
|
@ -238,11 +238,11 @@ public:
|
|||
AtenGeluBackwardOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp,
|
||||
AtenCosOp, AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
|
||||
AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp,
|
||||
AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
|
||||
AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp,
|
||||
AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
|
||||
AtenDropoutOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp,
|
||||
AtenAddIntOp, AtenAbsOp, AtenReciprocalOp, AtenCeilOp>(op)) {
|
||||
AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp,
|
||||
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
|
||||
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
|
||||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
||||
AtenAbsOp, AtenReciprocalOp, AtenCeilOp>(op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
|
||||
|
@ -1272,7 +1272,6 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
|
|||
Value t = op.t();
|
||||
Value dtype = op.dtype();
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(1, 1);
|
||||
fillInDTypeGivenDTypeAndDataType(knowledge, dtype, t.getType());
|
||||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
|
|
@ -593,8 +593,8 @@ builtin.func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.ten
|
|||
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<[1],f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[1],f32> to !torch.tensor
|
||||
// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<[],f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[],f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
|
||||
builtin.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
|
||||
|
@ -611,8 +611,8 @@ builtin.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
|
|||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[CST11:.*]] = torch.constant.int 11
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<[1],i1>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[1],i1> to !torch.tensor
|
||||
// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<[],i1>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[],i1> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
|
||||
builtin.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor {
|
||||
|
|
Loading…
Reference in New Issue