Add aten.tensor.int and aten.tensor.float op lowerings.

Add the required lowerings and correct the existing test cases.
These ops produce zero-d tensors; refine types previously stated,
incorrectly, that they produce a 1-d tensor of size 1.
pull/484/head
Prashant Kumar 2021-12-14 17:45:07 +05:30
parent cce490d71d
commit ab81f871e4
4 changed files with 100 additions and 10 deletions
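As a quick sanity check of the zero-d claim in the commit message, a hypothetical eager-PyTorch snippet (not part of this change) behaves as follows:

import torch

assert torch.tensor(1).shape == torch.Size([])       # aten.tensor.int case: rank-0 result
assert torch.tensor(1.0).shape == torch.Size([])     # aten.tensor.float case: rank-0 result
assert torch.tensor([1.0]).shape == torch.Size([1])  # a 1-d tensor of size 1 needs a list argument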


@@ -835,6 +835,41 @@ def AddCDivModule_basic(module, tu: TestUtils):
# ==============================================================================


class tensorIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        a = 1
        return torch.tensor(a)


@register_test_case(module_factory=lambda: tensorIntModule())
def TensorIntModule_basic(module, tu: TestUtils):
    module.forward()


class tensorFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        a = 1.0
        return torch.tensor(a)


@register_test_case(module_factory=lambda: tensorFloatModule())
def TensorFloatModule_basic(module, tu: TestUtils):
    module.forward()


# ==============================================================================


class DropoutModule(torch.nn.Module):
    def __init__(self):
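For context, torch.tensor applied to a Python int or float is what TorchScript turns into the aten.tensor.int / aten.tensor.float ops targeted by this commit. A small standalone sketch (hypothetical module name, assuming a recent PyTorch) shows this in the scripted graph:

import torch

class TensorIntExample(torch.nn.Module):
    def forward(self):
        a = 1
        return torch.tensor(a)

# The printed graph should contain an aten::tensor call taking the int scalar,
# which corresponds to the aten.tensor.int op lowered below.
print(torch.jit.script(TensorIntExample()).graph)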


@@ -3451,6 +3451,60 @@ public:
};
} // namespace
namespace {
// Lowers aten.tensor.int and aten.tensor.float, which wrap a scalar into a
// zero-d tensor.
struct ConvertAtenScalarToTensorLike : ConversionPattern {
  ConvertAtenScalarToTensorLike(TypeConverter &typeConverter,
                                MLIRContext *context)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          context) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<AtenTensorIntOp, AtenTensorFloatOp>(op))
      return rewriter.notifyMatchFailure(
          op, "not a supported Scalar to Tensor like op");
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();
    Location loc = op->getLoc();
    Value elemVal, dtype, device, requires_grad;
    if (AtenTensorIntOp tensorIntOp = dyn_cast<AtenTensorIntOp>(op)) {
      AtenTensorIntOp::Adaptor adaptor(operands);
      elemVal = adaptor.t();
      dtype = tensorIntOp.dtype();
      device = tensorIntOp.device();
      requires_grad = tensorIntOp.requires_grad();
    }
    if (AtenTensorFloatOp tensorFloatOp = dyn_cast<AtenTensorFloatOp>(op)) {
      AtenTensorFloatOp::Adaptor adaptor(operands);
      elemVal = adaptor.t();
      dtype = tensorFloatOp.dtype();
      device = tensorFloatOp.device();
      requires_grad = tensorFloatOp.requires_grad();
    }
    // TODO: Dtype conversion.
    if (!dtype.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(op, "Unimplemented non-None dtype");
    // TODO: Device information.
    if (!device.getType().isa<Torch::NoneType>())
      return rewriter.notifyMatchFailure(
          op, "Unimplemented non-None device information");
    RankedTensorType resultType = getTypeConverter()
                                      ->convertType(op->getResult(0).getType())
                                      .cast<RankedTensorType>();
    Type outElementType = resultType.getElementType();
    // Promote the scalar to the result element type, materialize it as a
    // rank-0 (zero-d) tensor, then cast to the converted result type.
    Value elemValProm =
        convertScalarToDtype(rewriter, loc, elemVal, outElementType);
    Value zeroDTensor =
        createInitTensor(rewriter, loc, {}, outElementType, elemValProm);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, zeroDTensor);
    return success();
  }
};
} // namespace
namespace {
// Converts constant tensor allocation like ops.
template <typename OpTy>
@@ -3751,6 +3805,8 @@ public:
    patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
    target.addIllegalOp<AtenIndexSelectOp>();
    patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
    patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
    target.addIllegalOp<AtenTensorIntOp, AtenTensorFloatOp>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
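Conceptually, the ConvertAtenScalarToTensorLike pattern above promotes the scalar operand to the converted result's element type and fills a rank-0 tensor with it. A rough eager-PyTorch model of that behavior (hypothetical helper, for illustration only) might look like:

import torch

def scalar_to_tensor_reference(value, dtype=None, device=None, requires_grad=False):
    # Mirror the pattern's guards: only dtype=None and device=None are supported.
    if dtype is not None or device is not None:
        raise NotImplementedError("non-None dtype/device is rejected by the lowering")
    # createInitTensor with an empty size list amounts to a rank-0 tensor
    # filled with the dtype-promoted scalar.
    return torch.full((), value)

assert torch.equal(scalar_to_tensor_reference(1.0), torch.tensor(1.0))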


@@ -238,11 +238,11 @@ public:
AtenGeluBackwardOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp,
AtenCosOp, AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp,
-AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
-AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp,
-AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
-AtenDropoutOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp,
-AtenAddIntOp, AtenAbsOp, AtenReciprocalOp, AtenCeilOp>(op)) {
+AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp,
+AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
+AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
+AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
+AtenAbsOp, AtenReciprocalOp, AtenCeilOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]);
}
@@ -1272,7 +1272,6 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
Value t = op.t();
Value dtype = op.dtype();
knowledge.hasSizes = true;
-knowledge.sizes.resize(1, 1);
fillInDTypeGivenDTypeAndDataType(knowledge, dtype, t.getType());
return getLatticeElement(op.getResult()).join(knowledge);
}


@@ -593,8 +593,8 @@ builtin.func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.ten
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<[1],f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[1],f32> to !torch.tensor
+// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<[],f32>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
builtin.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
@@ -611,8 +611,8 @@ builtin.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[CST11:.*]] = torch.constant.int 11
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<[1],i1>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[1],i1> to !torch.tensor
+// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<[],i1>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[],i1> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
builtin.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor {