mirror of https://github.com/llvm/torch-mlir
[Stablehlo] support uint8 (#3367)
Support lowering unsigned integer type to stablehlo as discussed in https://github.com/llvm/torch-mlir/pull/2184. The things I do in this PR: 1. create `setupBackendTypeConversionForStablehlo()`, `createFuncBackendTypeConversionForStablehloPass` and `createFinalizingBackendTypeConversionForStablehloPass`. 2. remove `InferTypeOpInterface` from `torch_c.to_builtin_tensor`, because it's different result type between linalg backend and stablehlo backend: ``` // linalg backend func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> { %c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xi8> %0 = tensor.empty() : tensor<3xf32> %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<3xi8>) outs(%0 : tensor<3xf32>) { ^bb0(%in: i8, %out: f32): %2 = arith.uitofp %in : i8 to f32 linalg.yield %2 : f32 } -> tensor<3xf32> return %1 : tensor<3xf32> } // stablehlo backend func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> { %c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xui8> %0 = stablehlo.convert %arg0 : (tensor<3xui8> -> tensor<3xf32> return %0 : tensor<3xf32> } ``` 3. fix stablehlo and linalg's conversionpull/3414/head
parent
56d21cba62
commit
50f7103098
|
@ -1 +1 @@
|
||||||
Subproject commit c44d9af8d4879adccf1054cb61a53377ae5898cb
|
Subproject commit 25d237f6273361bb29e8436349c7067ee559dca2
|
|
@ -25,9 +25,7 @@ class TorchConversion_Op<string mnemonic, list<Trait> traits = []>
|
||||||
// Conversions to backend types.
|
// Conversions to backend types.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor", [
|
def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor"> {
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
|
||||||
]> {
|
|
||||||
let summary = "Convert a `!torch.vtensor` to a `tensor`";
|
let summary = "Convert a `!torch.vtensor` to a `tensor`";
|
||||||
let description = [{
|
let description = [{
|
||||||
This op only operates on ValueTensorType, to avoid conflating conversions
|
This op only operates on ValueTensorType, to avoid conflating conversions
|
||||||
|
|
|
@ -25,6 +25,11 @@ void getBackendTypeConversionDependentDialects(DialectRegistry ®istry);
|
||||||
/// boundary (which currently consist only of builtin types).
|
/// boundary (which currently consist only of builtin types).
|
||||||
void setupBackendTypeConversion(ConversionTarget &target,
|
void setupBackendTypeConversion(ConversionTarget &target,
|
||||||
TypeConverter &typeConverter);
|
TypeConverter &typeConverter);
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
void setupBackendTypeConversionForStablehlo(ConversionTarget &target,
|
||||||
|
TypeConverter &typeConverter);
|
||||||
|
#endif
|
||||||
} // namespace TorchConversion
|
} // namespace TorchConversion
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -48,6 +48,13 @@ struct StablehloBackendPipelineOptions
|
||||||
|
|
||||||
void createTorchBackendToStablehloBackendPipeline(
|
void createTorchBackendToStablehloBackendPipeline(
|
||||||
OpPassManager &pm, const StablehloBackendPipelineOptions &options);
|
OpPassManager &pm, const StablehloBackendPipelineOptions &options);
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
createFuncBackendTypeConversionForStablehloPass();
|
||||||
|
|
||||||
|
std::unique_ptr<InterfacePass<FunctionOpInterface>>
|
||||||
|
createFinalizingBackendTypeConversionForStablehloPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
createVerifyStablehloBackendContractPass();
|
createVerifyStablehloBackendContractPass();
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -21,6 +21,17 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
def FuncBackendTypeConversionForStablehlo : Pass<"torch-func-backend-type-conversion-for-stablehlo", "ModuleOp"> {
|
||||||
|
let summary = "Convert functions to operate on builtin tensors for stablehlo backend";
|
||||||
|
let constructor = "mlir::torch::TorchConversion::createFuncBackendTypeConversionForStablehloPass()";
|
||||||
|
let description = [{
|
||||||
|
Partial type conversion pass analogous in scope to the upstream
|
||||||
|
`func-bufferize` pass. See details there.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
|
||||||
def FinalizingBackendTypeConversion
|
def FinalizingBackendTypeConversion
|
||||||
: InterfacePass<"torch-finalizing-backend-type-conversion", "mlir::FunctionOpInterface"> {
|
: InterfacePass<"torch-finalizing-backend-type-conversion", "mlir::FunctionOpInterface"> {
|
||||||
let summary = "Finalizes a partial conversion to builtin tensors";
|
let summary = "Finalizes a partial conversion to builtin tensors";
|
||||||
|
@ -32,6 +43,19 @@ def FinalizingBackendTypeConversion
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
def FinalizingBackendTypeConversionForStablehlo
|
||||||
|
: InterfacePass<"torch-finalizing-backend-type-conversion-for-stablehlo", "mlir::FunctionOpInterface"> {
|
||||||
|
let summary = "Finalizes a partial conversion to builtin tensors for stablehlo";
|
||||||
|
let constructor =
|
||||||
|
"mlir::torch::TorchConversion::createFinalizingBackendTypeConversionForStablehloPass()";
|
||||||
|
let description = [{
|
||||||
|
Analogous in scope to the upstream `finalizing-bufferize` pass.
|
||||||
|
See details there.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
|
||||||
def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
|
def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
|
||||||
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
|
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
|
||||||
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
|
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
|
||||||
|
|
|
@ -1197,6 +1197,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
|
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
|
||||||
Value input = payloadArgs[0];
|
Value input = payloadArgs[0];
|
||||||
|
Type inputElementType =
|
||||||
|
cast<BaseTensorType>(atenToDtype.getSelf().getType()).getDtype();
|
||||||
Type dtype =
|
Type dtype =
|
||||||
cast<RankedTensorType>(converter->convertType(atenToDtype.getType()))
|
cast<RankedTensorType>(converter->convertType(atenToDtype.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
|
@ -1215,7 +1217,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
resultElementType = *maybeResultElementType;
|
resultElementType = *maybeResultElementType;
|
||||||
Value result = convertScalarToDtype(b, loc, input, dtype,
|
Value result = convertScalarToDtype(b, loc, input, dtype,
|
||||||
/*srcOriginalDtype=*/std::nullopt,
|
/*srcOriginalDtype=*/inputElementType,
|
||||||
/*dstOriginalDtype=*/resultElementType);
|
/*dstOriginalDtype=*/resultElementType);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -277,8 +277,8 @@ public:
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto inputType = dyn_cast<RankedTensorType>(adaptor.getA().getType());
|
auto inputType = dyn_cast<RankedTensorType>(adaptor.getA().getType());
|
||||||
if (!inputType)
|
if (!inputType)
|
||||||
|
|
||||||
op.emitError("only Tensor types supported in StableHLO");
|
op.emitError("only Tensor types supported in StableHLO");
|
||||||
|
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = adaptor.getA();
|
Value input = adaptor.getA();
|
||||||
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
||||||
|
@ -290,14 +290,24 @@ public:
|
||||||
for (int64_t i = 0; i < inputRank; i++)
|
for (int64_t i = 0; i < inputRank; i++)
|
||||||
checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);
|
checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);
|
||||||
|
|
||||||
|
// handle unsigned interger
|
||||||
|
if (inputType.getElementType().isUnsignedInteger()) {
|
||||||
|
input = rewriter.create<stablehlo::ConvertOp>(
|
||||||
|
loc, input,
|
||||||
|
rewriter.getIntegerType(
|
||||||
|
inputType.getElementType().getIntOrFloatBitWidth()));
|
||||||
|
}
|
||||||
|
|
||||||
Value constantZero =
|
Value constantZero =
|
||||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
||||||
SmallVector<Value> indices(inputRank, constantZero);
|
SmallVector<Value> indices(inputRank, constantZero);
|
||||||
Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
|
Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
|
||||||
Type resultType =
|
Type resultType =
|
||||||
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||||
rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
|
rewriter.replaceOp(
|
||||||
resultType, inputDtype));
|
op,
|
||||||
|
convertScalarToDtype(rewriter, loc, result, resultType, inputDtype,
|
||||||
|
/*srcOriginalDtype=*/inputType.getElementType()));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -900,7 +900,7 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
||||||
for (int64_t i = maxIndexRank; i < inputRank; ++i) {
|
for (int64_t i = maxIndexRank; i < inputRank; ++i) {
|
||||||
updateWindowDims.push_back(i);
|
updateWindowDims.push_back(i);
|
||||||
}
|
}
|
||||||
llvm::outs() << "maxIndexRank: " << maxIndexRank << "\n";
|
|
||||||
auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get(
|
auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get(
|
||||||
rewriter.getContext(),
|
rewriter.getContext(),
|
||||||
/*updateWindowDims=*/updateWindowDims,
|
/*updateWindowDims=*/updateWindowDims,
|
||||||
|
|
|
@ -51,7 +51,8 @@ public:
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
TorchConversion::setupBackendTypeConversionForStablehlo(target,
|
||||||
|
typeConverter);
|
||||||
|
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,18 @@ static bool haveSameSizeAndElementType(TensorType lhs, TensorType rhs) {
|
||||||
if (lhs.hasRank() != rhs.hasRank())
|
if (lhs.hasRank() != rhs.hasRank())
|
||||||
return false;
|
return false;
|
||||||
bool sameSize = lhs.hasRank() ? lhs.getShape().equals(rhs.getShape()) : true;
|
bool sameSize = lhs.hasRank() ? lhs.getShape().equals(rhs.getShape()) : true;
|
||||||
bool sameElementType = lhs.getElementType() == rhs.getElementType();
|
bool sameElementType = false;
|
||||||
|
// Namely, it is worth mentioning that the backends can have different
|
||||||
|
// expectations for signedness when converting from and to the builtin MLIR
|
||||||
|
// types. Therefore, the verifier cannot expect the input and output types to
|
||||||
|
// match in their signedness.
|
||||||
|
if (isa<IntegerType>(lhs.getElementType()) &&
|
||||||
|
isa<IntegerType>(rhs.getElementType())) {
|
||||||
|
sameElementType = lhs.getElementType().getIntOrFloatBitWidth() ==
|
||||||
|
rhs.getElementType().getIntOrFloatBitWidth();
|
||||||
|
} else {
|
||||||
|
sameElementType = lhs.getElementType() == rhs.getElementType();
|
||||||
|
}
|
||||||
return sameElementType && sameSize;
|
return sameElementType && sameSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,18 +53,6 @@ LogicalResult ToBuiltinTensorOp::verify() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult ToBuiltinTensorOp::inferReturnTypes(
|
|
||||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
|
||||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
|
||||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
|
||||||
auto resultType =
|
|
||||||
cast<Torch::ValueTensorType>(operands[0].getType()).toBuiltinTensor();
|
|
||||||
if (!resultType)
|
|
||||||
return failure();
|
|
||||||
inferredReturnTypes.push_back(resultType);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// FromBuiltinTensorOp
|
// FromBuiltinTensorOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -23,22 +23,22 @@ void mlir::torch::TorchConversion::getBackendTypeConversionDependentDialects(
|
||||||
// Type conversion setup.
|
// Type conversion setup.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static void
|
using ValueTensorTypeConversionFn =
|
||||||
setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
|
std::function<std::optional<Type>(Torch::ValueTensorType)>;
|
||||||
TypeConverter &typeConverter) {
|
|
||||||
|
static void setupValueTensorToBuiltinTensorConversion(
|
||||||
|
ConversionTarget &target, TypeConverter &typeConverter,
|
||||||
|
const ValueTensorTypeConversionFn &conversionFn) {
|
||||||
target.addLegalOp<TorchConversion::ToBuiltinTensorOp,
|
target.addLegalOp<TorchConversion::ToBuiltinTensorOp,
|
||||||
TorchConversion::FromBuiltinTensorOp>();
|
TorchConversion::FromBuiltinTensorOp>();
|
||||||
typeConverter.addConversion(
|
typeConverter.addConversion(conversionFn);
|
||||||
[](Torch::ValueTensorType type) -> std::optional<Type> {
|
|
||||||
return type.toBuiltinTensor();
|
|
||||||
});
|
|
||||||
typeConverter.addTargetMaterialization([](OpBuilder &builder, TensorType type,
|
typeConverter.addTargetMaterialization([](OpBuilder &builder, TensorType type,
|
||||||
ValueRange inputs,
|
ValueRange inputs,
|
||||||
Location loc) -> Value {
|
Location loc) -> Value {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
if (!isa<Torch::BaseTensorType>(inputs[0].getType()))
|
if (!isa<Torch::BaseTensorType>(inputs[0].getType()))
|
||||||
return {};
|
return {};
|
||||||
return builder.create<ToBuiltinTensorOp>(loc, inputs[0]);
|
return builder.create<ToBuiltinTensorOp>(loc, type, inputs[0]);
|
||||||
});
|
});
|
||||||
auto sourceMaterialization = [](OpBuilder &builder,
|
auto sourceMaterialization = [](OpBuilder &builder,
|
||||||
Torch::ValueTensorType type,
|
Torch::ValueTensorType type,
|
||||||
|
@ -162,9 +162,34 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
|
||||||
|
|
||||||
void mlir::torch::TorchConversion::setupBackendTypeConversion(
|
void mlir::torch::TorchConversion::setupBackendTypeConversion(
|
||||||
ConversionTarget &target, TypeConverter &typeConverter) {
|
ConversionTarget &target, TypeConverter &typeConverter) {
|
||||||
setupValueTensorToBuiltinTensorConversion(target, typeConverter);
|
auto valueTensorTypeConversion =
|
||||||
|
[](Torch::ValueTensorType type) -> std::optional<Type> {
|
||||||
|
return type.toBuiltinTensor();
|
||||||
|
};
|
||||||
|
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
|
||||||
|
valueTensorTypeConversion);
|
||||||
setupTorchBoolToI1Conversion(target, typeConverter);
|
setupTorchBoolToI1Conversion(target, typeConverter);
|
||||||
setupTorchIntToI64Conversion(target, typeConverter);
|
setupTorchIntToI64Conversion(target, typeConverter);
|
||||||
setupTorchFloatToF64Conversion(target, typeConverter);
|
setupTorchFloatToF64Conversion(target, typeConverter);
|
||||||
setupTorchGeneratorToI64Conversion(target, typeConverter);
|
setupTorchGeneratorToI64Conversion(target, typeConverter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo(
|
||||||
|
ConversionTarget &target, TypeConverter &typeConverter) {
|
||||||
|
auto valueTensorTypeConversion =
|
||||||
|
[](Torch::ValueTensorType type) -> std::optional<Type> {
|
||||||
|
auto builtinType = type.toBuiltinTensor();
|
||||||
|
if (type.getDtype().isUnsignedInteger()) {
|
||||||
|
return builtinType.clone(type.getDtype());
|
||||||
|
}
|
||||||
|
return builtinType;
|
||||||
|
};
|
||||||
|
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
|
||||||
|
valueTensorTypeConversion);
|
||||||
|
setupTorchBoolToI1Conversion(target, typeConverter);
|
||||||
|
setupTorchIntToI64Conversion(target, typeConverter);
|
||||||
|
setupTorchFloatToF64Conversion(target, typeConverter);
|
||||||
|
setupTorchGeneratorToI64Conversion(target, typeConverter);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
|
@ -26,6 +26,32 @@ using namespace mlir::torch::TorchConversion;
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
void populateFuncBackendTypeConversionPatterns(TypeConverter &typeConverter,
|
||||||
|
RewritePatternSet &patterns,
|
||||||
|
ConversionTarget &target) {
|
||||||
|
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
|
||||||
|
typeConverter);
|
||||||
|
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
||||||
|
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
|
||||||
|
typeConverter.isLegal(&op.getBody());
|
||||||
|
});
|
||||||
|
populateCallOpTypeConversionPattern(patterns, typeConverter);
|
||||||
|
target.addDynamicallyLegalOp<func::CallOp>(
|
||||||
|
[&](func::CallOp op) { return typeConverter.isLegal(op); });
|
||||||
|
|
||||||
|
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
|
||||||
|
populateReturnOpTypeConversionPattern(patterns, typeConverter);
|
||||||
|
target.addLegalOp<ModuleOp>();
|
||||||
|
|
||||||
|
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
|
||||||
|
return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
|
||||||
|
isLegalForBranchOpInterfaceTypeConversionPattern(op,
|
||||||
|
typeConverter) ||
|
||||||
|
isLegalForReturnOpTypeConversionPattern(op, typeConverter);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
struct FuncBackendTypeConversionPass
|
struct FuncBackendTypeConversionPass
|
||||||
: public FuncBackendTypeConversionBase<FuncBackendTypeConversionPass> {
|
: public FuncBackendTypeConversionBase<FuncBackendTypeConversionPass> {
|
||||||
using FuncBackendTypeConversionBase<
|
using FuncBackendTypeConversionBase<
|
||||||
|
@ -43,31 +69,41 @@ struct FuncBackendTypeConversionPass
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||||
|
|
||||||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
|
populateFuncBackendTypeConversionPatterns(typeConverter, patterns, target);
|
||||||
patterns, typeConverter);
|
|
||||||
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
|
||||||
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
|
|
||||||
typeConverter.isLegal(&op.getBody());
|
|
||||||
});
|
|
||||||
populateCallOpTypeConversionPattern(patterns, typeConverter);
|
|
||||||
target.addDynamicallyLegalOp<func::CallOp>(
|
|
||||||
[&](func::CallOp op) { return typeConverter.isLegal(op); });
|
|
||||||
|
|
||||||
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
|
|
||||||
populateReturnOpTypeConversionPattern(patterns, typeConverter);
|
|
||||||
target.addLegalOp<ModuleOp>();
|
|
||||||
|
|
||||||
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
|
|
||||||
return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
|
|
||||||
isLegalForBranchOpInterfaceTypeConversionPattern(op,
|
|
||||||
typeConverter) ||
|
|
||||||
isLegalForReturnOpTypeConversionPattern(op, typeConverter);
|
|
||||||
});
|
|
||||||
|
|
||||||
if (failed(applyFullConversion(module, target, std::move(patterns))))
|
if (failed(applyFullConversion(module, target, std::move(patterns))))
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
struct FuncBackendTypeConversionForStablehloPass
|
||||||
|
: public FuncBackendTypeConversionForStablehloBase<
|
||||||
|
FuncBackendTypeConversionForStablehloPass> {
|
||||||
|
using FuncBackendTypeConversionForStablehloBase<
|
||||||
|
FuncBackendTypeConversionForStablehloPass>::
|
||||||
|
FuncBackendTypeConversionForStablehloBase;
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<TorchConversion::TorchConversionDialect>();
|
||||||
|
}
|
||||||
|
void runOnOperation() override {
|
||||||
|
auto module = getOperation();
|
||||||
|
auto *context = &getContext();
|
||||||
|
|
||||||
|
TypeConverter typeConverter;
|
||||||
|
RewritePatternSet patterns(context);
|
||||||
|
ConversionTarget target(*context);
|
||||||
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
TorchConversion::setupBackendTypeConversionForStablehlo(target,
|
||||||
|
typeConverter);
|
||||||
|
|
||||||
|
populateFuncBackendTypeConversionPatterns(typeConverter, patterns, target);
|
||||||
|
|
||||||
|
if (failed(applyFullConversion(module, target, std::move(patterns))))
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
@ -75,6 +111,13 @@ mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() {
|
||||||
return std::make_unique<FuncBackendTypeConversionPass>();
|
return std::make_unique<FuncBackendTypeConversionPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> mlir::torch::TorchConversion::
|
||||||
|
createFuncBackendTypeConversionForStablehloPass() {
|
||||||
|
return std::make_unique<FuncBackendTypeConversionForStablehloPass>();
|
||||||
|
}
|
||||||
|
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// FinalizingBackendTypeConversionPass
|
// FinalizingBackendTypeConversionPass
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -170,9 +213,61 @@ struct FinalizingBackendTypeConversionPass
|
||||||
stripTorchAttrs(func);
|
stripTorchAttrs(func);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
struct FinalizingBackendTypeConversionForStablehloPass
|
||||||
|
: public FinalizingBackendTypeConversionForStablehloBase<
|
||||||
|
FinalizingBackendTypeConversionForStablehloPass> {
|
||||||
|
using FinalizingBackendTypeConversionForStablehloBase<
|
||||||
|
FinalizingBackendTypeConversionForStablehloPass>::
|
||||||
|
FinalizingBackendTypeConversionForStablehloBase;
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
auto func = getOperation();
|
||||||
|
auto *context = &getContext();
|
||||||
|
|
||||||
|
TypeConverter typeConverter;
|
||||||
|
RewritePatternSet patterns(context);
|
||||||
|
ConversionTarget target(*context);
|
||||||
|
|
||||||
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
TorchConversion::setupBackendTypeConversionForStablehlo(target,
|
||||||
|
typeConverter);
|
||||||
|
|
||||||
|
// Mark materializations as illegal in this pass (since we are finalizing)
|
||||||
|
// and add patterns that eliminate them.
|
||||||
|
setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op,
|
||||||
|
FromI64Op, ToI64Op, FromF64Op, ToF64Op, I64ToGeneratorOp,
|
||||||
|
GeneratorToI64Op>(target, patterns, typeConverter);
|
||||||
|
|
||||||
|
// If all result types are legal, and all block arguments are legal, then
|
||||||
|
// all types in the program are legal.
|
||||||
|
//
|
||||||
|
// We also check that the operand types are legal to avoid creating invalid
|
||||||
|
// IR. For example, this prevents the patterns from updating
|
||||||
|
// the types of the operands to a return op without updating the enclosing
|
||||||
|
// function.
|
||||||
|
target.markUnknownOpDynamicallyLegal(
|
||||||
|
[&](Operation *op) { return typeConverter.isLegal(op); });
|
||||||
|
|
||||||
|
if (failed(applyFullConversion(func, target, std::move(patterns))))
|
||||||
|
signalPassFailure();
|
||||||
|
|
||||||
|
// Drop attributes that are no longer used after conversion out of Torch.
|
||||||
|
stripTorchAttrs(func);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<InterfacePass<FunctionOpInterface>>
|
std::unique_ptr<InterfacePass<FunctionOpInterface>>
|
||||||
mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
|
mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
|
||||||
return std::make_unique<FinalizingBackendTypeConversionPass>();
|
return std::make_unique<FinalizingBackendTypeConversionPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
std::unique_ptr<InterfacePass<FunctionOpInterface>> mlir::torch::
|
||||||
|
TorchConversion::createFinalizingBackendTypeConversionForStablehloPass() {
|
||||||
|
return std::make_unique<FinalizingBackendTypeConversionForStablehloPass>();
|
||||||
|
}
|
||||||
|
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
|
|
|
@ -148,10 +148,11 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
|
||||||
|
|
||||||
// Finish the type conversion from `torch` types to the types of the
|
// Finish the type conversion from `torch` types to the types of the
|
||||||
// StableHLO backend contract.
|
// StableHLO backend contract.
|
||||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
pm.addPass(
|
||||||
|
TorchConversion::createFuncBackendTypeConversionForStablehloPass());
|
||||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||||
pm.addNestedPass<func::FuncOp>(
|
pm.addNestedPass<func::FuncOp>(
|
||||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
TorchConversion::createFinalizingBackendTypeConversionForStablehloPass());
|
||||||
|
|
||||||
// Verify that we have lowered to Stablehlo ops.
|
// Verify that we have lowered to Stablehlo ops.
|
||||||
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
|
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
|
||||||
|
|
|
@ -66,6 +66,7 @@ void mlir::torch::registerAllPasses() {
|
||||||
mlir::stablehlo::registerStablehloLegalizeToLinalgPass();
|
mlir::stablehlo::registerStablehloLegalizeToLinalgPass();
|
||||||
mlir::stablehlo::registerStablehloAggressiveSimplificationPass();
|
mlir::stablehlo::registerStablehloAggressiveSimplificationPass();
|
||||||
mlir::stablehlo::registerStablehloRefineShapesPass();
|
mlir::stablehlo::registerStablehloRefineShapesPass();
|
||||||
|
mlir::stablehlo::registerStablehloConvertToSignlessPass();
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef TORCH_MLIR_ENABLE_REFBACKEND
|
#ifdef TORCH_MLIR_ENABLE_REFBACKEND
|
||||||
|
|
|
@ -826,6 +826,8 @@ STABLEHLO_PASS_SET = {
|
||||||
"SplitWithSizes_Module_basic",
|
"SplitWithSizes_Module_basic",
|
||||||
"TensorSplitSections_GetItemModule_basic",
|
"TensorSplitSections_GetItemModule_basic",
|
||||||
"TensorSplitSections_ListUnpackModule_basic",
|
"TensorSplitSections_ListUnpackModule_basic",
|
||||||
|
"EmptyModule_uint8",
|
||||||
|
"TypeConversionUint8ToF32Module_basic",
|
||||||
"AtenLinear1D_basic",
|
"AtenLinear1D_basic",
|
||||||
"AtenLinear2D_basic",
|
"AtenLinear2D_basic",
|
||||||
"AtenLinear3DBias_basic",
|
"AtenLinear3DBias_basic",
|
||||||
|
|
|
@ -23,6 +23,7 @@ STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join(
|
||||||
[
|
[
|
||||||
"func.func(stablehlo-aggressive-simplification)",
|
"func.func(stablehlo-aggressive-simplification)",
|
||||||
"stablehlo-legalize-to-linalg",
|
"stablehlo-legalize-to-linalg",
|
||||||
|
"stablehlo-convert-to-signless",
|
||||||
"canonicalize",
|
"canonicalize",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -136,6 +136,26 @@ def TypeConversionI1ToF64Module_basic(module, tu: TestUtils):
|
||||||
module.forward(tensor)
|
module.forward(tensor)
|
||||||
|
|
||||||
|
|
||||||
|
class TypeConversionUint8ToF32Module(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([3], torch.uint8, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return x.to(torch.float)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TypeConversionUint8ToF32Module())
|
||||||
|
def TypeConversionUint8ToF32Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.tensor([0, 1, 255]).to(torch.uint8))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue