Fix recent break due to PyTorch changes.

Tracing seems now now capture a 4-operand version of aten::add instead
of 3-operand.

I fixed the tests that made sense. One test was XFAIL'ed, as I don't
have in cache the exact way to fix it yet (requires touching
aten-recogniz-kernels stuff).  I'll be context switching to work on the
kernel recognition stuff soon, and will fix it then.
pull/183/head
Sean Silva 2021-03-03 15:03:08 -08:00
parent 43dba03afd
commit a36113e586
3 changed files with 18 additions and 8 deletions

View File

@ -46,7 +46,7 @@ with mb.capture_function("resa", [inputs]) as f:
# CHECK-LABEL: func @resa(
# CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[1,16,128,128]:f32>) -> !numpy.ndarray<[1,16,128,128]:f32> {
# CHECK: %[[VAL_118:.*]] = torch.kernel_call "aten::convolution" {{.*}} : (!numpy.ndarray<[1,8,128,128]:f32>, !numpy.ndarray<[16,8,1,1]:f32>, !numpy.ndarray<[16]:f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType, !basicpy.ListType, i64) -> !numpy.ndarray<[1,16,128,128]:f32>
# CHECK: %[[VAL_119:.*]] = torch.kernel_call "aten::add" %{{.*}}, %[[VAL_118]], %{{.*}} : (!numpy.ndarray<[1,16,128,128]:f32>, !numpy.ndarray<[1,16,128,128]:f32>, i64) -> !numpy.ndarray<[1,16,128,128]:f32>
# CHECK: %[[VAL_119:.*]] = torch.kernel_call "aten::add" %{{.*}}, %[[VAL_118]], %{{.*}}, %{{.*}} : (!numpy.ndarray<[1,16,128,128]:f32>, !numpy.ndarray<[1,16,128,128]:f32>, i64, !numpy.ndarray<[1,16,128,128]:f32>) -> !numpy.ndarray<[1,16,128,128]:f32>
# CHECK: return %[[VAL_119]] : !numpy.ndarray<[1,16,128,128]:f32>
# CHECK: }
mb.module.operation.print(large_elements_limit=2)

View File

@ -17,11 +17,18 @@ mb = torch_mlir.ModuleBuilder()
with mb.capture_function("add3", [t0, t1, t2]) as f:
t3 = t0 + t1 + t2
f.returns([t3])
# CHECK-LABEL: func @add3
# CHECK: %[[CST_1A:.*]] = constant 1 : i64
# CHECK: %[[CST_1B:.*]] = constant 1 : i64
# CHECK: %[[ADD0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[CST_1A]] : (!numpy.ndarray<[1,2,3,4]:f32>, !numpy.ndarray<[1,2,3,4]:f32>, i64) -> !numpy.ndarray<[1,2,3,4]:f32> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
# CHECK: %[[ADD1:.*]] = torch.kernel_call "aten::add" %[[ADD0]], %arg2, %[[CST_1B]] : (!numpy.ndarray<[1,2,3,4]:f32>, !numpy.ndarray<[1,2,3,4]:f32>, i64) -> !numpy.ndarray<[1,2,3,4]:f32> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
# CHECK: return %[[ADD1]]
# NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
# CHECK-LABEL: func @add3(
# CHECK-SAME: %[[VAL_0:.*]]: !numpy.ndarray<[1,2,3,4]:f32>, %[[VAL_1:.*]]: !numpy.ndarray<[1,2,3,4]:f32>,
# CHECK-SAME: %[[VAL_2:.*]]: !numpy.ndarray<[1,2,3,4]:f32>) -> !numpy.ndarray<[1,2,3,4]:f32> {
# CHECK: %[[VAL_3:.*]] = constant 1 : i64
# CHECK: %[[VAL_4:.*]] = constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
# CHECK: %[[VAL_5:.*]] = constant 1 : i64
# CHECK: %[[VAL_6:.*]] = constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
# CHECK: %[[VAL_7:.*]] = numpy.create_array_from_tensor %[[VAL_4]] : (tensor<1x2x3x4xf32>) -> !numpy.ndarray<[1,2,3,4]:f32>
# CHECK: %[[VAL_8:.*]] = torch.kernel_call "aten::add" %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_7]] : (!numpy.ndarray<[1,2,3,4]:f32>, !numpy.ndarray<[1,2,3,4]:f32>, i64, !numpy.ndarray<[1,2,3,4]:f32>) -> !numpy.ndarray<[1,2,3,4]:f32> {sigArgTypes = ["Tensor", "Tensor", "Scalar", "Tensor"], sigIsMutable = true, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
# CHECK: %[[VAL_9:.*]] = numpy.create_array_from_tensor %[[VAL_6]] : (tensor<1x2x3x4xf32>) -> !numpy.ndarray<[1,2,3,4]:f32>
# CHECK: %[[VAL_10:.*]] = torch.kernel_call "aten::add" %[[VAL_8]], %[[VAL_2]], %[[VAL_5]], %[[VAL_9]] : (!numpy.ndarray<[1,2,3,4]:f32>, !numpy.ndarray<[1,2,3,4]:f32>, i64, !numpy.ndarray<[1,2,3,4]:f32>) -> !numpy.ndarray<[1,2,3,4]:f32> {sigArgTypes = ["Tensor", "Tensor", "Scalar", "Tensor"], sigIsMutable = true, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
# CHECK: return %[[VAL_10]] : !numpy.ndarray<[1,2,3,4]:f32>
# CHECK: }
print(mb.module)

View File

@ -1,7 +1,10 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
# RUN: %PYTHON %s | npcomp-opt -aten-recognize-kernels -numpy-public-functions-to-tensor -canonicalize | FileCheck %s
# TODO: Re-enable after adding support for 4-operand aten::add in `aten-recognize-kernels`.
# XFAIL: *
# TODO: This test should go away or become part of an e2e test suite. It is
# preserved right now as a stop-gap.