- Add fixes for
af78e5daf0
- Add fixes for
bb6d5c2200
pull/2758/head
Han-Chung Wang 2024-01-15 07:12:12 -08:00 committed by GitHub
parent dc37616d67
commit 10acea71be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 23 additions and 21 deletions

@ -1 +1 @@
Subproject commit 6b65d79fbb4682468333cea42b62f15c2dffd8f3 Subproject commit 0cb024b357aff294b1ba0f9d3de8f48ab684962b

View File

@ -33,8 +33,9 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op,
rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)), rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)), rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
rewriter.getDenseI32ArrayAttr({multiplier}), rewriter.getDenseI32ArrayAttr({multiplier}),
rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32), rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(false)); rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round),
rewriter.getBoolAttr(false));
return rescale_op.getResult(); return rescale_op.getResult();
} }
@ -86,8 +87,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
rewriter, op->getLoc(), output_type, conv_val, rewriter, op->getLoc(), output_type, conv_val,
rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
rewriter.getDenseI32ArrayAttr({multiplier}), rewriter.getDenseI32ArrayAttr({multiplier}),
rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32), rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
rewriter.getBoolAttr(true), rewriter.getBoolAttr(false)); rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(true),
rewriter.getBoolAttr(false));
return rescale_op.getResult(); return rescale_op.getResult();
@ -96,7 +98,7 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) { .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
// Per-channel quantization // Per-channel quantization
SmallVector<int32_t> multiplier_arr; SmallVector<int32_t> multiplier_arr;
SmallVector<int32_t> shift_arr; SmallVector<int8_t> shift_arr;
SmallVector<double> weight_scale_arr( SmallVector<double> weight_scale_arr(
weight_per_channel_qtype.getScales().begin(), weight_per_channel_qtype.getScales().begin(),
@ -115,14 +117,14 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
scale_width); scale_width);
multiplier_arr.push_back(multiplier); multiplier_arr.push_back(multiplier);
shift_arr.push_back(shift); shift_arr.push_back(static_cast<int8_t>(shift));
} }
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>( auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
rewriter, op->getLoc(), output_type, conv_val, rewriter, op->getLoc(), output_type, conv_val,
rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
rewriter.getDenseI32ArrayAttr(multiplier_arr), rewriter.getDenseI32ArrayAttr(multiplier_arr),
rewriter.getDenseI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32), rewriter.getDenseI8ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32),
rewriter.getBoolAttr(true), rewriter.getBoolAttr(true)); rewriter.getBoolAttr(true), rewriter.getBoolAttr(true));
return rescale_op.getResult(); return rescale_op.getResult();

View File

@ -1,10 +1,10 @@
// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s // RUN: torch-mlir-opt %s -canonicalize | FileCheck %s
// CHECK-LABEL: func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) { // CHECK-LABEL: func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) {
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT3:.*]] = torch.constant.int 3 // CHECK-DAG: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INTM1:.*]] = torch.constant.int -1 // CHECK-DAG: %[[INTM1:.*]] = torch.constant.int -1
// CHECK: %[[NEG_STEP:.*]] = torch.aten.__range_length %[[INT1]], %[[INT3]], %[[INTM1]] : !torch.int, !torch.int, !torch.int -> !torch.int // CHECK: %[[NEG_STEP:.*]] = torch.aten.__range_length %[[INT1]], %[[INT3]], %[[INTM1]] : !torch.int, !torch.int, !torch.int -> !torch.int
// CHECK: return %[[INT2]], %[[INT2]], %[[INT1]], %[[NEG_STEP]] : !torch.int, !torch.int, !torch.int, !torch.int // CHECK: return %[[INT2]], %[[INT2]], %[[INT1]], %[[NEG_STEP]] : !torch.int, !torch.int, !torch.int, !torch.int
func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) { func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) {

View File

@ -84,8 +84,8 @@ func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(%arg0: !torch.vtensor
// CHECK-LABEL: func.func @torch.aten.type_as$basic( // CHECK-LABEL: func.func @torch.aten.type_as$basic(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor { // CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none // CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG_1]] : !torch.tensor -> !torch.int // CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG_1]] : !torch.tensor -> !torch.int
// CHECK: %[[VAR:.*]] = torch.aten.to.dtype %[[ARG_0]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor // CHECK: %[[VAR:.*]] = torch.aten.to.dtype %[[ARG_0]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
// CHECK: return %[[VAR]] : !torch.tensor // CHECK: return %[[VAR]] : !torch.tensor

View File

@ -105,9 +105,9 @@ func.func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !tor
// CHECK-LABEL: func.func @fully_unroll_prim_loop$unroll( // CHECK-LABEL: func.func @fully_unroll_prim_loop$unroll(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[ARG1:.*]]: !torch.list<int>) -> !torch.vtensor { // CHECK-SAME: %[[ARG1:.*]]: !torch.list<int>) -> !torch.vtensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[RESULT:.*]] = torch.shape.calculate { // CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: torch.shape.calculate.yield %[[ARG0]] : !torch.vtensor // CHECK: torch.shape.calculate.yield %[[ARG0]] : !torch.vtensor
// CHECK: } shapes { // CHECK: } shapes {
@ -375,8 +375,8 @@ func.func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch.
// missing. // missing.
// CHECK-LABEL: func.func @basic_integration( // CHECK-LABEL: func.func @basic_integration(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],unk>) -> !torch.vtensor { // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],unk>) -> !torch.vtensor {
// CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[RESULT:.*]] = torch.shape.calculate { // CHECK: %[[RESULT:.*]] = torch.shape.calculate {
// CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?],unk> -> !torch.vtensor<[?,?],unk> // CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?],unk> -> !torch.vtensor<[?,?],unk>
// CHECK: torch.shape.calculate.yield %[[TANH]] : !torch.vtensor<[?,?],unk> // CHECK: torch.shape.calculate.yield %[[TANH]] : !torch.vtensor<[?,?],unk>
@ -410,8 +410,8 @@ func.func @basic_integration(%arg0: !torch.vtensor<[?,?],unk>) -> !torch.vtensor
// CHECK-LABEL: func.func @fold_prim_unchecked_cast_op( // CHECK-LABEL: func.func @fold_prim_unchecked_cast_op(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor, // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor { // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor {
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK-DAG: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK-DAG: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.shape.calculate { // CHECK: %[[VAL_4:.*]] = torch.shape.calculate {
// CHECK: %[[VAL_5:.*]] = torch.tensor_static_info_cast %[[VAL_0]] : !torch.vtensor to !torch.vtensor<[?,?],unk> // CHECK: %[[VAL_5:.*]] = torch.tensor_static_info_cast %[[VAL_0]] : !torch.vtensor to !torch.vtensor<[?,?],unk>
// CHECK: torch.shape.calculate.yield %[[VAL_5]] : !torch.vtensor<[?,?],unk> // CHECK: torch.shape.calculate.yield %[[VAL_5]] : !torch.vtensor<[?,?],unk>