mirror of https://github.com/llvm/torch-mlir
[Torch] fix toBuiltinTensor() (#3415)
* Let `toBuiltinTensor()` reflect the original dtype of `!torch.vtensor`. * Backends handle dtype conversion themselves. pull/3437/head
parent
75af64fc12
commit
689efc8917
|
@ -737,7 +737,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
std::numeric_limits<float>::lowest()))
|
||||
return failure();
|
||||
auto minSplatAttr = SplatElementsAttr::get(
|
||||
resultType.toBuiltinTensor().clone(resultDtype),
|
||||
resultType.toBuiltinTensor(),
|
||||
rewriter.getFloatAttr(resultDtype, minValue));
|
||||
min = rewriter.create<Torch::ValueTensorLiteralOp>(
|
||||
binder.getLoc(), resultType, minSplatAttr);
|
||||
|
@ -748,7 +748,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
std::numeric_limits<float>::max()))
|
||||
return failure();
|
||||
auto maxSplatAttr = SplatElementsAttr::get(
|
||||
resultType.toBuiltinTensor().clone(resultDtype),
|
||||
resultType.toBuiltinTensor(),
|
||||
rewriter.getFloatAttr(resultDtype, maxValue));
|
||||
max = rewriter.create<Torch::ValueTensorLiteralOp>(
|
||||
binder.getLoc(), resultType, maxSplatAttr);
|
||||
|
@ -861,7 +861,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
if (binder.op->hasAttr("torch.onnx.value_float") &&
|
||||
!binder.f32FloatAttr(floatValue, "value_float", 0.0)) {
|
||||
auto splatAttr =
|
||||
SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
|
||||
SplatElementsAttr::get(resultType.toBuiltinTensor(),
|
||||
rewriter.getFloatAttr(dtype, floatValue));
|
||||
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
||||
binder.op, resultType, splatAttr);
|
||||
|
@ -872,7 +872,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
if (binder.op->hasAttr("torch.onnx.value_int") &&
|
||||
!binder.s64IntegerAttr(intValue, "value_int", 0)) {
|
||||
auto splatAttr =
|
||||
SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
|
||||
SplatElementsAttr::get(resultType.toBuiltinTensor(),
|
||||
rewriter.getIntegerAttr(dtype, intValue));
|
||||
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
||||
binder.op, resultType, splatAttr);
|
||||
|
@ -932,8 +932,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
for (auto intVal : intValues) {
|
||||
apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal));
|
||||
}
|
||||
auto attr = DenseElementsAttr::get(
|
||||
resultType.toBuiltinTensor().clone(dtype), apValues);
|
||||
auto attr =
|
||||
DenseElementsAttr::get(resultType.toBuiltinTensor(), apValues);
|
||||
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
||||
binder.op, resultType, attr);
|
||||
return success();
|
||||
|
@ -2272,9 +2272,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
// Extract the fill value and dtype
|
||||
// ONNX requires value attr to be a tensor
|
||||
if (!attr) {
|
||||
attr = DenseElementsAttr::get(
|
||||
resultType.toBuiltinTensor().clone(resultDType),
|
||||
rewriter.getFloatAttr(resultDType, 0.0));
|
||||
attr =
|
||||
DenseElementsAttr::get(resultType.toBuiltinTensor(),
|
||||
rewriter.getFloatAttr(resultDType, 0.0));
|
||||
}
|
||||
|
||||
// If its a dense resource attr we need to convert to a dense type:
|
||||
|
|
|
@ -146,12 +146,11 @@ public:
|
|||
"mismatching contracting dimension for torch.aten.mm"));
|
||||
}
|
||||
|
||||
auto resultTy = cast<ValueTensorType>(op.getType());
|
||||
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
Type elementType = cast<TensorType>(newResultType).getElementType();
|
||||
auto accumulatorDType = getDefaultAccType(rewriter, resultDTy);
|
||||
if (accumulatorDType != resultDTy) {
|
||||
TensorType resultType =
|
||||
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
Type elementType = resultType.getElementType();
|
||||
auto accumulatorDType = getDefaultAccType(rewriter, elementType);
|
||||
if (accumulatorDType != resultType.getElementType()) {
|
||||
elementType = accumulatorDType;
|
||||
}
|
||||
Value zeroFill = createZeroInitTensor(
|
||||
|
@ -197,18 +196,16 @@ public:
|
|||
.getResult(0);
|
||||
}
|
||||
|
||||
if (accumulatorDType != resultDTy) {
|
||||
Type resultElementType =
|
||||
cast<RankedTensorType>(newResultType).getElementType();
|
||||
if (accumulatorDType != resultType.getElementType()) {
|
||||
matmul = torch_to_linalg::convertTensorToElementType(
|
||||
rewriter, loc, matmul, resultElementType);
|
||||
rewriter, loc, matmul, resultType.getElementType());
|
||||
}
|
||||
// When constructed with just dynamic sizes, EmptyOp will have a result
|
||||
// type which has all `?`'s for dimensions, which might not be the result
|
||||
// type of `op`. The constraints on later linalg ops means that the result
|
||||
// of the MatmulOp will have this type too. So cast it to the desired type
|
||||
// so that in the end we have the original result type.
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, matmul);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -1311,7 +1311,7 @@ static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
|
|||
return nullptr;
|
||||
|
||||
auto dty = resultTy.getDtype();
|
||||
auto resultBTy = resultTy.toBuiltinTensor().clone(dty);
|
||||
auto resultBTy = resultTy.toBuiltinTensor();
|
||||
|
||||
auto fpTy = dyn_cast<mlir::FloatType>(dty);
|
||||
auto intTy = dyn_cast<mlir::IntegerType>(dty);
|
||||
|
@ -1521,7 +1521,7 @@ OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) {
|
|||
if (!ty || !ty.hasDtype() || !ty.hasSizes())
|
||||
return nullptr;
|
||||
|
||||
auto bty = ty.toBuiltinTensor().clone(ty.getDtype());
|
||||
auto bty = ty.toBuiltinTensor();
|
||||
if (!bty.hasStaticShape())
|
||||
return nullptr;
|
||||
|
||||
|
@ -1635,7 +1635,6 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
|
|||
return nullptr;
|
||||
|
||||
auto ctx = lhs.getContext();
|
||||
auto resultETy = resultTy.getDtype();
|
||||
auto tensorETy = cast<RankedTensorType>(lhs.getType()).getElementType();
|
||||
if (lhs.isSplat()) {
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>(rhs)) {
|
||||
|
@ -1647,8 +1646,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
|
|||
unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign);
|
||||
auto resultBool = intFolder(tensorAP, scalarAP, unsign);
|
||||
auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool);
|
||||
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy),
|
||||
resultAP);
|
||||
return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP);
|
||||
}
|
||||
|
||||
if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) {
|
||||
|
@ -1657,8 +1655,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
|
|||
auto resultBool =
|
||||
fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble());
|
||||
auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool);
|
||||
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy),
|
||||
resultAP);
|
||||
return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1681,8 +1678,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
|
|||
auto resultBool = intFolder(tensorAP, scalarAP, unsign);
|
||||
values.push_back(resultBool);
|
||||
}
|
||||
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy),
|
||||
values);
|
||||
return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values);
|
||||
}
|
||||
|
||||
if (auto floatAttr = dyn_cast<FloatAttr>(rhs)) {
|
||||
|
@ -1693,8 +1689,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs,
|
|||
fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble());
|
||||
values.push_back(resultBool);
|
||||
}
|
||||
return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy),
|
||||
values);
|
||||
return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
@ -1844,7 +1839,7 @@ static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand,
|
|||
if (!fpTy && !intTy)
|
||||
return nullptr;
|
||||
|
||||
auto resultBTy = resultTy.toBuiltinTensor().clone(resultTy.getDtype());
|
||||
auto resultBTy = resultTy.toBuiltinTensor();
|
||||
bool splat = operand.isSplat();
|
||||
bool withinMaxFold =
|
||||
resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold;
|
||||
|
@ -2192,7 +2187,7 @@ OpFoldResult AtenSelectIntOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
|
||||
auto selfTy = cast<ShapedType>(self.getType());
|
||||
auto bty = ty.toBuiltinTensor().clone(ty.getDtype());
|
||||
auto bty = ty.toBuiltinTensor();
|
||||
if (!bty.hasStaticShape())
|
||||
return nullptr;
|
||||
|
||||
|
@ -2656,8 +2651,7 @@ LogicalResult AtenSortOp::fold(FoldAdaptor adaptor,
|
|||
|
||||
if (!indicesTensorType.hasDtype())
|
||||
return failure();
|
||||
auto indicesType =
|
||||
indicesTensorType.toBuiltinTensor().clone(indicesTensorType.getDtype());
|
||||
auto indicesType = indicesTensorType.toBuiltinTensor();
|
||||
if (!indicesType || !indicesType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
|
@ -3612,9 +3606,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
|
||||
if (input && input.isSplat())
|
||||
return DenseElementsAttr::get(
|
||||
outType.toBuiltinTensor().clone(inType.getDtype()),
|
||||
input.getSplatValue<Attribute>());
|
||||
return DenseElementsAttr::get(outType.toBuiltinTensor(),
|
||||
input.getSplatValue<Attribute>());
|
||||
|
||||
int count = 1;
|
||||
for (auto dim : outType.getSizes())
|
||||
|
@ -3652,8 +3645,7 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
|||
for (int i = begin; i < limit; i += stride)
|
||||
values.push_back(input.getValues<Attribute>()[i]);
|
||||
|
||||
return DenseElementsAttr::get(
|
||||
outType.toBuiltinTensor().clone(inType.getDtype()), values);
|
||||
return DenseElementsAttr::get(outType.toBuiltinTensor(), values);
|
||||
}
|
||||
|
||||
// If the input and output shapes are the same we can just fold:
|
||||
|
@ -3923,7 +3915,7 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
|
|||
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
|
||||
return nullptr;
|
||||
Type eTy = resultTy.getDtype();
|
||||
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);
|
||||
ShapedType shapedTy = resultTy.toBuiltinTensor();
|
||||
|
||||
SmallVector<int64_t> data;
|
||||
if (matchPattern(getData(), m_TorchListOfConstantInts(data)) &&
|
||||
|
@ -3944,7 +3936,7 @@ OpFoldResult AtenTensorIntOp::fold(FoldAdaptor adaptor) {
|
|||
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
|
||||
return nullptr;
|
||||
Type eTy = resultTy.getDtype();
|
||||
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);
|
||||
ShapedType shapedTy = resultTy.toBuiltinTensor();
|
||||
|
||||
int64_t data;
|
||||
if (matchPattern(getT(), m_TorchConstantInt(&data))) {
|
||||
|
@ -3964,7 +3956,7 @@ OpFoldResult AtenTensorFloatOp::fold(FoldAdaptor adaptor) {
|
|||
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
|
||||
return nullptr;
|
||||
Type eTy = resultTy.getDtype();
|
||||
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);
|
||||
ShapedType shapedTy = resultTy.toBuiltinTensor();
|
||||
|
||||
double data;
|
||||
if (matchPattern(getT(), m_TorchConstantFloat(&data))) {
|
||||
|
@ -4137,7 +4129,7 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
|
|||
: selfAttr.getValues<Attribute>()[indexInt];
|
||||
|
||||
auto dty = resultTy.getDtype();
|
||||
auto attrTy = resultTy.toBuiltinTensor().clone(dty);
|
||||
auto attrTy = resultTy.toBuiltinTensor();
|
||||
if (auto floatAttr = dyn_cast<FloatAttr>(splattr))
|
||||
return DenseElementsAttr::get(
|
||||
attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble()));
|
||||
|
@ -4330,7 +4322,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
|
|||
if (!valueDense.isSplat())
|
||||
return nullptr;
|
||||
auto splattr = valueDense.getSplatValue<Attribute>();
|
||||
auto attrty = ty.toBuiltinTensor().clone(dty);
|
||||
auto attrty = ty.toBuiltinTensor();
|
||||
return DenseElementsAttr::get(attrty, splattr);
|
||||
}
|
||||
|
||||
|
@ -4338,7 +4330,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
|
|||
if (!isa<mlir::IntegerType>(dty))
|
||||
return nullptr;
|
||||
int64_t intval = intAttr.getInt();
|
||||
auto attrty = ty.toBuiltinTensor().clone(dty);
|
||||
auto attrty = ty.toBuiltinTensor();
|
||||
return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval));
|
||||
}
|
||||
|
||||
|
@ -4346,7 +4338,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
|
|||
if (!isa<mlir::FloatType>(dty))
|
||||
return nullptr;
|
||||
double dblval = fpAttr.getValueAsDouble();
|
||||
auto attrty = ty.toBuiltinTensor().clone(dty);
|
||||
auto attrty = ty.toBuiltinTensor();
|
||||
return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval));
|
||||
}
|
||||
|
||||
|
|
|
@ -453,12 +453,7 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
|
|||
}
|
||||
|
||||
static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
|
||||
if (auto floatType = dyn_cast<mlir::FloatType>(dtype)) {
|
||||
return dtype;
|
||||
} else if (auto integerType = dyn_cast<IntegerType>(dtype)) {
|
||||
return IntegerType::get(context, integerType.getWidth(),
|
||||
IntegerType::Signless);
|
||||
} else if (isa<mlir::ComplexType>(dtype)) {
|
||||
if (isa<mlir::FloatType, IntegerType, mlir::ComplexType>(dtype)) {
|
||||
return dtype;
|
||||
}
|
||||
|
||||
|
@ -480,11 +475,11 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
|
|||
TensorType ValueTensorType::toBuiltinTensor() const {
|
||||
if (!hasDtype())
|
||||
return nullptr;
|
||||
if (!hasSizes())
|
||||
return UnrankedTensorType::get(getDtype());
|
||||
Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype());
|
||||
if (!elementType)
|
||||
return nullptr;
|
||||
if (!hasSizes())
|
||||
return UnrankedTensorType::get(elementType);
|
||||
return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType,
|
||||
getOptionalSparsity());
|
||||
}
|
||||
|
|
|
@ -164,7 +164,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversion(
|
|||
ConversionTarget &target, TypeConverter &typeConverter) {
|
||||
auto valueTensorTypeConversion =
|
||||
[](Torch::ValueTensorType type) -> std::optional<Type> {
|
||||
return type.toBuiltinTensor();
|
||||
auto builtinType = type.toBuiltinTensor();
|
||||
if (!builtinType)
|
||||
return std::nullopt;
|
||||
|
||||
// convert any integer type to signless
|
||||
if (type.getDtype().isInteger()) {
|
||||
return builtinType.clone(IntegerType::get(
|
||||
builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(),
|
||||
IntegerType::Signless));
|
||||
}
|
||||
|
||||
return builtinType;
|
||||
};
|
||||
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
|
||||
valueTensorTypeConversion);
|
||||
|
@ -180,9 +191,18 @@ void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo(
|
|||
auto valueTensorTypeConversion =
|
||||
[](Torch::ValueTensorType type) -> std::optional<Type> {
|
||||
auto builtinType = type.toBuiltinTensor();
|
||||
if (!builtinType)
|
||||
return std::nullopt;
|
||||
|
||||
// convert signed integer type to signless, keep unsigned as unsigned
|
||||
if (type.getDtype().isUnsignedInteger()) {
|
||||
return builtinType.clone(type.getDtype());
|
||||
} else if (type.getDtype().isSignedInteger()) {
|
||||
return builtinType.clone(IntegerType::get(
|
||||
builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(),
|
||||
IntegerType::Signless));
|
||||
}
|
||||
|
||||
return builtinType;
|
||||
};
|
||||
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
|
||||
|
|
Loading…
Reference in New Issue