diff --git a/python/npcomp/tracing/emitters.py b/python/npcomp/tracing/emitters.py index 540663724..af75f71af 100644 --- a/python/npcomp/tracing/emitters.py +++ b/python/npcomp/tracing/emitters.py @@ -17,16 +17,16 @@ class TraceValueType(Enum): NDARRAY = 1 -class TraceValue( - namedtuple("TraceValue", ["value", "type"], - defaults=(TraceValueType.NDARRAY,))): +class TraceValue(namedtuple("TraceValue", ["value", "type"])): __slots__ = () """A Python value and the trace type that it should correspond to.""" +TraceValue.__new__.__defaults__ = (TraceValueType.NDARRAY,) + + class TraceInvocation( - namedtuple("TraceInvocation", ["inputs", "kwargs", "protocol", "method"], - defaults=(Protocol.ARRAY_FUNC, "__call__"))): + namedtuple("TraceInvocation", ["inputs", "kwargs", "protocol", "method"])): """An invocation of a single functions. This abstracts over both ufuncs and array_funcs, differentiating by the @@ -35,10 +35,12 @@ class TraceInvocation( __slots__ = () +TraceInvocation.__new__.__defaults__ = (Protocol.ARRAY_FUNC, "__call__") + + class EmissionRequest( namedtuple("EmissionRequest", - ["input_ssa_values", "dialect_helper", "extra"], - defaults=(None,))): + ["input_ssa_values", "dialect_helper", "extra"])): """Represents the result of processing inputs from an invocation. The `input_ssa_values` are mlir.ir.Value instances corresponding to @@ -53,10 +55,12 @@ class EmissionRequest( __slots__ = () +EmissionRequest.__new__.__defaults__ = (None,) + + class TraceValueMap( namedtuple("TraceValueMap", - ["input_trace_values", "result_trace_value_types", "extra"], - defaults=(None,))): + ["input_trace_values", "result_trace_value_types", "extra"])): """The result of mapping an invocation to corresponding op structure. This type associates: @@ -69,6 +73,9 @@ class TraceValueMap( __slots__ = () +TraceValueMap.__new__.__defaults__ = (None) + + class FuncEmitter: """An emitter for an op-like function invocation."""