[Torch] fix toBuiltinTensor() (#3415)

* Let `toBuiltinTensor()` reflects the original dtype of
`!torch.vtensor`.
* Backend handles dtype conversion themselves.
pull/3437/head
Yuanqiang Liu 2024-06-08 09:36:32 +08:00 committed by GitHub
parent 75af64fc12
commit 689efc8917
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 60 additions and 56 deletions

View File

@ -737,7 +737,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
std::numeric_limits<float>::lowest()))
return failure();
auto minSplatAttr = SplatElementsAttr::get(
resultType.toBuiltinTensor().clone(resultDtype),
resultType.toBuiltinTensor(),
rewriter.getFloatAttr(resultDtype, minValue));
min = rewriter.create<Torch::ValueTensorLiteralOp>(
binder.getLoc(), resultType, minSplatAttr);
@ -748,7 +748,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
std::numeric_limits<float>::max()))
return failure();
auto maxSplatAttr = SplatElementsAttr::get(
resultType.toBuiltinTensor().clone(resultDtype),
resultType.toBuiltinTensor(),
rewriter.getFloatAttr(resultDtype, maxValue));
max = rewriter.create<Torch::ValueTensorLiteralOp>(
binder.getLoc(), resultType, maxSplatAttr);
@ -861,7 +861,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
if (binder.op->hasAttr("torch.onnx.value_float") &&
!binder.f32FloatAttr(floatValue, "value_float", 0.0)) {
auto splatAttr =
SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
SplatElementsAttr::get(resultType.toBuiltinTensor(),
rewriter.getFloatAttr(dtype, floatValue));
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
binder.op, resultType, splatAttr);
@ -872,7 +872,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
if (binder.op->hasAttr("torch.onnx.value_int") &&
!binder.s64IntegerAttr(intValue, "value_int", 0)) {
auto splatAttr =
SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
SplatElementsAttr::get(resultType.toBuiltinTensor(),
rewriter.getIntegerAttr(dtype, intValue));
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
binder.op, resultType, splatAttr);
@ -932,8 +932,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
for (auto intVal : intValues) {
apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal));
}
auto attr = DenseElementsAttr::get(
resultType.toBuiltinTensor().clone(dtype), apValues);
auto attr =
DenseElementsAttr::get(resultType.toBuiltinTensor(), apValues);
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
binder.op, resultType, attr);
return success();
@ -2272,9 +2272,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
// Extract the fill value and dtype
// ONNX requires value attr to be a tensor
if (!attr) {
attr = DenseElementsAttr::get(
resultType.toBuiltinTensor().clone(resultDType),
rewriter.getFloatAttr(resultDType, 0.0));
attr =
DenseElementsAttr::get(resultType.toBuiltinTensor(),
rewriter.getFloatAttr(resultDType, 0.0));
}
// If its a dense resource attr we need to convert to a dense type:

View File

@ -146,12 +146,11 @@ public:
"mismatching contracting dimension for torch.aten.mm"));
}
auto resultTy = cast<ValueTensorType>(op.getType());
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = cast<TensorType>(newResultType).getElementType();
auto accumulatorDType = getDefaultAccType(rewriter, resultDTy);
if (accumulatorDType != resultDTy) {
TensorType resultType =
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
Type elementType = resultType.getElementType();
auto accumulatorDType = getDefaultAccType(rewriter, elementType);
if (accumulatorDType != resultType.getElementType()) {
elementType = accumulatorDType;
}
Value zeroFill = createZeroInitTensor(
@ -197,18 +196,16 @@ public:
.getResult(0);
}
if (accumulatorDType != resultDTy) {
Type resultElementType =
cast<RankedTensorType>(newResultType).getElementType();
if (accumulatorDType != resultType.getElementType()) {
matmul = torch_to_linalg::convertTensorToElementType(
rewriter, loc, matmul, resultElementType);
rewriter, loc, matmul, resultType.getElementType());
}
// When constructed with just dynamic sizes, EmptyOp will have a result
// type which has all `?`'s for dimensions, which might not be the result
// type of `op`. The constraints on later linalg ops means that the result
// of the MatmulOp will have this type too. So cast it to the desired type
// so that in the end we have the original result type.
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, matmul);
return success();
}

