As an example, we'll define an N-dimensional vector. We start with the basic structure and a constructor. To be able to instantiate the object as `NDimPoint(x1, x2, ..., xN)`, the constructor has to convert the vararg tuple into an array.
"""
    NDimPoint{T<:Real}

A point (vector) whose coefficients are real numbers of a common type `T`.

# Fields
- `coeffs::Vector{T}`: the coefficients of the point.
- `N::Int`: the number of coefficients. The vararg constructor sets it to
  `length(coeffs)`; the default two-argument constructor stores the value it is given.
"""
struct NDimPoint{T<:Real}
    coeffs::Vector{T}
    N::Int
end

"""
    NDimPoint(x::Real, xs::Real...)

Build a point from its coefficients, e.g. `NDimPoint(1, 2, 3)`.

The arguments are gathered into a `Vector` (the array literal picks the common element
type) and the dimension is taken from the length of that vector.

# Throws
- `MethodError`: if no coefficient is given or if any coefficient is not a `Real`.
"""
function NDimPoint(x::Real, xs::Real...)
    # Restricting the varargs to `Real` keeps this method from matching its own
    # `(vector, length)` call below, which would otherwise recurse without end
    # whenever the gathered vector is not a `Vector{<:Real}`.
    coeffs = [x, xs...]
    return NDimPoint(coeffs, length(coeffs))
end
To enable iteration over a vector, we need to define `Base.length` and `Base.iterate`. I like to define `Base.eltype` for anything that houses an internal container, since it simplifies retrieving the datatype when needed.
"""
    length(p::NDimPoint)

Return the number of coefficients recorded in the point's `N` field.
"""
function Base.length(p::NDimPoint)
    return p.N
end
"""
    eltype(::Type{NDimPoint{T}})

Return the coefficient type `T` of an `NDimPoint{T}`.
"""
function Base.eltype(::Type{NDimPoint{T}}) where {T}
    return T
end
"""
    iterate(p::NDimPoint, state=1)

Advance the iteration over the coefficients of `p`.

Return `(p.coeffs[state], state + 1)` while `state` does not exceed `p.N`, and `nothing`
once it does.
"""
function Base.iterate(p::NDimPoint, state=1)
    if state > p.N
        return nothing
    end
    return (p.coeffs[state], state + 1)
end
I also want to be able to easily retrieve components by index, so I'll define `Base.getindex`, `Base.firstindex`, and `Base.lastindex`. I also want to be able to call `eachindex` on a point for iteration, so I need to define `Base.keys`.
"""
    firstindex(p::NDimPoint)

Return `1`, the index of the first coefficient (this is what `p[begin]` resolves to).
"""
function Base.firstindex(p::NDimPoint)
    return 1
end
"""
    lastindex(p::NDimPoint)

Return `p.N`, the index of the last coefficient (this is what `p[end]` resolves to).
"""
function Base.lastindex(p::NDimPoint)
    return p.N
end
"""
    getindex(p::NDimPoint, i::Integer)

Return the `i`-th coefficient of `p`, enabling the `p[i]` syntax.

Any `Integer` index is accepted, not only `Int`, matching how arrays are indexed.

# Throws
- `BoundsError`: if `i` is outside `1:p.N`.
"""
function Base.getindex(p::NDimPoint, i::Integer)
    # Check against the point's own dimension so the error names the point,
    # not its internal vector.
    1 <= i <= p.N || throw(BoundsError(p, i))
    return p.coeffs[i]
end
"""
    keys(p::NDimPoint)

Return the range `1:length(p)` of valid indices.
"""
Base.keys(p::NDimPoint) = 1:length(p)
It will simplify things to be able to perform vector algebra with these objects, so we'll define vector addition and subtraction, scalar multiplication, and the inner product.
import Base.:+, Base.:-, Base.:*, Base.:/
# x + y
"""
    +(x::NDimPoint, y::NDimPoint)

Return the point whose coefficients are the element-wise sum of those of `x` and `y`.

# Throws
- `DimensionMismatch`: if `length(x) != length(y)`.
"""
function Base.:+(x::NDimPoint, y::NDimPoint)
    length(x) == length(y) || throw(DimensionMismatch())
    summed = x.coeffs + y.coeffs
    return NDimPoint(summed, x.N)
end
# x + (-1)y
"""
    -(x::NDimPoint, y::NDimPoint)

Return the point whose coefficients are the element-wise difference of those of `x`
and `y`.

# Throws
- `DimensionMismatch`: if `length(x) != length(y)`.
"""
function Base.:-(x::NDimPoint, y::NDimPoint)
    length(x) == length(y) || throw(DimensionMismatch())
    difference = x.coeffs - y.coeffs
    return NDimPoint(difference, x.N)
end
# λ * x
"""
    *(λ, x::NDimPoint)

Multiply the coefficients of `x` by `λ` from the left and return the result as a new
point that carries the same `N` as `x`.

`λ` is not restricted by type; it is passed straight to `λ * x.coeffs`.
"""
function Base.:*(λ::T, x::NDimPoint) where {T}
    scaled = λ * x.coeffs
    return NDimPoint(scaled, x.N)
end
# x * λ
"""
    *(x::NDimPoint, λ)

Right-hand multiplication by `λ`; forwards to `λ * x`.
"""
Base.:*(x::NDimPoint, λ::T) where {T} = λ * x
# x * y
"""
    *(x::NDimPoint, y::NDimPoint)

Return the inner product of `x` and `y`, i.e. the sum of `x[k] * y[k]` over all indices.

# Throws
- `DimensionMismatch`: if `length(x) != length(y)`.
"""
function Base.:*(x::NDimPoint, y::NDimPoint)
    length(x) == length(y) || throw(DimensionMismatch())
    # Seed the accumulator with the type that a sum of products will have, so it does
    # not change type inside the loop when the two points have different element types
    # (e.g. Int and Float64).
    product_zero = zero(eltype(x)) * zero(eltype(y))
    s = product_zero + product_zero
    for k in eachindex(x)
        s += x[k] * y[k]
    end
    return s
end
# x / y => undefined
"""
    /(x::NDimPoint, y::NDimPoint)

Division of one point by another is not defined; this method always throws.

The arguments are matched as point instances so that `x / y` reaches this method and
reports a descriptive error.

# Throws
- `ArgumentError`: always.
"""
function Base.:/(::NDimPoint, ::NDimPoint)
    throw(ArgumentError("Division is not defined for Rn vectors"))
end