mirror of https://github.com/llvm/torch-mlir
torch: handle `torch.prim.dtype` ops during type refinement (#1013)
The canonicalizer converts `torch.prim.dtype` ops into integer constants for valid types, but the type may not be known until type refinement is complete. However, type refinement cannot make progress until `torch.prim.dtype` ops have been resolved to their corresponding integer constants, thus creating a circular dependency. This patch creates a tight coupling between type refinement and the lowering of `torch.prim.dtype` ops by handling such ops as they are encountered during type refinement. The unit test in this patch aims to check whether the type refinement pass can now handle chains of operations that alternate between type construction and type refinement.pull/1032/head
parent
5bd9362c61
commit
340d8af28a
|
@ -363,6 +363,17 @@ public:
|
|||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) final;
|
||||
|
||||
private:
|
||||
// Get the MLIR type of the tensor dtype given the dtype integer value and the
|
||||
// input dtype. When DType is None the type is inferred from the input dtype.
|
||||
void fillInDTypeGivenDTypeIntAndInputDType(ValueKnowledge &knowledge,
|
||||
Value dtype, Type inputDType);
|
||||
|
||||
// Get the MLIR type of the tensor dtype given the dtype integer value and
|
||||
// data type of torch type. When DType is None the type is inferred from the
|
||||
// data type.
|
||||
void fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge, Value dtype,
|
||||
Type dataType);
|
||||
|
||||
/// Incorporates `knowledge` into the lattice state of `v`.
|
||||
///
|
||||
/// This method should be used instead of
|
||||
|
@ -587,24 +598,21 @@ getPromotedResultTypeAssumingNonZeroRank(MLIRContext *context,
|
|||
/*skipRankCheck=*/true);
|
||||
}
|
||||
|
||||
// Get the MLIR type of the tensor dtype given the dtype integer value and the
|
||||
// input dtype. When DType is None the type is inferred from the input dtype.
|
||||
static void fillInDTypeGivenDTypeIntAndInputDType(ValueKnowledge &knowledge,
|
||||
Value dtype,
|
||||
Type inputDType) {
|
||||
void TypeAnalyzer::fillInDTypeGivenDTypeIntAndInputDType(
|
||||
ValueKnowledge &knowledge, Value dtype, Type inputDType) {
|
||||
assert(isBuiltInType(inputDType) && "`inputDType` must be a builtin type");
|
||||
int64_t dtypeInt;
|
||||
if (dtype.getType().isa<Torch::NoneType>())
|
||||
knowledge.dtype = inputDType;
|
||||
else if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||
knowledge.dtype = getTypeForDTypeInteger(dtype.getContext(), dtypeInt);
|
||||
else if (auto primDtypeOp = dyn_cast<PrimDtypeOp>(dtype.getDefiningOp()))
|
||||
knowledge.dtype = getLatticeElement(primDtypeOp.a()).getValue().dtype;
|
||||
}
|
||||
|
||||
// Get the MLIR type of the tensor dtype given the dtype integer value and data
|
||||
// type of torch type. When DType is None the type is inferred from the data
|
||||
// type.
|
||||
static void fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge,
|
||||
Value dtype, Type dataType) {
|
||||
void TypeAnalyzer::fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge,
|
||||
Value dtype,
|
||||
Type dataType) {
|
||||
assert(isa<TorchDialect>(dataType.getDialect()) &&
|
||||
"`dataType` must be a torch type");
|
||||
Type dtypeForDataType = getDefaultDtypeForTorchScalar(dataType);
|
||||
|
|
|
@ -196,3 +196,47 @@ func.func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
|
|||
%1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number
|
||||
return %1 : !torch.number
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @prim.dtype(
|
||||
// CHECK-SAME: %[[arg:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor {
|
||||
|
||||
// CHECK: %[[zero:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[false:.*]] = torch.constant.bool false
|
||||
|
||||
// CHECK: %[[neg:.*]] = torch.aten.neg %[[arg]] : !torch.vtensor<*,bf16> -> !torch.vtensor<*,bf16>
|
||||
// CHECK: %[[dtype0:.*]] = torch.prim.dtype %[[neg]] : !torch.vtensor<*,bf16> -> !torch.int
|
||||
// CHECK: %[[device0:.*]] = torch.prim.device %[[neg]] : !torch.vtensor<*,bf16> -> !torch.Device
|
||||
// CHECK: %[[tensor:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype0]], %[[device0]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
|
||||
|
||||
// CHECK: %[[dtype1:.*]] = torch.prim.dtype %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.int
|
||||
// CHECK: %[[device1:.*]] = torch.prim.device %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.Device
|
||||
// CHECK: %[[result:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype1]], %[[device1]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
|
||||
|
||||
// CHECK: %[[cast:.*]] = torch.tensor_static_info_cast %[[result]] : !torch.vtensor<*,bf16> to !torch.vtensor
|
||||
// CHECK: return %[[cast]] : !torch.vtensor
|
||||
// CHECK: }
|
||||
|
||||
func.func @prim.dtype(%arg: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> {
|
||||
%zero = torch.constant.int 0
|
||||
%false = torch.constant.bool false
|
||||
|
||||
// Op that requires type refinement
|
||||
%neg = torch.aten.neg %arg : !torch.vtensor<*,bf16> -> !torch.vtensor<*,unk>
|
||||
|
||||
// Op whose processing requires type refinement on its source argument.
|
||||
%dtype = torch.prim.dtype %neg : !torch.vtensor<*,unk> -> !torch.int
|
||||
%device = torch.prim.device %neg : !torch.vtensor<*,unk> -> !torch.Device
|
||||
|
||||
// Another op that requires type refinement
|
||||
%result = torch.aten.tensor.int %zero, %dtype, %device, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
|
||||
|
||||
// Repeat the above three ops a second time to ensure that the type refinement
|
||||
// code works regardless of the number of alternating refinement+prim.dtype
|
||||
// sequences.
|
||||
%dtype2 = torch.prim.dtype %result : !torch.vtensor<*,unk> -> !torch.int
|
||||
%device2 = torch.prim.device %result : !torch.vtensor<*,unk> -> !torch.Device
|
||||
%result2 = torch.aten.tensor.int %zero, %dtype2, %device2, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
|
||||
|
||||
return %result2 : !torch.vtensor<*,unk>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue