mirror of https://github.com/llvm/torch-mlir
Fix type promotion code for scalar-only operations
Fix the type promotion code for scalar-only operations to return a Torch type, which is the type tracked in ValueKnowledge.scalarType:

- Fix `getPromotedResultScalarType` to return a Torch type.
- Add a `getBuiltInTypeForTorchScalar` helper to convert a scalar type to a builtin type before passing it to the next-level type promotion helper `updateResultTypeState`.
- Add a `setScalarType` helper to make setting ValueKnowledge.scalarType easier.
parent b20679e1b8
commit 28be6511d2
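The fix hinges on two type systems living side by side: a tensor's dtype is tracked with builtin types (si64, f64), while a Scalar's type is tracked in ValueKnowledge.scalarType with Torch dialect types (!torch.int, !torch.float, !torch.number). The old scalar-only path leaked a builtin type into scalarType; the patch instead widens each Torch scalar to a builtin type, promotes, and maps the winner back. A minimal standalone C++ sketch of that sequence (illustrative enums and names, not the actual MLIR code):

#include <cassert>

// Illustrative stand-ins for the two type systems; the real code works on
// mlir::Type values from the Torch and builtin dialects.
enum class TorchScalar { Int, Float };  // !torch.int, !torch.float
enum class Builtin { SI64, F64 };       // builtin si64, f64

// Mirrors getBuiltInTypeForTorchScalar: widen a Torch scalar to a builtin type.
Builtin builtinFor(TorchScalar t) {
  return t == TorchScalar::Float ? Builtin::F64 : Builtin::SI64;
}

// Mirrors the promotion step: a floating-point operand wins, as in PyTorch.
Builtin promote(Builtin a, Builtin b) {
  return (a == Builtin::F64 || b == Builtin::F64) ? Builtin::F64 : Builtin::SI64;
}

// Mirrors getTorchTypeForScalarType: map the winner back to a Torch type so it
// can be stored in ValueKnowledge.scalarType.
TorchScalar torchFor(Builtin b) {
  return b == Builtin::F64 ? TorchScalar::Float : TorchScalar::Int;
}

int main() {
  TorchScalar result = torchFor(
      promote(builtinFor(TorchScalar::Float), builtinFor(TorchScalar::Int)));
  assert(result == TorchScalar::Float); // float + int -> float, as a Torch type
}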
@@ -5650,6 +5650,30 @@ def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [
   }];
 }
 
+def Torch_AtenAddOp : Torch_Op<"aten.add", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::add : (Scalar, Scalar) -> (Scalar)`";
+  let arguments = (ins
+    AnyTorchScalarType:$a,
+    AnyTorchScalarType:$b
+  );
+  let results = (outs
+    AnyTorchScalarType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAddOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenAddOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
     AllowsTypeRefinement,
     HasValueSemantics,
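(The block above is generated code: the emit("aten::add : (Scalar, Scalar) -> (Scalar)") registration added to torch_ods_gen.py further down produces this def, and the `2, 1` passed to parseDefaultTorchOp/printDefaultTorchOp presumably denote the op's two operands and one result.)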
@@ -30,11 +30,18 @@ torch_upstream::ScalarType getScalarTypeForType(Type type);
 Type getTypeForScalarType(
     MLIRContext *context, torch_upstream::ScalarType dtypeInt,
     mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
 
+Type getTorchTypeForScalarType(MLIRContext *context,
+                               torch_upstream::ScalarType dtypeInt);
+
 Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
                               Type dtype);
 // Helper to convert a tensor to a specific scalar type.
 Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
                            Type dtype);
 
+bool isBuiltInType(Type type);
+
 // Helper function to get rank of `Base tensor type`.
 // -1 is returned if the tensorRank can't be determined.
 int getTensorRank(Value tensor);
@@ -108,8 +108,9 @@ public:
       result.returnOp = returnOp;
     } else {
       return rewriter.notifyMatchFailure(
-          copyToNonValueTensor,
-          "unsupported op encountered during abstract analysis");
+          copyToNonValueTensor, "unsupported op `" +
+                                    user->getName().getStringRef() +
+                                    "` encountered during abstract analysis");
     }
   }
   return result;
@@ -114,6 +114,10 @@ static torch_upstream::TypeKind getTypeKind(Type type) {
 /// on the `dtype` from tensors and can't be used on other types like scalar
 /// types.
 static Optional<Type> meetElementTypes(Type lhs, Type rhs) {
+  auto isNullOrBuiltIn = [](Type type) { return !type || isBuiltInType(type); };
+  assert(isNullOrBuiltIn(lhs) && "`lhs` must be a builtin type");
+  assert(isNullOrBuiltIn(rhs) && "`rhs` must be a builtin type");
+
   if (!lhs)
     return rhs;
   if (!rhs)
@@ -167,6 +171,14 @@ struct ValueKnowledge {
       : dtype(dtype), scalarType(scalarType), kind(kind),
         optional(optionalKnowledge) {}
 
+  void setScalarType(Type type) {
+    bool isValidScalarType = type.isa<NumberType, IntType, Torch::FloatType>();
+    assert(isValidScalarType &&
+           "scalarType can only be one of NumberType, IntType and FloatType");
+    scalarType = type;
+    kind = getTypeKind(type);
+  }
+
   // Get the static knowledge intrinsic to `type`.
   static ValueKnowledge getKnowledgeFromType(Type type) {
     ValueKnowledge result = getPessimisticValueState(type.getContext());
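Previously callers had to remember to update `kind` whenever they assigned `scalarType` (see the visitBinaryScalarOp hunk below). A minimal standalone sketch of the invariant the new setter enforces (illustrative types, not the real ValueKnowledge):

#include <cassert>

enum class Kind { IntKind, FloatKind, NumberKind };

// Stand-in for the real struct; in torch-mlir, scalarType is an mlir::Type and
// kind is derived from it via getTypeKind.
struct Knowledge {
  Kind scalarType = Kind::NumberKind;
  Kind kind = Kind::NumberKind;
  void setScalarType(Kind type) {
    scalarType = type;
    kind = type; // recomputed from the type, so the two fields cannot drift
  }
};

int main() {
  Knowledge k;
  k.setScalarType(Kind::FloatKind);
  assert(k.kind == Kind::FloatKind); // updated in one place, not by each caller
}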
@@ -420,7 +432,7 @@ private:
 // This is the type rule used for deciding dtype for:
 // 1. A new tensor created from given data.
 // 2. The scalar type for type promotion when a scalar is an operand of a tensor
-// and scalar binary operation.
+//    operation (such as AtenMulScalarOp, AtenAddScalarOp etc)
 // If the data is floating-point, the `dtype` is inferred to be the
 // default dtype, see `torch.get_default_dtype`.
 static Type getDefaultDtypeForTorchScalar(Type type) {
@@ -438,12 +450,29 @@ static Type getDefaultDtypeForTorchScalar(Type type) {
       "getDefaultDtypeForTorchScalar called on an unsupported type");
 }
 
+// This is the type rule used for deciding builtin type for:
+// 1. The dtype of the result tensor when converting a Scalar into a Tensor like
+//    PrimNumToTensorScalarOp.
+// 2. The scalar type for type promotion when a scalar is an operand of a
+//    scalar-only operation like AtenAddOp.
+static Type getBuiltInTypeForTorchScalar(Type type) {
+  MLIRContext *context = type.getContext();
+  if (type.isa<Torch::FloatType>())
+    return Float64Type::get(context);
+  if (type.isa<Torch::IntType>())
+    return IntegerType::get(context, 64, IntegerType::Signed);
+  if (type.isa<Torch::BoolType>())
+    return IntegerType::get(context, 1);
+  llvm_unreachable(
+      "getBuiltInTypeForTorchScalar called on an unsupported type");
+}
+
 static torch_upstream::ResultTypeState
 updateResultTypeState(Type scalarType,
                       const torch_upstream::ResultTypeState &inState) {
+  assert(isBuiltInType(scalarType) && "scalarType must be builtin type");
   torch_upstream::ResultTypeState new_state = inState;
-  torch_upstream::ScalarType current =
-      getScalarTypeForType(getDefaultDtypeForTorchScalar(scalarType));
+  torch_upstream::ScalarType current = getScalarTypeForType(scalarType);
   new_state.wrappedResult =
       promote_skip_undefined(inState.wrappedResult, current);
   return new_state;
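In the scalar-only path, every operand is a wrapped scalar, so only the `wrappedResult` slot of ResultTypeState is ever promoted into. A rough standalone model of that accumulation (the optional models the Undefined slot; the semantics of upstream promote_skip_undefined are assumed from its use here, simplified to a two-type lattice):

#include <cassert>
#include <optional>

enum class ScalarType { Long, Double };

// Pairwise promotion: floating point wins (simplified two-type lattice).
ScalarType promote(ScalarType a, ScalarType b) {
  return (a == ScalarType::Double || b == ScalarType::Double)
             ? ScalarType::Double
             : ScalarType::Long;
}

// Assumed behavior of promote_skip_undefined: an empty (Undefined) slot simply
// takes the incoming type instead of promoting.
std::optional<ScalarType> promoteSkipUndefined(std::optional<ScalarType> acc,
                                               ScalarType next) {
  return acc ? promote(*acc, next) : next;
}

int main() {
  std::optional<ScalarType> wrappedResult; // state = {} starts Undefined
  wrappedResult = promoteSkipUndefined(wrappedResult, ScalarType::Long);   // from !torch.int
  wrappedResult = promoteSkipUndefined(wrappedResult, ScalarType::Double); // from !torch.float
  assert(*wrappedResult == ScalarType::Double); // result_type(state) -> Double
}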
@@ -481,15 +510,22 @@ updateResultTypeState(ValueKnowledge *tensor, Optional<bool> rankIsNonZero,
   return new_state;
 }
 
-static Type getPromotedResultType(ArrayRef<Type> scalarTypes) {
+// Type promotion helper for operators where only scalar operands participate
+// in type promotion, like AtenAddOp.
+//
+// \return The result, which is always a Torch type.
+static Type getPromotedResultScalarType(ArrayRef<Type> scalarTypes) {
   torch_upstream::ResultTypeState state = {};
-  for (const Type &scalarType : scalarTypes)
-    state = updateResultTypeState(scalarType, state);
-  return getTypeForScalarType(scalarTypes[0].getContext(), result_type(state));
+  for (const Type &scalarType : scalarTypes) {
+    state =
+        updateResultTypeState(getBuiltInTypeForTorchScalar(scalarType), state);
+  }
+  return getTorchTypeForScalarType(scalarTypes[0].getContext(),
+                                   result_type(state));
 }
 
 // Returns most generic type Type() if the tensor dtype is unknown.
-static Type getPromotedResultType(ValueKnowledge *tensor, Type scalarType) {
+static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
   if (!tensor->dtype)
     return Type();
   torch_upstream::ResultTypeState state = {};
@@ -497,7 +533,8 @@ static Type getPromotedResultType(ValueKnowledge *tensor, Type scalarType) {
   // wrappedResult which is a lower priority than both dimResult and zeroResult.
   state = updateResultTypeState(tensor, /*rankIsNonZero=*/None, state,
                                 /*skipRankCheck=*/true);
-  state = updateResultTypeState(scalarType, state);
+  state =
+      updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
   return getTypeForScalarType(scalarType.getContext(), result_type(state));
 }
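Two different scalar-to-builtin rules now coexist, and the hunk above is where they meet: tensor-and-scalar ops feed getDefaultDtypeForTorchScalar into the promotion, while scalar-only ops use getBuiltInTypeForTorchScalar. A small sketch of the difference, under the assumption that the default dtype is float32 (torch.get_default_dtype):

#include <cassert>

enum class Torch { Int, Float };
enum class Dtype { SI64, F32, F64 };

// Mirrors getDefaultDtypeForTorchScalar: rule for tensor results, where a
// float scalar takes the default dtype (assumed f32 here).
Dtype defaultDtypeFor(Torch t) {
  return t == Torch::Float ? Dtype::F32 : Dtype::SI64;
}

// Mirrors getBuiltInTypeForTorchScalar: rule for Scalar results, where a float
// scalar is a full-width double (f64).
Dtype builtinFor(Torch t) {
  return t == Torch::Float ? Dtype::F64 : Dtype::SI64;
}

int main() {
  assert(defaultDtypeFor(Torch::Float) == Dtype::F32); // tensor-op promotion path
  assert(builtinFor(Torch::Float) == Dtype::F64);      // scalar-only / NumToTensor path
}

This matches PyTorch itself, where adding a Python float to a float32 tensor keeps the tensor at float32, while the float on its own is a double.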
@@ -550,8 +587,7 @@ getPromotedResultTypeAssumingNonZeroRank(MLIRContext *context,
 static void fillInDTypeGivenDTypeIntAndInputDType(ValueKnowledge &knowledge,
                                                   Value dtype,
                                                   Type inputDType) {
-  assert(isa<BuiltinDialect>(inputDType.getDialect()) &&
-         "`inputDType` must be a builtin type");
+  assert(isBuiltInType(inputDType) && "`inputDType` must be a builtin type");
   int64_t dtypeInt;
   if (dtype.getType().isa<Torch::NoneType>())
     knowledge.dtype = inputDType;
@@ -692,7 +728,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     Value scalar = op->getOperand(1);
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(getContext());
-    knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
+    knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType());
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
@@ -712,8 +748,8 @@ ChangeResult TypeAnalyzer::visitOperation(
     Value rhsScalar = op->getOperand(2);
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(getContext());
-    knowledge.dtype =
-        getPromotedResultType({lhsScalar.getType(), rhsScalar.getType()});
+    knowledge.dtype = getDefaultDtypeForTorchScalar(getPromotedResultScalarType(
+        {lhsScalar.getType(), rhsScalar.getType()}));
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
@@ -723,7 +759,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     Value scalar = op->getOperand(2);
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(getContext());
-    knowledge.dtype = getPromotedResultType(&lhs, scalar.getType());
+    knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType());
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
@@ -733,7 +769,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     Value scalar = op->getOperand(1);
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(getContext());
-    knowledge.dtype = getPromotedResultType(&rhs, scalar.getType());
+    knowledge.dtype = getPromotedResultDType(&rhs, scalar.getType());
     return incorporateKnowledge(op->getResult(0), knowledge);
   }
@@ -942,7 +978,7 @@ ChangeResult TypeAnalyzer::visitOperation(
     return visitNumToTensorOp(numToTensorOp);
  }
 
-  if (isa<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>(op)) {
+  if (isa<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp, AtenAddOp>(op)) {
     return visitBinaryScalarOp(op, operands);
   }
@@ -1069,10 +1105,9 @@ ChangeResult TypeAnalyzer::visitBinaryScalarOp(
     Operation *op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
       ValueKnowledge::getScalarPessimisticValueState(op->getContext());
-  Type resultType = getPromotedResultType(
+  Type resultType = getPromotedResultScalarType(
      {op->getOperand(0).getType(), op->getOperand(1).getType()});
-  knowledge.scalarType = resultType;
-  knowledge.kind = getTypeKind(resultType);
+  knowledge.setScalarType(resultType);
   return incorporateKnowledge(op->getResult(0), knowledge);
 }
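Together with the isa<> change just above, aten.add now takes the same path as the existing int-scalar ops: getPromotedResultScalarType produces a Torch type, and setScalarType stores it while keeping `kind` in sync, replacing the two separate assignments the old code needed.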
@@ -1183,12 +1218,7 @@ ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
   // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h.
   // `NumToTensor` falls in the latter case.
   Type type = op.a().getType();
-  if (type.isa<Torch::FloatType>())
-    knowledge.dtype = Float64Type::get(op.getContext());
-  else if (type.isa<Torch::IntType>())
-    knowledge.dtype =
-        IntegerType::get(op.getContext(), 64, IntegerType::Signed);
-
+  knowledge.dtype = getBuiltInTypeForTorchScalar(type);
   return incorporateKnowledge(op.getResult(), knowledge);
 }
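Besides deduplicating the mapping, this rewrite also picks up the Torch::BoolType to i1 case, which the deleted if/else chain did not handle at all (a bool scalar previously left the dtype unknown).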
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "mlir/IR/BuiltinDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 
 using namespace mlir;
@@ -78,6 +79,19 @@ Type Torch::getTypeForScalarType(
   }
 }
 
+Type Torch::getTorchTypeForScalarType(MLIRContext *context,
+                                      torch_upstream::ScalarType dtypeInt) {
+  switch (dtypeInt) {
+  case torch_upstream::ScalarType::Double:
+    return Torch::FloatType::get(context);
+  case torch_upstream::ScalarType::Long:
+    return Torch::IntType::get(context);
+  default:
+    llvm::report_fatal_error(
+        "Unsupported scalar type to Torch type conversion");
+  }
+}
+
 Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
                                      Type dtype) {
   int intType = (int)getScalarTypeForType(dtype);
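Only Double and Long get Torch equivalents here, presumably because the scalar-only promotion path can currently only produce those two widened types; any other value signals a logic error upstream, hence the hard report_fatal_error rather than a recoverable failure.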
@@ -100,6 +114,10 @@ Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
   return converted;
 }
 
+bool Torch::isBuiltInType(Type type) {
+  return isa<BuiltinDialect>(type.getDialect());
+}
+
 int Torch::getTensorRank(Value tensor) {
   int tensorRank = -1;
   BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
@@ -518,6 +518,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
     emit("aten::_set_item.t : (t[], int, t) -> (t[])")
     emit("aten::div : (Scalar, Scalar) -> (float)")
+    emit("aten::add : (Scalar, Scalar) -> (Scalar)")
+
     emit("aten::eq.device : (Device, Device) -> (bool)")
     emit("aten::ceil.float : (float) -> (int)", has_folder=True)
@@ -127,6 +127,18 @@ func @type_promotion$alpha_wider(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.v
   return %0 : !torch.vtensor<[?],unk>
 }
 
+// -----
+// CHECK-LABEL:   func @type_promotion_scalar_operation(
+// CHECK-SAME:        %[[FLOAT:.*]]: !torch.float,
+// CHECK-SAME:        %[[INT:.*]]: !torch.int) -> !torch.number {
+// CHECK:           %[[ADD:.*]] = torch.aten.add %[[FLOAT]], %[[INT]] : !torch.float, !torch.int -> !torch.float
+// CHECK:           %[[RET:.*]] = torch.derefine %[[ADD]] : !torch.float to !torch.number
+// CHECK:           return %[[RET]] : !torch.number
+func @type_promotion_scalar_operation(%float: !torch.float, %int: !torch.int) -> !torch.number {
+  %ret = torch.aten.add %float, %int : !torch.float, !torch.int -> !torch.number
+  return %ret : !torch.number
+}
+
 // -----
 // CHECK-LABEL:   func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
 // CHECK-SAME:        %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
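The new test case exercises the full chain: refinement narrows the result of torch.aten.add on a !torch.float and a !torch.int operand to !torch.float (float wins the promotion), and a torch.derefine is inserted so the function still returns its declared !torch.number type.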