[TORCH][MLIR] Add E2E support for aten::to.dtype.

This commit adds end-to-end support for AtenToDtypeOp, lowering it from
aten to linalg.

Signed-off-by: Prateek Gupta <prateek@nod-labs.com>
pull/403/head
Prateek Gupta 2021-11-01 11:46:46 +00:00 committed by Yi Zhang
parent 4bb9b44775
commit 18e8806b14
5 changed files with 89 additions and 14 deletions
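For context (not part of the change itself): a quick way to see the op this commit lowers is to script a small function in PyTorch and print its graph. The `aten::to.dtype` overload takes the input tensor plus dtype, non_blocking, copy, and memory_format operands, and the dtype constant 4 corresponds to torch.int64, matching the `%int4` constant in the IR test at the end of this diff.

import torch

def convert(x):
    # The float32 -> int64 conversion this commit lowers end-to-end.
    return x.to(torch.int64)

# The scripted graph contains the aten::to.dtype call with operands
# (self, dtype, non_blocking, copy, memory_format).
print(torch.jit.script(convert).graph)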


@@ -311,7 +311,7 @@ class ElementwiseClampModule(torch.nn.Module):
def ElementwiseClampModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5, low=-10, high=10))

# ==============================================================================

class RsubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -344,6 +344,7 @@ class RsubModule_noalpha(torch.nn.Module):
def RsubModule_noalpha_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))

# ==============================================================================

class ElementwiseLogModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -410,3 +411,21 @@ class ElementwisePowModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwisePowModule())
def ElementwisePowModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))

# ==============================================================================

class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True)
    ])
    def forward(self, x):
        return x.to(torch.int64)

@register_test_case(module_factory=lambda: ElementwiseToDtypeF32ToI64Module())
def ElementwiseToDtypeF32ToI64Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 5))
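As a side note (not part of the diff): in eager PyTorch, the float32 -> int64 conversion this new test exercises truncates toward zero, which is the behavior the e2e harness compares the lowered code against. A minimal eager-mode check:

import torch

x = torch.tensor([[0.2, 1.9, -1.9]])
print(x.to(torch.int64))  # tensor([[ 0,  1, -1]]): values are truncated toward zero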


@@ -1470,6 +1470,26 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    Value mult = b.create<arith::MulFOp>(loc, self, alpha);
    return b.create<arith::SubFOp>(loc, other, mult);
  }
  if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
    Value input = payloadArgs[0];
    Type inType = input.getType();
    Type outType = atenToDtype.getType().cast<ValueTensorType>().getDtype();
    Value result;
    if (!inType.isF32()) {
      atenToDtype.emitError("unimplemented: non-f32 input dtype");
      return nullptr;
    }
    if (inType == outType)
      result = input;
    else if (outType.isInteger(64))
      result = b.create<arith::FPToSIOp>(loc, b.getI64Type(), input);
    else if (outType.isInteger(1))
      result = b.create<arith::FPToSIOp>(loc, b.getI1Type(), input);
    else {
      atenToDtype.emitError("unimplemented: unsupported target dtype");
      return nullptr;
    }
    return result;
  }
  op->emitError("unimplemented lowering in "
                "createLinalgPayloadCalculationForElementwiseOp");
  return nullptr;
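A rough Python analogue (illustrative only, not a torch-mlir API) of the scalar payload added above, covering the f32 passthrough and f32 -> si64 paths: the linalg.generic built by ConvertElementwiseOp applies this scalar conversion once per element, and for in-range values arith.fptosi truncates toward zero the same way Python's int() does.

def to_dtype_payload(x: float, target_dtype: str):
    # Same dtype: pass the value through unchanged.
    if target_dtype == "f32":
        return x
    # f32 -> si64: arith.fptosi-style truncation toward zero.
    if target_dtype == "i64":
        return int(x)
    raise NotImplementedError("unimplemented: unsupported target dtype")

print([to_dtype_payload(v, "i64") for v in [0.2, 1.9, -1.9]])  # [0, 1, -1]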
@@ -1679,8 +1699,8 @@ struct ConvertElementwiseOp : ConversionPattern {
    if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
             AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
             AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
             AtenMaximumOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
             AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp>(op))
             AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
             AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -2816,8 +2836,9 @@ public:
    target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
                        AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
                        AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
                        AtenMaximumOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
                        AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp>();
                        AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
                        AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
                        AtenPowTensorScalarOp>();
    patterns.add<ConvertElementwiseOp>(typeConverter, context);
    target.addIllegalOp<AtenUnsqueezeOp>();
    patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
@@ -2857,7 +2878,6 @@ public:
    patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
    target.addIllegalOp<AtenIntTensorOp>();
    patterns.add<ConvertAtenIntTensorOp>(typeConverter, context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();


@@ -91,8 +91,8 @@ public:
        copyToValueTensorOps.push_back(copyToValueTensor);
      } else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
                     AtenTransposeIntOp, TensorStaticInfoCastOp,
                     AtenBroadcastToOp, AtenContiguousOp, AtenPermuteOp,
                     AtenViewOp, AtenExpandOp>(op)) {
                     AtenBroadcastToOp, AtenToDtypeOp, AtenContiguousOp,
                     AtenPermuteOp, AtenViewOp, AtenExpandOp>(op)) {
        // AtenContiguousOp might return a view, so this is conservatively
        // correct. We could potentially be more precise and identify the cases
        // that it does not return a view and treat those as having value
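AtenToDtypeOp presumably joins this list for the same conservative reason: Tensor.to can return the input tensor itself (an alias) when no conversion is needed. A quick eager-mode illustration:

import torch

x = torch.rand(2, 3)
print(x.to(torch.float32) is x)  # True: dtype already matches, no copy is made
print(x.to(torch.int64) is x)    # False: a converted tensor is materialized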


@@ -226,12 +226,11 @@ public:
    if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
            AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp, AtenEqScalarOp,
            AtenGeScalarOp, AtenGtScalarOp, AtenNeScalarOp, AtenBitwiseNotOp,
            AtenToDtypeOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
            DerefineOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
            AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp,
            AtenCopy_Op, AtenIndexPut_Op, AtenCopy_Op, AtenCumsumOp,
            AtenLayerNormOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
            AtenSqrtOp, AtenFloorOp>(op)) {
            AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
            AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
            AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
            AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp, AtenClampOp,
            AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }
@@ -346,6 +345,8 @@ public:
    } else if (auto emptyMemoryFormat = dyn_cast<AtenEmptyMemoryFormatOp>(op)) {
      return visitConstantTensorAllocOp<AtenEmptyMemoryFormatOp>(
          emptyMemoryFormat);
    } else if (auto toDtype = dyn_cast<AtenToDtypeOp>(op)) {
      return visitAtenToDtypeOp(toDtype, operands);
    } else if (auto toOther = dyn_cast<AtenToOtherOp>(op)) {
      return visitTypeConversionOp<AtenToOtherOp>(toOther, operands);
    } else if (auto typeAs = dyn_cast<AtenTypeAsOp>(op)) {
@@ -480,6 +481,9 @@ private:
  ChangeResult visitScalarToTensorConversionOp(OpTy op);
  ChangeResult visitAtenTensorOp(AtenTensorOp op);
  template <typename OpTy> ChangeResult visitConstantTensorAllocOp(OpTy op);
  ChangeResult
  visitAtenToDtypeOp(AtenToDtypeOp op,
                     ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  template <typename OpTy>
  ChangeResult
  visitTypeConversionOp(OpTy op,
@@ -1078,6 +1082,20 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocOp(OpTy op) {
  return getLatticeElement(op.getResult()).join(knowledge);
}

// Convert input tensor type to the given `dtype`.
ChangeResult TypeAnalyzer::visitAtenToDtypeOp(
    AtenToDtypeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto input = operands[0]->getValue();
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  knowledge.hasSizes = input.hasSizes;
  if (input.hasSizes)
    knowledge.sizes = input.sizes;
  Value dtype = op.dtype();
  fillInDTypeGivenDTypeAndDataType(knowledge, dtype, input.dtype);
  return getLatticeElement(op.getResult()).join(knowledge);
}

// Convert input tensor type to the same as the other tensor.
template <typename OpTy>
ChangeResult TypeAnalyzer::visitTypeConversionOp(
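A hypothetical sketch (plain Python, not the torch-mlir lattice API) of the transfer rule visitAtenToDtypeOp implements above: the result keeps the operand's sizes and takes its element type from the dtype operand, which is what turns !torch.tensor<[?,?],f32> into !torch.tensor<[?,?],si64> in the test below.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Knowledge:
    sizes: Optional[List[int]]   # None when the rank/sizes are unknown
    dtype: Optional[str]         # None when the element type is unknown

def refine_to_dtype(operand: Knowledge, requested_dtype: Optional[str]) -> Knowledge:
    # Sizes transfer unchanged; the dtype comes from the op's dtype argument
    # (left unknown when that argument is not a known constant).
    return Knowledge(sizes=operand.sizes, dtype=requested_dtype)

print(refine_to_dtype(Knowledge(sizes=[-1, -1], dtype="f32"), "si64"))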


@@ -1004,3 +1004,21 @@ func @aten_matmul_broadcast_vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor
  return %0 : !torch.tensor
}

// -----
// CHECK-LABEL: func @torch.aten.to.dtype
// CHECK-SAME: (%[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype
// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
// CHECK-SAME: -> !torch.tensor<[?,?],si64>
// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<[?,?],si64> to !torch.tensor
// CHECK-NEXT: return %[[RES]] : !torch.tensor
func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %int4 = torch.constant.int 4
  %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
  return %0 : !torch.tensor
}