diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 8774c77c0..a3904a191 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -363,6 +363,17 @@ public:
                  ArrayRef<LatticeElement<ValueKnowledge> *> operands) final;
 
 private:
+  // Get the MLIR type of the tensor dtype given the dtype integer value and the
+  // input dtype. When DType is None the type is inferred from the input dtype.
+  void fillInDTypeGivenDTypeIntAndInputDType(ValueKnowledge &knowledge,
+                                             Value dtype, Type inputDType);
+
+  // Get the MLIR type of the tensor dtype given the dtype integer value and
+  // data type of torch type. When DType is None the type is inferred from the
+  // data type.
+  void fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge, Value dtype,
+                                        Type dataType);
+
   /// Incorporates `knowledge` into the lattice state of `v`.
   ///
   /// This method should be used instead of
@@ -587,24 +598,21 @@ getPromotedResultTypeAssumingNonZeroRank(MLIRContext *context,
                                /*skipRankCheck=*/true);
 }
 
-// Get the MLIR type of the tensor dtype given the dtype integer value and the
-// input dtype. When DType is None the type is inferred from the input dtype.
-static void fillInDTypeGivenDTypeIntAndInputDType(ValueKnowledge &knowledge,
-                                                  Value dtype,
-                                                  Type inputDType) {
+void TypeAnalyzer::fillInDTypeGivenDTypeIntAndInputDType(
+    ValueKnowledge &knowledge, Value dtype, Type inputDType) {
   assert(isBuiltInType(inputDType) && "`inputDType` must be a builtin type");
   int64_t dtypeInt;
   if (dtype.getType().isa<Torch::NoneType>())
     knowledge.dtype = inputDType;
   else if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
     knowledge.dtype = getTypeForDTypeInteger(dtype.getContext(), dtypeInt);
+  else if (auto primDtypeOp = dyn_cast<PrimDtypeOp>(dtype.getDefiningOp()))
+    knowledge.dtype = getLatticeElement(primDtypeOp.a()).getValue().dtype;
 }
 
-// Get the MLIR type of the tensor dtype given the dtype integer value and data
-// type of torch type. When DType is None the type is inferred from the data
-// type.
-static void fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge,
-                                             Value dtype, Type dataType) {
+void TypeAnalyzer::fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge,
+                                                    Value dtype,
+                                                    Type dataType) {
   assert(isa<Torch::TorchDialect>(dataType.getDialect()) &&
          "`dataType` must be a torch type");
   Type dtypeForDataType = getDefaultDtypeForTorchScalar(dataType);
diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir
index eadefbfa6..38ce507fe 100644
--- a/test/Dialect/Torch/refine-types.mlir
+++ b/test/Dialect/Torch/refine-types.mlir
@@ -196,3 +196,47 @@ func.func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
   %1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number
   return %1 : !torch.number
 }
+
+// -----
+// CHECK-LABEL: func.func @prim.dtype(
+// CHECK-SAME:     %[[arg:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor {
+
+// CHECK: %[[zero:.*]] = torch.constant.int 0
+// CHECK: %[[false:.*]] = torch.constant.bool false
+
+// CHECK: %[[neg:.*]] = torch.aten.neg %[[arg]] : !torch.vtensor<*,bf16> -> !torch.vtensor<*,bf16>
+// CHECK: %[[dtype0:.*]] = torch.prim.dtype %[[neg]] : !torch.vtensor<*,bf16> -> !torch.int
+// CHECK: %[[device0:.*]] = torch.prim.device %[[neg]] : !torch.vtensor<*,bf16> -> !torch.Device
+// CHECK: %[[tensor:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype0]], %[[device0]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
+
+// CHECK: %[[dtype1:.*]] = torch.prim.dtype %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.int
+// CHECK: %[[device1:.*]] = torch.prim.device %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.Device
+// CHECK: %[[result:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype1]], %[[device1]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
+
+// CHECK: %[[cast:.*]] = torch.tensor_static_info_cast %[[result]] : !torch.vtensor<*,bf16> to !torch.vtensor
+// CHECK: return %[[cast]] : !torch.vtensor
+// CHECK: }
+
+func.func @prim.dtype(%arg: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> {
+  %zero = torch.constant.int 0
+  %false = torch.constant.bool false
+
+  // Op that requires type refinement.
+  %neg = torch.aten.neg %arg : !torch.vtensor<*,bf16> -> !torch.vtensor<*,unk>
+
+  // Op whose processing requires type refinement on its source argument.
+  %dtype = torch.prim.dtype %neg : !torch.vtensor<*,unk> -> !torch.int
+  %device = torch.prim.device %neg : !torch.vtensor<*,unk> -> !torch.Device
+
+  // Another op that requires type refinement.
+  %result = torch.aten.tensor.int %zero, %dtype, %device, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
+
+  // Repeat the above three ops a second time to ensure that the type refinement
+  // code works regardless of the number of alternating refinement+prim.dtype
+  // sequences.
+  %dtype2 = torch.prim.dtype %result : !torch.vtensor<*,unk> -> !torch.int
+  %device2 = torch.prim.device %result : !torch.vtensor<*,unk> -> !torch.Device
+  %result2 = torch.aten.tensor.int %zero, %dtype2, %device2, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
+
+  return %result2 : !torch.vtensor<*,unk>
+}
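
For reference, a minimal sketch of the behavior the new `PrimDtypeOp` branch enables, distilled from the test above (the function name `@sketch` and the exact pass invocation are illustrative assumptions, not part of the patch). Previously, the dtype operand of `torch.aten.tensor.int` produced by `torch.prim.dtype` was neither None nor a constant int, so `fillInDTypeGivenDTypeIntAndInputDType` left the dtype unknown and the result stayed `!torch.vtensor<*,unk>`; now the analysis reads the dtype from the lattice state of the `prim.dtype` operand and the result refines to `!torch.vtensor<*,bf16>`:

// Illustrative input, assuming invocation: torch-mlir-opt -torch-refine-types
func.func @sketch(%t: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> {
  %int0 = torch.constant.int 0
  %false = torch.constant.bool false
  // prim.dtype yields the dtype of %t as a !torch.int at runtime; the analysis
  // now also propagates %t's static dtype (bf16) through it at compile time.
  %dtype = torch.prim.dtype %t : !torch.vtensor<*,bf16> -> !torch.int
  %device = torch.prim.device %t : !torch.vtensor<*,bf16> -> !torch.Device
  // Before this patch: stays <*,unk>; after: refined to <*,bf16> via the lattice.
  %r = torch.aten.tensor.int %int0, %dtype, %device, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
  return %r : !torch.vtensor<*,unk>
}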