[Torch] fix toBuiltinTensor() (#3415)

* Let `toBuiltinTensor()` reflect the original dtype of
`!torch.vtensor`.
* Backends handle dtype conversion themselves.
pull/3437/head
Yuanqiang Liu 2024-06-08 09:36:32 +08:00 committed by GitHub
parent 75af64fc12
commit 689efc8917
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 60 additions and 56 deletions

View File

@ -737,7 +737,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
std::numeric_limits<float>::lowest())) std::numeric_limits<float>::lowest()))
return failure(); return failure();
auto minSplatAttr = SplatElementsAttr::get( auto minSplatAttr = SplatElementsAttr::get(
resultType.toBuiltinTensor().clone(resultDtype), resultType.toBuiltinTensor(),
rewriter.getFloatAttr(resultDtype, minValue)); rewriter.getFloatAttr(resultDtype, minValue));
min = rewriter.create<Torch::ValueTensorLiteralOp>( min = rewriter.create<Torch::ValueTensorLiteralOp>(
binder.getLoc(), resultType, minSplatAttr); binder.getLoc(), resultType, minSplatAttr);
@ -748,7 +748,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
std::numeric_limits<float>::max())) std::numeric_limits<float>::max()))
return failure(); return failure();
auto maxSplatAttr = SplatElementsAttr::get( auto maxSplatAttr = SplatElementsAttr::get(
resultType.toBuiltinTensor().clone(resultDtype), resultType.toBuiltinTensor(),
rewriter.getFloatAttr(resultDtype, maxValue)); rewriter.getFloatAttr(resultDtype, maxValue));
max = rewriter.create<Torch::ValueTensorLiteralOp>( max = rewriter.create<Torch::ValueTensorLiteralOp>(
binder.getLoc(), resultType, maxSplatAttr); binder.getLoc(), resultType, maxSplatAttr);
@ -861,7 +861,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
if (binder.op->hasAttr("torch.onnx.value_float") && if (binder.op->hasAttr("torch.onnx.value_float") &&
!binder.f32FloatAttr(floatValue, "value_float", 0.0)) { !binder.f32FloatAttr(floatValue, "value_float", 0.0)) {
auto splatAttr = auto splatAttr =
SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), SplatElementsAttr::get(resultType.toBuiltinTensor(),
rewriter.getFloatAttr(dtype, floatValue)); rewriter.getFloatAttr(dtype, floatValue));
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>( rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
binder.op, resultType, splatAttr); binder.op, resultType, splatAttr);
@ -872,7 +872,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
if (binder.op->hasAttr("torch.onnx.value_int") && if (binder.op->hasAttr("torch.onnx.value_int") &&
!binder.s64IntegerAttr(intValue, "value_int", 0)) { !binder.s64IntegerAttr(intValue, "value_int", 0)) {
auto splatAttr = auto splatAttr =
SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), SplatElementsAttr::get(resultType.toBuiltinTensor(),
rewriter.getIntegerAttr(dtype, intValue)); rewriter.getIntegerAttr(dtype, intValue));
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>( rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
binder.op, resultType, splatAttr); binder.op, resultType, splatAttr);
@ -932,8 +932,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
for (auto intVal : intValues) { for (auto intVal : intValues) {
apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal)); apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal));
} }
auto attr = DenseElementsAttr::get( auto attr =
resultType.toBuiltinTensor().clone(dtype), apValues); DenseElementsAttr::get(resultType.toBuiltinTensor(), apValues);
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>( rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
binder.op, resultType, attr); binder.op, resultType, attr);
return success(); return success();
@ -2272,8 +2272,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
// Extract the fill value and dtype // Extract the fill value and dtype
// ONNX requires value attr to be a tensor // ONNX requires value attr to be a tensor
if (!attr) { if (!attr) {
attr = DenseElementsAttr::get( attr =
resultType.toBuiltinTensor().clone(resultDType), DenseElementsAttr::get(resultType.toBuiltinTensor(),
rewriter.getFloatAttr(resultDType, 0.0)); rewriter.getFloatAttr(resultDType, 0.0));
} }

View File

@ -146,12 +146,11 @@ public:
"mismatching contracting dimension for torch.aten.mm")); "mismatching contracting dimension for torch.aten.mm"));
} }
auto resultTy = cast<ValueTensorType>(op.getType()); TensorType resultType =
auto resultDTy = resultTy.toBuiltinTensor().getElementType(); cast<TensorType>(getTypeConverter()->convertType(op.getType()));
Type newResultType = getTypeConverter()->convertType(op.getType()); Type elementType = resultType.getElementType();
Type elementType = cast<TensorType>(newResultType).getElementType(); auto accumulatorDType = getDefaultAccType(rewriter, elementType);
auto accumulatorDType = getDefaultAccType(rewriter, resultDTy); if (accumulatorDType != resultType.getElementType()) {
if (accumulatorDType != resultDTy) {
elementType = accumulatorDType; elementType = accumulatorDType;
} }
Value zeroFill = createZeroInitTensor( Value zeroFill = createZeroInitTensor(
@ -197,18 +196,16 @@ public:
.getResult(0); .getResult(0);
} }
if (accumulatorDType != resultDTy) { if (accumulatorDType != resultType.getElementType()) {
Type resultElementType =
cast<RankedTensorType>(newResultType).getElementType();
matmul = torch_to_linalg::convertTensorToElementType( matmul = torch_to_linalg::convertTensorToElementType(
rewriter, loc, matmul, resultElementType); rewriter, loc, matmul, resultType.getElementType());
} }
// When constructed with just dynamic sizes, EmptyOp will have a result // When constructed with just dynamic sizes, EmptyOp will have a result
// type which has all `?`'s for dimensions, which might not be the result // type which has all `?`'s for dimensions, which might not be the result
// type of `op`. The constraints on later linalg ops means that the result // type of `op`. The constraints on later linalg ops means that the result
// of the MatmulOp will have this type too. So cast it to the desired type // of the MatmulOp will have this type too. So cast it to the desired type
// so that in the end we have the original result type. // so that in the end we have the original result type.
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, matmul);
return success(); return success();
} }

View File

@ -1311,7 +1311,7 @@ static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
return nullptr; return nullptr;
auto dty = resultTy.getDtype(); auto dty = resultTy.getDtype();
auto resultBTy = resultTy.toBuiltinTensor().clone(dty); auto resultBTy = resultTy.toBuiltinTensor();
auto fpTy = dyn_cast<mlir::FloatType>(dty); auto fpTy = dyn_cast<mlir::FloatType>(dty);
auto intTy = dyn_cast<mlir::IntegerType>(dty); auto intTy = dyn_cast<mlir::IntegerType>(dty);
@ -1521,7 +1521,7 @@ OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) {
if (!ty || !ty.hasDtype() || !ty.hasSizes()) if (!ty || !ty.hasDtype() || !ty.hasSizes())
return nullptr; return nullptr;
auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); auto bty = ty.toBuiltinTensor();
if (!bty.hasStaticShape()) if (!bty.hasStaticShape())
return nullptr; return nullptr;
@ -1635,7 +1635,6 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
return nullptr; return nullptr;
auto ctx = lhs.getContext(); auto ctx = lhs.getContext();
auto resultETy = resultTy.getDtype();
auto tensorETy = cast<RankedTensorType>(lhs.getType()).getElementType(); auto tensorETy = cast<RankedTensorType>(lhs.getType()).getElementType();
if (lhs.isSplat()) { if (lhs.isSplat()) {
if (auto intAttr = dyn_cast<IntegerAttr>(rhs)) { if (auto intAttr = dyn_cast<IntegerAttr>(rhs)) {
@ -1647,8 +1646,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign); unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign);
auto resultBool = intFolder(tensorAP, scalarAP, unsign); auto resultBool = intFolder(tensorAP, scalarAP, unsign);
auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool);
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP);
resultAP);
} }
if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) { if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) {
@ -1657,8 +1655,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
auto resultBool = auto resultBool =
fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble());
auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool);
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP);
resultAP);
} }
return nullptr; return nullptr;
} }
@ -1681,8 +1678,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
auto resultBool = intFolder(tensorAP, scalarAP, unsign); auto resultBool = intFolder(tensorAP, scalarAP, unsign);
values.push_back(resultBool); values.push_back(resultBool);
} }
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values);
values);
} }
if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) { if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) {
@ -1693,8 +1689,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble());
values.push_back(resultBool); values.push_back(resultBool);
} }
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values);
values);
} }
return nullptr; return nullptr;
@ -1844,7 +1839,7 @@ static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand,
if (!fpTy && !intTy) if (!fpTy && !intTy)
return nullptr; return nullptr;
auto resultBTy = resultTy.toBuiltinTensor().clone(resultTy.getDtype()); auto resultBTy = resultTy.toBuiltinTensor();
bool splat = operand.isSplat(); bool splat = operand.isSplat();
bool withinMaxFold = bool withinMaxFold =
resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold; resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold;
@ -2192,7 +2187,7 @@ OpFoldResult AtenSelectIntOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
auto selfTy = cast<ShapedType>(self.getType()); auto selfTy = cast<ShapedType>(self.getType());
auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); auto bty = ty.toBuiltinTensor();
if (!bty.hasStaticShape()) if (!bty.hasStaticShape())
return nullptr; return nullptr;
@ -2656,8 +2651,7 @@ LogicalResult AtenSortOp::fold(FoldAdaptor adaptor,
if (!indicesTensorType.hasDtype()) if (!indicesTensorType.hasDtype())
return failure(); return failure();
auto indicesType = auto indicesType = indicesTensorType.toBuiltinTensor();
indicesTensorType.toBuiltinTensor().clone(indicesTensorType.getDtype());
if (!indicesType || !indicesType.hasStaticShape()) if (!indicesType || !indicesType.hasStaticShape())
return failure(); return failure();
@ -3612,8 +3606,7 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
if (input && input.isSplat()) if (input && input.isSplat())
return DenseElementsAttr::get( return DenseElementsAttr::get(outType.toBuiltinTensor(),
outType.toBuiltinTensor().clone(inType.getDtype()),
input.getSplatValue<Attribute>()); input.getSplatValue<Attribute>());
int count = 1; int count = 1;
@ -3652,8 +3645,7 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
for (int i = begin; i < limit; i += stride) for (int i = begin; i < limit; i += stride)
values.push_back(input.getValues<Attribute>()[i]); values.push_back(input.getValues<Attribute>()[i]);
return DenseElementsAttr::get( return DenseElementsAttr::get(outType.toBuiltinTensor(), values);
outType.toBuiltinTensor().clone(inType.getDtype()), values);
} }
// If the input and output shapes are the same we can just fold: // If the input and output shapes are the same we can just fold:
@ -3923,7 +3915,7 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
return nullptr; return nullptr;
Type eTy = resultTy.getDtype(); Type eTy = resultTy.getDtype();
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); ShapedType shapedTy = resultTy.toBuiltinTensor();
SmallVector<int64_t> data; SmallVector<int64_t> data;
if (matchPattern(getData(), m_TorchListOfConstantInts(data)) && if (matchPattern(getData(), m_TorchListOfConstantInts(data)) &&
@ -3944,7 +3936,7 @@ OpFoldResult AtenTensorIntOp::fold(FoldAdaptor adaptor) {
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
return nullptr; return nullptr;
Type eTy = resultTy.getDtype(); Type eTy = resultTy.getDtype();
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); ShapedType shapedTy = resultTy.toBuiltinTensor();
int64_t data; int64_t data;
if (matchPattern(getT(), m_TorchConstantInt(&data))) { if (matchPattern(getT(), m_TorchConstantInt(&data))) {
@ -3964,7 +3956,7 @@ OpFoldResult AtenTensorFloatOp::fold(FoldAdaptor adaptor) {
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
return nullptr; return nullptr;
Type eTy = resultTy.getDtype(); Type eTy = resultTy.getDtype();
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); ShapedType shapedTy = resultTy.toBuiltinTensor();
double data; double data;
if (matchPattern(getT(), m_TorchConstantFloat(&data))) { if (matchPattern(getT(), m_TorchConstantFloat(&data))) {
@ -4137,7 +4129,7 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
: selfAttr.getValues<Attribute>()[indexInt]; : selfAttr.getValues<Attribute>()[indexInt];
auto dty = resultTy.getDtype(); auto dty = resultTy.getDtype();
auto attrTy = resultTy.toBuiltinTensor().clone(dty); auto attrTy = resultTy.toBuiltinTensor();
if (auto floatAttr = dyn_cast<FloatAttr>(splattr)) if (auto floatAttr = dyn_cast<FloatAttr>(splattr))
return DenseElementsAttr::get( return DenseElementsAttr::get(
attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble())); attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble()));
@ -4330,7 +4322,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
if (!valueDense.isSplat()) if (!valueDense.isSplat())
return nullptr; return nullptr;
auto splattr = valueDense.getSplatValue<Attribute>(); auto splattr = valueDense.getSplatValue<Attribute>();
auto attrty = ty.toBuiltinTensor().clone(dty); auto attrty = ty.toBuiltinTensor();
return DenseElementsAttr::get(attrty, splattr); return DenseElementsAttr::get(attrty, splattr);
} }
@ -4338,7 +4330,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
if (!isa<mlir::IntegerType>(dty)) if (!isa<mlir::IntegerType>(dty))
return nullptr; return nullptr;
int64_t intval = intAttr.getInt(); int64_t intval = intAttr.getInt();
auto attrty = ty.toBuiltinTensor().clone(dty); auto attrty = ty.toBuiltinTensor();
return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval)); return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval));
} }
@ -4346,7 +4338,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
if (!isa<mlir::FloatType>(dty)) if (!isa<mlir::FloatType>(dty))
return nullptr; return nullptr;
double dblval = fpAttr.getValueAsDouble(); double dblval = fpAttr.getValueAsDouble();
auto attrty = ty.toBuiltinTensor().clone(dty); auto attrty = ty.toBuiltinTensor();
return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval)); return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval));
} }

