mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add E2E support for aten::to.dtype.
This commit adds end-to-end support for AtenToDtypeOp, lowering it from aten to linalg.

Signed-off-by: Prateek Gupta <prateek@nod-labs.com>

parent 4bb9b44775
commit 18e8806b14
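For context (not part of the diff): aten::to.dtype is the op behind torch.Tensor.to(dtype). A minimal PyTorch-level sketch of the behavior the new lowering has to reproduce, using only standard torch APIs:

import torch

x = torch.rand(3, 5)           # float32 input, as in the new e2e test below
y = x.to(torch.int64)          # dispatches to aten::to.dtype
assert y.dtype == torch.int64 and y.shape == x.shape

# Float to signed-integer conversion truncates toward zero, the same rounding
# behavior as the arith.fptosi op used in the linalg payload further down.
assert torch.equal(torch.tensor([1.9, -1.9]).to(torch.int64),
                   torch.tensor([1, -1]))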
@@ -311,7 +311,7 @@ class ElementwiseClampModule(torch.nn.Module):
 def ElementwiseClampModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5, low=-10, high=10))

 # ==============================================================================

 class RsubModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

@@ -344,6 +344,7 @@ class RsubModule_noalpha(torch.nn.Module):
 def RsubModule_noalpha_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))

 # ==============================================================================

 class ElementwiseLogModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

@@ -410,3 +411,21 @@ class ElementwisePowModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ElementwisePowModule())
 def ElementwisePowModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))

 # ==============================================================================

+class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True)
+    ])
+    def forward(self, x):
+        return x.to(torch.int64)
+
+@register_test_case(module_factory=lambda: ElementwiseToDtypeF32ToI64Module())
+def ElementwiseToDtypeF32ToI64Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5))
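One observation about the new test (not part of the commit): assuming tu.rand(3, 5) draws values from [0, 1) like torch.rand, every element truncates to zero after the conversion, so the golden output is an all-zero int64 tensor and the test mainly exercises the dtype and shape plumbing rather than interesting numeric values.

import torch

# Assuming tu.rand behaves like torch.rand (values in [0, 1)):
x = torch.rand(3, 5)
assert torch.equal(x.to(torch.int64), torch.zeros(3, 5, dtype=torch.int64))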
@@ -1470,6 +1470,26 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value mult = b.create<arith::MulFOp>(loc, self, alpha);
     return b.create<arith::SubFOp>(loc, other, mult);
   }
+  if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
+    Value input = payloadArgs[0];
+    Type inType = input.getType();
+    Type outType = atenToDtype.getType().cast<ValueTensorType>().getDtype();
+    Value result;
+    if (!inType.isF32()) {
+      atenToDtype.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
+    if (inType == outType)
+      result = input;
+    else if (outType.isInteger(64))
+      result = b.create<arith::FPToSIOp>(loc, b.getI64Type(), input);
+    else if (outType.isInteger(1))
+      result = b.create<arith::FPToSIOp>(loc, b.getI1Type(), input);
+    else
+      atenToDtype.emitError("unimplemented: unsupported target dtype");
+    return result;
+  }

   op->emitError("unimplemented lowering in "
                 "createLinalgPayloadCalculationForElementwiseOp");
   return nullptr;
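To make the rounding behavior of the payload above concrete, here is a hypothetical Python analogue (not code from the commit) of what the generated linalg.generic computes per element in the f32 to si64 case:

def payload_f32_to_si64(x: float) -> int:
    # arith.fptosi rounds toward zero, like int() in Python or a C integer cast.
    return int(x)

assert payload_f32_to_si64(2.9) == 2
assert payload_f32_to_si64(-2.9) == -2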
@@ -1679,8 +1699,8 @@ struct ConvertElementwiseOp : ConversionPattern {
     if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
              AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
              AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
-             AtenMaximumOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
-             AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp>(op))
+             AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
+             AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -2816,8 +2836,9 @@ public:
     target.addIllegalOp<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenAddTensorOp,
                         AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
                         AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
-                        AtenMaximumOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
-                        AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp>();
+                        AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
+                        AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
+                        AtenPowTensorScalarOp>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
     target.addIllegalOp<AtenUnsqueezeOp>();
     patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
@@ -2857,7 +2878,6 @@ public:
     patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
     target.addIllegalOp<AtenIntTensorOp>();
     patterns.add<ConvertAtenIntTensorOp>(typeConverter, context);

     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
@@ -91,8 +91,8 @@ public:
       copyToValueTensorOps.push_back(copyToValueTensor);
     } else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
                    AtenTransposeIntOp, TensorStaticInfoCastOp,
-                   AtenBroadcastToOp, AtenContiguousOp, AtenPermuteOp,
-                   AtenViewOp, AtenExpandOp>(op)) {
+                   AtenBroadcastToOp, AtenToDtypeOp, AtenContiguousOp,
+                   AtenPermuteOp, AtenViewOp, AtenExpandOp>(op)) {
       // AtenContiguousOp might return a view, so this is conservatively
       // correct. We could potentially be more precise and identify the cases
       // that it does not return a view and treat those as having value
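The hunk above lets the value-semantics rewrite look through aten.to.dtype the same way it does for the other listed ops. As a PyTorch-level illustration (not part of the commit) of why the conservative treatment mentioned in the comment also fits to.dtype:

import torch

x = torch.rand(2, 2)
# When the dtype actually changes, Tensor.to allocates a fresh tensor...
y = x.to(torch.int64)
assert y.data_ptr() != x.data_ptr()
# ...but a no-op conversion returns the input tensor itself, so, like
# aten.contiguous, aten.to.dtype cannot be assumed to always copy.
z = x.to(torch.float32)
assert z is x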
@@ -226,12 +226,11 @@ public:
   if (isa<TensorStaticInfoCastOp, CopyToValueTensorOp, CopyToNonValueTensorOp,
           AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp, AtenEqScalarOp,
           AtenGeScalarOp, AtenGtScalarOp, AtenNeScalarOp, AtenBitwiseNotOp,
-          AtenToDtypeOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
-          DerefineOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
-          AtenFill_ScalarOp, AtenDetachOp, AtenMaskedFill_ScalarOp,
-          AtenCopy_Op, AtenIndexPut_Op, AtenCopy_Op, AtenCumsumOp,
-          AtenLayerNormOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
-          AtenSqrtOp, AtenFloorOp>(op)) {
+          AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
+          AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
+          AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
+          AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp, AtenClampOp,
+          AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp>(op)) {
     return getLatticeElement(op->getResult(0)).join(*operands[0]);
   }

@@ -346,6 +345,8 @@ public:
   } else if (auto emptyMemoryFormat = dyn_cast<AtenEmptyMemoryFormatOp>(op)) {
     return visitConstantTensorAllocOp<AtenEmptyMemoryFormatOp>(
         emptyMemoryFormat);
+  } else if (auto toDtype = dyn_cast<AtenToDtypeOp>(op)) {
+    return visitAtenToDtypeOp(toDtype, operands);
   } else if (auto toOther = dyn_cast<AtenToOtherOp>(op)) {
     return visitTypeConversionOp<AtenToOtherOp>(toOther, operands);
   } else if (auto typeAs = dyn_cast<AtenTypeAsOp>(op)) {
@@ -480,6 +481,9 @@ private:
   ChangeResult visitScalarToTensorConversionOp(OpTy op);
   ChangeResult visitAtenTensorOp(AtenTensorOp op);
   template <typename OpTy> ChangeResult visitConstantTensorAllocOp(OpTy op);
+  ChangeResult
+  visitAtenToDtypeOp(AtenToDtypeOp op,
+                     ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   template <typename OpTy>
   ChangeResult
   visitTypeConversionOp(OpTy op,
@@ -1078,6 +1082,20 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocOp(OpTy op) {
   return getLatticeElement(op.getResult()).join(knowledge);
 }

+// Convert input tensor type to the given `dtype`.
+ChangeResult TypeAnalyzer::visitAtenToDtypeOp(
+    AtenToDtypeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  knowledge.hasSizes = input.hasSizes;
+  if (input.hasSizes)
+    knowledge.sizes = input.sizes;
+  Value dtype = op.dtype();
+  fillInDTypeGivenDTypeAndDataType(knowledge, dtype, input.dtype);
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 // Convert input tensor type to the same as the other tensor.
 template <typename OpTy>
 ChangeResult TypeAnalyzer::visitTypeConversionOp(
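A rough Python sketch (hypothetical names, not torch-mlir API) of the refinement rule the C++ above implements: sizes pass through from the input unchanged, and only the element type is rewritten based on the dtype operand.

def refine_to_dtype(input_sizes, target_dtype):
    # Sizes are carried over; the dtype comes from the op's dtype operand.
    return {"sizes": input_sizes, "dtype": target_dtype}

# e.g. the FileCheck test below: a [?, ?] f32 tensor converted with dtype int64
# refines to !torch.tensor<[?,?],si64>.
print(refine_to_dtype([-1, -1], "si64"))   # {'sizes': [-1, -1], 'dtype': 'si64'}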
@@ -1004,3 +1004,21 @@ func @aten_matmul_broadcast_vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1
   %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor
   return %0 : !torch.tensor
 }
+
+// -----
+// CHECK-LABEL: func @torch.aten.to.dtype
+// CHECK-SAME:    (%[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
+// CHECK:         %[[TODTYPE:.*]] = torch.aten.to.dtype
+// CHECK-SAME:      %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
+// CHECK-SAME:      !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+// CHECK-SAME:      -> !torch.tensor<[?,?],si64>
+// CHECK-NEXT:    %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<[?,?],si64> to !torch.tensor
+// CHECK-NEXT:    return %[[RES]] : !torch.tensor
+
+func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor {
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %int4 = torch.constant.int 4
+  %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
+  return %0 : !torch.tensor
+}
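One detail worth calling out in the test above: %int4 is PyTorch's integer encoding of a ScalarType, and 4 corresponds to torch.int64, which is why the result refines to !torch.tensor<[?,?],si64>. A quick way to see that encoding from Python (illustrative only; the exact graph text varies by PyTorch version):

import torch

@torch.jit.script
def f(x):
    return x.to(torch.int64)

# The scripted graph materializes torch.int64 as an integer constant with
# value 4 (ScalarType::Long) feeding aten::to, the same %int4 seen above.
print(f.graph)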