Fix as_strided/slice symint (#1401)

* Fix as_strided symint

* Re-enable LTC tests

* Re-enable LTC

* Add hardtanh shape inference function

* Fix slice symint
Jae Hoon (Antonio) Kim 2022-09-26 12:16:49 -04:00 committed by GitHub
parent 41d45400be
commit 3e27aa2be3
6 changed files with 26 additions and 11 deletions

View File

@@ -78,8 +78,13 @@ symint:
 - expand
 - expand_copy
 - narrow_copy
+- slice_backward
+- slice_copy.Tensor
+- slice_scatter
 - view
 - view_copy
+- as_strided_copy
+- as_strided_scatter
 additional_ops:
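For context, listing an op under `symint:` makes the LTC backend handle it through its SymInt schema, so sizes and indices flow through as `c10::SymInt` instead of plain `int64_t` (the exact codegen effect is inferred from the hunks below, not spelled out here). A minimal sketch of the resulting signature change, mirroring the `slice_backward` hunk further down (the header choice and standalone-declaration framing are illustrative only):

```cpp
#include <ATen/ATen.h>  // brings in at::Tensor, at::SymIntArrayRef, c10::SymInt

// Before: plain integer sizes and slice bounds.
// at::Tensor slice_backward(const at::Tensor& grad_output,
//                           at::IntArrayRef input_sizes, int64_t dim,
//                           int64_t start, int64_t end, int64_t step);

// After: sizes and slice bounds are carried as SymInts, so symbolic
// (dynamic) shapes can pass through the lazy tracing layer unchanged.
at::Tensor slice_backward_symint(
    const at::Tensor& grad_output, at::SymIntArrayRef input_sizes, int64_t dim,
    c10::SymInt start, c10::SymInt end, c10::SymInt step);
```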

View File

@@ -177,7 +177,7 @@ function build_in_tree() {
 -DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR="/main_checkout/torch-mlir/externals/llvm-external-projects/torch-mlir-dialects" \
 -DLLVM_TARGETS_TO_BUILD=host \
 -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
--DTORCH_MLIR_ENABLE_LTC=OFF \
+-DTORCH_MLIR_ENABLE_LTC=ON \
 -DTORCH_MLIR_USE_INSTALLED_PYTORCH="$torch_from_src" \
 -DPython3_EXECUTABLE="$(which python3)" \
 /main_checkout/torch-mlir/externals/llvm-project/llvm
@@ -240,9 +240,8 @@ function test_in_tree() {
 # - AvgPool2dFloatModule_basic,AvgPool2dCeilModeTrueModule_basic: https://github.com/llvm/torch-mlir/issues/1361
 python -m e2e_testing.main --config=tosa -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed AvgPool2dFloatModule_basic AvgPool2dCeilModeTrueModule_basic
-# Temporarily disabled in top of main (https://github.com/llvm/torch-mlir/pull/1292)
-#echo ":::: Run Lazy Tensor Core e2e integration tests"
-#python -m e2e_testing.torchscript.main --config=lazy_tensor_core -v
+echo ":::: Run Lazy Tensor Core e2e integration tests"
+python -m e2e_testing.main --config=lazy_tensor_core -v
 }
 function setup_venv() {
@@ -291,7 +290,7 @@ function build_out_of_tree() {
 -DLLVM_DIR="/main_checkout/torch-mlir/llvm-build/lib/cmake/llvm/" \
 -DMLIR_DIR="/main_checkout/torch-mlir/llvm-build/lib/cmake/mlir/" \
 -DMLIR_ENABLE_BINDINGS_PYTHON=OFF \
--DTORCH_MLIR_ENABLE_LTC=OFF \
+-DTORCH_MLIR_ENABLE_LTC=ON \
 -DTORCH_MLIR_USE_INSTALLED_PYTORCH="$torch_from_src" \
 -DPython3_EXECUTABLE="$(which python3)" \
 /main_checkout/torch-mlir

View File

@@ -450,6 +450,7 @@ LTC_XFAIL_SET = {
 "_Convolution2DTF32Module_basic",
 "_ConvolutionDeprecated2DAllFalseModule_basic",
 "_ConvolutionDeprecated2DBenchmarkModule_basic",
+"_ConvolutionDeprecated2DCudnnModule_basic",
 "_ConvolutionDeprecated2DDeterministicModule_basic",
 "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
 "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
@@ -496,8 +497,6 @@ LTC_XFAIL_SET = {
 "GtFloatIntModule_basic",
 "GtIntModule_basic",
 "HBC_basic",
-"HardTanhIntModule_basic",
-"HardTanhModule_basic",
 "IndexPut1DFloatAccumulateModule_basic",
 "IndexPut1DFloatNonAccumulateModule_basic",
 "IndexPut1DIntAccumulateModule_basic",
@@ -545,6 +544,7 @@ LTC_XFAIL_SET = {
 "IndexTensorHackedTwinModule_basic",
 "IndexTensorHackedTwinModule3dInput_basic",
 "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
+"LiftFreshCopyModule_basic",
 "Matmul_dot",
 "Matmul_matvec",
 "MulIntModule_basic",

View File

@@ -419,9 +419,13 @@ at::Tensor LazyNativeFunctions::select_backward(
 return at::functionalization::functionalize_aten_op<ATEN_OP(
 select_backward)>::call(grad_output, input_sizes, dim, index);
 }
-at::Tensor LazyNativeFunctions::slice_backward(
-const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t dim,
-int64_t start, int64_t end, int64_t step) {
+at::Tensor LazyNativeFunctions::slice_backward_symint(
+const at::Tensor& grad_output,
+at::SymIntArrayRef input_sizes,
+int64_t dim,
+c10::SymInt start,
+c10::SymInt end,
+c10::SymInt step) {
 return at::functionalization::functionalize_aten_op<ATEN_OP(
 slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
 }
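For reference, the wrapped op itself has simple semantics: `aten::slice_backward` scatters the gradient of a slice back into a zero tensor shaped like the original input; the symint overload only changes how the sizes and bounds are represented. A small eager-mode illustration (a sketch assuming a program linked against libtorch; shapes and values are arbitrary):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Forward: slicing a 4x5 tensor along dim 1 keeps columns [1, 4).
  auto input = torch::arange(20, torch::kFloat).reshape({4, 5});
  auto sliced = input.slice(/*dim=*/1, /*start=*/1, /*end=*/4, /*step=*/1);

  // Backward: the slice gradient is scattered into a zero tensor of the
  // original input's size; everything outside columns 1..3 stays zero.
  auto grad_output = torch::ones_like(sliced);
  auto grad_input = at::slice_backward(grad_output, /*input_sizes=*/{4, 5},
                                       /*dim=*/1, /*start=*/1, /*end=*/4,
                                       /*step=*/1);
  std::cout << grad_input << std::endl;
  return 0;
}
```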

View File

@@ -36,5 +36,12 @@ std::vector<torch::lazy::Shape> compute_shape_var(
 return {Shape(self.scalar_type(), {})};
 }
+std::vector<torch::lazy::Shape> compute_shape_hardtanh(
+const at::Tensor& self, const at::Scalar& min_val, const at::Scalar& max_val
+) {
+return {Shape(self.scalar_type(), self.sizes().vec())};
+}
 } // namespace lazy
 } // namespace torch
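The shape rule above encodes the fact that hardtanh is elementwise: the result has the same dtype and sizes as its input, which is all the lazy backend needs to build the IR node. A quick eager-mode check (a sketch assuming a program linked against libtorch):

```cpp
#include <torch/torch.h>

int main() {
  auto x = torch::randn({2, 3, 4});
  auto y = at::hardtanh(x, /*min_val=*/-1.0, /*max_val=*/1.0);
  // Mirrors what compute_shape_hardtanh reports: dtype and sizes are preserved.
  TORCH_CHECK(y.sizes() == x.sizes(), "hardtanh must keep the input sizes");
  TORCH_CHECK(y.scalar_type() == x.scalar_type(), "hardtanh must keep the input dtype");
  return 0;
}
```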

View File

@@ -46,7 +46,7 @@ import torch
 PACKAGE_VERSION = os.environ.get("TORCH_MLIR_PYTHON_PACKAGE_VERSION") or "0.0.1"
 # If true, enable LTC build by default
-TORCH_MLIR_ENABLE_LTC_DEFAULT = False
+TORCH_MLIR_ENABLE_LTC_DEFAULT = True
 # Build phase discovery is unreliable. Just tell it what phases to run.
 class CustomBuild(_build):