View File

@ -453,12 +453,7 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
} }
static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
if (auto floatType = dyn_cast<mlir::FloatType>(dtype)) { if (isa<mlir::FloatType, IntegerType, mlir::ComplexType>(dtype)) {
return dtype;
} else if (auto integerType = dyn_cast<IntegerType>(dtype)) {
return IntegerType::get(context, integerType.getWidth(),
IntegerType::Signless);
} else if (isa<mlir::ComplexType>(dtype)) {
return dtype; return dtype;
} }
@ -480,11 +475,11 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
TensorType ValueTensorType::toBuiltinTensor() const { TensorType ValueTensorType::toBuiltinTensor() const {
if (!hasDtype()) if (!hasDtype())
return nullptr; return nullptr;
if (!hasSizes())
return UnrankedTensorType::get(getDtype());
Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype()); Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype());
if (!elementType) if (!elementType)
return nullptr; return nullptr;
if (!hasSizes())
return UnrankedTensorType::get(elementType);
return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType, return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType,
getOptionalSparsity()); getOptionalSparsity());
} }

View File

@ -164,7 +164,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversion(
ConversionTarget &target, TypeConverter &typeConverter) { ConversionTarget &target, TypeConverter &typeConverter) {
auto valueTensorTypeConversion = auto valueTensorTypeConversion =
[](Torch::ValueTensorType type) -> std::optional<Type> { [](Torch::ValueTensorType type) -> std::optional<Type> {
return type.toBuiltinTensor(); auto builtinType = type.toBuiltinTensor();
if (!builtinType)
return std::nullopt;
// convert any integer type to signless
if (type.getDtype().isInteger()) {
return builtinType.clone(IntegerType::get(
builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(),
IntegerType::Signless));
}
return builtinType;
}; };
setupValueTensorToBuiltinTensorConversion(target, typeConverter, setupValueTensorToBuiltinTensorConversion(target, typeConverter,
valueTensorTypeConversion); valueTensorTypeConversion);
@ -180,9 +191,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo(
auto valueTensorTypeConversion = auto valueTensorTypeConversion =
[](Torch::ValueTensorType type) -> std::optional<Type> { [](Torch::ValueTensorType type) -> std::optional<Type> {
auto builtinType = type.toBuiltinTensor(); auto builtinType = type.toBuiltinTensor();
if (!builtinType)
return std::nullopt;
// convert signed integer type to signless, keep unsigned as unsigned
if (type.getDtype().isUnsignedInteger()) { if (type.getDtype().isUnsignedInteger()) {
return builtinType.clone(type.getDtype()); return builtinType.clone(type.getDtype());
} else if (type.getDtype().isSignedInteger()) {
return builtinType.clone(IntegerType::get(
builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(),
IntegerType::Signless));
} }
return builtinType; return builtinType;
}; };
setupValueTensorToBuiltinTensorConversion(target, typeConverter, setupValueTensorToBuiltinTensorConversion(target, typeConverter,