View File

@ -1311,7 +1311,7 @@ static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
return nullptr;
auto dty = resultTy.getDtype();
auto resultBTy = resultTy.toBuiltinTensor().clone(dty);
auto resultBTy = resultTy.toBuiltinTensor();
auto fpTy = dyn_cast<mlir::FloatType>(dty);
auto intTy = dyn_cast<mlir::IntegerType>(dty);
@ -1521,7 +1521,7 @@ OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) {
if (!ty || !ty.hasDtype() || !ty.hasSizes())
return nullptr;
auto bty = ty.toBuiltinTensor().clone(ty.getDtype());
auto bty = ty.toBuiltinTensor();
if (!bty.hasStaticShape())
return nullptr;
@ -1635,7 +1635,6 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
return nullptr;
auto ctx = lhs.getContext();
auto resultETy = resultTy.getDtype();
auto tensorETy = cast<RankedTensorType>(lhs.getType()).getElementType();
if (lhs.isSplat()) {
if (auto intAttr = dyn_cast<IntegerAttr>(rhs)) {
@ -1647,8 +1646,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign);
auto resultBool = intFolder(tensorAP, scalarAP, unsign);
auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool);
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy),
resultAP);
return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP);
}
if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) {
@ -1657,8 +1655,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
auto resultBool =
fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble());
auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool);
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy),
resultAP);
return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP);
}
return nullptr;
}
@ -1681,8 +1678,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
auto resultBool = intFolder(tensorAP, scalarAP, unsign);
values.push_back(resultBool);
}
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy),
values);
return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values);
}
if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) {
@ -1693,8 +1689,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble());
values.push_back(resultBool);
}
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy),
values);
return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values);
}
return nullptr;
@ -1844,7 +1839,7 @@ static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand,
if (!fpTy && !intTy)
return nullptr;
auto resultBTy = resultTy.toBuiltinTensor().clone(resultTy.getDtype());
auto resultBTy = resultTy.toBuiltinTensor();
bool splat = operand.isSplat();
bool withinMaxFold =
resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold;
@ -2192,7 +2187,7 @@ OpFoldResult AtenSelectIntOp::fold(FoldAdaptor adaptor) {
return nullptr;
auto selfTy = cast<ShapedType>(self.getType());
auto bty = ty.toBuiltinTensor().clone(ty.getDtype());
auto bty = ty.toBuiltinTensor();
if (!bty.hasStaticShape())
return nullptr;
@ -2656,8 +2651,7 @@ LogicalResult AtenSortOp::fold(FoldAdaptor adaptor,
if (!indicesTensorType.hasDtype())
return failure();
auto indicesType =
indicesTensorType.toBuiltinTensor().clone(indicesTensorType.getDtype());
auto indicesType = indicesTensorType.toBuiltinTensor();
if (!indicesType || !indicesType.hasStaticShape())
return failure();
@ -3612,9 +3606,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
return nullptr;
if (input && input.isSplat())
return DenseElementsAttr::get(
outType.toBuiltinTensor().clone(inType.getDtype()),
input.getSplatValue<Attribute>());
return DenseElementsAttr::get(outType.toBuiltinTensor(),
input.getSplatValue<Attribute>());
int count = 1;
for (auto dim : outType.getSizes())
@ -3652,8 +3645,7 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
for (int i = begin; i < limit; i += stride)
values.push_back(input.getValues<Attribute>()[i]);
return DenseElementsAttr::get(
outType.toBuiltinTensor().clone(inType.getDtype()), values);
return DenseElementsAttr::get(outType.toBuiltinTensor(), values);
}
// If the input and output shapes are the same we can just fold:
@ -3923,7 +3915,7 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
return nullptr;
Type eTy = resultTy.getDtype();
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);
ShapedType shapedTy = resultTy.toBuiltinTensor();
SmallVector<int64_t> data;
if (matchPattern(getData(), m_TorchListOfConstantInts(data)) &&
@ -3944,7 +3936,7 @@ OpFoldResult AtenTensorIntOp::fold(FoldAdaptor adaptor) {
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
return nullptr;
Type eTy = resultTy.getDtype();
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);
ShapedType shapedTy = resultTy.toBuiltinTensor();
int64_t data;
if (matchPattern(getT(), m_TorchConstantInt(&data))) {
@ -3964,7 +3956,7 @@ OpFoldResult AtenTensorFloatOp::fold(FoldAdaptor adaptor) {
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
return nullptr;
Type eTy = resultTy.getDtype();
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);
ShapedType shapedTy = resultTy.toBuiltinTensor();
double data;
if (matchPattern(getT(), m_TorchConstantFloat(&data))) {
@ -4137,7 +4129,7 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
: selfAttr.getValues<Attribute>()[indexInt];
auto dty = resultTy.getDtype();
auto attrTy = resultTy.toBuiltinTensor().clone(dty);
auto attrTy = resultTy.toBuiltinTensor();
if (auto floatAttr = dyn_cast<FloatAttr>(splattr))
return DenseElementsAttr::get(
attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble()));
@ -4330,7 +4322,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
if (!valueDense.isSplat())
return nullptr;
auto splattr = valueDense.getSplatValue<Attribute>();
auto attrty = ty.toBuiltinTensor().clone(dty);
auto attrty = ty.toBuiltinTensor();
return DenseElementsAttr::get(attrty, splattr);
}
@ -4338,7 +4330,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
if (!isa<mlir::IntegerType>(dty))
return nullptr;
int64_t intval = intAttr.getInt();
auto attrty = ty.toBuiltinTensor().clone(dty);
auto attrty = ty.toBuiltinTensor();
return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval));
}
@ -4346,7 +4338,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
if (!isa<mlir::FloatType>(dty))
return nullptr;
double dblval = fpAttr.getValueAsDouble();
auto attrty = ty.toBuiltinTensor().clone(dty);
auto attrty = ty.toBuiltinTensor();
return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval));
}

View File

@ -453,12 +453,7 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
}
static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
if (auto floatType = dyn_cast<mlir::FloatType>(dtype)) {
return dtype;
} else if (auto integerType = dyn_cast<IntegerType>(dtype)) {
return IntegerType::get(context, integerType.getWidth(),
IntegerType::Signless);
} else if (isa<mlir::ComplexType>(dtype)) {
if (isa<mlir::FloatType, IntegerType, mlir::ComplexType>(dtype)) {
return dtype;
}
@ -480,11 +475,11 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
TensorType ValueTensorType::toBuiltinTensor() const {
if (!hasDtype())
return nullptr;
if (!hasSizes())
return UnrankedTensorType::get(getDtype());
Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype());
if (!elementType)
return nullptr;
if (!hasSizes())
return UnrankedTensorType::get(elementType);
return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType,
getOptionalSparsity());
}

View File

@ -164,7 +164,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversion(
ConversionTarget &target, TypeConverter &typeConverter) {
auto valueTensorTypeConversion =
[](Torch::ValueTensorType type) -> std::optional<Type> {
return type.toBuiltinTensor();
auto builtinType = type.toBuiltinTensor();
if (!builtinType)
return std::nullopt;
// convert any integer type to signless
if (type.getDtype().isInteger()) {
return builtinType.clone(IntegerType::get(
builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(),
IntegerType::Signless));
}
return builtinType;
};
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
valueTensorTypeConversion);
@ -180,9 +191,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo(
auto valueTensorTypeConversion =
[](Torch::ValueTensorType type) -> std::optional<Type> {
auto builtinType = type.toBuiltinTensor();
if (!builtinType)
return std::nullopt;
// convert signed integer type to signless, keep unsigned as unsigned
if (type.getDtype().isUnsignedInteger()) {
return builtinType.clone(type.getDtype());
} else if (type.getDtype().isSignedInteger()) {
return builtinType.clone(IntegerType::get(
builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(),
IntegerType::Signless));
}
return builtinType;
};
setupValueTensorToBuiltinTensorConversion(target, typeConverter,