torch: handle `torch.prim.dtype` ops during type refinement (#1013)

The canonicalizer converts `torch.prim.dtype` ops into integer constants
for valid types, but the type may not be known until type refinement is
complete.  However, type refinement cannot make progress until
`torch.prim.dtype` ops have been resolved to their corresponding integer
constants, thus creating a circular dependency.
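
To make the cycle concrete, here is the chain of ops from the unit test added
in this patch (types as written before refinement): the canonicalizer cannot
fold %dtype into an integer constant while the dtype of %neg is still `unk`,
yet refining the result dtype of %result requires exactly that constant.

    %zero = torch.constant.int 0
    %false = torch.constant.bool false
    %neg = torch.aten.neg %arg : !torch.vtensor<*,bf16> -> !torch.vtensor<*,unk>
    %dtype = torch.prim.dtype %neg : !torch.vtensor<*,unk> -> !torch.int
    %device = torch.prim.device %neg : !torch.vtensor<*,unk> -> !torch.Device
    %result = torch.aten.tensor.int %zero, %dtype, %device, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>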

This patch creates a tight coupling between type refinement and the
lowering of `torch.prim.dtype` ops by handling such ops as they are
encountered during type refinement.  The unit test in this patch aims to
check whether the type refinement pass can now handle chains of
operations that alternate between type construction and type refinement.
Ashay Rane 2022-07-08 16:38:51 -07:00 committed by GitHub
parent 5bd9362c61
commit 340d8af28a
2 changed files with 62 additions and 10 deletions

lib/Dialect/Torch/Transforms/RefineTypes.cpp

@@ -363,6 +363,17 @@ public:
                  ArrayRef<LatticeElement<ValueKnowledge> *> operands) final;
 
 private:
+  // Get the MLIR type of the tensor dtype given the dtype integer value and the
+  // input dtype. When DType is None the type is inferred from the input dtype.
+  void fillInDTypeGivenDTypeIntAndInputDType(ValueKnowledge &knowledge,
+                                             Value dtype, Type inputDType);
+
+  // Get the MLIR type of the tensor dtype given the dtype integer value and
+  // data type of torch type. When DType is None the type is inferred from the
+  // data type.
+  void fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge, Value dtype,
+                                        Type dataType);
+
   /// Incorporates `knowledge` into the lattice state of `v`.
   ///
   /// This method should be used instead of
@@ -587,24 +598,21 @@ getPromotedResultTypeAssumingNonZeroRank(MLIRContext *context,
                                /*skipRankCheck=*/true);
 }
 
-// Get the MLIR type of the tensor dtype given the dtype integer value and the
-// input dtype. When DType is None the type is inferred from the input dtype.
-static void fillInDTypeGivenDTypeIntAndInputDType(ValueKnowledge &knowledge,
-                                                  Value dtype,
-                                                  Type inputDType) {
+void TypeAnalyzer::fillInDTypeGivenDTypeIntAndInputDType(
+    ValueKnowledge &knowledge, Value dtype, Type inputDType) {
   assert(isBuiltInType(inputDType) && "`inputDType` must be a builtin type");
   int64_t dtypeInt;
   if (dtype.getType().isa<Torch::NoneType>())
     knowledge.dtype = inputDType;
   else if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
     knowledge.dtype = getTypeForDTypeInteger(dtype.getContext(), dtypeInt);
+  else if (auto primDtypeOp = dyn_cast<PrimDtypeOp>(dtype.getDefiningOp()))
+    knowledge.dtype = getLatticeElement(primDtypeOp.a()).getValue().dtype;
 }
 
-// Get the MLIR type of the tensor dtype given the dtype integer value and data
-// type of torch type. When DType is None the type is inferred from the data
-// type.
-static void fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge,
-                                             Value dtype, Type dataType) {
+void TypeAnalyzer::fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge,
+                                                    Value dtype,
+                                                    Type dataType) {
   assert(isa<TorchDialect>(dataType.getDialect()) &&
          "`dataType` must be a torch type");
   Type dtypeForDataType = getDefaultDtypeForTorchScalar(dataType);

test/Dialect/Torch/refine-types.mlir

@@ -196,3 +196,47 @@ func.func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
   %1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number
   return %1 : !torch.number
 }
+
+// -----
+
+// CHECK-LABEL: func.func @prim.dtype(
+// CHECK-SAME:     %[[arg:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor {
+// CHECK:   %[[zero:.*]] = torch.constant.int 0
+// CHECK:   %[[false:.*]] = torch.constant.bool false
+// CHECK:   %[[neg:.*]] = torch.aten.neg %[[arg]] : !torch.vtensor<*,bf16> -> !torch.vtensor<*,bf16>
+// CHECK:   %[[dtype0:.*]] = torch.prim.dtype %[[neg]] : !torch.vtensor<*,bf16> -> !torch.int
+// CHECK:   %[[device0:.*]] = torch.prim.device %[[neg]] : !torch.vtensor<*,bf16> -> !torch.Device
+// CHECK:   %[[tensor:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype0]], %[[device0]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
+// CHECK:   %[[dtype1:.*]] = torch.prim.dtype %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.int
+// CHECK:   %[[device1:.*]] = torch.prim.device %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.Device
+// CHECK:   %[[result:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype1]], %[[device1]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
+// CHECK:   %[[cast:.*]] = torch.tensor_static_info_cast %[[result]] : !torch.vtensor<*,bf16> to !torch.vtensor
+// CHECK:   return %[[cast]] : !torch.vtensor
+// CHECK: }
+func.func @prim.dtype(%arg: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> {
+  %zero = torch.constant.int 0
+  %false = torch.constant.bool false
+
+  // Op that requires type refinement.
+  %neg = torch.aten.neg %arg : !torch.vtensor<*,bf16> -> !torch.vtensor<*,unk>
+
+  // Op whose processing requires type refinement on its source argument.
+  %dtype = torch.prim.dtype %neg : !torch.vtensor<*,unk> -> !torch.int
+  %device = torch.prim.device %neg : !torch.vtensor<*,unk> -> !torch.Device
+
+  // Another op that requires type refinement.
+  %result = torch.aten.tensor.int %zero, %dtype, %device, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
+
+  // Repeat the above three ops a second time to ensure that the type refinement
+  // code works regardless of the number of alternating refinement+prim.dtype
+  // sequences.
+  %dtype2 = torch.prim.dtype %result : !torch.vtensor<*,unk> -> !torch.int
+  %device2 = torch.prim.device %result : !torch.vtensor<*,unk> -> !torch.Device
+  %result2 = torch.aten.tensor.int %zero, %dtype2, %device2, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
+
+  return %result2 : !torch.vtensor<*,unk>
+}